同步操作将从 Liereyy/Satori 强制同步,此操作会覆盖自 Fork 仓库以来所做的任何修改,且无法恢复!!!
确定后同步将在后台操作,完成时将刷新页面,请耐心等待。
基于Transformer模型+强化学习训练的立直麻将agent
此项目刚开始不久,欢迎提PR,欢迎交流讨论
见tests/run_test.py
,令OP=2
,则对于手牌6678m3445p4567s44z, dora=2z
,给出$\pi(s)$如下:
将手牌、场况等信息编码为Nx1x34的张量作为ViT的输入,将麻将何切问题视为34-分类问题,吃/碰/立直问题视为2-分类问题,使用tenhou网站上爬下来的数据进行监督训练
使用单GPU或CPU训练:运行SL/SL.py
使用DDP在多GPU上训练:运行SL/SL_ddp.py
训练了4个模型:Discard, Chi, Peng, Reach(舍牌、吃、碰、立直)
由于杠出现频率很低,数据集很小不易训练,且往往导向不利的局面,不考虑杠的操作
Discard: 9M训练集,在2张RTX 3080上训练46个epoch后,在200K验证集上取得68%的准确率,由于验证集与测试集的accuracy,loss几乎完全一致 (这是可预期的,因为所有数据集服从一样的分布) ,这里只列出验证集:
Chi: 1.6M训练集,在2张RTX 3080上训练
Peng: 1M训练集,在2张RTX 3080上训练
Reach:
监督学习得到的模型在实际环境中已经能做出较好的决策,但对于未见过的状态会有偏差,需要RL进一步提高泛化能力
SL训练出的模型在《何切三百问》前165题上能取得58%的一致率
使用策略梯度算法,借鉴PPO算法,self-play,收集trajectory进行梯度更新,进一步优化模型
运行RL/ppo_discrete_main.py
优化目标: $$ \text{maximize}\quad J(\theta) =E[\frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)}r_\pi(S)] $$
求梯度并加上熵正则项:
$$ \nabla_\theta J(\theta) =E[\frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)} \frac{\partial \log\pi_\theta(a_t|s_t)}{\partial\theta}q_\pi(s_t,a_t)] +\alpha\nabla_\theta H(\pi_\theta) $$
由估计
$$
q_\pi(s_t,a_t)=E[U_t|s_t,a_t]\approx u_t=\sum_{i=t}^{T}\gamma^{i-t} r_i
$$
一条trajectory定义为一整个round的agent_pos
做出的所有discard决策,reward仅在round结束后得到,即最后一个action的reward是round的收支点数/100(如放铳-5200点,此round的reward就是-52),其余的所有action的reward都是0,即$r_T=-52, r_t=0,t<T$,这样实际上有$u_t=\gamma^{T-t}r_T$
这里还借鉴PPO的clamp操作,避免较大的变化,在trust区域内更新,将重要性采样项替换为($\epsilon=0.2$): $$ p_t^{\text{clamp}}=\text{clamp}(\frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)}, 1-\epsilon, 1+\epsilon) $$
于是,梯度可写为: $$ \nabla_\theta J(\theta) =E[p_t^{\text{clamp}} \frac{\partial \log\pi_\theta(a_t|s_t)}{\partial\theta}u_t] +\alpha\nabla_\theta H(\pi_\theta) $$
与环境交互获得一条trajectory:$(s_t, a_t, r_{t+1}, a_{t+1})|_{t=0}^{T-1}$
将此时的策略记为$\pi_{\theta_{old}}$,重复步骤3若干次
最新的策略记为$\pi_\theta$,根据$\pi_\theta$计算熵、trajectory上所有状态的$a_t$、$p_t^{\text{clamp}}$,计算梯度,用平均梯度更新$\pi_\theta$
经过400局self-play后,与《何切三百问》前165题上能取得60%的一致率
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。