生成对抗网络(GAN)的损失函数推导与最小最大博弈解释
我将为你讲解生成对抗网络(GAN)损失函数的数学推导,并详细解释其最小最大博弈原理。这个题目与之前讲过的"生成对抗网络(GAN)的原理与训练过程"不同,这里我们聚焦于损失函数本身的推导和理论解释。
题目描述
生成对抗网络(GAN)由生成器G和判别器D组成,通过对抗训练学习数据分布。其核心是最小最大博弈(minimax game),需要推导出GAN的原始损失函数,并解释为什么这个优化问题能引导生成器生成真实数据。
解题过程
步骤1:问题建模与基本设定
GAN的目标是让生成器G学会从随机噪声z(通常来自标准正态分布)生成与真实数据相似的数据。判别器D的任务是区分真实数据(来自真实分布p_data)和生成数据(来自生成分布p_g)。
设:
- 真实数据分布:p_data(x)
- 噪声先验分布:p_z(z)(通常是标准正态分布)
- 生成器:G(z; θ_g),参数为θ_g
- 判别器:D(x; θ_d),输出标量表示"x来自真实数据"的概率
步骤2:判别器的优化目标(固定生成器)
判别器D希望最大化自己判断的准确率:
- 对真实数据x∼p_data(x),D(x)应该接近1
- 对生成数据G(z)(z∼p_z(z)),D(G(z))应该接近0
这可以表示为最大化以下目标函数:
V(D, G) = E_{x∼p_data}[log D(x)] + E_{z∼p_z}[log(1 - D(G(z)))]
解释:
- 第一项E[log D(x)]:真实数据被判别为真的对数概率期望,希望最大化
- 第二项E[log(1 - D(G(z)))]:生成数据被判别为假的对数概率期望,希望最大化
注意:这里使用对数是因为在伯努利分布下,最大化正确分类的对数似然等价于最小化交叉熵损失。
步骤3:生成器的优化目标(固定判别器)
生成器G希望生成的样本能够"欺骗"判别器,即让判别器D认为生成数据是真实的。因此,生成器希望最大化D(G(z)),或者等价地,最小化log(1 - D(G(z)))。
所以生成器的目标是最小化:
E_{z∼p_z}[log(1 - D(G(z)))]
但注意:在实际训练中,通常采用最大化E_{z∼p_z}[log D(G(z))]作为替代,因为log(1 - D(G(z)))在训练初期梯度很小,会导致训练困难。
步骤4:最小最大博弈的形式化
结合判别器和生成器的目标,我们得到GAN的原始最小最大博弈:
min_G max_D V(D, G) = E_{x∼p_data}[log D(x)] + E_{z∼p_z}[log(1 - D(G(z)))]
直观解释:
- 内层max_D:判别器D试图最大化V(D,G),即更好地区分真实和生成数据
- 外层min_G:生成器G试图最小化V(D,G),即让判别器难以区分
步骤5:理论最优解的推导
对于一个固定的生成器G,我们来求最优判别器D*_G。
对于任意输入x,判别器的优化问题是:
max_D {p_data(x) log D(x) + p_g(x) log(1 - D(x))}
其中p_g(x)是生成器定义的分布。
这是一个关于D(x)的函数最大化问题。令f(y) = a log y + b log(1 - y),其中a = p_data(x), b = p_g(x)。
求导:f'(y) = a/y - b/(1 - y) = 0
解得:y* = a/(a+b) = p_data(x) / [p_data(x) + p_g(x)]
所以最优判别器为:
D*_G(x) = p_data(x) / [p_data(x) + p_g(x)]
理解:最优判别器输出的是"x来自真实数据而非生成数据"的后验概率。
步骤6:全局最优解的证明
将最优判别器D*_G代回目标函数V(D,G):
C(G) = max_D V(D,G) = V(D*G, G)
= E{x∼p_data}[log(p_data(x)/(p_data(x)+p_g(x)))] + E_{x∼p_g}[log(p_g(x)/(p_data(x)+p_g(x)))]
这可以重写为:
C(G) = E_{x∼p_data}[log(p_data(x)/((p_data(x)+p_g(x))/2))] + E_{x∼p_g}[log(p_g(x)/((p_data(x)+p_g(x))/2))] - 2log2
= KL(p_data || (p_data+p_g)/2) + KL(p_g || (p_data+p_g)/2) - 2log2
= 2JSD(p_data || p_g) - 2log2
其中JSD是Jensen-Shannon散度,是KL散度的对称版本。
由于JSD(p_data||p_g) ≥ 0,且当且仅当p_data = p_g时取最小值0,所以:
min_G C(G) = min_G 2JSD(p_data||p_g) - 2log2 = -2log2
当且仅当p_g = p_data时取得。
结论:在生成器G达到全局最优时,生成分布p_g等于真实分布p_data,此时最优判别器D*_G(x) = 1/2,即完全无法区分真实和生成数据。
步骤7:实际训练中的损失函数
在实际训练中,GAN通常采用交替优化:
-
判别器更新(固定G):
最大化:L_D = E_{x∼p_data}[log D(x)] + E_{z∼p_z}[log(1 - D(G(z)))]
等价于最小化二元交叉熵损失 -
生成器更新(固定D):
原始形式:最小化L_G = E_{z∼p_z}[log(1 - D(G(z)))]
实际常用:最大化L_G = E_{z∼p_z}[log D(G(z))] (非饱和版本,梯度更友好)
步骤8:最小最大博弈的直观理解
可以将GAN的训练看作:
- 判别器:如警察,学习识别假币
- 生成器:如伪造者,学习制造更逼真的假币
- 两者在对抗中不断提升:更好的判别器迫使生成器改进,更好的生成器迫使判别器改进
- 最终平衡时,生成器制造的"假币"与真币无法区分
关键点总结
- GAN的原始目标函数是最小最大博弈:min_G max_D V(D,G)
- 最优判别器是后验概率:D*_G(x) = p_data(x) / [p_data(x) + p_g(x)]
- 全局最优解对应p_g = p_data,此时判别器输出始终为0.5
- 理论上的优化目标等价于最小化p_data和p_g之间的Jensen-Shannon散度
- 实际训练采用交替优化和梯度下降,生成器损失常用非饱和版本以避免梯度消失
这个推导揭示了为什么对抗训练能够引导生成器学习到真实数据分布,也为理解GAN的训练动态和困难(如模式坍塌、训练不稳定)提供了理论基础。