19. Generative Models
Unsupervised Learning¶
Supervised Learning: given
Unsupervised Learning: given
后者的例子有很多,比如:
- 聚类:K-clustering
- 给定一组点,我们可以从中学习这些点是由哪几个(高斯)概率分布生成的。其中,聚类的中心就是多元正态概率分布的期望。
- 降维:PCA
- 给定一组点,从中学习哪几个维度有最大的信息量。使用数学的语言,就是
- 假设
,其中 就是我们取到的点,而 分别是高斯分布和高斯噪声(其中高斯噪声的 covariance matrix 是 identity matrix 的若干倍,i.e. 噪声向量的不同 element 之间独立) - 使用 PCA,可以最大化互信息
, where ,i.e. 投影的矩阵
- 假设
- 具体详见Wikipedia
- 给定一组点,从中学习哪几个维度有最大的信息量。使用数学的语言,就是
- 密度估计:和聚类是类似的,就是由概率分布产生的点反推概率分布
- 自编码器:……
后者的好处就是:由于无类别的数据远多于有类别的数据,因此 supervised learning 需要耗费手工标注的成本,而 unsupervised learning 无需标注。
Difference Between Models¶
我们可以将 model 分为 3 类:
- Discriminative Model:
- (Unconditional) Generative Model:
- Conditional Generative Model:
Discriminative Model¶
注意到,同一个样本空间内,概率之间是互相竞争的,如果你多,那么我必须少;反之亦然。因此,如果使用
- 虽然上面两张图片和猫、狗没有任何关系,但是由于
中,不同标签之间是竞争关系,因此,必须有一个标签的概率比较大,即使这个标签对于这张图片而言非常离谱。
Generative Model¶
如果采用生成式模型,那么互相竞争的就不是标签,而是
由于图片空间
- e.g. is a dog more likely to sit or stand? How about a 3-legged dog vs 3-armed monkey?
另外,generative model 实际上是有能力去 reject 一个图片的。比如说该张图片的出现概率非常低,那么,模型可以认为这张图片就是不正常的图片,从而拒绝。相比之下,discriminative model 并不能拒绝图片,因为所有标签概率之和总是 1。
Conditional Generative Model¶
By Bayesian Law:
如上图,可以通过上述两个模型,加上
- 和 generative model 一样,conditional generative model 可以拒绝一张图片。比如说即使是
也在 threshold 之下。
Comparison¶
最根本的两种模型,就是 discriminative model 以及 generative model。
Generative Models¶
Generative model 的目标函数就是:
- 也就是假设所有样本均独立,然后采用最大似然推断
或者可以转换一下:
而对于某个
而这样的依赖关系,和 RNN 完全一样:
因此,我们可以使用 RNN 来生成这些概率,然后取对数加起来,得到的就是
然后,我们将每一张图片都这样做,然后加起来,就得到了
Taxonomy¶
Autoregressive Model¶
Pixel RNN¶
Pixel RNN,顾名思义,就是在 Pixel 上使用 RNN。给出依赖关系(i.e. 依赖左侧和上侧的 pixels),就可以像上图一样“泛洪式”地生成。
缺点:速度太慢。
Pixel CNN¶
另外还有一种 Pixel CNN,也就是 Mask CNN 的一个变种:
基本想法就是:我们通过多个 resolution-preserving, masked convolution layer,在保持形状的同时(保证最后的输出中,每一个点对应着原始图片的一个点),遮盖住像素的后继(避免该像素看到后面的像素)进行卷积。最后,我们对输出取一个对数,然后将所有输出相加,就得到了这一张图片的“概率”。
当然,这虽然相比 pixel RNN 有所进步,但是其实也并不快。
Pros and Cons¶
当然,还有很多 tricks 可以提升这些 RNN 的效率。
Autoencoders¶
Autoencoder 很简单,如上图,就是先通过 downsampling 将 vector 压缩,再通过 upsampling 将 vector 还原,然后 loss 就是 original vector 与 output vector 的 L2 metric。
但是,这种 autoencoder 无法判断图片出现的概率,因此不是 generative model。
另外,我们也不能够用它来生成图像。因为
因此,我们需要给
可惜:
因此,考虑贝叶斯公式:
因此,我们决定使用
也就是下面的结构:
然后,我们就可以得出:
然后,我们使用 expectation trick:
因此,我们找到了目标:找到这样的 encoder 和 decoder,使之能够最大化
VAE¶
How To Optimize?¶
KL Divergence¶
其中,
- 一般而言,z 这个高斯分布,我们就设成
。反正不同高斯分布之间就是一个线性变换的关系。
因此:
推导过程:
因此:
- 和上面略微有些出入,不过无关紧要。第一张图中的
可能是因为那张图里的 是标准差而不是方差。
另外,
- 比如 Gaussian 就是一个性质非常好,很容易计算的分布)
最后,我们设置
¶
我们使用数据集中的
How to Train (in Details)?¶
- 将数据集的一个 x 通过 encoder,得到
- 计算出该 x 为 condition 下的 KL 散度
- 同时,通过
这个高维高斯随机变量,采样得到若干个 z - 再将这些 z 通过 decoder,得到若干个
- 计算出
- 从而,通过 (2),(5),我们就得到了
的一个估计的下界 - 对每一个数据集里的 x 都重复这样的操作,就计算得到了这个数据集的 likelihood 的下界。我们的目标就是最大化这个 likelihood,梯度上升即可
How to Generate?¶
由于
Interpretability¶
通过调整里面的参数,可以得到不同的分布。由于 z 的每两个参数之间是独立的,因此可以这样得到某一个参数的实际含义。
Practice: Image Editing¶
不足之处,就是 VAE 会导致图片变模糊——这有可能是因为我们对图片采用了不恰当的假设,图片的真实高斯分布不是对角矩阵。
Pros and Cons of Autoregressive Models and VAE¶
VQ-VAE2¶
如上图:
- 使用多层的 encoder 以及 decoder
- 在每一层上,使用 PixelCNN
GAN¶
GAN 使用一个生成器和一个判别器:
其中,z 就是一个简单的分布(比如标准高斯/均匀),通过生成器 G,转化成另一个分布。我们的目标就是:让这个分布 G(z) 与图像的分布尽量靠近。
我们再训练一个鉴别器 D。
我们的目标函数就是:
其中:
又由于:
因此,令
因此:
进一步的推导如下:
由于 Jensen-Shannon Divergence 是一个度量,因此在 JSD 取到最小值,当且仅当 p 和 q 相等。
Problem: Gradient Vanishing¶
为了避免一开始的梯度消失(如图中蓝线所示),我们可以将
Caveats of GAN¶
-
上文中说到:令
。但是,实际上 可能根本表示不出来右边这个函数。 -
虽然取到最小值的时候,JSD 可以保证两个分布相等,但是这是非凸优化问题,不能保证收敛到最小值
-
特别是:JS Divergence 在两个分布相差很大的时候,很容易产生梯度消失的现象:
-
我们可以通过 Wasserstein GAN 来解决这个问题
- 期望最小化的目标函数不是 JSD,而是 Wasserstein Metric:
- 通过 Kantorovich-Rubinstein对偶定理 (Proof),让 Wasserstein 距离更容易计算:
- 从而,我们就可以优化
- 当然,需要保证
。这一点可以采用简单暴力的 clip 操作完成。
- 期望最小化的目标函数不是 JSD,而是 Wasserstein Metric:
-
-
最后一点,JSD 并不是 metric,因为不满足三角形不等式。当然这一点貌似无关紧要。
Practices of GAN¶
我们除了在 z 上随机取点以外,还可以通过 vector math 来进行更加复杂的操作。比如在两个 z 向量之间进行插值,就可以得到 very non-trivial interpolation between their corresponding images;乃至是 man w glasses - man w/o glasses + woman w/o glasses = woman w glasses。
Conditional GAN¶
为了在 GAN 中加入条件信息,我们需要向生成器和鉴别器中引入这个信息。
比如,可以通过下面的方式引入:
当然,也可以直接将标签 concatenate 到生成器的随机向量 z 后面——这个标签可以是经过 embedding 的,从而有着更多的信息。