Pyro SVI 第二部分:条件独立,二次取样和摊销

目标:拓展SVI到大数据集

对于一个有N个观察的模型,跑modelguide,建立ELBO这两件事涉及了估计对数概率密度函数,而这个估计的复杂程度挺糟糕的。如果我们想要扩展到大数据集,这将是个问题。幸运的是,ELBO目标天然支持二次采样,只要我们的model/guide有一些可供我们使用的条件独立的结构。比方说,当观察对隐变量条件独立,ELBO中的对数似然就可以被近似为:

其中 $I_M$ 是M个索引组成的mini-batch(M<N)(讨论详见[1,2].)好了,问题解决了!但我们能用它做什么呢?

在Pyro中标记条件独立

如果一个使用者想要做这件事,他首先需要保证model和guide都是按Pyro可以利用相关条件独立的方式写下的。Pyro给条件独立提供了两种原始类型:platemarkov。让我们看一下简单的那个

序列化 plate

让我们回到上个教程的例子中。为了方便,让我们在这里重新写一下主要的逻辑model:

1
2
3
4
def model(data):
# sample f from the beta prior
f = pyro.sample("latent_fairness", dist.Beta(alpha0, beta0))
# loop over the observed data using pyro.sample with the obs keyword argument

对于这个模型,观测对于隐随机变量latent_fairness是条件独立的。要在Pyro中标记这个,我们基本只需要用Pyro的plate来替代Python内置的range
1
2
3
4
5
6
7
def model(data):
# sample f from the beta prior
f = pyro.sample("latent_fairness", dist.Beta(alpha0, beta0))
# loop over the observed data [WE ONLY CHANGE THE NEXT LINE]
for i in pyro.plate("data_loop", len(data)):
# observe datapoint i using the bernoulli likelihood
pyro.sample("obs_{}".format(i), dist.Bernoulli(f), obs=data[i])

我们看到pyro.platerange非常相似,只有一处主要的不同:plate每次调用都要求使用者提供唯一的名字。第二个参数则是一个整数,就像range那样。

目前为止都没问题。Pyro现在可以利用观测对隐随机变量的条件独立了。但是到底是如何实现的?本质上来说,pyro.plate是用一个上下文管理器(context manager)实现的。每次执行完for循环的主体部分时,我们进入一个新的(条件)独立的上下文,然后在for循环主体结尾处退出。让我们详细的看一下:

  • 因为每个观察到的pyro.sample表达式出现在一个不同的for循环的执行之中,Pyro会将每个观察都视作独立
  • 这个独立是对latent_fairness的一个适当的条件独立,因为latent_fairness是在for_loop之外采样的。

当使用plate时,有些陷阱需要避免。考虑一下,对上面代码的一个变种:

1
2
3
4
# WARNING do not do this
my_reified_list = list(pyro.plate("data_loop", len(data)))
for i in my_reified_list:
pyro.sample("obs_{}".format(i), dist.Bernoulli(f), obs=data[i])

这不能使你得到你想要的结果,因为list()将会在一个的pyro.sample被调用前,就进入然后完全退出上下文data_loop。相似的,使用者需要注意不要泄露可变的计算操作到这个上下文管理器之外,这将会引起一起奇妙的bug。比方说,pyro.plate不支持那种会根据上次迭代而改变的局部model;在这种情况下,应该使用rangepyro.marko

向量化 plate

概念上,向量化(vectorized)plate和序列化plate,除了他是一个向量化操作(就像是torch.arange之于range)。对比序列化plate使用的for循环,他可能实现极大的加速。让我们看一下如何实现之前的例子。首先我们需要data是张量tensor的形式:

1
2
data = torch.zeros(10)
data[0:6] = torch.ones(6) # 6 heads and 4 tails

那么我们有:

1
2
with plate('observe_data'):
pyro.sample('obs', dist.Bernoulli(f), obs=data)

我们来点对点的对比一下它和相似的序列化的plate

  • 两者的形式都要求使用者明确一个唯一的名字。
  • 注意到这个代码片段只是引入了一个单个的(观测到的)随机变量(叫做obs),因为整个张量是同时被考虑的。
  • 因为没有必要使用迭代,所以没有必要明确plate上下文中张量的长度了。

注意到,我们之前所提到的序列化plate的陷阱同样适用于向量化plate

二次采样Subsampling

我们现在知道如何在Pyro中标记条件独立。它对它自身对它自身之内,是相当有用的(见SVI第三部分dependency tracking section),但是我们实现二次采样,这样我们就可以使用大数据集了。根据model和guide的结构,Pyro支持多种二次采样的途径,让我们一个一个看。

使用plate自动二次采样

我们先来看一下最简单的例子,在此之中我们会自由使用一个或两个更多的参数来二次采样:

1
2
for i in pyro.plate("data_loop", len(data), subsample_size=5):
pyro.sample("obs_{}".format(i), dist.Bernoulli(f), obs=data[i])

这就是全部了:我们就用了参数subsample_size。无论我们什么运行model(),我们现在只对从data中随机抽取的5个数据做对数似然;进一步,对数似然将会自动按比调整 $\frac{10}{5}= 2$ 。那么对于向量化plate呢?‘咒语’完全相似。
1
2
with plate('observe_data', size=10,subsample_size=5) as ind:
pyro.sample('obs', dist.Bernoulli(f), obs=data.index_select(0,ind))

重要的是,plate现在返回的是索引组成的张量,在这个例子中长度是5。注意打动,除了参数subsample_size,我们也加入了参数size,这样plate就可以知道data张量的大小了,这样它就可以计算正确的调整比例了。就像是序列化的plate,使用者需要使用plate提供的索引来选择正确的数据。

最后,要注意的是,如果data在GPU上的话,使用者必须给plate传入一个device参数。

使用plate自定义二次采样策略

上面每次model()运行时,plate都会采样新的二次采样索引。因为这次二次采样是无状态的,这将会导致一些问题:基本上对于一个足够大的数据集,即使在很多次的迭代后,我们将仍有不可忽略的概率无法抽到某些数据点。为了防止这事,使用者可以通过使用subsample参数来控制plate的二次采样的方法。详见文档

当只有局部随机变量时二次采样

我们记得一个联合概率密度的模型定义如下:

对于一个有如此依赖关系的模型,二次采样中使用的比例因子会同时作用于ELBO的所有式子中。普通VAE中就是如此的。这就是为什么VAE允许使用者完全接手二次采样,并输入mini-batches直接给model和guide;plate仍然被使用,但subsample_sizesubsample不行。想知道他们到底长什么样,参见VAE教程

当同时存在全局和局部随机变量时二次采样

在抛硬币的例子中,plate在model中出现,但是没有在guide出现,这是因为唯一一个要被二次采样的是观测。让我们看一个复杂点的例子,它的二次采样同时出现在model和guide中。要让这事儿简单,让我们接着有点抽象的讨论,并不急着写个完整的model和guide。

考虑一个定义如下的联合分布模型:

有N个观测 $\{x_i\}$ 和 N个局部的隐随机变量 $\{z_i\}$。还有一个全局的隐随机变量 $\beta$。我们的guide将被向量化成如下形式:

我们明确引入了N个局部变分参数 $\{\lambda_i\}$,而另外的变分参数被隐藏了起来。model和guide都是条件独立的。特别的,在model一边,给定了 $\{z_i\}$ ,观测 $\{x_i\}$ 是独立的。在guide一边,给定了变分参数 $\{\lambda_i\}$ 和 $\beta$,隐随机变量 $\{z_i\}$ 是独立的。要在Pyro中标记这些条件判断并做二次采样,我们需要在model和guide中都使用plate。让我们用序列化plate来概述一下基本逻辑(更具体的代码片段会包含pyro.param等表达式)。首先,模型如下:

1
2
3
4
5
6
7
8
def model(data):
beta = pyro.sample("beta", ...) # sample the global review
for i in pyro.plate("locals", len(data)):
z_i = pyro.sample("z_{}".format(i), ...)
# compute the parameter used to define the observation
# likelihood using the local random variable
theta_i = compute_something(z_i)
pyro.sample("obs_{}").format(i), dist.MyDist(theta_i), obs=data[i]

注意对比我们的抛硬币的例子,这里我们在plate循环的内外都有pyro.sample。对于下一个guide:

1
2
3
4
5
def guide(data):
beta = pyro.sample("data", ...) # sample the global review
for i in pyro.plate("locals", len(data), subsample_size=5):
# sample the local RVs
pyro.sample("z_{}".format(i), ..., lambda_i)

注意到索引只在guide中二次采样时采用到;在模型的执行中,Pyro的后端保证只会用到相同的索引集合。因为这个原因,所以subsample_size只需要在guide中申明。

摊销Amortization

让我们再考虑一下有全局和局部隐随机变量以及局部变分参数的模型:

对于中小型N,使用想这样的局部变分参数是个好方法。但如果N大了,那么我们试图优化的随着N增长的空间可能是个真正的问题。有一个方法是避免这种恶性的增长方式,那就是摊销Amortization。

它是这么工作的。与其引入局部变分参数,我们不如学习一个单参数函数 $f(\cdot)$ 并和变分分布合作,得到如下形式:

函数 $f(\cdot)$ ——其基本上就是,映射一个给定的观测到一组针对那个数据点的变分参数 —— 将需要足够的宽泛来准确地捕捉后验,但现在我们无需引入许多的变分参数,就可以处理大数据集了。这个方法还有另外一个好处:比如,在学习中时, $f(\cdot)$ 有效的让我们能去在不同数据点中使用统计的力量。注意到,这就是VAE中采用的方法。

张量形状和向量化plate

在本章节中,pyro.plate的使用受限于这些相对简单的例子。比如说,没有一个plate是在另一个plate之中的。为了能完全使用plate,使用者必须小心的使用Pyro的张量形状语义。对此的讨论见张量形状教程

Reference

[1] Stochastic Variational Inference, Matthew D. Hoffman, David M. Blei, Chong Wang, John Paisley

[2] Auto-Encoding Variational Bayes, Diederik P Kingma, Max Welling