Sarsa - on-policy TD learning method
The Sarsa method updates the value of a state-action pair Q(s,a) in the following way.
s : current state
a : action to take at state s, chosen by policy PI
r : reward received for taking action a at state s
s' : next state after taking action a at s
a' : action to take at state s', chosen by policy PI
Q(s,a) <- Q(s,a) + alpha * [ r + gamma * Q(s', a') - Q(s, a) ]
The update uses the variables s, a, r, s', a' in that order, which is why this algorithm is named Sarsa.
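As a concrete illustration, here is a minimal sketch of a single Sarsa update in plain Python (the dict-based Q table and all names here are hypothetical, not part of the library):

def sarsa_update(Q, s, a, r, s_next, a_next, alpha, gamma):
    # Q is a dict mapping (state, action) pairs to float values;
    # unseen pairs default to 0.
    td_target = r + gamma * Q.get((s_next, a_next), 0.0)
    td_error = td_target - Q.get((s, a), 0.0)
    Q[(s, a)] = Q.get((s, a), 0.0) + alpha * td_error

Q = {}
sarsa_update(Q, s=0, a=1, r=1.0, s_next=2, a_next=0, alpha=0.1, gamma=0.99)
# Q[(0, 1)] is now 0.1 * (1.0 + 0.99 * 0 - 0) = 0.1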
Sarsa is also called an on-policy TD learning method; the keyword is on-policy. On-policy means that Sarsa uses the same policy PI to choose both a and a'. An algorithm that uses different policies to select a and a' is called an off-policy method (e.g. QLearning).
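To make the distinction concrete, here is a hedged sketch comparing the two backup targets (hypothetical names; both assume a dict-based Q table and a list of possible actions):

def sarsa_target(Q, r, s_next, a_next, gamma):
    # On-policy: bootstrap from the action a' that policy PI actually chose.
    return r + gamma * Q.get((s_next, a_next), 0.0)

def qlearning_target(Q, r, s_next, possible_actions, gamma):
    # Off-policy: bootstrap from the greedy action at s', regardless of
    # which action PI will actually take there.
    return r + gamma * max(Q.get((s_next, b), 0.0) for b in possible_actions)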
Algorithm
Parameters:
a <- alpha, the learning rate, in [0,1]
g <- gamma, the discount factor, in [0,1]
Initialize:
T <- your RL task
PI <- policy used in the algorithm
Q <- action value function
Repeat until computational budget runs out:
S <- generate initial state of task T
A <- choose action at S by following policy PI
Repeat until S is a terminal state:
S' <- next state of S after taking action A
R <- reward gained by taking action A at state S
A' <- next action at S' by following policy PI
Q(S, A) <- Q(S, A) + a * [ R + g * Q(S', A') - Q(S, A)]
S, A <- S', A'
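The pseudocode translates almost line-for-line into Python. Below is a minimal self-contained sketch; the task methods (generate_initial_state, is_terminal_state, transit_state, calculate_reward) and the epsilon-greedy helper are hypothetical stand-ins, not this library's API.

import random
from collections import defaultdict

def epsilon_greedy(possible_actions, eps=0.1):
    # Returns a policy PI(Q, s): random action with probability eps,
    # otherwise the greedy action under the current Q estimates.
    def policy(Q, s):
        if random.random() < eps:
            return random.choice(possible_actions)
        return max(possible_actions, key=lambda a: Q[(s, a)])
    return policy

def run_sarsa(task, policy, alpha=0.1, gamma=0.99, num_episodes=1000):
    Q = defaultdict(float)  # unseen (state, action) pairs start at 0
    for _ in range(num_episodes):
        S = task.generate_initial_state()
        A = policy(Q, S)
        while not task.is_terminal_state(S):
            S_next = task.transit_state(S, A)
            R = task.calculate_reward(S, A, S_next)
            A_next = policy(Q, S_next)
            # Sarsa backup: the target uses the action A' actually selected.
            Q[(S, A)] += alpha * (R + gamma * Q[(S_next, A_next)] - Q[(S, A)])
            S, A = S_next, A_next
    return Q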
Value function
The Sarsa method provides both tabular and approximation types of value functions.
SarsaTabularActionValueFunction
If your task is tabular size, you can use SarsaTabularActionValueFunction. If you can store the values of all state-action pairs in memory (e.g. in an array), your task is tabular size.
SarsaTabularActionValueFunction has 3 abstract methods that define the table for your task.
generate_initial_table : initialize the table object and return it here
fetch_value_from_table : define how to fetch a value from your table
insert_value_into_table : define how to insert a new value into your table
If the shape of your state-action space is S x A, the implementation would look like this.
class MyTabularActionValueFunction(SarsaTabularActionValueFunction):
def generate_initial_table(self):
return [[0 for j in range(A)] for i in range(S)]
def fetch_value_from_table(self, table, state, action):
return table[state][action]
def insert_value_into_table(self, table, state, action, new_value):
table[state][action] = new_value
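Nothing here forces a nested list: the table object is whatever generate_initial_table returns, so a dict keyed by (state, action) works just as well, as long as the fetch and insert methods agree on the layout. A minimal alternative sketch:

class MyDictActionValueFunction(SarsaTabularActionValueFunction):
    def generate_initial_table(self):
        # Start empty; unseen pairs default to 0 in fetch below.
        return {}
    def fetch_value_from_table(self, table, state, action):
        return table.get((state, action), 0)
    def insert_value_into_table(self, table, state, action, new_value):
        table[(state, action)] = new_value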
SarsaApproxActionValueFunction
If your task is not tabular size, you can use SarsaApproxActionValueFunction.
SarsaApproxActionValueFunction has 3 abstract methods. You would wrap some prediction model (e.g. a neural network) in these methods.
construct_features : transform the state-action pair into a feature representation
approx_predict_value : predict the value of the state-action pair with the prediction model of your choice
approx_backup : update your model in a supervised learning way with the passed input-output pair
An implementation with some neural network library would look like this.
class MyApproxActionValueFunction(SarsaApproxActionValueFunction):
def setup(self):
super(MyApproxActionValueFunction, self).setup()
self.neuralnet = build_neuralnet_in_some_way()
def construct_features(self, state, action):
feature1 = do_something(state, action)
feature2 = do_anotherthing(state, action)
return [feature1, feature2]
def approx_predict_value(self, features):
return self.neuralnet.predict(features)
def approx_backup(self, features, backup_target, alpha):
self.neuralnet.incremental_training(X=features, Y=backup_target)
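If you do not have a neural network library at hand, any model that can be updated incrementally will do. Here is a hedged sketch using a hand-rolled linear approximator (the features and the gradient step are illustrative assumptions, not part of the library):

import numpy as np

class MyLinearActionValueFunction(SarsaApproxActionValueFunction):
    def setup(self):
        super(MyLinearActionValueFunction, self).setup()
        self.weights = np.zeros(2)  # one weight per feature below
    def construct_features(self, state, action):
        # Hypothetical hand-crafted features; replace with your own.
        return np.array([float(state), float(action)])
    def approx_predict_value(self, features):
        return float(np.dot(self.weights, features))
    def approx_backup(self, features, backup_target, alpha):
        # One gradient step toward the backup target (squared-error loss).
        error = backup_target - self.approx_predict_value(features)
        self.weights += alpha * error * features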
Sample code to start learning
test_length = 1000
task = MyTask()
policy = EpsilonGreedyPolicy(eps=0.1)
value_func = MyTabularActionValueFunction()
algorithm = Sarsa(gamma=0.99)
algorithm.setup(task, policy, value_func)
algorithm.run_gpi(test_length)
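Here run_gpi drives the learning loop described in the Algorithm section for test_length iterations (GPI stands for generalized policy iteration), and EpsilonGreedyPolicy(eps=0.1) picks a random action with probability 0.1 and the greedy action under the current value estimates otherwise. To learn with function approximation instead, swap MyTabularActionValueFunction for your SarsaApproxActionValueFunction subclass.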