Agent
We provide a base class AgentBase
with some utility functions to extract the desired
state from the observation. You can inherit from the base class and implement your own methods
in the air_hockey_agent/agent_builder.py
file. A Dummy Agent example can be found
in Dummy Agent.
Load and Save Agent
We also provide a simple and effective way to save and load your agent. We extend the
Dummy Agent
example and set variables of different types. You can add these variables
to the save list by calling the self._add_save_attr
function.
The available methods are:
primitive, to store any primitive type. This includes lists and dictionaries of primitive values.
numpy, to store NumPy arrays.
torch, to store any torch object.
pickle, to store any Python object that cannot be stored with the above methods.
json, to store a textual version of the attribute that is easy to read.
none, to add the attribute without storing a value; you can assign the value to the attribute later.
import numpy as np
import torch
from air_hockey_challenge.framework import AgentBase, AirHockeyChallengeWrapper
def build_agent(env_info, **kwargs):
    """
    Create the Agent that controls the environment.

    The returned Agent is expected to derive from the AgentBase class.

    Args:
        env_info (dict): The environment information
        kwargs (any): Additional settings from agent_config.yml

    Returns:
        (AgentBase) An instance of the Agent
    """
    # All keyword settings are handed to the agent constructor unchanged.
    agent = DummyAgent(env_info, **kwargs)
    return agent
class DummyAgent(AgentBase):
    """
    Example agent that holds the joint position seen at the start of an episode.

    It also carries a set of example attributes of different kinds and registers
    each of them with a save method through ``_add_save_attr``.
    """

    def __init__(self, env_info, value, **kwargs):
        """
        Set up the hold-position state and the example attributes to be saved.

        Args:
            env_info (dict): The environment information, forwarded to AgentBase
            value: Number used to fill the primitive, NumPy and torch examples
            kwargs (any): Additional settings, forwarded to AgentBase
        """
        super().__init__(env_info, **kwargs)

        # State used by draw_action / reset
        self.new_start = True
        self.hold_position = None

        # Primitive python variable
        self.primitive_variable = value
        # NumPy array
        self.numpy_vector = np.array([1, 2, 3]) * value
        # Plain python list (mixed, nested content)
        self.list_variable = [1, 'list', [2, 3]]
        # Dictionary
        self.dictionary = {'some': 'random', 'keywords': 2, 'fill': 'the dictionary'}

        # Torch object built from a NumPy array
        self.torch_object = torch.nn.Parameter(torch.from_numpy(np.ones(3) * value))

        # An object that cannot be serialized
        self.object_instance = object()

        # Something not worth keeping, e.g. a buffer
        self.not_important = np.zeros(10000)

        # Save method for every attribute. A trailing '!' means the attribute
        # is saved only if full_save is True.
        save_methods = {
            'primitive_variable': 'primitive',
            'numpy_vector': 'numpy',
            'list_variable': 'primitive',
            'dictionary': 'pickle',
            'torch_object': 'torch',
            'object_instance': 'none',
            'not_important': 'numpy!',
        }
        self._add_save_attr(**save_methods)

    def reset(self):
        """Forget the held position so the next draw_action call records a new one."""
        self.hold_position = None
        self.new_start = True

    def draw_action(self, observation):
        """
        Return an action that keeps the robot at the recorded joint position.

        On the first call after construction or reset(), the joint position
        obtained from ``get_joint_pos(observation)`` is stored as the position
        to hold.

        Args:
            observation: The current observation of the environment

        Returns:
            The held position stacked on top of a zero velocity of the same shape
        """
        if self.new_start:
            self.hold_position = self.get_joint_pos(observation)
            self.new_start = False

        zero_velocity = np.zeros_like(self.hold_position)
        return np.vstack([self.hold_position, zero_velocity])
if __name__ == '__main__':
    # Demo: build an agent, save it to disk, load it back and print the
    # attributes of both instances.
    env = AirHockeyChallengeWrapper("3dof-hit")

    # Construct Agent
    args = {'value': 1.1}
    agent_save = build_agent(env.env_info, **args)

    hash_rule = "######################################################"
    dash_rule = "------------------------------------------------------"

    # Attributes printed for both the saved and the loaded agent.
    saved_attrs = (
        "primitive_variable",
        "numpy_vector",
        "list_variable",
        "dictionary",
        "torch_object",
    )

    print(hash_rule)
    print("Save Agent Variables")
    print(hash_rule)
    for attr in saved_attrs:
        print(f"agent_save.{attr}: ", getattr(agent_save, attr))

    # The not_important variable will not be saved unless the full_save is set True
    agent_save.save("agent.msh", full_save=False)

    agent_load = DummyAgent.load_agent("agent.msh", env.env_info)

    print(hash_rule)
    print("Load the Agent")
    print(hash_rule)
    # object_instance is registered with the 'none' method, so it is printed
    # only for the loaded agent.
    for attr in saved_attrs + ("object_instance",):
        print(f"agent_load.{attr}: ", getattr(agent_load, attr))

    print(dash_rule)
    print("These variable will not be saved while full_save is False")
    print("agent_load.not_important: ", agent_load.not_important)

    print(dash_rule)
    print("These variable will be parsed from env_info:")
    # Label fixed: the original printed "keys()s" with a stray trailing 's'.
    print("agent_load.env_info.keys(): ", agent_load.env_info.keys())
    print("agent_load.agent_id: ", agent_load.agent_id)
    print("agent_load.robot_model: ", agent_load.robot_model)
    print("agent_load.robot_data: ", agent_load.robot_data)
AgentBase
air_hockey_challenge.framework.agent_base