SARSA

SARSA (State-Action-Reward-State-Action) is a reinforcement learning algorithm used to learn the optimal action-selection policy for a given Markov Decision Process (MDP). It is an "on-policy" learning algorithm, which means that it evaluates and improves the very policy it follows while interacting with the environment.


The basic idea behind SARSA is to learn an estimate of the state-action value function, or Q-function, of the MDP. The Q-function takes a state and an action and returns the expected long-term (cumulative, discounted) reward for taking that action in that state and following the policy afterwards. The goal of SARSA is to find the Q-function whose greedy policy maximizes this expected long-term reward.
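
For a small grid world like the one used later in this post, the Q-function can be stored as a plain lookup table. The sketch below is purely illustrative (the 5x5 grid shape and the four-action set are assumptions that happen to match the environment defined further down):

import numpy as np

# one Q-value per (row, col, action) triple
n_rows, n_cols, n_actions = 5, 5, 4
Q = np.zeros((n_rows, n_cols, n_actions))

# Q[row, col, action] estimates the long-term reward of taking `action` in cell (row, col);
# the greedy action for a state is simply the argmax over that state's Q-values
state = (0, 0)
greedy_action = int(np.argmax(Q[state[0], state[1]]))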


To do this, SARSA uses an iterative update rule to improve the estimate of the Q-function over time. At each step, the algorithm takes an action in the current state, observes the resulting reward and next state, selects the next action with its current policy, and then nudges the Q-value of the state-action pair it just left toward the observed reward plus the discounted Q-value of the next state-action pair. This update is applied repeatedly until the Q-function estimate converges; under standard conditions (exploration that is gradually reduced and suitably decaying learning rates) it converges to the optimal Q-function of the MDP.
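
Written out, if the agent is in state s, takes action a, receives reward r, lands in state s', and then picks its next action a' with its current policy, the update is

Q(s, a) ← Q(s, a) + α [ r + γ Q(s', a') − Q(s, a) ]

where α is the learning rate and γ is the discount factor. The quintuple (S, A, R, S', A') is exactly what gives SARSA its name.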


One key advantage of SARSA is that, being on-policy, it takes the effect of its own exploration into account, which can make learning more stable and easier to converge than with "off-policy" algorithms, which can sometimes experience instability or divergence. Additionally, SARSA is a model-free algorithm, which means that it learns directly from sampled transitions and does not require a model of the environment's dynamics. This makes it applicable even in environments whose dynamics are unknown or hard to model.


Overall, SARSA is a powerful and widely used reinforcement learning algorithm that is well suited to a variety of environments and tasks. It is simple to implement and has been applied successfully to a wide range of real-world problems, including robot control, game playing, and autonomous vehicles.


Utility Functions
import numpy as np
import os

def plot_state_value_table(table, cols):
    # print a dict of {state: value} entries as a grid with `cols` columns
    for idx, state in enumerate(table):
        print(f"{table[state]:.2f}", end="\t")
        if (idx + 1) % cols == 0:
            print()



def plot_q_table(table, cols):
    # print the rounded Q-values of every grid cell, one row of cells per line
    for r in range(table.shape[0]):
        for c in range(cols):
            q_values = np.round(table[r][c], 2)
            print(q_values, end="\t")
        print()

# clear the terminal ('cls' on Windows, 'clear' everywhere else)
clear = lambda: os.system('cls' if os.name == 'nt' else 'clear')
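
A quick way to sanity-check these helpers is to call them on dummy data; the values below are made up purely for illustration and assume the functions above are in scope:

# dummy 3x3 state-value table keyed by (row, col) tuples
dummy_values = {(r, c): 0.0 for r in range(3) for c in range(3)}
plot_state_value_table(dummy_values, cols=3)

# dummy Q-table: 3x3 grid with 4 actions per cell
dummy_q = np.zeros((3, 3, 4))
plot_q_table(dummy_q, cols=3)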
 
The Grid World Environment
import numpy as np

class Action(): 
    def __init__(self):
        self.action_space = {'U': (-1, 0), 'D': (1, 0),\
                                'L': (0, -1), 'R': (0, 1)}
        self.possible_actions = list(self.action_space.keys())
        self.action_n = len(self.possible_actions)
        self.action_idxs = range(self.action_n)

    def get_action_by_idx(self, idx):
        return self.action_space[self.possible_actions[idx]]

    def get_action_by_key(self, key):
        return self.action_space[key]

class GridWorld():
    def __init__(self, shape=(3, 3), obstacles=[], terminal=None,\
                    agent_pos=np.array([0, 0])):
        self.action  = Action()
        self.shape = shape
        self.rows = shape[0]
        self.cols = shape[1]
        self.obstacles = obstacles
        self.agent_pos = agent_pos
        self.agent_init_pos = agent_pos
        if terminal is None:
            self.terminal_state = (self.rows-1, self.cols-1)
        else:
            self.terminal_state = terminal
        self.done = False

        self.add_state_action = lambda state, action : \
            tuple(np.array(state) + np.array(self.action.get_action_by_idx(action)))
        self.is_same_state = lambda s1, s2 : s1[0] == s2[0] and s1[1] == s2[1]

    def is_obstacle(self, state):
        for s in self.obstacles:
            if self.is_same_state(s, state):  
                return True
        return False

    def is_terminal(self, s):
        return self.is_same_state(s, self.terminal_state)  

    def is_edge(self, state):
        if state[0] < 0 or state[0] > self.rows -1 \
            or state[1] < 0 or state[1] > self.cols -1:
            return True
        return False

    def set_agent_pos(self, state):
        self.agent_pos = state

    def get_agent_pos(self):
        return self.agent_pos
    
    def step(self, action):
        # candidate position = current agent position + action offset
        state = self.get_agent_pos()
        tmp_state = self.add_state_action(state, action)
        if self.is_obstacle(tmp_state):
            # blocked by an obstacle: the agent stays where it is
            pass
        elif self.is_terminal(tmp_state):
            # reached the terminal state: move the agent and end the episode
            self.set_agent_pos(tmp_state)
            self.done = True
        elif self.is_edge(tmp_state):
            # move would leave the grid: the agent stays where it is
            pass
        else:
            self.set_agent_pos(tmp_state)

        # -1 per step, 0 once the terminal state has been reached
        reward = -1 if not self.done else 0

        return self.get_agent_pos(), reward, self.done, None

    def simulated_step(self, state, action):
        # like step(), but operates on an arbitrary state without moving the agent
        if self.is_terminal(state):
            return state, 0, True, None

        next_state = self.add_state_action(state, action)
        # obstacles and grid edges block the move; otherwise the state advances
        if self.is_obstacle(next_state) or self.is_edge(next_state):
            pass
        else:
            state = next_state
        return state, -1, False, None

    def reset(self):
        self.set_agent_pos(self.agent_init_pos)
        self.done = False
        return self.agent_init_pos

    def render(self):
        # O = agent, X = obstacle, [] = terminal state, [O] = agent on the terminal state
        for r in range(self.rows):
            for c in range(self.cols):
                state = np.array((r, c))
                if self.is_terminal(state):
                    if self.done:
                        print('[O]', end="\t")
                    else:
                        print('[]', end="\t")
                elif  self.is_same_state(self.get_agent_pos(), state): 
                    print('O', end="\t")
                elif self.is_obstacle(state):
                    print('X', end="\t")
                else:
                    print('-', end="\t")
            print()


if __name__ == '__main__':
    env = GridWorld(shape=(5, 5), obstacles=((0, 1), (1, 1)))
    env.render()
    env.step(0) # 0 -> up
    env.step(1) # 1 -> down
    env.step(2) # 2 -> left
    env.render()
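
Before plugging in a learning agent, the reset/step interface can be exercised with a purely random walk. This snippet is only an illustration (the obstacle-free 5x5 grid is an arbitrary choice, and the number of steps will vary from run to run):

import numpy as np

env = GridWorld(shape=(5, 5))
state = env.reset()
done, total_reward, steps = False, 0, 0

while not done:
    # sample a random action index: 0=up, 1=down, 2=left, 3=right
    action = np.random.choice(env.action.action_idxs)
    state, reward, done, _ = env.step(action)
    total_reward += reward
    steps += 1

print(f"random walk reached the terminal state after {steps} steps (return {total_reward})")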
 
The SARSA Agent
import numpy as np
import time

from env import GridWorld 
from utils import *

class Sarsa():
    def __init__(self, env, episodes=10000, epsilon=.2, alpha=.1, gamma=.99):
        self.env = env
        # one Q-value per (row, col, action) triple, initialized to zero
        self.action_values = np.zeros((env.rows, env.cols, env.action.action_n))
        self.episodes = episodes
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = epsilon

    def policy(self, state):
        # epsilon-greedy: explore with probability epsilon,
        # otherwise pick (one of) the greedy actions for this state
        if np.random.random() < self.epsilon:
            return np.random.choice(self.env.action.action_idxs)
        else:
            av = self.action_values[state[0]][state[1]]
            return np.random.choice(np.flatnonzero(av == av.max()))

    def sarsa(self):
        for _ in range(1, self.episodes+1):
            state = self.env.reset()
            action = self.policy(state)
            done = False

            while not done:
                next_state, reward, done, _ = self.env.step(action)
                next_action = self.policy(next_state)

                # SARSA update: Q(s,a) <- Q(s,a) + alpha * (r + gamma * Q(s',a') - Q(s,a))
                qsa = self.action_values[state[0]][state[1]][action]
                next_qsa = self.action_values[next_state[0]][next_state[1]][next_action]
                self.action_values[state[0]][state[1]][action] = qsa + self.alpha * (reward + self.gamma * next_qsa - qsa)

                state = next_state
                action = next_action


if __name__ == '__main__':
    env = GridWorld(shape = np.array((5,5)), obstacles = np.array([[0,1], [1,1], [2,1], [3,1],\
                    [1,3],[2,3],[3,3],[4,3] ]))
    
    sarsa = Sarsa(env, episodes = 10000, epsilon=.2)
    sarsa.sarsa()
    steps = 0
    done = False
    env.reset()
    while True: 
        steps += 1
        clear()
        state = env.get_agent_pos()
        action = sarsa.policy(state)
        state, _, done, _ = env.step(action)
        env.render()
        if done:
            print(f"the agent reached terminal state in {steps} steps")
            plot_q_table(sarsa.action_values, 5)
            break

        time.sleep(.5)
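
Once training has finished, the greedy policy can be read directly out of the Q-table. The helper below is not part of the files above; it is just one possible way to visualize the learned policy as arrows, assuming sarsa and env are still in scope after running the script:

def print_greedy_policy(action_values, env):
    # arrows for the action indices in the order defined by Action: U, D, L, R
    arrows = {0: '^', 1: 'v', 2: '<', 3: '>'}
    for r in range(env.rows):
        row = []
        for c in range(env.cols):
            if env.is_terminal((r, c)):
                row.append('T')
            elif env.is_obstacle((r, c)):
                row.append('X')
            else:
                row.append(arrows[int(np.argmax(action_values[r][c]))])
        print("\t".join(row))

print_greedy_policy(sarsa.action_values, env)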