强化学习:Deep Q-Learning

在文章强化学习:蒙特卡洛和时序差分中介绍了使用时序差分解决强化学习问题的一种经典方法:Q-Learning,但是该方法适用于有限的状态集合\mathcal{S},一般来说使用n行(n = number of states)和m列(m= number of actions)的矩阵(Q table)来储存action-value function的值,如下图所示:

对于连续的状态集合\mathcal{S},上述方法就不能适用了,这时可以引入神经网络来估计Q的值,即Deep Q-Learning,如下图所示:

接下来介绍Deep Q-Learning中常用的几种技巧,用于提升学习效果:

  • Stack States:对于连续的状态集合,单个状态不能很好地描述整体的状况,可以使用CNN或RNN模型同时考虑多个连续的状态和它们之间的依赖关系。例如下图所示,要判断黑色方块的移动方向,仅凭一副图像是无法判断的,需要连续的多幅图像才能判断出黑色方块在向右移动。
  • Experience Replay:如下图所示,为了防止算法在训练过程中忘记了之前场景获得的经验,可以创建一个Replay Buffer,不断回放之前的场景对算法进行训练;另一方面,相邻的场景之间(例如[S_{t},A_{t},R_{t+1},S_{t+1}][S_{t+1},A_{t+1},R_{t+2},S_{t+2}])有着一定的相关性,为了防止算法被固定在某些特定的状态空间,从Replay Buffer中随机抽样选取场景进行训练可打乱场景之间的顺序,减少相邻场景的相关性。
  • Target Network:不同于Q-Learning中不断更新Q(S,A)的值使之趋向于对应最优策略\pi^*q^*(S,A),在Deep Q-Learning中通过对计算Q值的神经网络的权重系数\vec w不断进行更新,使神经网络成为函数q^*的近似。定义损失函数J(\vec w)=[\hat{q}(S,A,\vec w)-q_{\text{reference}}(S,A)]^2其中q_{\text{reference}}叫做TD target,为网络\hat q的近似真值,此时有\Delta \vec{w}=-\frac{\alpha}{2}\nabla_{\vec{w}}J(\vec w),\quad\vec w\gets\vec{w}+\Delta \vec{w}假设场景为[S,A,R,S^{\prime}]q_{\text{reference}}可表示为q_{\text{reference}}(S,A)=R+\gamma \max_a\hat q(S^{\prime},a,\vec w)上述表示方式的一个缺点是TD target会随着参数\vec w的更新不断变化,加大了训练难度,减小了训练效率。一个解决方法是使用不同的神经网络(Target Network)来计算TD target,为了简化算法,Target Network通常会使用与\hat q相同的网络架构,但是每隔一定的步数才会更新其参数,对J(\vec w)计算梯度,参数更新步长可以写为
  • Double DQN:为了解决TD target对\hat{q}的真值可能高估的问题(参考公式\mathbb{E}_{\pi}[\max_a Q(s,a)] \geq \max_a \mathbb{E}_{\pi}[Q(s,a)]),可以将动作a的选择过程与TD target的计算过程进行分割,此时的TD target可以写为:
  • Dueling DQN:如下图所示,相对于直接计算action-value function \hat q(s,a),即Q(s,a),可以将Q(s,a)分解为state-value function V(s)与advantage function A(s,a)之和,即直接对V(s)进行估计,再叠加上采取不同行动对其的影响。在实际计算中,还需对A(s,a)添加额外的限制,下面例举两种常用的方式:\begin{cases}\text{Option 1: }\max_a A(s,a)=0 \\ \text{Option 2: }\sum_a A(s,a)=0\end{cases}

代码实现

使用GYM强化学习环境,以其中的一个任务BreakoutDeterministic-v0为基础进行训练:

一、搭建强化学习环境

点击查看代码

二、搭建卷积神经网络

点击查看代码

三、建立Replay Buffer

点击查看代码

四、训练和验证网络

点击查看代码

Leave a Comment

Your email address will not be published. Required fields are marked *