This example shows how to create a custom agent for your own custom reinforcement learning algorithm. Doing so allows you to leverage built-in functionality from the Reinforcement Learning Toolbox™ software, such as the train and sim functions.
In this example, you convert a custom REINFORCE training loop into a custom agent class. For more information on the REINFORCE custom training loop, see Train Reinforcement Learning Policy Using Custom Training Loop. For more information on writing custom agent classes, see Custom Agents.
Fix the random generator seed for reproducibility.
rng(0)
Create the same training environment used in the Train Reinforcement Learning Policy Using Custom Training Loop example. The environment is a cart-pole balancing environment with a discrete action space. Create the environment using the rlPredefinedEnv function.
env = rlPredefinedEnv('CartPole-Discrete');
Extract the observation and action specifications from the environment.
obsInfo = getObservationInfo(env);
actInfo = getActionInfo(env);
Obtain the number of observations (numObs) and actions (numAct).
numObs = obsInfo.Dimension(1);
numAct = numel(actInfo.Elements);
For more information on this environment, see Load Predefined Control System Environments.
The reinforcement learning policy in this example is a discrete-action stochastic policy. It is represented by a deep neural network that contains fullyConnectedLayer, reluLayer, and softmaxLayer layers. This network outputs probabilities for each discrete action given the current observations. The softmaxLayer ensures that the representation outputs probability values in the range [0 1] and that all probabilities sum to 1.
Create the deep neural network for the actor.
actorNetwork = [featureInputLayer(numObs,'Normalization','none','Name','state')
    fullyConnectedLayer(24,'Name','fc1')
    reluLayer('Name','relu1')
    fullyConnectedLayer(24,'Name','fc2')
    reluLayer('Name','relu2')
    fullyConnectedLayer(2,'Name','output')
    softmaxLayer('Name','actionProb')];
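If you want to confirm that the network outputs valid probabilities before creating the actor, you can evaluate it directly. The following optional check is a sketch, not part of the original example; it assumes a Deep Learning Toolbox workflow in which the layer array can be converted to a dlnetwork object and evaluated with predict.
% Optional check (sketch): the network output for a random observation
% should be a two-element vector of probabilities that sums to 1.
checkNet = dlnetwork(layerGraph(actorNetwork));
probs = predict(checkNet,dlarray(rand(numObs,1),'CB'));
sum(extractdata(probs))   % expected to be 1, up to floating-point error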
Create the actor representation using an rlStochasticActorRepresentation object.
actorOpts = rlRepresentationOptions('LearnRate',1e-3,'GradientThreshold',1);
actor = rlStochasticActorRepresentation(actorNetwork,...
    obsInfo,actInfo,'Observation','state',actorOpts);
To define your custom agent, first create a class that is a subclass of the rl.agent.CustomAgent class. The custom agent class for this example is defined in CustomReinforceAgent.m.
The CustomReinforceAgent class has the following class definition, which indicates the agent class name and the associated abstract agent.
classdef CustomReinforceAgent < rl.agent.CustomAgent
To define your agent you must specify the following:
Agent properties
Constructor function
Critic representation that estimates the discounted long-term reward (if required for learning)
Actor representation that selects an action based on the current observation (if required for learning)
Required agent methods
Optional agent methods
In the properties section of the class file, specify any parameters necessary for creating and training the agent.
The rl.agent.CustomAgent class already includes properties for the agent sample time (SampleTime) and the action and observation specifications (ActionInfo and ObservationInfo, respectively).
The custom REINFORCE agent defines the following additional agent properties.
properties
    % Actor representation
    Actor

    % Agent options
    Options

    % Experience buffer
    ObservationBuffer
    ActionBuffer
    RewardBuffer
end

properties (Access = private)
    % Training utilities
    Counter
    NumObservation
    NumAction
end
To create your custom agent, you must define a constructor function. The constructor function performs the following actions.
Defines the action and observation specifications. For more information about creating these specifications, see rlNumericSpec and rlFiniteSetSpec.
Sets the agent properties.
Calls the constructor of the base abstract class.
Defines the sample time (required for training in Simulink environments).
For example, the CustomReinforceAgent constructor defines action and observation spaces based on the input actor representation.
function obj = CustomReinforceAgent(Actor,Options)
    %CUSTOMREINFORCEAGENT Construct custom agent
    %   AGENT = CUSTOMREINFORCEAGENT(ACTOR,OPTIONS) creates custom
    %   REINFORCE AGENT from rlStochasticActorRepresentation ACTOR
    %   and structure OPTIONS. OPTIONS has fields:
    %       - DiscountFactor
    %       - MaxStepsPerEpisode

    % (required) Call the abstract class constructor.
    obj = obj@rl.agent.CustomAgent();
    obj.ObservationInfo = Actor.ObservationInfo;
    obj.ActionInfo = Actor.ActionInfo;

    % (required for Simulink environment) Register sample time.
    % For MATLAB environment, use -1.
    obj.SampleTime = -1;

    % (optional) Register actor and agent options.
    Actor = setLoss(Actor,@lossFunction);
    obj.Actor = Actor;
    obj.Options = Options;

    % (optional) Cache the number of observations and actions.
    obj.NumObservation = prod(obj.ObservationInfo.Dimension);
    obj.NumAction = prod(obj.ActionInfo.Dimension);

    % (optional) Initialize buffer and counter.
    reset(obj);
end
The constructor sets the loss function of the actor representation using a function handle to lossFunction, which is implemented as a local function in CustomReinforceAgent.m.
function loss = lossFunction(policy,lossData)
    % Create the action indication matrix.
    batchSize = lossData.batchSize;
    Z = repmat(lossData.actInfo.Elements',1,batchSize);
    actionIndicationMatrix = lossData.actionBatch(:,:) == Z;

    % Resize the discounted return to the size of policy.
    G = actionIndicationMatrix .* lossData.discountedReturn;
    G = reshape(G,size(policy));

    % Round any policy values less than eps to eps.
    policy(policy < eps) = eps;

    % Compute the loss.
    loss = -sum(G .* log(policy),'all');
end
To create a custom reinforcement learning agent, you must define the following implementation functions.
getActionImpl — Evaluate the agent policy and select an action during simulation.
getActionWithExplorationImpl — Evaluate the policy and select an action with exploration during training.
learnImpl — Learn from the current experience.
To call these functions in your own code, use the wrapper methods from the abstract base class. For example, to call getActionImpl, use getAction. The wrapper methods have the same input and output arguments as the implementation methods.
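For instance, once you have constructed the custom agent later in this example, a call such as the following dispatches to your getActionImpl implementation. This is a usage sketch rather than part of the class file; it assumes the observation is wrapped in a cell array that matches the observation specification.
% Query an action from the agent policy (calls getActionImpl internally).
sampleObs = {rand(numObs,1)};
sampleAct = getAction(agent,sampleObs);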
getActionImpl Function
The getActionImpl function is used to evaluate the policy of your agent and select an action when simulating the agent using the sim function. This function must have the following signature, where obj is the agent object, Observation is the current observation, and Action is the selected action.
function Action = getActionImpl(obj,Observation)
For the custom REINFORCE agent, you select an action by calling the getAction function for the actor representation. The discrete rlStochasticActorRepresentation generates a discrete distribution from an observation and samples an action from this distribution.
function Action = getActionImpl(obj,Observation)
    % Compute an action using the policy given the current
    % observation.
    Action = getAction(obj.Actor,Observation);
end
getActionWithExplorationImpl Function
The getActionWithExplorationImpl function selects an action using the exploration model of your agent when training the agent using the train function. Using this function you can implement exploration techniques such as epsilon-greedy exploration or the addition of Gaussian noise. This function must have the following signature, where obj is the agent object, Observation is the current observation, and Action is the selected action.
function Action = getActionWithExplorationImpl(obj,Observation)
For the custom REINFORCE agent, the getActionWithExplorationImpl function is the same as getActionImpl. By default, stochastic actors always explore, that is, they always select an action based on a probability distribution.
function Action = getActionWithExplorationImpl(obj,Observation)
    % Compute an action using the exploration policy given the
    % current observation.

    % REINFORCE: Stochastic actors always explore by default
    % (sample from a probability distribution)
    Action = getAction(obj.Actor,Observation);
end
learnImpl Function
The learnImpl function defines how the agent learns from the current experience. This function implements the custom learning algorithm of your agent by updating the policy parameters and selecting an action with exploration for the next state. This function must have the following signature, where obj is the agent object, Experience is the current agent experience, and Action is the selected action.
function Action = learnImpl(obj,Experience)
The agent experience is the cell array Experience = {state,action,reward,nextstate,isdone}, where the elements are as follows (an illustrative example appears after this list).
state is the current observation.
action is the current action. This is different from the output argument Action, which is an action for the next state.
reward is the current reward.
nextState is the next observation.
isDone is a logical flag indicating that the training episode is complete.
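For example, for the cart-pole environment in this example, a single experience passed to learnImpl might look like the following sketch. The numeric values are illustrative only; the observation is a 4-element vector and the action is one of the two discrete force values defined by the environment.
% Illustrative experience (values are made up for this sketch).
state     = {[0.01; -0.02; 0.03; 0.05]};  % current observation
action    = {10};                         % current action
reward    = 1;                            % current reward
nextState = {[0.02; 0.01; 0.02; 0.04]};   % next observation
isDone    = false;                        % episode not finished
Experience = {state,action,reward,nextState,isDone};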
For the custom REINFORCE agent, replicate steps 2 through 7 of the custom training loop in Train Reinforcement Learning Policy Using Custom Training Loop. You omit steps 1, 8, and 9 since you will use the built-in train function to train your agent.
function Action = learnImpl(obj,Experience)
    % Define how the agent learns from an Experience, which is a
    % cell array with the following format.
    %   Experience = {observation,action,reward,nextObservation,isDone}

    % Reset buffer at the beginning of the episode.
    if obj.Counter < 2
        resetBuffer(obj);
    end

    % Extract data from experience.
    Obs = Experience{1};
    Action = Experience{2};
    Reward = Experience{3};
    NextObs = Experience{4};
    IsDone = Experience{5};

    % Save data to buffer.
    obj.ObservationBuffer(:,:,obj.Counter) = Obs{1};
    obj.ActionBuffer(:,:,obj.Counter) = Action{1};
    obj.RewardBuffer(:,obj.Counter) = Reward;

    if ~IsDone
        % Choose an action for the next state.
        Action = getActionWithExplorationImpl(obj, NextObs);
        obj.Counter = obj.Counter + 1;
    else
        % Learn from episodic data.

        % Collect data from the buffer.
        BatchSize = min(obj.Counter,obj.Options.MaxStepsPerEpisode);
        ObservationBatch = obj.ObservationBuffer(:,:,1:BatchSize);
        ActionBatch = obj.ActionBuffer(:,:,1:BatchSize);
        RewardBatch = obj.RewardBuffer(:,1:BatchSize);

        % Compute the discounted future reward.
        DiscountedReturn = zeros(1,BatchSize);
        for t = 1:BatchSize
            G = 0;
            for k = t:BatchSize
                G = G + obj.Options.DiscountFactor ^ (k-t) * RewardBatch(k);
            end
            DiscountedReturn(t) = G;
        end

        % Organize data to pass to the loss function.
        LossData.batchSize = BatchSize;
        LossData.actInfo = obj.ActionInfo;
        LossData.actionBatch = ActionBatch;
        LossData.discountedReturn = DiscountedReturn;

        % Compute the gradient of the loss with respect to the
        % actor parameters.
        ActorGradient = gradient(obj.Actor,'loss-parameters',...
            {ObservationBatch},LossData);

        % Update the actor parameters using the computed gradients.
        obj.Actor = optimize(obj.Actor,ActorGradient);

        % Reset the counter.
        obj.Counter = 1;
    end
end
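As an aside, the nested loop that computes DiscountedReturn evaluates the recursion G(t) = R(t) + DiscountFactor*G(t+1), and you can compute the same values without the inner loop. The following sketch is not part of the class file; it assumes the same RewardBatch row vector and Options property as in learnImpl above, and uses the built-in filter function on the time-reversed reward sequence.
% Equivalent computation of the discounted return without the inner loop.
gamma = obj.Options.DiscountFactor;
DiscountedReturn = fliplr(filter(1,[1 -gamma],fliplr(RewardBatch)));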
Optionally, you can define how your agent is reset at the start of training by specifying a resetImpl function with the following function signature, where obj is the agent object.
function resetImpl(obj)
Using this function, you can set the agent into a known or random condition before training.
function resetImpl(obj)
    % (Optional) Define how the agent is reset before training.
    resetBuffer(obj);
    obj.Counter = 1;
end
Also, you can define any other helper functions in your custom agent class as required. For example, the custom REINFORCE agent defines a resetBuffer function for reinitializing the experience buffer at the beginning of each training episode.
function resetBuffer(obj)
    % Reinitialize all experience buffers.
    obj.ObservationBuffer = zeros(obj.NumObservation,1,obj.Options.MaxStepsPerEpisode);
    obj.ActionBuffer = zeros(obj.NumAction,1,obj.Options.MaxStepsPerEpisode);
    obj.RewardBuffer = zeros(1,obj.Options.MaxStepsPerEpisode);
end
Once you have defined your custom agent class, create an instance of it in the MATLAB workspace. To create the custom REINFORCE agent, first specify the agent options.
options.MaxStepsPerEpisode = 250;
options.DiscountFactor = 0.995;
Then, using the options and the previously defined actor representation, call the custom agent constructor function.
agent = CustomReinforceAgent(actor,options);
Configure the training to use the following options.
Set up the training to last at most 5000 episodes, with each episode lasting at most 250 steps.
Terminate the training after the maximum number of episodes is reached or when the average reward across 100 episodes reaches a value of 240.
For more information, see rlTrainingOptions.
numEpisodes = 5000;
aveWindowSize = 100;
trainingTerminationValue = 240;

trainOpts = rlTrainingOptions(...
    'MaxEpisodes',numEpisodes,...
    'MaxStepsPerEpisode',options.MaxStepsPerEpisode,...
    'ScoreAveragingWindowLength',aveWindowSize,...
    'StopTrainingValue',trainingTerminationValue);
Train the agent using the train function. Training this agent is a computationally intensive process that takes several minutes to complete. To save time while running this example, load a pretrained agent by setting doTraining to false. To train the agent yourself, set doTraining to true.
doTraining = false;
if doTraining
    % Train the agent.
    trainStats = train(agent,env,trainOpts);
else
    % Load pretrained agent for the example.
    load('CustomReinforce.mat','agent');
end
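If you train the agent yourself, the train function returns training statistics in trainStats, which you can inspect after training. The following line is a sketch; the EpisodeReward field name assumes the standard training statistics output.
% Plot the reward obtained in each training episode
% (only available when doTraining is true).
plot(trainStats.EpisodeReward)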
Enable the environment visualization, which is updated each time the environment step function is called.
plot(env)
To validate the performance of the trained agent, simulate it within the cart-pole environment. For more information on agent simulation, see rlSimulationOptions and sim.
simOpts = rlSimulationOptions('MaxSteps',options.MaxStepsPerEpisode);
experience = sim(env,agent,simOpts);
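To quantify the result, you can compute the total reward collected during the simulation from the returned experience structure. This is a sketch; it assumes the Reward field of the simulation output is a timeseries, as returned by sim for MATLAB environments.
% Total reward accumulated over the simulated episode.
totalReward = sum(experience.Reward.Data)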