别再死记硬背MDP了!用Python手搓一个GridWorld环境,直观理解状态、动作与奖励
用Python构建GridWorld可视化理解强化学习核心概念当第一次接触强化学习时那些抽象的概率公式和数学符号往往让人望而生畏。与其在纸上反复推导马尔可夫决策过程(MDP)的各种方程不如直接动手用代码构建一个简单的环境亲眼看看智能体是如何通过试错来学习最优策略的。本文将带你用Python从零开始实现一个GridWorld环境通过可视化手段直观理解状态、动作、奖励这些核心概念。1. 环境搭建与基础实现1.1 创建网格世界GridWorld是最经典的强化学习教学环境之一它由一个二维网格组成每个格子代表一个不同的状态。我们先定义网格的基本结构import numpy as np import matplotlib.pyplot as plt from matplotlib.colors import ListedColormap class GridWorld: def __init__(self, size5): self.size size # 0:普通格子, 1:障碍物, 2:目标格子, 3:危险格子 self.grid np.zeros((size, size)) self._setup_grid() def _setup_grid(self): # 设置障碍物 self.grid[1, 1] 1 # 设置目标格子 self.grid[size-1, size-1] 2 # 设置危险格子 self.grid[2, 3] 3 def render(self): cmap ListedColormap([white, black, green, red]) plt.imshow(self.grid, cmapcmap) plt.grid(whichmajor, axisboth, linestyle-, colork, linewidth2) plt.xticks(np.arange(-.5, self.size, 1), []) plt.yticks(np.arange(-.5, self.size, 1), []) plt.show()这个基础实现创建了一个5x5的网格其中(1,1)是障碍物黑色(4,4)是目标格子绿色(2,3)是危险格子红色其余是普通格子白色1.2 定义状态与动作在GridWorld中每个格子坐标就是一个状态。我们定义四种基本动作上(0)、右(1)、下(2)、左(3)。class GridWorld(GridWorld): def __init__(self, size5): super().__init__(size) self.actions [0, 1, 2, 3] # 上,右,下,左 self.action_effects [(-1,0), (0,1), (1,0), (0,-1)] self.state (0, 0) # 初始位置 def step(self, action): x, y self.state dx, dy self.action_effects[action] new_x, new_y x dx, y dy # 检查边界和障碍物 if (0 new_x self.size and 0 new_y self.size and self.grid[new_x, new_y] ! 1): self.state (new_x, new_y) # 计算奖励 reward self._get_reward() done self._is_terminal() return self.state, reward, done def _get_reward(self): x, y self.state cell_type self.grid[x, y] if cell_type 2: # 目标格子 return 10 elif cell_type 3: # 危险格子 return -10 else: # 普通格子 return -1 # 每步小惩罚鼓励尽快到达目标 def _is_terminal(self): x, y self.state return self.grid[x, y] in [2, 3] # 到达目标或危险格子结束2. 可视化智能体移动2.1 实时渲染环境为了让学习过程更直观我们增强render方法显示智能体的当前位置class GridWorld(GridWorld): def render(self, q_valuesNone, policyNone): plt.figure(figsize(8, 8)) # 绘制网格 cmap ListedColormap([white, black, green, red]) plt.imshow(self.grid, cmapcmap) # 标记智能体位置 agent_x, agent_y self.state plt.scatter(agent_y, agent_x, cblue, s200, markero) # 可选显示Q值或策略 if q_values is not None: for i in range(self.size): for j in range(self.size): if self.grid[i, j] 0: # 只在空白格子显示 plt.text(j-0.3, i, f{q_values[i,j,0]:.1f}, fontsize8) plt.text(j, i0.3, f{q_values[i,j,1]:.1f}, fontsize8) plt.text(j0.3, i, f{q_values[i,j,2]:.1f}, fontsize8) plt.text(j, i-0.3, f{q_values[i,j,3]:.1f}, fontsize8) plt.grid(whichmajor, axisboth, linestyle-, colork, linewidth2) plt.xticks(np.arange(-.5, self.size, 1), []) plt.yticks(np.arange(-.5, self.size, 1), []) plt.show()2.2 交互式演示让我们创建一个简单的随机策略观察智能体如何移动env GridWorld() for _ in range(20): action np.random.choice(env.actions) # 随机选择动作 state, reward, done env.step(action) env.render() print(fAction: {action}, Reward: {reward}, Done: {done}) if done: break这个演示中智能体会随机选择动作移动直到到达目标或危险格子。通过观察你可以直观看到状态智能体所在的格子坐标动作每次移动的方向奖励不同格子给予的不同反馈终止条件到达特定格子结束回合3. 实现价值函数与策略3.1 状态价值函数计算状态价值函数V(s)表示从状态s出发按照某个策略能获得的期望回报。我们可以通过动态编程来计算def compute_state_values(env, gamma0.9, theta1e-4): V np.zeros((env.size, env.size)) while True: delta 0 for i in range(env.size): for j in range(env.size): if env.grid[i, j] in [1, 2, 3]: # 跳过障碍物和终止状态 continue v 0 # 假设均匀随机策略 for action in env.actions: env.state (i, j) next_state, reward, done env.step(action) next_i, next_j next_state v 0.25 * (reward gamma * V[next_i, next_j]) delta max(delta, abs(v - V[i, j])) V[i, j] v if delta theta: break return V3.2 动作价值函数计算动作价值函数Q(s,a)表示在状态s下采取动作a能获得的期望回报def compute_action_values(env, gamma0.9, theta1e-4): Q np.zeros((env.size, env.size, len(env.actions))) while True: delta 0 for i in range(env.size): for j in range(env.size): if env.grid[i, j] in [1, 2, 3]: continue for a in range(len(env.actions)): env.state (i, j) next_state, reward, done env.step(a) next_i, next_j next_state # 最优策略下取最大Q值 max_q np.max(Q[next_i, next_j]) if not done else 0 new_q reward gamma * max_q delta max(delta, abs(new_q - Q[i, j, a])) Q[i, j, a] new_q if delta theta: break return Q3.3 策略提取与改进有了Q值后我们可以提取最优策略def extract_policy(env, Q): policy np.zeros((env.size, env.size), dtypeint) for i in range(env.size): for j in range(env.size): if env.grid[i, j] in [1, 2, 3]: continue policy[i, j] np.argmax(Q[i, j]) return policy def visualize_policy(env, policy): plt.figure(figsize(8, 8)) cmap ListedColormap([white, black, green, red]) plt.imshow(env.grid, cmapcmap) arrow_map {0: ↑, 1: →, 2: ↓, 3: ←} for i in range(env.size): for j in range(env.size): if env.grid[i, j] 0: # 只在空白格子显示 plt.text(j, i, arrow_map[policy[i, j]], hacenter, vacenter, fontsize20) plt.grid(whichmajor, axisboth, linestyle-, colork, linewidth2) plt.xticks(np.arange(-.5, env.size, 1), []) plt.yticks(np.arange(-.5, env.size, 1), []) plt.show()4. 完整训练流程与结果分析4.1 价值迭代算法实现结合前面的函数我们可以实现完整的价值迭代算法def value_iteration(env, gamma0.9, theta1e-4): Q compute_action_values(env, gamma, theta) policy extract_policy(env, Q) return Q, policy # 训练并可视化结果 env GridWorld() Q, policy value_iteration(env) visualize_policy(env, policy) # 显示某个状态的Q值 state (0, 0) print(fQ-values at state {state}:) print(fUp: {Q[state][0]:.2f}, Right: {Q[state][1]:.2f}) print(fDown: {Q[state][2]:.2f}, Left: {Q[state][3]:.2f})4.2 结果解读与分析运行上述代码后你会看到每个格子上显示最优动作箭头表示初始状态(0,0)的各个动作Q值通过分析这些结果我们可以理解状态价值离目标越近的格子通常有更高的状态价值动作价值朝向目标的动作会有更高的Q值策略形成智能体学会了避开障碍物和危险区域选择最快路径到达目标4.3 参数实验与调优尝试调整以下参数观察效果变化参数默认值影响建议实验值gamma0.9折扣因子影响未来奖励的重要性0.5, 0.9, 0.99危险格子奖励-10控制避开危险区域的强度-5, -10, -20每步惩罚-1鼓励尽快完成任务的程度0, -1, -0.1例如降低gamma值会使智能体更关注即时奖励Q_low_gamma, _ value_iteration(env, gamma0.5) print(fLow gamma Q-values at (0,0): {Q_low_gamma[0,0]})在实际项目中我发现gamma值的选择对策略有显著影响。较高的gamma值(如0.99)会使智能体更有远见但可能需要更多训练时间而较低的gamma值(如0.5)会使智能体更短视可能错过长期更优的路径。