SARSA
SARSA (short for State-Action-Reward-State-Action, the quintuple used in its update) is a reinforcement learning algorithm used to learn the optimal action-selection policy for a given Markov Decision Process (MDP). It is an "on-policy" learning algorithm, which means that it evaluates and improves the very policy it follows while interacting with the environment.
The basic idea behind SARSA is to learn an estimate of the state-action value function, or Q-function, of the MDP. The Q-function takes a state and an action and returns the expected long-term (cumulative, discounted) reward for taking that action in that state and following the current policy afterwards. The goal of SARSA is to find the Q-function of a policy that maximizes this expected long-term reward.
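In the tabular setting used in this section, the Q-function can be stored as a simple array with one entry per state-action pair. Below is a minimal sketch of such a table for a 5x5 grid with four actions, matching the GridWorld built later; the name q_table is purely illustrative.

import numpy as np

# tabular Q-function for a 5x5 grid with 4 actions (U, D, L, R):
# q_table[row, col, action] holds the estimated long-term reward for
# taking `action` in cell (row, col); all estimates start at zero
q_table = np.zeros((5, 5, 4))

# look up the value of action 0 ('U') in the top-left cell
print(q_table[0, 0, 0])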
To do this, SARSA improves its estimate of the Q-function iteratively. At each step, the agent takes an action in the current state, observes the resulting reward and next state, chooses the next action with its current policy, and then updates the Q-value of the current state-action pair towards the observed reward plus the discounted Q-value of that next state-action pair. This update rule is applied repeatedly until the estimate converges towards the optimal Q-function for the MDP.
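Written out as code, the update for a single transition looks roughly as follows. This is a minimal sketch using the tabular representation from above; the helper name sarsa_update and the default values for alpha (learning rate) and gamma (discount factor) are illustrative, not part of the implementation later in this section.

import numpy as np

def sarsa_update(q_table, s, a, r, s_next, a_next, alpha=0.1, gamma=0.99):
    # move Q(s, a) a small step towards the target r + gamma * Q(s', a')
    target = r + gamma * q_table[s_next][a_next]
    q_table[s][a] += alpha * (target - q_table[s][a])

# example: one update after moving down from (0, 0) to (1, 0) with reward -1
q_table = np.zeros((5, 5, 4))
sarsa_update(q_table, s=(0, 0), a=1, r=-1, s_next=(1, 0), a_next=1)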
One key advantage of SARSA being on-policy is that the values it learns reflect the behavior of the policy the agent actually follows, including its exploration. This can make it more stable and easier to converge compared with "off-policy" learning algorithms, which can sometimes experience instability or divergence. Additionally, SARSA is a model-free algorithm: it learns directly from sampled transitions and does not require a model of the environment's dynamics, which makes it applicable even when those dynamics are unknown.
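The on-policy/off-policy distinction is easiest to see in the update target itself. The sketch below contrasts the SARSA target with the Q-learning target (Q-learning appears here only as the standard off-policy counterpart; the variable names are illustrative):

import numpy as np

q_table = np.zeros((5, 5, 4))   # tabular Q-function as above
gamma, r = 0.99, -1             # discount factor and observed reward
s_next, a_next = (1, 0), 1      # next state and the action the policy chose there

# SARSA (on-policy): the target uses the action actually chosen in s_next
sarsa_target = r + gamma * q_table[s_next][a_next]

# Q-learning (off-policy): the target uses the best action in s_next,
# regardless of which action the behaviour policy will take
q_learning_target = r + gamma * q_table[s_next].max()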
Overall, SARSA is a powerful and widely-used reinforcement learning algorithm that is well-suited to a variety of environments and tasks. It is simple to implement and has been applied successfully to a wide range of real-world problems, including robot control, game playing, and autonomous vehicles.
# utils.py -- helpers for printing value tables and clearing the terminal
import numpy as np
import os


def plot_state_value_table(table, cols):
    # print a dict of state -> value as a grid with `cols` columns
    for idx, state in enumerate(table):
        print(f"{table[state]:.2f}", end="\t")
        if (idx + 1) % cols == 0:
            print()


def plot_q_table(table, cols):
    # print the action values of every cell, one grid row per line
    for r in range(table.shape[0]):
        for c in range(cols):
            q_values = np.round(table[r][c], 2)
            print(q_values, end="\t")
        print()


# clear the terminal ('cls' on Windows, 'clear' elsewhere)
clear = lambda: os.system('cls' if os.name == 'nt' else 'clear')
# env.py -- a simple grid-world environment
import numpy as np


class Action():
    def __init__(self):
        # each action maps to a (row, column) offset on the grid
        self.action_space = {'U': (-1, 0), 'D': (1, 0),
                             'L': (0, -1), 'R': (0, 1)}
        self.possible_actions = list(self.action_space.keys())
        self.action_n = len(self.possible_actions)
        self.action_idxs = range(self.action_n)

    def get_action_by_idx(self, idx):
        return self.action_space[self.possible_actions[idx]]

    def get_action_by_key(self, key):
        return self.action_space[key]
class GridWorld():
    def __init__(self, shape=(3, 3), obstacles=[], terminal=None,
                 agent_pos=np.array([0, 0])):
        self.action = Action()
        self.shape = shape
        self.rows = shape[0]
        self.cols = shape[1]
        self.obstacles = obstacles
        self.agent_pos = agent_pos
        self.agent_init_pos = agent_pos
        # by default the terminal state is the bottom-right cell
        if terminal is None:
            self.terminal_state = (self.rows - 1, self.cols - 1)
        else:
            self.terminal_state = terminal
        self.done = False
        # apply an action (given by its index) to a state and return the new state
        self.add_state_action = lambda state, action: \
            tuple(np.array(state) + np.array(self.action.get_action_by_idx(action)))
        self.is_same_state = lambda s1, s2: s1[0] == s2[0] and s1[1] == s2[1]
    def is_obstacle(self, state):
        for s in self.obstacles:
            if self.is_same_state(s, state):
                return True
        return False

    def is_terminal(self, s):
        return self.is_same_state(s, self.terminal_state)

    def is_edge(self, state):
        if state[0] < 0 or state[0] > self.rows - 1 \
                or state[1] < 0 or state[1] > self.cols - 1:
            return True
        return False

    def set_agent_pos(self, state):
        self.agent_pos = state

    def get_agent_pos(self):
        return self.agent_pos
    def step(self, action):
        # the new agent location is the current location plus the action offset
        state = self.get_agent_pos()
        tmp_state = self.add_state_action(state, action)
        if self.is_obstacle(tmp_state):
            # moving into an obstacle: the agent stays where it is
            pass
        elif self.is_terminal(tmp_state):
            self.set_agent_pos(tmp_state)
            self.done = True
        elif self.is_edge(tmp_state):
            # moving off the grid: the agent stays where it is
            pass
        else:
            self.set_agent_pos(tmp_state)
        # every step costs -1 until the terminal state is reached
        reward = -1 if not self.done else 0
        return self.get_agent_pos(), reward, self.done, None
    def simulated_step(self, state, action):
        # like step(), but without moving the agent's actual position
        if self.is_terminal(state):
            return state, 0, True, None
        next_state = self.add_state_action(state, action)
        if self.is_obstacle(next_state) or self.is_edge(next_state):
            pass
        else:
            state = next_state
        return state, -1, False, None

    def reset(self):
        self.set_agent_pos(self.agent_init_pos)
        self.done = False
        return self.agent_init_pos
    def render(self):
        # draw the grid: O = agent, X = obstacle, [] = terminal, - = free cell
        for r in range(self.rows):
            for c in range(self.cols):
                state = np.array((r, c))
                if self.is_terminal(state):
                    if self.done:
                        print('[O]', end="\t")
                    else:
                        print('[]', end="\t")
                elif self.is_same_state(self.get_agent_pos(), state):
                    print('O', end="\t")
                elif self.is_obstacle(state):
                    print('X', end="\t")
                else:
                    print('-', end="\t")
            print()
if __name__ == '__main__':
    env = GridWorld(shape=(5, 5), obstacles=((0, 1), (1, 1)))
    env.render()
    env.step(0)  # 0 -> up
    env.step(1)  # 1 -> down
    env.step(2)  # 2 -> left
    env.render()
# sarsa.py -- SARSA agent for the GridWorld environment
import numpy as np
import time
from env import GridWorld
from utils import *


class Sarsa():
    def __init__(self, env, episodes=10000, epsilon=.2, alpha=.1, gamma=.99):
        self.env = env
        # one Q-value per (row, column, action) triple
        self.action_values = np.zeros((env.rows, env.cols, env.action.action_n))
        self.episodes = episodes
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = epsilon

    def policy(self, state):
        # epsilon-greedy: explore with probability epsilon, otherwise act greedily
        if np.random.random() < self.epsilon:
            return np.random.choice(self.env.action.action_idxs)
        else:
            av = self.action_values[state[0]][state[1]]
            # break ties between equally valued actions at random
            return np.random.choice(np.flatnonzero(av == av.max()))

    def sarsa(self):
        for _ in range(1, self.episodes + 1):
            state = self.env.reset()
            action = self.policy(state)
            done = False
            while not done:
                next_state, reward, done, _ = self.env.step(action)
                # choose the next action from the next state with the current policy
                next_action = self.policy(next_state)
                qsa = self.action_values[state[0]][state[1]][action]
                next_qsa = self.action_values[next_state[0]][next_state[1]][next_action]
                # SARSA update: move Q(s, a) towards reward + gamma * Q(s', a')
                self.action_values[state[0]][state[1]][action] = \
                    qsa + self.alpha * (reward + self.gamma * next_qsa - qsa)
                state = next_state
                action = next_action
if __name__ == '__main__':
    env = GridWorld(shape=(5, 5),
                    obstacles=np.array([[0, 1], [1, 1], [2, 1], [3, 1],
                                        [1, 3], [2, 3], [3, 3], [4, 3]]))
    sarsa = Sarsa(env, episodes=10000, epsilon=.2)
    sarsa.sarsa()

    # replay one episode with the learned action values and render it step by step
    steps = 0
    done = False
    env.reset()
    while True:
        steps += 1
        clear()
        state = env.get_agent_pos()
        action = sarsa.policy(state)
        state, _, done, _ = env.step(action)
        env.render()
        if done:
            print(f"the agent reached the terminal state in {steps} steps")
            plot_q_table(sarsa.action_values, 5)
            break
        time.sleep(.5)
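As an additional sanity check of the learned values, one can print the greedy action for every free cell of the grid. The helper below is a minimal sketch and not part of the code above (the name print_greedy_policy is illustrative); it assumes it is added to the same script, where np, env, and sarsa are already defined.

def print_greedy_policy(action_values, env):
    # print the greedy action (U/D/L/R) per cell, X for obstacles, T for the terminal
    for r in range(env.rows):
        for c in range(env.cols):
            if env.is_terminal((r, c)):
                print('T', end="\t")
            elif env.is_obstacle((r, c)):
                print('X', end="\t")
            else:
                best = np.argmax(action_values[r][c])
                print(env.action.possible_actions[best], end="\t")
        print()

# e.g. after training:
# print_greedy_policy(sarsa.action_values, env)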