Reinforcement Learning in Python: Implementing SARSA Agent in Taxi Environment - GeeksforGeeks (2024)

SARSA (State-Action-Reward-State-Action) is an on-policy reinforcement learning algorithm that updates its policy based on the current state-action pair, the reward received, the next state, and the next action chosen by the current policy.

In this article, we will implement SARSA in Gymnasium’s Taxi-v3 environment, walking through the setup, agent definition, training, and visualization of the agent’s learning process.

Table of Contents

  • SARSA Learning
    • Key Concepts in SARSA
    • SARSA Update Rule
  • Implementing SARSA in Gymnasium’s Taxi-v3 Environment
    • Step 1: Setup and Initialization
    • Step 2: Define the SARSA Agent
    • Step 3: Training the Agent
    • Step 4: Visualization of Learning Progress
    • Step 5: Running the Trained Agent
    • Output Explanation
  • Conclusion

SARSA Learning

SARSA is a temporal difference (TD) learning algorithm that combines ideas from dynamic programming and Monte Carlo methods. The key feature of SARSA is that it learns the Q-value based on the action taken by the current policy, making it an on-policy method. This means that the agent follows its current policy both when selecting actions to take and when updating the Q-values.
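The on-policy idea can be sketched in a few lines. A single epsilon-greedy function picks both the action A that the agent executes in S and the next action A' that appears in the update target. The function name `epsilon_greedy` and the list-based Q-values are illustrative assumptions. They are not part of the agent defined later.

```python
import random

def epsilon_greedy(q_row, epsilon, rng=random):
    """Pick an action index from one state's list of Q-values.

    With probability epsilon a uniformly random action is returned (explore);
    otherwise the action with the highest Q-value is returned (exploit).
    """
    if rng.random() < epsilon:
        return rng.randrange(len(q_row))
    return max(range(len(q_row)), key=q_row.__getitem__)

# On-policy: the SAME function (with the same epsilon) chooses both the action
# that is executed (A) and the next action A' used in the update target.
q = {"s": [0.0, 1.5, -2.0], "s_next": [0.3, 0.1, 0.2]}
a = epsilon_greedy(q["s"], epsilon=0.0)            # A in S
a_next = epsilon_greedy(q["s_next"], epsilon=0.0)  # A' in S'
print(a, a_next)  # 1 0
```

With `epsilon=0.0` the choice is purely greedy, which makes the output predictable. During training, a positive epsilon keeps some exploration in both choices.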

Key Concepts in SARSA

  • State (S): The current situation or configuration the agent is in.
  • Action (A): The decision or move the agent takes in a given state.
  • Reward (R): The immediate gain or loss the agent receives after taking an action in a state.
  • Next State (S’): The new state the agent transitions to after taking an action.
  • Next Action (A’): The action the agent takes in the next state according to its policy.

SARSA Update Rule

The SARSA update rule for Q-values is given by:

$$Q(s_t, a_t) \leftarrow Q(s_t, a_t) + \alpha \left[ r_{t+1} + \gamma \, Q(s_{t+1}, a_{t+1}) - Q(s_t, a_t) \right]$$


  • $\alpha$ is the learning rate.
  • $\gamma$ is the discount factor.
  • $Q(s_t, a_t)$ is the Q-value of the current state-action pair.
  • $r_{t+1}$ is the reward received after taking action $a_t$ in state $s_t$.
  • $Q(s_{t+1}, a_{t+1})$ is the Q-value of the next state-action pair.
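To make the update rule concrete, here is one update worked through with made-up numbers: $Q(s_t, a_t) = 0.5$, $r_{t+1} = -1$, $Q(s_{t+1}, a_{t+1}) = 2.0$, $\alpha = 0.5$ and $\gamma = 0.95$. The helper name `sarsa_update` is only illustrative.

```python
def sarsa_update(q_sa, reward, q_next, alpha, gamma):
    """Return the new Q(s_t, a_t) after one SARSA step."""
    td_target = reward + gamma * q_next  # r_{t+1} + gamma * Q(s_{t+1}, a_{t+1})
    td_error = td_target - q_sa          # how far the current estimate is off
    return q_sa + alpha * td_error       # move a fraction alpha toward the target

new_q = sarsa_update(0.5, -1.0, 2.0, alpha=0.5, gamma=0.95)
# target = -1 + 0.95 * 2.0 = 0.9, error = 0.9 - 0.5 = 0.4, new Q = 0.5 + 0.5 * 0.4 = 0.7
print(round(new_q, 6))  # 0.7
```

The estimate moves only halfway toward the target because $\alpha = 0.5$. With $\alpha = 1$ it would jump straight to 0.9.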

Implementing SARSA in Gymnasium’s Taxi-v3 Environment

We’ll walk through setting up the environment, defining a SARSA agent, training it, and visualizing its learning progress and final behavior. Along the way, note what makes SARSA on-policy. Its update target uses the action the agent will actually take next, as chosen by its current epsilon-greedy policy. Q-learning, by contrast, is off-policy: its target uses the highest-valued action in the next state, whether or not the agent actually takes that action.
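The difference between the two targets shows up on a single transition. The sketch below compares them using made-up Q-values for the next state. In this example, the epsilon-greedy policy happened to explore and picked a poor action as A'.

```python
# Made-up Q-values for the next state s' (three actions) and one transition.
q_next_state = [1.0, 4.0, -3.0]
reward, gamma = -1.0, 0.9

# Suppose the epsilon-greedy policy explored in s' and picked action 2 as A'.
a_next = 2

# SARSA (on-policy): bootstrap from the action the agent will actually take.
sarsa_target = reward + gamma * q_next_state[a_next]
# Q-learning (off-policy): bootstrap from the best action, whatever is taken.
q_learning_target = reward + gamma * max(q_next_state)

print(f"SARSA target: {sarsa_target:.2f}, Q-learning target: {q_learning_target:.2f}")
# SARSA target: -3.70, Q-learning target: 2.60
```

Because SARSA's target includes the cost of exploratory actions, it learns the values of the epsilon-greedy policy it actually follows. Q-learning learns the values of the purely greedy policy.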

Before proceeding with the implementation, make sure you have installed Gymnasium.

Gymnasium is the continuation and evolution of the popular OpenAI Gym library for developing and comparing reinforcement learning algorithms. You can install it with Python’s package manager pip. The code below also uses Matplotlib for plotting, so install both:

pip install gymnasium matplotlib

Step 1: Setup and Initialization

First, we begin by importing necessary libraries and defining a plotting function that we will use later to visualize the agent’s performance over training episodes.

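import gymnasium as gym
import numpy as np
from collections import defaultdict
import matplotlib.pyplot as plt

def plot_returns(returns):
    """Plot the total (undiscounted) return of every training episode."""
    plt.plot(np.arange(len(returns)), returns)
    plt.title('Episode returns')
    plt.xlabel('Episode')
    plt.ylabel('Return')
    plt.show()  # needed to open the window when run as a script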

Step 2: Define the SARSA Agent

Next, we define the SARSAAgent class. This agent initializes with a set of parameters that dictate its learning and decision-making processes. It also includes methods for selecting actions, updating Q-values, and adjusting the exploration rate.

class SARSAAgent:
    def __init__(self, env, learning_rate, initial_epsilon, epsilon_decay, final_epsilon, discount_factor=0.95):
        self.env = env
        self.learning_rate = learning_rate      # alpha
        self.discount_factor = discount_factor  # gamma
        self.epsilon = initial_epsilon
        self.epsilon_decay = epsilon_decay      # subtracted from epsilon after every episode
        self.final_epsilon = final_epsilon
        # Q-table: one array of action values per state, initialised to zeros on first access
        self.q_values = defaultdict(lambda: np.zeros(env.action_space.n))

    def get_action(self, obs) -> int:
        """Epsilon-greedy action selection."""
        if np.random.rand() < self.epsilon:
            return self.env.action_space.sample()  # Explore
        return int(np.argmax(self.q_values[obs]))  # Exploit

    def update(self, obs, action, reward, terminated, next_obs, next_action):
        """SARSA update: Q(S, A) += alpha * (R + gamma * Q(S', A') - Q(S, A))."""
        # A terminal state has no future value, so the target is just the reward
        future_q = 0.0 if terminated else self.q_values[next_obs][next_action]
        td_target = reward + self.discount_factor * future_q
        td_error = td_target - self.q_values[obs][action]
        self.q_values[obs][action] += self.learning_rate * td_error

    def decay_epsilon(self):
        """Decrease epsilon linearly until it reaches final_epsilon."""
        self.epsilon = max(self.final_epsilon, self.epsilon - self.epsilon_decay)
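The terminal case deserves special care in Taxi-v3. The only positive reward there, +20 for a correct drop-off, comes on the transition that ends the episode. An update that skips terminal transitions would therefore never learn from that reward. The standalone sketch below makes the terminal branch explicit. The function name `sarsa_step` and the state index 42 are hypothetical, chosen for illustration.

```python
from collections import defaultdict
import numpy as np

def sarsa_step(q_values, obs, action, reward, terminated, next_obs, next_action,
               learning_rate=0.5, discount_factor=0.95):
    """Apply one SARSA update to a defaultdict Q-table in place."""
    # A terminal state has no successor, so there is no future value to bootstrap
    # from and the target is just the final reward. A time-limit truncation is
    # not terminal, so in that case Q(S', A') is still used.
    future_q = 0.0 if terminated else q_values[next_obs][next_action]
    td_target = reward + discount_factor * future_q
    q_values[obs][action] += learning_rate * (td_target - q_values[obs][action])

q = defaultdict(lambda: np.zeros(6))  # 6 actions, as in Taxi-v3
# Hypothetical final transition: a successful drop-off (action 5, reward +20)
# from an arbitrarily chosen state index 42.
sarsa_step(q, obs=42, action=5, reward=20, terminated=True, next_obs=0, next_action=0)
print(q[42][5])  # 10.0
```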

Step 3: Training the Agent

With the SARSA agent defined, we proceed to train it across multiple episodes. The training function loops over each episode, allowing the agent to interact with the environment, learn from actions, and gradually improve its policy.

def train_agent(agent, env, episodes, eval_interval=100):
    rewards = []
    for i in range(episodes):
        obs, _ = env.reset()
        action = agent.get_action(obs)  # A: chosen once, before the loop
        terminated = truncated = False
        total_reward = 0
        while not terminated and not truncated:
            next_obs, reward, terminated, truncated, _ = env.step(action)
            next_action = agent.get_action(next_obs)  # A' from the same (current) policy
            agent.update(obs, action, reward, terminated, next_obs, next_action)
            obs, action = next_obs, next_action  # A' is the action actually executed next
            total_reward += reward
        agent.decay_epsilon()
        rewards.append(total_reward)
        if i % eval_interval == 0 and i > 0:
            avg_return = np.mean(rewards[i - eval_interval:i])
            print(f"Episode {i} -> Average Return: {avg_return}")
    return rewards
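The same episode structure can be tried without Gymnasium on a tiny problem. In it, A is chosen before the loop, A' comes from the same epsilon-greedy policy, A' is the action executed on the next step, and nothing is bootstrapped from a terminal state. The five-cell corridor below is a hypothetical toy task, not a Gymnasium environment. Epsilon decays linearly to 0 over the first half of training, mirroring the schedule used in this article.

```python
import random

# Hypothetical toy task, not a Gymnasium environment: cells 0..4 in a row,
# every episode starts in cell 0 and ends on reaching cell 4, reward -1 per step.
N_CELLS, LEFT, RIGHT = 5, 0, 1

def step(state, action):
    """Move one cell left or right (walls at both ends); return (S', R, terminated)."""
    next_state = max(0, min(N_CELLS - 1, state + (1 if action == RIGHT else -1)))
    return next_state, -1.0, next_state == N_CELLS - 1

def choose(q, state, epsilon):
    """Epsilon-greedy over the two actions (ties go to LEFT)."""
    if random.random() < epsilon:
        return random.choice((LEFT, RIGHT))
    return RIGHT if q[state][RIGHT] > q[state][LEFT] else LEFT

def train_sarsa(episodes=500, alpha=0.5, gamma=0.95, max_steps=100):
    q = [[0.0, 0.0] for _ in range(N_CELLS)]
    epsilon, decay = 1.0, 1.0 / (episodes / 2)  # reach 0 halfway through
    for _ in range(episodes):
        state = 0
        action = choose(q, state, epsilon)  # A is chosen once, before the loop
        for _ in range(max_steps):
            next_state, reward, terminated = step(state, action)
            next_action = choose(q, next_state, epsilon)  # A' from the same policy
            future = 0.0 if terminated else gamma * q[next_state][next_action]
            q[state][action] += alpha * (reward + future - q[state][action])
            if terminated:
                break
            state, action = next_state, next_action  # A' is the action executed next
        epsilon = max(0.0, epsilon - decay)
    return q

random.seed(0)
q = train_sarsa()
greedy = [RIGHT if q[s][RIGHT] > q[s][LEFT] else LEFT for s in range(N_CELLS - 1)]
print(greedy)
```

Because every step costs -1, the learned greedy policy should be to move right (action 1) in every non-terminal cell.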

Step 4: Visualization of Learning Progress

After training, it’s beneficial to visualize the learning progress. We use the plot_returns function to display the returns per episode, offering insights into the effectiveness of our training regimen.

env = gym.make('Taxi-v3', render_mode='ansi')
episodes = 20000
learning_rate = 0.5
initial_epsilon = 1.0
final_epsilon = 0.0
# Lower epsilon from 1 to 0 over the first half of training. The step must be
# positive, because decay_epsilon() subtracts it from epsilon.
epsilon_decay = (initial_epsilon - final_epsilon) / (episodes / 2)
agent = SARSAAgent(env, learning_rate, initial_epsilon, epsilon_decay, final_epsilon)
returns = train_agent(agent, env, episodes)
plot_returns(returns)
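The per-episode subtraction in `decay_epsilon()` is equivalent to a closed-form linear schedule. That form makes the direction of the decay easy to check. The helper `linear_epsilon` is a hypothetical illustration. It uses the values above: epsilon falls from 1 to 0 over `episodes / 2 = 10000` episodes, a step of 0.0001 per episode.

```python
def linear_epsilon(episode, initial_epsilon=1.0, final_epsilon=0.0, decay_episodes=10_000):
    """Epsilon after `episode` episodes of linear decay, clamped at final_epsilon."""
    # The step is positive: epsilon must go DOWN from initial_epsilon to final_epsilon.
    step = (initial_epsilon - final_epsilon) / decay_episodes
    return max(final_epsilon, initial_epsilon - step * episode)

for ep in (0, 5_000, 10_000, 15_000):
    print(ep, round(linear_epsilon(ep), 4))
```

Epsilon is 1.0 at the start, about 0.5 after 5,000 episodes, and 0 from episode 10,000 onward. A negative step would instead push epsilon above 1 and keep the agent exploring at random.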

Step 5: Running the Trained Agent

Finally, to see our trained agent in action, we run it in the environment with exploration switched off (epsilon = 0). The agent then always takes the action with the highest learned Q-value, which should follow a near-optimal policy if training succeeded.

def run_agent(agent, env):
    agent.epsilon = 0  # No need to keep exploring
    obs, _ = env.reset()
    print(env.render())  # with render_mode='ansi', render() returns the frame as a string
    terminated = truncated = False
    while not terminated and not truncated:
        action = agent.get_action(obs)
        obs, _, terminated, truncated, _ = env.step(action)  # keep obs up to date
        print(env.render())

run_agent(agent, env)
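The returns recorded during training include exploratory moves. To measure the learned policy on its own, one option is to average the return of several purely greedy episodes. The helper `evaluate_greedy` below is a hypothetical addition, not part of the article's code. It assumes the Gymnasium API, where `reset()` returns `(obs, info)` and `step()` returns `(obs, reward, terminated, truncated, info)`. The `ThreeStepEnv` class is a made-up stand-in, used only to show the call.

```python
import numpy as np

def evaluate_greedy(q_values, env, n_episodes=100):
    """Average undiscounted return of the greedy policy (epsilon = 0)."""
    returns = []
    for _ in range(n_episodes):
        obs, _ = env.reset()
        terminated = truncated = False
        total = 0.0
        while not (terminated or truncated):
            action = int(np.argmax(q_values[obs]))  # always exploit
            obs, reward, terminated, truncated, _ = env.step(action)
            total += reward
        returns.append(total)
    return float(np.mean(returns))

# Hypothetical stand-in environment: every episode lasts three steps and gives
# +1 for action 1 and -1 for anything else.
class ThreeStepEnv:
    def reset(self):
        self.t = 0
        return 0, {}

    def step(self, action):
        self.t += 1
        return 0, (1.0 if action == 1 else -1.0), self.t >= 3, False, {}

avg = evaluate_greedy({0: [0.0, 5.0]}, ThreeStepEnv(), n_episodes=5)
print(avg)  # 3.0
```

For the trained agent, you would call `evaluate_greedy(agent.q_values, gym.make('Taxi-v3'))`. Taxi-v3's default 200-step time limit ends each episode through `truncated`, even if the greedy policy gets stuck in a loop.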

Complete Code to Implement SARSA in Gymnasium’s Taxi-v3 Environment

import gymnasium as gym
import numpy as np
from collections import defaultdict
import matplotlib.pyplot as plt

# Define the plotting function early in the script
def plot_returns(returns):
    plt.plot(np.arange(len(returns)), returns)
    plt.title('Episode returns')
    plt.xlabel('Episode')
    plt.ylabel('Return')
    plt.show()

class SARSAAgent:
    def __init__(self, env, learning_rate, initial_epsilon, epsilon_decay, final_epsilon, discount_factor=0.95):
        self.env = env
        self.learning_rate = learning_rate
        self.discount_factor = discount_factor
        self.epsilon = initial_epsilon
        self.epsilon_decay = epsilon_decay
        self.final_epsilon = final_epsilon
        self.q_values = defaultdict(lambda: np.zeros(env.action_space.n))

    def get_action(self, obs) -> int:
        if np.random.rand() < self.epsilon:
            return self.env.action_space.sample()
        return int(np.argmax(self.q_values[obs]))

    def update(self, obs, action, reward, terminated, next_obs, next_action):
        # No bootstrapping from a terminal state: the target is just the reward
        future_q = 0.0 if terminated else self.q_values[next_obs][next_action]
        td_target = reward + self.discount_factor * future_q
        td_error = td_target - self.q_values[obs][action]
        self.q_values[obs][action] += self.learning_rate * td_error

    def decay_epsilon(self):
        """Decrease the exploration rate epsilon until it reaches its final value"""
        self.epsilon = max(self.final_epsilon, self.epsilon - self.epsilon_decay)

def train_agent(agent, env, episodes, eval_interval=100):
    rewards = []
    best_reward = -np.inf
    for i in range(episodes):
        obs, _ = env.reset()
        action = agent.get_action(obs)  # First action of the episode
        terminated = truncated = False
        total_reward = 0
        while not terminated and not truncated:
            next_obs, reward, terminated, truncated, _ = env.step(action)
            next_action = agent.get_action(next_obs)  # Get the next action using current policy
            agent.update(obs, action, reward, terminated, next_obs, next_action)
            obs, action = next_obs, next_action  # The chosen next action is executed next
            total_reward += reward
        agent.decay_epsilon()
        rewards.append(total_reward)
        if i % eval_interval == 0 and i > 0:
            avg_return = np.mean(rewards[i - eval_interval:i])
            best_reward = max(avg_return, best_reward)
            print(f"Episode {i} -> best_reward={best_reward}")
    return rewards

def run_agent(agent, env):
    agent.epsilon = 0  # No need to keep exploring
    obs, _ = env.reset()
    print(env.render())
    terminated = truncated = False
    while not terminated and not truncated:
        action = agent.get_action(obs)
        obs, _, terminated, truncated, _ = env.step(action)
        print(env.render())

# Initialize the environment and agent
env = gym.make('Taxi-v3', render_mode='ansi')
episodes = 20000
learning_rate = 0.5
discount_factor = 0.95
initial_epsilon = 1.0
final_epsilon = 0.0
epsilon_decay = (initial_epsilon - final_epsilon) / (episodes / 2)  # positive step
agent = SARSAAgent(env, learning_rate, initial_epsilon, epsilon_decay, final_epsilon, discount_factor)

# Train the agent and plot the returns
returns = train_agent(agent, env, episodes)
plot_returns(returns)

# Watch the trained agent
run_agent(agent, env)


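Episode 100 -> best_reward=-778.52
Episode 200 -> best_reward=-757.13
Episode 300 -> best_reward=-757.13
Episode 400 -> best_reward=-757.13
Episode 500 -> best_reward=-757.13
..
Episode 19500 -> best_reward=-743.8
Episode 19600 -> best_reward=-743.8
Episode 19700 -> best_reward=-743.8
Episode 19800 -> best_reward=-743.8
Episode 19900 -> best_reward=-743.8

These averages barely change over 20,000 episodes. They stay roughly where an agent acting at random would score in Taxi-v3, where an episode in which the passenger is never delivered is cut off after 200 penalized steps. That pattern points to an agent that never stopped exploring. If epsilon_decay is computed as (final_epsilon - initial_epsilon) / (episodes / 2), the step is negative, so decay_epsilon() raises epsilon every episode instead of lowering it. With a positive step, epsilon reaches 0 halfway through training, and the average return should rise well above these values.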


Episode vs Return Graph


Taxi-v3 Environment

Output Explanation

The output of the code consists of two main parts:

  1. Training Progress and Episode Returns Plot:
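    • During training, the script averages the returns of the most recent eval_interval training episodes every eval_interval episodes. These returns still include exploratory moves. The Step 3 version prints this average, while the complete script prints the best such average seen so far as best_reward.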
    • After training, a plot of episode returns over time is displayed, showing how the agent’s performance changes as it learns from more episodes.
  2. Agent’s Behavior Demonstration:
    • After training, the run_agent function demonstrates the agent’s behavior in the “Taxi-v3” environment. The environment’s state is printed to the console at each step, showing the agent’s decisions and movements.

Returns vs Episode Plot: At the end of training, the plot_returns function generates a plot showing the total return for each episode. The x-axis represents the episode number, and the y-axis represents the return (total reward) for that episode. This plot helps visualize the learning curve of the agent, showing trends such as improvements, plateaus, or fluctuations in performance.
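Per-episode returns are noisy, because each episode starts from a random taxi, passenger, and destination configuration and training includes exploration. A common way to make the trend easier to see is to plot a moving average alongside, or instead of, the raw curve. The helper `moving_average` below is a hypothetical addition built on NumPy's `np.convolve`.

```python
import numpy as np

def moving_average(values, window=100):
    """Mean of each run of `window` consecutive values.

    Uses 'valid' mode, so the result has window - 1 fewer entries than the input.
    """
    values = np.asarray(values, dtype=float)
    return np.convolve(values, np.ones(window) / window, mode="valid")

smoothed = moving_average([1, 2, 3, 4, 5], window=2)
print(smoothed)  # [1.5 2.5 3.5 4.5]
```

With the training returns, something like `plt.plot(moving_average(returns, window=100))` before `plt.show()` would draw the smoothed learning curve.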

Demonstration of the Output Grid:

  • Each printed frame shows the 5×5 Taxi grid after one step. The four designated locations are marked R, G, Y and B. One of them is the passenger’s pickup point and another is the destination, and the taxi’s current cell is highlighted.
  • The word in parentheses under each frame after the first, such as (North) or (Pickup), is the action the agent just took. The six possible actions are South, North, East, West, Pickup and Dropoff.
  • The episode ends when the passenger is dropped off at the destination. If that never happens, it ends at the 200-step time limit.
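Each integer observation that keys the agent’s Q-table corresponds to one such frame. The sketch below unpacks an observation. It assumes the encoding used by Gymnasium’s Taxi environment: ((taxi_row * 5 + taxi_col) * 5 + passenger_location) * 4 + destination. Locations are indexed R = 0, G = 1, Y = 2, B = 3, and passenger_location 4 means the passenger is in the taxi. Gymnasium’s Taxi environment also offers its own decode method, reachable through env.unwrapped. The standalone functions `decode` and `describe` here are only illustrations.

```python
LOCS = ["R", "G", "Y", "B"]  # location indices 0-3; passenger index 4 = "in the taxi"

def decode(state):
    """Split a Taxi-v3 state in 0..499 into (taxi_row, taxi_col, passenger, destination)."""
    destination = state % 4
    state //= 4
    passenger = state % 5
    state //= 5
    taxi_col = state % 5
    taxi_row = state // 5
    return taxi_row, taxi_col, passenger, destination

def describe(state):
    row, col, passenger, destination = decode(state)
    where = "in taxi" if passenger == 4 else f"at {LOCS[passenger]}"
    return f"taxi at ({row}, {col}), passenger {where}, destination {LOCS[destination]}"

print(describe(328))  # taxi at (3, 1), passenger at Y, destination R
```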

Conclusion

Implementing a SARSA agent in Gymnasium’s Taxi-v3 environment provides a hands-on way to understand on-policy reinforcement learning algorithms. By setting up the environment, defining the agent, training it, and visualizing its progress, we see how SARSA updates its Q-values from the actions its current policy actually takes and the rewards that follow.




Article information

Author: Lakeisha Bayer VM
