In game AI, reinforcement learning's capacity for autonomous learning powers intelligent decision-making in everything from simple mini-games to complex AAA titles. The Deep Q-Network (DQN), a milestone algorithm bridging deep learning and reinforcement learning, was the first to bring AI to human-level play on Atari games. Starting from the basic principles, this article uses hands-on examples to explain how the DQN algorithm makes decisions in games on its own.
I. Core Principles of the DQN Algorithm
The core idea of DQN (Deep Q-Network) is to approximate the Q-value function with a deep neural network, solving the problem that traditional reinforcement learning cannot handle high-dimensional state spaces.
1. Q-learning Basics
Q-learning guides decision-making by learning an action-value function Q(s,a), where s is the current state, a is an action, and Q(s,a) is the expected cumulative reward for taking action a in state s:
def q_learning_update(q_table, state, action, reward, next_state, alpha=0.1, gamma=0.9):
    """Q-learning update rule"""
    old_value = q_table[state][action]
    # Maximum Q-value over actions in the next state
    next_max = max(q_table[next_state].values())
    # Update rule: Q(s,a) <- Q(s,a) + α[r + γ·max_a' Q(s',a') - Q(s,a)]
    new_value = old_value + alpha * (reward + gamma * next_max - old_value)
    q_table[state][action] = new_value
    return q_table
Traditional Q-learning stores Q-values in a table, so it only works when the state space is small (e.g., maze problems).
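To make the tabular setting concrete, here is a minimal sketch (the grid world, its states, and the transition values are invented for illustration) that keeps a nested-dict Q-table and applies q_learning_update to a single observed transition:

from collections import defaultdict

# Toy grid world: states are (row, col) cells, actions are the four moves.
actions = ["up", "down", "left", "right"]
q_table = defaultdict(lambda: {a: 0.0 for a in actions})

# One observed transition: moving right from (0, 0) to (0, 1) with reward -1.
q_table = q_learning_update(q_table, state=(0, 0), action="right",
                            reward=-1, next_state=(0, 1))
print(q_table[(0, 0)]["right"])  # -0.1 with the defaults alpha=0.1, gamma=0.9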
2. DQN's Key Innovations
DQN addresses the instability of training deep neural networks with reinforcement learning through two key techniques:
- Experience Replay: store the agent's transitions (s, a, r, s') in a replay buffer and sample them at random, breaking the correlation between consecutive samples
- Target Network: compute target Q-values with a separate target network whose parameters are periodically synchronized from the main (policy) network
import random
from collections import deque

import numpy as np

class ReplayBuffer:
    """Experience replay buffer"""
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)  # bounded double-ended queue

    def add(self, state, action, reward, next_state, done):
        """Store one transition"""
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        """Randomly sample a batch of transitions"""
        transitions = random.sample(self.buffer, batch_size)
        # Convert to batched arrays; done flags become float32 so they can be used directly in the TD target
        state, action, reward, next_state, done = zip(*transitions)
        return (np.array(state), np.array(action), np.array(reward),
                np.array(next_state), np.array(done, dtype=np.float32))

    def size(self):
        """Number of transitions currently stored"""
        return len(self.buffer)
By decorrelating samples, experience replay brings the network's training data closer to the i.i.d. assumption and noticeably improves training stability.
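A short usage sketch of the buffer (the dummy 11-dimensional states and random actions below are placeholders; they simply exercise the interface used later in the snake example):

buffer = ReplayBuffer(capacity=10000)

# Push a few dummy transitions.
for _ in range(100):
    s = np.random.randint(0, 2, size=11)
    s_next = np.random.randint(0, 2, size=11)
    buffer.add(s, action=np.random.randint(3), reward=0.0, next_state=s_next, done=False)

states, actions, rewards, next_states, dones = buffer.sample(batch_size=64)
print(states.shape, actions.shape)  # (64, 11) (64,)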
II. DQN Network Architecture and Implementation
For pixel inputs, DQN processes game frames with a convolutional neural network; the network outputs a Q-value estimate for every action.
1. Network Architecture
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from collections import deque
import random

class DQN(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=64):
        super(DQN, self).__init__()
        # Fully connected network for low-dimensional states (e.g., Snake)
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, action_dim)  # one Q-value per action

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)

class DQN_CNN(nn.Module):
    def __init__(self, action_dim):
        super(DQN_CNN, self).__init__()
        # Convolutional network for image input (e.g., Atari games)
        self.conv1 = nn.Conv2d(4, 32, kernel_size=8, stride=4)  # 4 stacked grayscale frames
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
        self.fc1 = nn.Linear(7 * 7 * 64, 512)  # flattened conv output for 84x84 input
        self.fc2 = nn.Linear(512, action_dim)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = torch.relu(self.conv3(x))
        x = x.view(x.size(0), -1)  # flatten
        x = torch.relu(self.fc1(x))
        return self.fc2(x)
The output layer has one unit per action, so each output value is the Q-value estimate for the corresponding action.
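As a quick sanity check of the two architectures (the batch size and action counts below are arbitrary; the 84x84 input matches the standard Atari preprocessing used later):

# MLP variant: a batch of 11-dim states -> Q-values for 3 actions
mlp = DQN(state_dim=11, action_dim=3)
print(mlp(torch.zeros(32, 11)).shape)         # torch.Size([32, 3])

# CNN variant: a batch of 4 stacked 84x84 frames -> Q-values for 6 actions
cnn = DQN_CNN(action_dim=6)
print(cnn(torch.zeros(32, 4, 84, 84)).shape)  # torch.Size([32, 6])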
2. DQN Agent Implementation
class DQNAgent:
    def __init__(self, state_dim, action_dim, lr=1e-3, gamma=0.99, epsilon=1.0):
        self.action_dim = action_dim
        self.gamma = gamma          # discount factor
        self.epsilon = epsilon      # exploration rate
        self.epsilon_min = 0.01
        self.epsilon_decay = 0.995  # decay factor, applied once per episode in the training loop
        # Policy (main) network and target network
        self.policy_net = DQN(state_dim, action_dim)
        self.target_net = DQN(state_dim, action_dim)
        self.target_net.load_state_dict(self.policy_net.state_dict())  # start with identical weights
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=lr)
        self.buffer = ReplayBuffer(capacity=10000)
        self.batch_size = 64

    def select_action(self, state):
        """ε-greedy action selection"""
        if random.random() < self.epsilon:
            # Random exploration
            return random.randint(0, self.action_dim - 1)
        else:
            # Greedy: pick the action with the highest Q-value
            with torch.no_grad():
                state = torch.FloatTensor(state).unsqueeze(0)
                q_values = self.policy_net(state)
                return q_values.max(1)[1].item()

    def update(self):
        """Sample from the replay buffer and update the policy network"""
        if self.buffer.size() < self.batch_size:
            return  # not enough experience yet
        # Sample a batch of transitions
        states, actions, rewards, next_states, dones = self.buffer.sample(self.batch_size)
        # Convert to tensors
        states = torch.FloatTensor(states)
        actions = torch.LongTensor(actions).unsqueeze(1)
        rewards = torch.FloatTensor(rewards)
        next_states = torch.FloatTensor(next_states)
        dones = torch.FloatTensor(dones)
        # Current Q-values: the policy network's outputs for the actions actually taken
        current_q = self.policy_net(states).gather(1, actions).squeeze(1)
        # Target Q-values: max next-state Q-value from the target network
        with torch.no_grad():
            next_q = self.target_net(next_states).max(1)[0]
            target_q = rewards + (1 - dones) * self.gamma * next_q
        # Compute the loss and optimize
        loss = nn.MSELoss()(current_q, target_q)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def sync_target_network(self):
        """Copy the policy network's parameters into the target network"""
        self.target_net.load_state_dict(self.policy_net.state_dict())
The target network is synchronized only periodically (for example, every 1,000 steps), which keeps the regression targets from shifting after every update and stabilizes training.
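For step-based synchronization, a generic training loop might look like the following sketch (the train_with_step_sync helper and its interval are illustrative; the snake example below syncs per episode instead):

def train_with_step_sync(agent, env, num_episodes=200, sync_interval=1000):
    """Generic training loop that syncs the target network every `sync_interval` steps (sketch)."""
    global_step = 0
    for _ in range(num_episodes):
        state = env.reset()
        done = False
        while not done:
            action = agent.select_action(state)
            next_state, reward, done = env.step(action)
            agent.buffer.add(state, action, reward, next_state, done)
            agent.update()
            state = next_state
            global_step += 1
            if global_step % sync_interval == 0:
                agent.sync_target_network()  # hard copy of the policy-network weights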
III. Hands-On Example: DQN Plays Snake
Using the classic Snake game as an example, this section shows how DQN learns on its own to avoid obstacles and eat food.
1. Game Environment Design
import random

import numpy as np
import pygame  # imported for rendering; not used in the logic shown here

class SnakeGame:
    def __init__(self, width=400, height=400, block_size=20):
        self.width = width
        self.height = height
        self.block_size = block_size
        self.reset()

    def reset(self):
        """Reset the game state"""
        self.snake = [(100, 100), (80, 100), (60, 100)]  # snake body, head first
        self.direction = (self.block_size, 0)  # initially moving right
        self.food = self._generate_food()
        self.score = 0
        self.done = False
        return self._get_state()

    def _generate_food(self):
        """Place food at a random cell that does not overlap the snake"""
        while True:
            x = random.randint(0, (self.width - self.block_size) // self.block_size) * self.block_size
            y = random.randint(0, (self.height - self.block_size) // self.block_size) * self.block_size
            food = (x, y)
            if food not in self.snake:
                return food

    def _get_state(self):
        """Extract the state features"""
        head_x, head_y = self.snake[0]
        # Direction features
        dir_left = self.direction == (-self.block_size, 0)
        dir_right = self.direction == (self.block_size, 0)
        dir_up = self.direction == (0, -self.block_size)
        dir_down = self.direction == (0, self.block_size)
        # Danger features (wall or own body)
        danger_straight = (
            (dir_right and (head_x + self.block_size >= self.width or (head_x + self.block_size, head_y) in self.snake)) or
            (dir_left and (head_x - self.block_size < 0 or (head_x - self.block_size, head_y) in self.snake)) or
            (dir_up and (head_y - self.block_size < 0 or (head_x, head_y - self.block_size) in self.snake)) or
            (dir_down and (head_y + self.block_size >= self.height or (head_x, head_y + self.block_size) in self.snake))
        )
        danger_right = (
            (dir_up and (head_x + self.block_size >= self.width or (head_x + self.block_size, head_y) in self.snake)) or
            (dir_down and (head_x - self.block_size < 0 or (head_x - self.block_size, head_y) in self.snake)) or
            (dir_left and (head_y - self.block_size < 0 or (head_x, head_y - self.block_size) in self.snake)) or
            (dir_right and (head_y + self.block_size >= self.height or (head_x, head_y + self.block_size) in self.snake))
        )
        danger_left = (
            (dir_down and (head_x + self.block_size >= self.width or (head_x + self.block_size, head_y) in self.snake)) or
            (dir_up and (head_x - self.block_size < 0 or (head_x - self.block_size, head_y) in self.snake)) or
            (dir_right and (head_y - self.block_size < 0 or (head_x, head_y - self.block_size) in self.snake)) or
            (dir_left and (head_y + self.block_size >= self.height or (head_x, head_y + self.block_size) in self.snake))
        )
        # Food position relative to the head
        food_up = head_y > self.food[1]
        food_down = head_y < self.food[1]
        food_left = head_x > self.food[0]
        food_right = head_x < self.food[0]
        # Combined state vector (11 features)
        return np.array([
            danger_straight, danger_right, danger_left,
            dir_left, dir_right, dir_up, dir_down,
            food_left, food_right, food_up, food_down
        ], dtype=int)

    def step(self, action):
        """Apply an action; return (next_state, reward, done)"""
        # Action mapping: 0 = go straight, 1 = turn right, 2 = turn left
        dx, dy = self.direction
        if action == 1:    # turn right
            dx, dy = dy, -dx
        elif action == 2:  # turn left
            dx, dy = -dy, dx
        self.direction = (dx, dy)
        # Move the head
        head_x, head_y = self.snake[0]
        new_head = (head_x + dx, head_y + dy)
        # Collision check
        if (new_head[0] < 0 or new_head[0] >= self.width or
                new_head[1] < 0 or new_head[1] >= self.height or
                new_head in self.snake):
            self.done = True
            reward = -10  # penalty for crashing
        else:
            # Advance the body
            self.snake.insert(0, new_head)
            # Check whether food was eaten
            if new_head == self.food:
                self.score += 1
                reward = 10  # reward for eating food
                self.food = self._generate_food()
            else:
                self.snake.pop()  # no food eaten: drop the tail
                reward = 0        # ordinary step, no reward
        return self._get_state(), reward, self.done
The environment encodes the snake's heading, nearby danger, and the food's relative position into an 11-dimensional state vector, which serves as the DQN's input.
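A quick interaction check of the environment (actions are chosen at random just to exercise the interface):

env = SnakeGame()
state = env.reset()
print(state.shape)  # (11,)

done = False
while not done:
    next_state, reward, done = env.step(random.choice([0, 1, 2]))
print("episode finished, score:", env.score)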
2. Training Loop
import matplotlib.pyplot as plt

def train_snake(episodes=500):
    """Train the Snake AI"""
    env = SnakeGame()
    state_dim = 11   # size of the state feature vector
    action_dim = 3   # straight, turn right, turn left
    agent = DQNAgent(state_dim, action_dim)
    scores = []
    target_update_freq = 10  # sync the target network every 10 episodes

    for episode in range(episodes):
        state = env.reset()
        total_reward = 0
        done = False
        while not done:
            # Select an action
            action = agent.select_action(state)
            # Apply it to the environment
            next_state, reward, done = env.step(action)
            # Store the transition
            agent.buffer.add(state, action, reward, next_state, done)
            # Update the policy network
            agent.update()
            state = next_state
            total_reward += reward

        scores.append(env.score)
        # Decay the exploration rate once per episode
        if agent.epsilon > agent.epsilon_min:
            agent.epsilon *= agent.epsilon_decay
        # Periodically sync the target network
        if episode % target_update_freq == 0:
            agent.sync_target_network()
        # Progress report
        if (episode + 1) % 10 == 0:
            avg_score = np.mean(scores[-10:])
            print(f"Episode {episode + 1}/{episodes}, score: {env.score}, "
                  f"avg score: {avg_score:.1f}, epsilon: {agent.epsilon:.3f}")

    # Save the trained policy network
    torch.save(agent.policy_net.state_dict(), "snake_dqn.pth")
    return scores

# Run training
scores = train_snake()

# Plot the training curve
plt.plot(scores)
plt.xlabel("Episode")
plt.ylabel("Score")
plt.title("Snake DQN training curve")
plt.show()
During training, the AI progresses from random crashing to avoiding walls and chasing food, and the score climbs steadily over the episodes, demonstrating DQN's ability to learn on its own.
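After training, the saved weights can be reloaded to watch the learned policy act greedily. The sketch below assumes the snake_dqn.pth file written by train_snake; the play_snake helper is illustrative:

def play_snake(model_path="snake_dqn.pth", episodes=5):
    """Run the trained policy greedily (no exploration)."""
    env = SnakeGame()
    policy_net = DQN(state_dim=11, action_dim=3)
    policy_net.load_state_dict(torch.load(model_path))
    policy_net.eval()

    for episode in range(episodes):
        state = env.reset()
        done = False
        while not done:
            with torch.no_grad():
                q_values = policy_net(torch.FloatTensor(state).unsqueeze(0))
                action = q_values.max(1)[1].item()
            state, _, done = env.step(action)
        print(f"Evaluation episode {episode + 1}: score = {env.score}")

play_snake()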
IV. DQN in More Complex Games
For pixel-based games such as Atari titles, the image input goes through the CNN variant of the network, with some extra preprocessing:
import cv2

def preprocess_frame(frame):
    """Preprocess one game frame"""
    # Convert to grayscale
    gray = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
    # Downscale to 84x84
    resized = cv2.resize(gray, (84, 84), interpolation=cv2.INTER_AREA)
    # Binarize
    _, binary = cv2.threshold(resized, 1, 255, cv2.THRESH_BINARY)
    # Normalize to [0, 1]
    return binary / 255.0

class AtariAgent(DQNAgent):
    """DQN agent for Atari games"""
    def __init__(self, action_dim):
        # Use the CNN networks for image input
        self.policy_net = DQN_CNN(action_dim)
        self.target_net = DQN_CNN(action_dim)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        # Remaining setup mirrors DQNAgent.__init__
        self.action_dim = action_dim
        self.gamma = 0.99
        self.epsilon = 1.0
        self.epsilon_min = 0.01
        self.epsilon_decay = 0.995
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=1e-3)
        self.buffer = ReplayBuffer(capacity=10000)
        self.batch_size = 64
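Because a single frame hides motion, Atari-style agents normally stack the last few preprocessed frames into the network input. Below is a minimal sketch of that stacking step, matching the 4-channel input DQN_CNN expects (the reset_stack and append_frame helpers are illustrative, not from any library):

from collections import deque

import numpy as np

FRAME_STACK = 4  # DQN_CNN expects 4 stacked 84x84 frames

def reset_stack(first_frame):
    """Start a new episode's frame stack by repeating the first preprocessed frame."""
    processed = preprocess_frame(first_frame)
    return deque([processed] * FRAME_STACK, maxlen=FRAME_STACK)

def append_frame(stack, frame):
    """Push the newest frame and return the (4, 84, 84) network input."""
    stack.append(preprocess_frame(frame))
    return np.stack(stack, axis=0).astype(np.float32)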