What is callback
You can define some procedure which you want to execute in the middle of training through callback object.
(ex. record performance of value function in each 1000 iteration of training.)
How to set callback for training is createing list of callbacks and pass it to algorithm.run_gpi(nb_iteration, callbacks=None)
.
(You can also pass single callback object (not list) as callbacks arguments)
Callback methods
The base class of all callback objects is kyoka.callback.BaseCallback
.
All callback object must inherit this class and override callback methods as you want.
kyoka.callback.BaseCallback
class has 5 callback methods which callback object can override.
before_gpi_start(self, task, value_function)
- called when
algorithm.run_gpi
is called
- called when
before_update(self, iteration_count, task, value_function)
- called before
iteration_count
of episode is played in training
- called before
after_update(self, iteration_count, task, value_function)
- called after
iteration_count
of episode is played in training
- called after
after_gpi_finish(self, task, value_function)
- called when training finishes
interrupt_gpi(self, iteration_count, task, value_function)
- if you return
True
training finishes even if it doesn't reach maximum iteration count
- if you return
In default, above 4 methods are implemented as empty method and interrupt_gpi
just returns False
.
So callback object don't need to override all methods if nothing to do.
Logging methods
kyoka.callback.BaseCallback
class also have utility method log(message)
.
If you want to log something on console, we recommend you to use this method instead of print(message)
.
log(message)
prints passed message with tag like below.
>>> callback = WatchIterationCount(5000)
>>> callback.log("Start GPI iteration for 5000 times")
[WatchIterationCount] Start GPI iteration for 5000 times
We use class name of callback as tag in default. But you can easily customize it by overriding define_log_tag
method.
class WatchIterationCount(BaseCallback):
# some codes...
def define_log_tag(self):
return "Progress"
>>> callback = WatchIterationCount(5000)
>>> callback.log("Start GPI iteration for 5000 times")
[Progress] Start GPI iteration for 5000 times
Here is the sample custom callback to record value of initial state in every 1000 iteration.
import csv
from kyoka.policy import choose_best_action
class InitialStateValueRecorder(BaseCallback):
def __init__(self, score_file_path):
self.score_file_path = score_file_path
self.score_holder = []
def before_gpi_start(self, task, value_function):
value = self._predict_value_of_initial_state(task, value_function)
self.log("Value of initial state is [ %s ]" % value)
self.score_holder.append(value)
def after_update(self, iteration_count, task, value_function):
value = self._predict_value_of_initial_state(task, value_function)
self.log("Value of initial state is [ %s ]" % value)
self.score_holder.append(value)
def after_gpi_finish(self, task, value_function):
with open(self.score_file_path, "wb") as f:
writer = csv.writer(f, lineterminator="\n")
writer.writerow(self.score_holder)
self.log("Score is saved on [ %s ]" % self.score_file_path)
def _predict_value_of_initial_state(self, task, value_function):
state = task.generate_initial_state()
action = choose_best_action(task, value_function, state)
return value_function.predict_value(state, action)