使用增强学习方法解魔方问题
A more clear jupyter notebook version
http://nbviewer.jupyter.org/github/cinqs/rubikcubesolverbyrl/blob/master/workbench.ipynb
A glance of the recent version
http://nbviewer.jupyter.org/github/cinqs/rubikcubesolverbyrl/blob/master/Untitled.ipynb
of course the code
https://github.com/cinqs/rubikcubesolverbyrl
the dumped Q table(with format tricked)
https://github.com/cinqs/rubikcubesolverbyrl/blob/master/Q.json
添加必须的依赖
from pycuber_sc import Cube # 经过部分修改的 pycuber, 参见https://github.com/cinqs/rubikcubesolverbyrl
import numpy as np
import tensorflow as tf
from collections import deque
import itertools
import sys
from collections import defaultdict
import json
from matplotlib import pyplot as plt
plt.rcParams['figure.figsize'] = [30, 10]
import pandas as pd
定义action_space
# read this http://rubiks.wikia.com/wiki/Notation
action_space = ['u', 'd', 'l', 'r', 'f', 'b', 'u\'', 'd\'', 'l\'', 'r\'', 'f\'', 'b\'']
这是12个可以执行的动作,分别对应了一个面的顺时针和逆时针旋转。于是动作空间的大小为12
定义环境
class Env:
def __init__(self, tweaked_times_limit=100, tweaked_times=1):
"""
@param tweaked_times_limit 放弃旋转之前尝试的次数
@param tweaked_times 初始化一个魔方前旋转的次数
"""
self._cube = Cube()
# 随机初始化一个魔方
for _ in range(tweaked_times):
self._cube(str(np.random.choice(action_space)))
self.reset()
self.nA = len(action_space)
self.nS = len(self._get_state())
self._tweaked_times_limit = tweaked_times_limit
def reset(self):
"""
重置这个环境的时候将魔方恢复到初始状态
"""
self.cube = self._cube.copy()
self.tweaked_times = 0
return self._get_state()
def step(self, action):
self.cube(str(action_space[action]))
self.tweaked_times += 1
done = self.cube.check() or self.tweaked_times > self._tweaked_times_limit #完成或者放弃
reward = -0.1 if not self.cube.check() else 1 # -0.1是生存的代价,负值保证了智能体不会在收集芝麻的道路上乐此不疲
return self._get_state(), reward, done, None
def _get_state(self):
"""
将状态归一化之后,重组为元组类型
"""
return tuple(s / 5. for s in self.cube.get_state())
定义训练办法(Sarsa)
先按照Sarsa的办法训练
Q = defaultdict(lambda :np.zeros(len(action_space)))
#epsilon = .05
epsilon = 0
def get_action(state):
"""
保证策略可以在贪婪性上做出变化
此处的贪婪系数为0,是因为在这种环境下,贪婪策略不会导致智能体永远选择 次好的 动作(随着魔方初始旋转次数的增加,可以考虑改变贪婪系数探索更优策略的可能)
"""
probs = np.ones(nA) * (epsilon / nA)
probs[np.argmax(Q[state])] += (1. - epsilon)
return np.random.choice(np.arange(nA), p=probs)
for n in range(0, 6): # n代表了初始化的魔方的旋转次数(如果直接选择较大的旋转次数,会增加训练的智能体的迷茫)
for i in range(5 * (6 ** n)): # 5 * 6 ** n 稍微加强了智能体遍历各种魔方状态的可能(但是泛化能力几乎为零,这种方法类似于MC方法的暴力美学)
# implement Sarsa with one Q in the above cell
env = Env(tweaked_times=n, tweaked_times_limit=n * 2)
num_episode = 100 # 一种初始状态下的最多的尝试次数
discounter = 0.9 # temporal difference (TD) 中的折扣系数 lambda 参见 贝尔曼方程 bellman equation https://en.wikipedia.org/wiki/Bellman_equation
sum_rewards = deque(maxlen=30)
avg_sum_rewards = []
nA = env.nA
for episode_idx in range(num_episode):
state = env.reset()
action = get_action(state)
td_errors = 0 # Q-V 的误差累积值
rewards = 0
for t in itertools.count():
state_prime, reward, done, _ = env.step(action)
rewards += reward
action_prime = get_action(state_prime)
td_error = reward + discounter * Q[state_prime][action_prime] - Q[state][action] # 参见 https://en.wikipedia.org/wiki/Temporal_difference_learning
Q[state][action] += td_error * 0.3 # 0.3是学习率 eta
td_errors += np.abs(td_error)
if done:
break
else:
state = state_prime
action = action_prime
sum_rewards.append(rewards)
avg_sum_rewards.append(np.mean(sum_rewards))
print('\r {} {} {} {:>30}'.format(n, i, episode_idx, td_errors), end='')
if td_errors < 10 ** -5: # 如果累积的误差已经很小的话,没有必要再持续尝试(state value近乎收敛)
break
项目最近还在慢慢的进行中
假如说直接将魔方旋转5次以上,允许RL使用MC方法直接暴力破解然后学习,这样子见效太慢。
我在这个程序内使用的循序渐进的办法:
将模仿旋转1次,然后让RL学习如何解一次旋转的魔方。然后旋转两次,RL会学习解第二次旋转的问题,依次渐进,将一个复杂的见效慢的问题逐步解决
使用Q-table的问题:
Qtable的好处也是见效比较快,实现起来比较简单。
但是问题是,当状态数量变得逐渐多的时候,例如魔方旋转5次。Qtable jsonized的文件大小达到了500MB以上。
所以下一步就要使用神经网络来代替Q-table
决定采用A2C模式
使用Q-table生成两份训练数据,一份数据用来训练Critic,数据包含状态(X),和各动作的Value(Y) 一份数据用来训练Actor, 数据包含状态(X), 和从各动作的value中提取出来的执行的动作(Y one-hoted)
这样训练出来的Actor, Critic就是两个预训练的网络,应该训练完成之后就可以处理五次旋转的魔方
A glance
http://nbviewer.jupyter.org/github/cinqs/rubikcubesolverbyrl/blob/master/Untitled.ipynb
of course the code
https://github.com/cinqs/rubikcubesolverbyrl
the dumped Q table(with format tricked)
https://github.com/cinqs/rubikcubesolverbyrl/blob/master/Q.json
Thanks for the project of Adrian Liaw