SARSA (State-Action-Reward-State-Action) is an on-policy reinforcement learning algorithm that updates its policy based on the current state-action pair, the reward received, the next state, and the next action chosen by the current policy.
In this article, we will implement SARSA in Gymnasium’s Taxi-v3 environment, walking through the setup, agent definition, training, and visualization of the agent’s learning process.
Table of Contents
- SARSA Learning
- Key Concepts in SARSA
- SARSA Update Rule
- Implementing SARSA in Gymnasium’s Taxi-v3 Environment
- Step 1: Setup and Initialization
- Step 2: Define the SARSA Agent
- Step 3: Training the Agent
- Step 4: Visualization of Learning Progress
- Step 5: Running the Trained Agent
- Output Explanation
- Conclusion
SARSA Learning
SARSA is a temporal difference (TD) learning algorithm that combines ideas from dynamic programming and Monte Carlo methods. The key feature of SARSA is that it learns the Q-value based on the action taken by the current policy, making it an on-policy method. This means that the agent follows its current policy both when selecting actions to take and when updating the Q-values.
Key Concepts in SARSA
- State (S): The current situation or configuration the agent is in.
- Action (A): The decision or move the agent takes in a given state.
- Reward (R): The immediate gain or loss the agent receives after taking an action in a state.
- Next State (S’): The new state the agent transitions to after taking an action.
- Next Action (A’): The action the agent takes in the next state according to its policy.
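To make S and A concrete for Taxi-v3: each observation is a single integer in `range(500)`. The Gymnasium Taxi-v3 documentation describes it as `((taxi_row * 5 + taxi_col) * 5 + passenger_location) * 4 + destination`, where locations 0 to 3 are R, G, Y, B and passenger location 4 means "in the taxi". There are six discrete actions (0 South, 1 North, 2 East, 3 West, 4 Pickup, 5 Dropoff). Rewards are -1 per step, +20 for a successful dropoff, and -10 for an illegal pickup or dropoff. The helper names below are hypothetical; the sketch only mirrors that documented packing order.

```python
# Hypothetical helpers mirroring the state encoding described in the Taxi-v3 docs.
LOCATIONS = ["R", "G", "Y", "B"]  # passenger_location == 4 means "in the taxi"
ACTIONS = ["South", "North", "East", "West", "Pickup", "Dropoff"]


def encode_taxi_state(taxi_row, taxi_col, passenger_location, destination):
    """Pack the four state components into one index in range(500)."""
    return ((taxi_row * 5 + taxi_col) * 5 + passenger_location) * 4 + destination


def decode_taxi_state(state):
    """Inverse of encode_taxi_state: unpack an index into its four components."""
    destination = state % 4
    state //= 4
    passenger_location = state % 5
    state //= 5
    taxi_col = state % 5
    taxi_row = state // 5
    return taxi_row, taxi_col, passenger_location, destination


# Every combination maps to a distinct index: 5 * 5 * 5 * 4 = 500 states.
n_states = len({encode_taxi_state(r, c, p, d)
                for r in range(5) for c in range(5)
                for p in range(5) for d in range(4)})
print(n_states)  # 500
print(decode_taxi_state(encode_taxi_state(3, 1, 2, 0)))  # (3, 1, 2, 0)
```

Because the state space is this small and discrete, a plain lookup table of Q-values (one row of six action values per state) is enough for SARSA.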
SARSA Update Rule
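The SARSA update rule for Q-values is given by:

$$Q(s, a) \leftarrow Q(s, a) + \alpha \left[ r + \gamma \, Q(s', a') - Q(s, a) \right]$$

Where:
- $\alpha$ is the learning rate.
- $\gamma$ is the discount factor.
- $Q(s, a)$ is the Q-value of the current state-action pair.
- $r$ is the reward received after taking action $a$ in state $s$.
- $Q(s', a')$ is the Q-value of the next state-action pair, where $a'$ is the action the current policy chooses in $s'$. If $s'$ is terminal, this term is taken as 0, so the target is just $r$.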
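The update rule can be traced by hand on a toy Q-table. The sketch below is a minimal illustration: the function name `sarsa_update` and the table values are hypothetical, and the Q-table is a plain dict of lists rather than the agent's `defaultdict` of NumPy arrays.

```python
def sarsa_update(q, s, a, r, s_next, a_next, alpha, gamma, terminal=False):
    """Apply one SARSA update to the table q in place and return the new Q(s, a)."""
    # A terminal transition has no next state-action pair, so nothing is bootstrapped.
    target = r if terminal else r + gamma * q[s_next][a_next]
    q[s][a] += alpha * (target - q[s][a])  # move Q(s, a) a fraction alpha toward the target
    return q[s][a]


# Toy table: two states with two actions each (values chosen for illustration only).
q = {0: [0.0, 0.0], 1: [2.0, 0.0]}

# Non-terminal step: target = -1 + 0.95 * 2.0 = 0.9, so Q(0, 0) moves halfway from 0 to 0.9.
print(f"{sarsa_update(q, 0, 0, -1.0, 1, 0, alpha=0.5, gamma=0.95):.2f}")  # 0.45

# Terminal step with reward +20: target = 20, so Q(1, 0) moves halfway from 2.0 to 20.
print(sarsa_update(q, 1, 0, 20.0, None, None, alpha=0.5, gamma=0.95, terminal=True))  # 11.0
```

With `alpha = 0.5`, each update closes half of the gap between the old estimate and the TD target.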
Implementing SARSA in Gymnasium’s Taxi-v3 Environment
We’ll walk through setting up the environment, defining a SARSA-based learning agent, training it, and finally visualizing its learning progress and performance. Each step shows how SARSA, an on-policy algorithm, updates its Q-values using the action its current policy actually takes next. This contrasts with off-policy methods such as Q-learning, which bootstrap from the greedy (highest-valued) action in the next state regardless of which action the agent actually takes there.
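The on-policy versus off-policy difference comes down to the bootstrap term of the TD target. The toy comparison below uses hypothetical function names and made-up Q-values for a single transition in which an exploratory step picked a non-greedy next action.

```python
def sarsa_target(q_next, a_next, r, gamma):
    """SARSA (on-policy): bootstrap from the action the policy actually chose in s'."""
    return r + gamma * q_next[a_next]


def q_learning_target(q_next, r, gamma):
    """Q-learning (off-policy): bootstrap from the greedy action in s', whatever is taken next."""
    return r + gamma * max(q_next)


# Q-values of the next state s' for three actions; action 2 is greedy,
# but suppose exploration made the policy pick action 0.
q_next = [-4.0, 1.0, 6.0]
print(sarsa_target(q_next, a_next=0, r=-1.0, gamma=0.5))  # -3.0
print(q_learning_target(q_next, r=-1.0, gamma=0.5))       # 2.0
```

SARSA's target reflects the cost of the exploratory action, so its Q-values describe the policy it is actually following, exploration included.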
Before proceeding with the implementation, make sure Gymnasium is installed. Gymnasium is the maintained continuation of the popular OpenAI Gym library for developing and comparing reinforcement learning algorithms. You can install it, together with Matplotlib for the plots used below, with Python’s package manager pip:
```shell
pip install gymnasium matplotlib
```
Step 1: Setup and Initialization
First, we begin by importing necessary libraries and defining a plotting function that we will use later to visualize the agent’s performance over training episodes.
```python
import gymnasium as gym
import numpy as np
from collections import defaultdict
import matplotlib.pyplot as plt


def plot_returns(returns):
    # Plot the total (undiscounted) return of every training episode
    plt.plot(np.arange(len(returns)), returns)
    plt.title('Episode returns')
    plt.xlabel('Episode')
    plt.ylabel('Return')
    plt.show()
```
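Per-episode returns in Taxi-v3 are noisy, so the raw curve can be hard to read. One option is to overlay a trailing moving average. The sketch below is pure Python, and `moving_average` is a hypothetical helper, not part of the article's code or of Matplotlib.

```python
from collections import deque


def moving_average(values, window=100):
    """Mean of each trailing window of `values` (windows are shorter at the start)."""
    recent = deque(maxlen=window)  # keeps only the last `window` items
    total = 0.0
    averages = []
    for v in values:
        if len(recent) == window:
            total -= recent[0]  # this value is evicted by the append below
        recent.append(v)
        total += v
        averages.append(total / len(recent))
    return averages


print(moving_average([1, 2, 3, 4], window=2))  # [1.0, 1.5, 2.5, 3.5]
```

The smoothed series could be drawn next to the raw one, for example with `plt.plot(moving_average(returns, window=100))`, before calling `plt.show()`.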
Step 2: Define the SARSA Agent
Next, we define the SARSAAgent class. The agent is initialized with parameters that control how it learns and explores, and it provides methods for selecting actions, updating Q-values, and decaying the exploration rate.
```python
class SARSAAgent:
    def __init__(self, env, learning_rate, initial_epsilon, epsilon_decay,
                 final_epsilon, discount_factor=0.95):
        self.env = env
        self.learning_rate = learning_rate
        self.discount_factor = discount_factor
        self.epsilon = initial_epsilon
        self.epsilon_decay = epsilon_decay
        self.final_epsilon = final_epsilon
        # Q-table: one array of action values per (discrete) observation
        self.q_values = defaultdict(lambda: np.zeros(env.action_space.n))

    def get_action(self, obs) -> int:
        # Epsilon-greedy action selection
        if np.random.rand() < self.epsilon:
            return int(self.env.action_space.sample())  # Explore
        return int(np.argmax(self.q_values[obs]))  # Exploit

    def update(self, obs, action, reward, terminated, next_obs, next_action):
        # A terminal state has no future value, so the target is just the reward;
        # otherwise bootstrap from Q(s', a') for the action the policy chose next.
        if terminated:
            td_target = reward
        else:
            td_target = reward + self.discount_factor * self.q_values[next_obs][next_action]
        td_error = td_target - self.q_values[obs][action]
        self.q_values[obs][action] += self.learning_rate * td_error

    def decay_epsilon(self):
        # Linearly decrease epsilon, but never below final_epsilon
        self.epsilon = max(self.final_epsilon, self.epsilon - self.epsilon_decay)
```
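A design note on the exploit branch: `np.argmax` returns the lowest index among ties, so while a row of the Q-table is still all zeros the greedy choice is always action 0 (South in Taxi-v3). A common variation is to break ties at random. The sketch below is a hypothetical, pure-Python `epsilon_greedy` that takes a plain list of Q-values instead of the agent's table.

```python
import random


def epsilon_greedy(q_row, epsilon, rng=random):
    """Random action with probability epsilon, otherwise a greedy one.

    Ties among equally valued greedy actions are broken uniformly at random.
    """
    if rng.random() < epsilon:
        return rng.randrange(len(q_row))  # explore
    best = max(q_row)
    best_actions = [a for a, value in enumerate(q_row) if value == best]
    return rng.choice(best_actions)  # exploit, random tie-breaking


rng = random.Random(42)
print(epsilon_greedy([0.0] * 6, epsilon=0.0, rng=rng))  # some action in 0..5
```

In this article's setup epsilon starts at 1, so early ties matter less; the variation mainly helps when training starts with little exploration.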
Step 3: Training the Agent
With the SARSA agent defined, we proceed to train it across multiple episodes. The training function loops over each episode, allowing the agent to interact with the environment, learn from actions, and gradually improve its policy.
```python
def train_agent(agent, env, episodes, eval_interval=100):
    rewards = []
    for i in range(episodes):
        obs, _ = env.reset()
        action = agent.get_action(obs)  # choose A before the first step
        terminated = truncated = False
        total_reward = 0
        while not terminated and not truncated:
            next_obs, reward, terminated, truncated, _ = env.step(action)
            next_action = agent.get_action(next_obs)  # choose A' with the same policy
            agent.update(obs, action, reward, terminated, next_obs, next_action)
            obs, action = next_obs, next_action  # A' is the action actually taken next
            total_reward += reward
        agent.decay_epsilon()
        rewards.append(total_reward)
        if i % eval_interval == 0 and i > 0:
            # Average return over the most recent eval_interval episodes
            avg_return = np.mean(rewards[-eval_interval:])
            print(f"Episode {i} -> Average Return: {avg_return}")
    return rewards
```
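To check the SARSA loop logic without Gymnasium or Taxi, here is a minimal pure-Python sketch on a hypothetical `Corridor` environment that mimics Gymnasium's `reset()`/`step()` return values. The names `Corridor`, `choose_action`, and `train_sarsa` are illustrative only. The loop follows the textbook SARSA ordering: A is chosen before the first step, A' is chosen from the same policy before each update, and A' is the action executed on the following step.

```python
import random


class Corridor:
    """Hypothetical 1-D corridor with a Gymnasium-style reset/step API.

    States 0..n-1, start at 0, goal at n-1. Actions: 0 = left, 1 = right.
    Every step gives reward -1; reaching the goal terminates the episode.
    """

    def __init__(self, n=6, max_steps=50):
        self.n, self.max_steps = n, max_steps

    def reset(self):
        self.pos, self.steps = 0, 0
        return self.pos, {}

    def step(self, action):
        move = 1 if action == 1 else -1
        self.pos = max(0, min(self.n - 1, self.pos + move))  # walls at both ends
        self.steps += 1
        terminated = self.pos == self.n - 1
        truncated = not terminated and self.steps >= self.max_steps
        return self.pos, -1.0, terminated, truncated, {}


def choose_action(q, s, epsilon, rng):
    """Epsilon-greedy over two actions, breaking ties at random."""
    if rng.random() < epsilon or q[s][0] == q[s][1]:
        return rng.randrange(2)
    return 0 if q[s][0] > q[s][1] else 1


def train_sarsa(env, episodes=300, alpha=0.5, gamma=0.95, epsilon=0.1, seed=0):
    rng = random.Random(seed)
    q = [[0.0, 0.0] for _ in range(env.n)]
    for _ in range(episodes):
        obs, _ = env.reset()
        action = choose_action(q, obs, epsilon, rng)  # A, chosen before the first step
        terminated = truncated = False
        while not (terminated or truncated):
            next_obs, reward, terminated, truncated, _ = env.step(action)
            next_action = choose_action(q, next_obs, epsilon, rng)  # A' from the same policy
            target = reward if terminated else reward + gamma * q[next_obs][next_action]
            q[obs][action] += alpha * (target - q[obs][action])
            obs, action = next_obs, next_action  # the A' used in the update is executed next
    return q


q = train_sarsa(Corridor())
print(["right" if q[s][1] > q[s][0] else "left" for s in range(5)])
```

After training, one would expect the greedy action in most states, and especially next to the goal, to be "right"; the exact Q-values depend on the seed and hyperparameters.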
Step 4: Visualization of Learning Progress
After training, it’s useful to visualize the learning progress. We use the plot_returns function to display the return of every episode, which shows whether and how quickly training improves the agent’s performance.
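```python
env = gym.make('Taxi-v3', render_mode='ansi')
episodes = 20000
learning_rate = 0.5
initial_epsilon = 1
final_epsilon = 0
# Positive step: decay_epsilon() subtracts it, so epsilon falls linearly
# from initial_epsilon to final_epsilon over the first half of training
epsilon_decay = (initial_epsilon - final_epsilon) / (episodes / 2)
agent = SARSAAgent(env, learning_rate, initial_epsilon, epsilon_decay, final_epsilon)
returns = train_agent(agent, env, episodes)
plot_returns(returns)
```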
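The exploration schedule can be checked in isolation. The hypothetical `linear_epsilon` below computes epsilon after a given number of episodes under the same linear decay, with `decay_episodes = 10_000` matching half of the 20,000 training episodes used here. It also shows why the per-episode step must be positive: `decay_epsilon()` subtracts it, so a negative step would make epsilon grow instead of shrink.

```python
def linear_epsilon(episode, initial_epsilon=1.0, final_epsilon=0.0, decay_episodes=10_000):
    """Epsilon after `episode` episodes of linear decay, never going below final_epsilon."""
    fraction = min(episode / decay_episodes, 1.0)  # share of the decay period completed
    return initial_epsilon - (initial_epsilon - final_epsilon) * fraction


for episode in (0, 5_000, 10_000, 15_000):
    print(episode, linear_epsilon(episode))  # 1.0, 0.5, 0.0, 0.0

# A negative step, e.g. (final - initial) / n, makes the subtraction increase epsilon.
buggy_step = (0.0 - 1.0) / 10_000
print(max(0.0, 1.0 - buggy_step) > 1.0)  # True: epsilon would grow, not shrink
```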
Step 5: Running the Trained Agent
Finally, to see the trained agent in action, we run it in the environment with exploration switched off (epsilon = 0), so it always takes the greedy action with respect to its learned Q-values. After enough training, this greedy policy should be close to optimal.
```python
def run_agent(agent, env):
    agent.epsilon = 0  # act greedily: no more exploration
    obs, _ = env.reset()
    print(env.render())  # with render_mode='ansi', render() returns a string
    terminated = truncated = False
    while not terminated and not truncated:
        action = agent.get_action(obs)
        obs, _, terminated, truncated, _ = env.step(action)  # advance to the new state
        print(env.render())


run_agent(agent, env)
```
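Watching one episode is informative, but a single number is easier to compare across runs. The hypothetical helper below averages the undiscounted return of a policy over many episodes. It assumes only Gymnasium's conventions that `reset()` returns `(obs, info)` and `step()` returns `(obs, reward, terminated, truncated, info)`.

```python
def evaluate_policy(env, policy, n_episodes=100):
    """Average undiscounted return of `policy` (a function obs -> action) over n_episodes."""
    total = 0.0
    for _ in range(n_episodes):
        obs, _ = env.reset()
        terminated = truncated = False
        while not (terminated or truncated):
            obs, reward, terminated, truncated, _ = env.step(policy(obs))
            total += reward
    return total / n_episodes
```

With the trained agent, one could set `agent.epsilon = 0` and then call `evaluate_policy(env, agent.get_action)` to score the greedy policy. Taxi-v3 truncates episodes after a fixed number of steps, so the loop always ends even if the policy gets stuck.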
Complete Code for Implementing SARSA in Gymnasium’s Taxi-v3 Environment
```python
import gymnasium as gym
import numpy as np
from collections import defaultdict
import matplotlib.pyplot as plt


# Define the plotting function early in the script
def plot_returns(returns):
    plt.plot(np.arange(len(returns)), returns)
    plt.title('Episode returns')
    plt.xlabel('Episode')
    plt.ylabel('Return')
    plt.show()


class SARSAAgent:
    def __init__(self, env, learning_rate, initial_epsilon, epsilon_decay,
                 final_epsilon, discount_factor=0.95):
        self.env = env
        self.learning_rate = learning_rate
        self.discount_factor = discount_factor
        self.epsilon = initial_epsilon
        self.epsilon_decay = epsilon_decay
        self.final_epsilon = final_epsilon
        self.q_values = defaultdict(lambda: np.zeros(env.action_space.n))

    def get_action(self, obs) -> int:
        if np.random.rand() < self.epsilon:
            return int(self.env.action_space.sample())  # Explore
        return int(np.argmax(self.q_values[obs]))  # Exploit

    def update(self, obs, action, reward, terminated, next_obs, next_action):
        # No bootstrapping from a terminal state: the target is just the reward
        if terminated:
            td_target = reward
        else:
            td_target = reward + self.discount_factor * self.q_values[next_obs][next_action]
        td_error = td_target - self.q_values[obs][action]
        self.q_values[obs][action] += self.learning_rate * td_error

    def decay_epsilon(self):
        """Decrease the exploration rate epsilon until it reaches its final value"""
        self.epsilon = max(self.final_epsilon, self.epsilon - self.epsilon_decay)


def train_agent(agent, env, episodes, eval_interval=100):
    rewards = []
    best_reward = -np.inf
    for i in range(episodes):
        obs, _ = env.reset()
        action = agent.get_action(obs)  # choose A before the first step
        terminated = truncated = False
        total_reward = 0
        while not terminated and not truncated:
            next_obs, reward, terminated, truncated, _ = env.step(action)
            next_action = agent.get_action(next_obs)  # Get the next action using current policy
            agent.update(obs, action, reward, terminated, next_obs, next_action)
            obs, action = next_obs, next_action  # A' is the action actually taken next
            total_reward += reward
        agent.decay_epsilon()
        rewards.append(total_reward)
        if i % eval_interval == 0 and i > 0:
            # Best average return over any block of eval_interval episodes so far
            avg_return = np.mean(rewards[-eval_interval:])
            best_reward = max(avg_return, best_reward)
            print(f"Episode {i} -> best_reward={best_reward}")
    return rewards


def run_agent(agent, env):
    agent.epsilon = 0  # No need to keep exploring
    obs, _ = env.reset()
    print(env.render())
    terminated = truncated = False
    while not terminated and not truncated:
        action = agent.get_action(obs)
        obs, _, terminated, truncated, _ = env.step(action)
        print(env.render())


# Initialize the environment and agent
env = gym.make('Taxi-v3', render_mode='ansi')
episodes = 20000
learning_rate = 0.5
discount_factor = 0.95
initial_epsilon = 1
final_epsilon = 0
epsilon_decay = (initial_epsilon - final_epsilon) / (episodes / 2)
agent = SARSAAgent(env, learning_rate, initial_epsilon, epsilon_decay,
                   final_epsilon, discount_factor)

# Train the agent and plot the returns
returns = train_agent(agent, env, episodes)
plot_returns(returns)

# Watch the trained agent act greedily for one episode
run_agent(agent, env)
```
Output (exact values vary from run to run):
```
Episode 100 -> best_reward=-778.52
Episode 200 -> best_reward=-757.13
Episode 300 -> best_reward=-757.13
Episode 400 -> best_reward=-757.13
Episode 500 -> best_reward=-757.13
..
Episode 19500 -> best_reward=-743.8
Episode 19600 -> best_reward=-743.8
Episode 19700 -> best_reward=-743.8
Episode 19800 -> best_reward=-743.8
Episode 19900 -> best_reward=-743.8
```
Figure: Episode vs Return graph
Figure: Taxi-v3 environment
Output Explanation
The output of the code consists of two main parts:
- Training progress and episode returns plot:
  - During training, every eval_interval episodes the code averages the returns of the most recent eval_interval episodes and prints the best such average seen so far (best_reward).
  - After training, a plot of the return of every episode is displayed, showing how the agent’s performance changes as it gains experience.
- Agent’s behavior demonstration:
  - After training, the run_agent function plays one greedy episode in the Taxi-v3 environment and prints the rendered grid after every step, showing the agent’s decisions and movements.
Returns vs Episode Plot: At the end of training, the plot_returns function plots the total return of each episode. The x-axis is the episode number and the y-axis is the return (total reward) of that episode. This learning curve reveals trends such as improvement, plateaus, or fluctuations in performance.
Demonstration of the Output Grid:
- Each frame printed by run_agent shows the 5×5 Taxi grid after one action, with the taxi’s current position highlighted.
- The letters R, G, Y, and B mark the four designated locations; the passenger waits at one of them and must be dropped off at another.
- The action just taken (South, North, East, West, Pickup, or Dropoff) is shown in parentheses below the grid, so the sequence of frames traces the agent’s route: drive to the passenger, pick them up, drive to the destination, and drop them off.
Conclusion
Implementing a SARSA agent in Gymnasium’s Taxi-v3 environment provides a hands-on way to understand on-policy reinforcement learning. By setting up the environment, defining the agent, training it, and visualizing its progress, we see how SARSA updates its Q-values using the action its current policy actually takes next, together with the reward that action produces.