Pyro SVI 第三部分:ELBO梯度估计器

Setup

我们已经定义了由观察x,隐变量z通过 $p_{\theta}(x,z) = p_{\theta}(x|z)p_{\theta}(z)$ 形式组成的model。我们同时也定义了 $q_{\phi}(z)$ 组成的guide。在这里,$\phi$ 和 $\theta$ 都是对model和guide的变分参数。(特别的,这里没有随机变量需要使用贝叶斯)。

我们想要通过最大化ELBO的方法最大化对数证据 $log\:p_{\theta}(x)$ :

通过这样做,我们将能在 ${\theta,\phi}$ 的空间中对ELBO使用(随机)梯度下降(这个方法的前期工作请参考【1,2】)。这样我们需要计算无偏估计:

我们如何在随机函数model()guide()上使用它呢?为了简化字符,让我们概括这个讨论一点,直接讨论我们如何计算任意一个损失函数 $f(z)$ 的期望的梯度。让我们同时忽略 $\theta$ 和 $\phi$ 之间的不同。所以我们希望计算:

让我们从最简单的例子开始。

简单:可重新参数化的随机变量

假设我们可以重新参数化下面这个:

重点是,我们去除了期望之中所有对 $\phi$ 的依赖; $q(\epsilon)$ 是一个对 $\phi$ 没有依赖的固定的分布。这样的重新参数化可以在很多分布上使用(例如,正态分布);详细讨论见参考【3】。在这个例子中我们直接向期望传递梯度,来得到:

假设 $f({\cdot})$ 和 $g({\cdot})$ 都足够平滑,我们现在可以通过这个期望的蒙特卡洛估计,得到感兴趣的梯度的无偏估计。

进阶:没有可重新参数化的随机变量

如果我们不能做上面的重新参数化呢?不幸的是,我们感兴趣的很多分布(distributions of interest?),比方说所有离散分布,都属于这种情况。在这个例子中,我们的估计器面对更复杂一点的形式。

我们先拓展感兴趣的梯度为:

然后使用链式规则写成:

在此处我们将遇到一个问题。我们知道如何从 $q(\cdot)$ 中生成取样 ——我们直接跑guide就行—— 但是 $\triangledown_{\phi}q_{\phi}(z)$ 甚至不是一个合格的概率密度。那么我们需要改变这个公式使得它是一个可用 $q(\cdot)$ 表示的期望。这可以通过下面这个identity轻松完成:

这样就让我们可以重新熟悉梯度为:

这个梯度估计器的形式——也有叫做加强估计器(REINFORCE estimator)或者评分函数估计器(score function estimator)或似然比例估计器(likelihood ratio estimator)——很适合做蒙特卡洛估计。

注意到,一个打包此结果的(方便实现的)方法是引入一个替代目标函数(surrogate objective function):

这里横线意味着这一项保持不变(这样就不会被 $\phi$ 求导了)。为了得到一个(单个样本)蒙特卡洛梯度估计,我们对隐随机变量采样,计算替代对象,并求导。结果是 $\triangledown_{\phi}\mathbb{E}_{q_{\phi}(z)}[f_{\phi}(z)]$ 的无偏估计。 等式:

方差或者为什么我希望我在做MLE Deep Learning

我们现在有了对损失函数期望的无偏梯度估计的配方了。不幸的是,更宽泛的情况,即我们的 $q(\cdot)$ 包括了不可重新参数画的随机变量的情况,这个估计器会有很高的方差。实际上,在很多感兴趣的例子中,这个方差高到使得估计器无法有效使用。所以我们需要一个方法来减少方差(详细讨论见参考【4】)。我们将跟随两个策略。第一个方法利用了 $f(\cdot)$ 的特殊结构。第二个方法有效地利用了之前的 $\mathbb{E}_{q_{\phi}(z)}[f_{\phi}(z)]$ 的估计来减少方差。这个方法和利用动量进行随机梯度下降有异曲同工之妙。

通过依赖结构减少方差

在上面的讨论中,我们受困于一个宽泛的损失函数 $f_{\phi}(z)$。我们可以在此路上继续前进(我们将讨论的这个方法适用于大多数情况)但是对于具体性,让我们放大来看。在随机变分推断中,我们对特定的损失函数形式感兴趣:

其中,我们分割对数比率 $log\:\frac{p_{\theta}(x,z)}{q_{\phi}(z)}$ 为一个观测的对数似然片段和不同隐随机变量 $\{z_i\}$ 的和。我们同时引入了 $Pa_p(\cdot)$ 和 $Pa_q(\cdot)$ 来表示model和guide中给定的随机变量的parents。(读者可能会担心在广泛的随机函数中什么是合适表示依赖的字符;这里我们简单地意为常规的一个单一执行踪迹中的前者依赖)。重点是,损失函数中不同的项会对随机变量 ${z_i}$ 有不同的依赖,而这是我们可以利用的地方。

简短的来说,对于任何不可重新参数画的隐随机变量 $z_i$,替代目标就是:

这使得我们可以去除 $\overline(f_{phi}(z))$ 的某些项,同时可以依旧得到一个无偏的梯度估计器;进一步,这样做通常将会减少方差。特别的(详见参考【4】)我们可以移去 $\overline(f_{phi}(z))$ 中不是隐变量 $z_i$ 的下游(downstream of latent variables)的项(downstream指的是guide的依赖结构)。注意到这个技巧——处理某些随机变量来减少方差——通常会被称为Rao-Blackwellization。

在Pyro中,所有的此类逻辑都自动被SVI类所纳入。具体的来说,只要我们使用TraceGraph_EBLO类,Pyro就将持续追踪model和guide里执行踪迹中的依赖结构,并建立移除了所有非必要项的替代目标:

1
svi = SVI(model, guide, optimizer, TraceGraph_EBLO())

注意到,利用这种依赖信息将花费额外的计算量,所以TraceGraph_EBLO应该只在模型中有不可重新参数化的随机变量时才使用;在多数应用中Trace_ELBO就足够了。

一个使用Rao-Blackwellization的例子

假设我们有一个k个元素组成的混合高斯模型。对于每个数据,我们:(1)先对各个组成的分布 $k \in [1, …, K]$ 取样,(2)使用第k个组成分布观察数据。最简单的方法是用如下方式写下模型:

1
2
3
ks = pyro.sample("k", dist.Categorical(probs).to_event(1))
pyro.sample("obs", dist.Normal(locs[ks], scale).to_event(1),
obs=data)

因为使用者没有注意去标记模型中的任何条件独立,PryoSVI类生成的梯度估计器无法利用Rao-Blackwellization,而其结果就是梯度估计器将伴有高方差。要解决这个问题,使用者需要明确的标记条件独立。值得高兴的是,这个工作量并不大:
1
2
3
4
5
6
7
# mark conditional independence
# (assumed to be along the rightmost tensor dimension)
with pyro.plate("foo", data.size(-1)):
ks = pyro.sample("k", dist.Categorical
(probs))
pyro.sample("obs", dist.Normal(locs[ks], scale),
obs=data)

就是这么简单。

其他:Pyro中的依赖跟踪

最后,讨论一个依赖跟踪(dependency tracking)。在包括任意Python代码的随机函数中做依赖追踪是有点麻烦的。目前Pyro实现的方法和WebPPL中用的相类似(参考【5】)。简单的来说,这里使用了一个依赖序列顺序的保守的依赖概念?。如果在一个给定的随机函数中随机变量 $z_2$ 在 $z_1$ 之后,那么 $z_2$ 可能依赖于 $z_1$, 进而假设他们是依赖的。要解决这种过度草率的假设行为,Pyro包含了platemarkov两种结构,用他们申明的事物是独立的(见之前的教程)。对于不可重新参数化变量的情况,使用者(当可以用时)完全利用SVI提供的方差减少的的结构是相当重要的。在一些例子中,也需要考虑(如果可能的话)重新安排一下随机函数中随机变量的顺序。在未来的Pyro版本中,我们期望添加更好的依赖跟踪概念。

用数据依赖基线减少方差

第二个减少ELBO梯度估计器的方差的方法名叫基线(baselines)(见参考【6】)。它实际利用了一点同样的,基于上文所提到的方差减少方法的数学。不同的是现在,我们添加项而不是减少项。基本上就是,我们不再移去那些会增加方差但又期望为0的项,我们将会添加一些特别选择的0期望的项,让他们来减少方差。这就是一个控制方差的方法。

更具体的,这个想法就是利用以下这个事实:对于任意常量b,下面这个表达一直成立:

因为 $q(\cdot)$ 是被正则过的,我们可以得到:

这就意味着我们可以将替代目标中的任意项:

替换为

这样做不会影响我们的梯度估计器的均值,但是它能影响到方差。如果我们巧妙的选择b,我们就能减少方差。事实上,b并不需要是一个常量:它可以依赖任意一个随机选择的 $z_i$ 的上游(或侧游)。

Pyro中的基线

在随机变分推断中,使用者有很多方法可以让Pyro来使用基线。因为基线可以被附加在任何不可重新参数化的随机变量上,目前的基线接口baseline interface在pyro.sample一层中。具体的,基线接口使用参数baseline,这个参数是申明基线选项的字典。注意到,只有在guide中申明对于采样语句的基线,它才有意义,在model则无用。

衰减均值基线 Decaying Average Baselines

最简单的基线可以从一个从 $\overline{f_{\phi}(z)}$ 的最近的样本的动态均值中建立。在Pyro中,这类基线可以按如下方式调用:

1
2
3
z = pyro.sample("z", dist.Bernoulli(...),
infer=dict(baseline={'use_decaying_avg_baseline':True,
'baseline_beta': 0.95}))

可选参数baseline_beta申明了衰减均值的衰减速度(默认值是0.90)。

神经基线 Neural Baselines

在一些情况中,一个衰减均值基线表现的很好。但在其他的的情况下,使用一个依赖上游随机性的基线,对于得到更好的方差减少是相当重要的。一个有力的方法是建立一个可以使用神经网络来通过学习调整的基线。Pyro提供了两条申明这种基线的途径(可以在AIR教程中见到扩展的例子)。

首先使用者需要决定机选将那些输入将被使用(比如,当前的考虑的数据,或者之前采样的随机变量)。然后,使用者需要建立一个nn.Module,他将封装基线的计算。他将长这个样:

1
2
3
4
5
6
7
8
9
10
class BaselineNN(nn.Module):
def __init__(self, dim_input, dim_hidden):
super().__init__()
self.linear = nn.Linear(dim_input, dim_hidden)
# ... finish initialization

def forward(self, x):
hidden = self.linear(x)
# ... do more computations ...
return baseline

然后,假设BaselineNN对象baseline_module已经被初始化过了,在guide中我们有如下:

1
2
3
4
5
6
def guide(x):  # here x is the current mini-batch of data
pyro.module("my_baseline", baseline_module)
# ... other computations ...
z = pyro.sample("z", dist.Bernoulli(...),
infer=dict(baseline={"nn_baseline": baseline_module,
'nn_baseline_input': x}))

这里,参数nn_baseline告诉Pyro使用哪个nn.Module来建立基线。在后端的参数nn_baseline_input被传入forward方法中来计算基线b。注意到基线module需要又一个pyro.module被Pyro注册,这样Pyro才能知道module中要训练的参数。

本质上来说,Pyro建立如下形式的损失:

它将被用于神经网络调整参数。没有理论证明这个是在此处最佳的损失函数,但实际上,它表现的相当的好。就像衰减均值基线,主要想法是一个可以追踪 $\overline{f_{\phi}(z)}$ 均值的基线可以帮助减少方差。实际上,SVI在基线损失中下降一步,就是在ELBO上下降一步。

注意,在实际中,对基线参数使用一组不同的学习超参(比如,更大的学习率)是相当重要的。在Pyro中,这可以用如下方式实现:

1
2
3
4
5
6
def per_param_args(module_name, param_name):
if 'baseline' in param_name or 'baseline' in module_name:
return {"lr": 0.010}
else:
return {"lr": 0.001}
optimizer = optim.Adam(per_param_args)

注意到,为了让整个过程正确,基线参数应该只通过基线损失被优化。相似的,model和guide的参数应该只通过ELBO被优化。为了保证这个情况,SVI从autograd graph中分离了进入ELBO的基线b。同时,因为神经基线的输入可能依赖于model和guide的参数,这些输入同样应该在被送入神经网络前,从autograd graph中分离。

最后,另一种方式申明神经基线的方法是直接使用参数baseline_value:

1
2
3
b = # do baseline computations
z = pyro.sample("z", dist.Bernoulli(...),
infer=dict(baseline={'baseline_value': b}))

如上可行,除了在这个例子中使用者需要自己保证任何连接b与model和guide中参数的autograd tape需要被切断。换言之,任何依赖 $\theta$ 或者 $\phi$ 的对b的输入需要被从autograd graph中用detach()来分离。

一个使用基线的复杂例子

记得在SVI的第一个教程中,我们考虑了一个抛硬币的伯努利-贝塔模型。因为beta随机变量是一个不可重新参数化(或者说不容易重新参数化)的变量,对应的ELBO梯度变得噪声相当大。那时,我们使用了Beta分布来提供(近似)重新参数化的梯度来解决这个问题。现在我们已经展示了,在Beta分布被当作不可重新参数化(这样ELBO梯度估计器就相当于打分函数)时,一个简单的衰减均值基线可以减少方差。尽管我们使用了这个方法,我们仍然用plate来写我们全向量化的model。

作为替代直接对比梯度方差,我们将看到它让SVI收敛要多少步。对于这种特殊的model(因为他是共轭的),我们可以计算明确的后验。那么要评估基线在此处的用处,我们预设如下简单的实验。我们用一组特定的变分参数初始化guide。然后我们待变分参数达到后验的参数的一定范围之内后做SVI。我们同时使用和不使用衰减均值基线。然后我们比较梯度下降的步数。下面是完整的代码:

(因为分隔使用了plateuse_decaying_avg_baseline,这段代码将和SVI教程的第一第二部分很相似,我们就不逐行解释了)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import os
import torch
import torch.distributions.constraints as constraints
import pyro
import pyro.distributions as dist
# Pyro also has a reparameterized Beta distribution so we import
# the non-reparameterized version to make our point
from pyro.distributions.testing.fakes import NonreparameterizedBeta
import pyro.optim as optim
from pyro.infer import SVI, TraceGraph_ELBO
import sys

# enable validation (e.g. validate parameters of distributions)
assert pyro.__version__.startswith('1.3.0')
pyro.enable_validation(True)

# this is for running the notebook in our testing framework
smoke_test = ('CI' in os.environ)
max_steps = 2 if smoke_test else 10000


def param_abs_error(name, target):
return torch.sum(torch.abs(target - pyro.param(name))).item()


class BernoulliBetaExample:
def __init__(self, max_steps):
# the maximum number of inference steps we do
self.max_steps = max_steps
# the two hyperparameters for the beta prior
self.alpha0 = 10.0
self.beta0 = 10.0
# the dataset consists of six 1s and four 0s
self.data = torch.zeros(10)
self.data[0:6] = torch.ones(6)
self.n_data = self.data.size(0)
# compute the alpha parameter of the exact beta posterior
self.alpha_n = self.data.sum() + self.alpha0
# compute the beta parameter of the exact beta posterior
self.beta_n = - self.data.sum() + torch.tensor(self.beta0 + self.n_data)
# initial values of the two variational parameters
self.alpha_q_0 = 15.0
self.beta_q_0 = 15.0

def model(self, use_decaying_avg_baseline):
# sample `latent_fairness` from the beta prior
f = pyro.sample("latent_fairness", dist.Beta(self.alpha0, self.beta0))
# use plate to indicate that the observations are
# conditionally independent given f and get vectorization
with pyro.plate("data_plate"):
# observe all ten datapoints using the bernoulli likelihood
pyro.sample("obs", dist.Bernoulli(f), obs=self.data)

def guide(self, use_decaying_avg_baseline):
# register the two variational parameters with pyro
alpha_q = pyro.param("alpha_q", torch.tensor(self.alpha_q_0),
constraint=constraints.positive)
beta_q = pyro.param("beta_q", torch.tensor(self.beta_q_0),
constraint=constraints.positive)
# sample f from the beta variational distribution
baseline_dict = {'use_decaying_avg_baseline': use_decaying_avg_baseline,
'baseline_beta': 0.90}
# note that the baseline_dict specifies whether we're using
# decaying average baselines or not
pyro.sample("latent_fairness", NonreparameterizedBeta(alpha_q, beta_q),
infer=dict(baseline=baseline_dict))

def do_inference(self, use_decaying_avg_baseline, tolerance=0.80):
# clear the param store in case we're in a REPL
pyro.clear_param_store()
# setup the optimizer and the inference algorithm
optimizer = optim.Adam({"lr": .0005, "betas": (0.93, 0.999)})
svi = SVI(self.model, self.guide, optimizer, loss=TraceGraph_ELBO())
print("Doing inference with use_decaying_avg_baseline=%s" % use_decaying_avg_baseline)

# do up to this many steps of inference
for k in range(self.max_steps):
svi.step(use_decaying_avg_baseline)
if k % 100 == 0:
print('.', end='')
sys.stdout.flush()

# compute the distance to the parameters of the true posterior
alpha_error = param_abs_error("alpha_q", self.alpha_n)
beta_error = param_abs_error("beta_q", self.beta_n)

# stop inference early if we're close to the true posterior
if alpha_error < tolerance and beta_error < tolerance:
break

print("\nDid %d steps of inference." % k)
print(("Final absolute errors for the two variational parameters " +
"were %.4f & %.4f") % (alpha_error, beta_error))

# do the experiment
bbe = BernoulliBetaExample(max_steps=max_steps)
bbe.do_inference(use_decaying_avg_baseline=True)
bbe.do_inference(use_decaying_avg_baseline=False)

Sample output:

1
2
3
4
5
6
7
8
Doing inference with use_decaying_avg_baseline=True
....................
Did 1932 steps of inference.
Final absolute errors for the two variational parameters were 0.7997 & 0.0800
Doing inference with use_decaying_avg_baseline=False
..................................................
Did 4908 steps of inference.
Final absolute errors for the two variational parameters were 0.7991 & 0.2532

对于这个特别的运行,我们可以看到基线基本减半了SVI步数。结果是随机的,每次都会变换,但是这仍是个正奋人心的结果。尽管这是个设计过的例子,但对于特定的model和guide组合,基线可以提供确实的好处。

参考

[1] Automated Variational Inference in Probabilistic Programming, David Wingate, Theo Weber

[2] Black Box Variational Inference, Rajesh Ranganath, Sean Gerrish, David M. Blei

[3] Auto-Encoding Variational Bayes, Diederik P Kingma, Max Welling

[4] Gradient Estimation Using Stochastic Computation Graphs, John Schulman, Nicolas Heess, Theophane Weber, Pieter Abbeel

[5] Deep Amortized Inference for Probabilistic Programs Daniel Ritchie, Paul Horsfall, Noah D. Goodman

[6] Neural Variational Inference and Learning in Belief Networks Andriy Mnih, Karol Gregor