To implement your own custom reinforcement learning algorithms, you can create a custom agent by subclassing the custom agent abstract class. You can then train and simulate this agent in MATLAB® and Simulink® environments. For more information about creating classes in MATLAB, see User-Defined Classes.
To define your custom agent, first create a class that is a subclass of the rl.agent.CustomAgent class. As an example, this topic describes the custom LQR agent trained in Train Custom LQR Agent. As a starting point for your own agent, you can open and modify this custom agent class. To add the example files to the MATLAB path and open the file, at the MATLAB command line, type the following code.
addpath(fullfile(matlabroot,'examples','rl','main'));
edit LQRCustomAgent.m
After saving the class to your own working folder, you can remove the example files from the path.
rmpath(fullfile(matlabroot,'examples','rl','main'));
This class has the following class definition, which indicates the agent class name and the associated abstract agent.
classdef LQRCustomAgent < rl.agent.CustomAgent
To define your agent, you must specify the following. A class skeleton showing how these pieces fit together appears after the list.
Agent properties
Constructor function
Critic representation that estimates the discounted long-term reward (if required for learning)
Actor representation that selects an action based on the current observation (if required for learning)
Required agent methods
Optional agent methods
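The following outline shows one way to arrange these pieces in a class file. The class name MyCustomAgent and the argument names are illustrative, and the method bodies are omitted; see the LQRCustomAgent.m file for a complete working example.

classdef MyCustomAgent < rl.agent.CustomAgent
    properties
        % Agent parameters, actor and critic representations,
        % experience buffers, and so on
    end

    methods
        function obj = MyCustomAgent(observationInfo,actionInfo)
            % Call the abstract base class constructor
            obj = obj@rl.agent.CustomAgent();

            % Define the observation and action specifications
            obj.ObservationInfo = observationInfo;
            obj.ActionInfo = actionInfo;

            % Create actor or critic representations and configure
            % other agent properties here
        end
    end

    methods (Access = protected)
        function action = getActionImpl(obj,Observation)
            % Return the action given by the current policy
        end

        function action = getActionWithExplorationImpl(obj,Observation)
            % Return an action that includes exploration
        end

        function action = learnImpl(obj,exp)
            % Update the learnable parameters from exp and return
            % an action with exploration
        end
    end
end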
In the properties section of the class file, specify any parameters necessary for creating and training the agent. These parameters can include:
Discount factor for discounting future rewards
Configuration parameters for exploration models, such as noise models or epsilon-greedy exploration
Experience buffers for using replay memory
Mini-batch sizes for sampling from the experience buffer
Number of steps to look ahead during training
For more information on potential agent properties, see the option objects for the built-in Reinforcement Learning Toolbox™ agents.
The rl.agent.CustomAgent class already includes properties for the agent sample time (SampleTime) and the action and observation specifications (ActionInfo and ObservationInfo, respectively).
The custom LQR agent defines the following agent properties.
properties
    % Q
    Q

    % R
    R

    % Feedback gain
    K

    % Discount factor
    Gamma = 0.95

    % Critic
    Critic

    % Buffer for K
    KBuffer

    % Number of updates for K
    KUpdate = 1

    % Number for estimator update
    EstimateNum = 10
end

properties (Access = private)
    Counter = 1
    YBuffer
    HBuffer
end
To create your custom agent, you must define a constructor function that:
Defines the action and observation specifications. For more information about creating these specifications, see rlNumericSpec and rlFiniteSetSpec.
Creates actor and critic representations as required by your training algorithm. For more information, see Create Policy and Value Function Representations.
Configures agent properties.
Calls the constructor of the base abstract class.
For example, the LQRCustomAgent constructor defines continuous action and observation spaces and creates a critic representation. The createCritic function is an optional helper function that defines the critic representation.
function obj = LQRCustomAgent(Q,R,InitialK)
    % Check the number of input arguments
    narginchk(3,3);

    % Call the abstract class constructor
    obj = obj@rl.agent.CustomAgent();

    % Set the Q and R matrices
    obj.Q = Q;
    obj.R = R;

    % Define the observation and action spaces
    obj.ObservationInfo = rlNumericSpec([size(Q,1),1]);
    obj.ActionInfo = rlNumericSpec([size(R,1),1]);

    % Create the critic representation
    obj.Critic = createCritic(obj);

    % Initialize the gain matrix
    obj.K = InitialK;

    % Initialize the experience buffers
    obj.YBuffer = zeros(obj.EstimateNum,1);
    num = size(Q,1) + size(R,1);
    obj.HBuffer = zeros(obj.EstimateNum,0.5*num*(num+1));
    obj.KBuffer = cell(1,1000);
    obj.KBuffer{1} = obj.K;
end
If your learning algorithm uses a critic representation to estimate the long-term reward, an actor for selecting an action, or both, you must add these as agent properties. You must then create these representations when you create your agent; that is, in the constructor function. For more information on creating actors and critics, see Create Policy and Value Function Representations.
For example, the custom LQR agent uses a critic representation, stored in its Critic property, and no actor. The critic creation is implemented in the createCritic helper function, which is called from the LQRCustomAgent constructor.
function critic = createCritic(obj)
    nQ = size(obj.Q,1);
    nR = size(obj.R,1);
    n = nQ + nR;
    w0 = 0.1*ones(0.5*(n+1)*n,1);
    critic = rlQValueRepresentation({@(x,u) computeQuadraticBasis(x,u,n),w0},...
        getObservationInfo(obj),getActionInfo(obj));
    critic.Options.GradientThreshold = 1;
end
In this case, the critic is an rlQValueRepresentation object. To create such a representation, you must specify the handle to a custom basis function, in this case the computeQuadraticBasis function. For more information on this critic representation, see Train Custom LQR Agent.
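The computeQuadraticBasis function itself is part of the example files. As a rough sketch (not the shipped implementation), a quadratic basis that stacks the state and action and collects the upper-triangular products row by row could look like the following; its output length matches the 0.5*(n+1)*n initial weights used in createCritic.

function B = computeQuadraticBasis(x,u,n)
    % Stack the state and action, then collect the n*(n+1)/2 unique
    % quadratic terms z(r)*z(c) with r <= c into a column vector.
    % Sketch only; the ordering in the example file may differ.
    z = [x;u];
    B = zeros(0.5*n*(n+1),1);
    idx = 1;
    for r = 1:n
        for c = r:n
            B(idx) = z(r)*z(c);
            idx = idx + 1;
        end
    end
end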
To create a custom reinforcement learning agent, you must define the following implementation functions. To call these functions in your own code, use the wrapper methods from the abstract base class. For example, to call getActionImpl, use getAction. The wrapper methods have the same input and output arguments as the implementation methods.
| Function | Description |
| --- | --- |
| getActionImpl | Selects an action by evaluating the agent policy for a given observation |
| getActionWithExplorationImpl | Selects an action using the exploration model of the agent |
| learnImpl | Learns from the current experiences and returns an action with exploration |
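For example, once you have constructed an agent instance (such as the LQR agent created later in this topic), you can exercise its policy from the MATLAB command line through the wrapper methods rather than calling the implementation functions directly. The observation value below is an arbitrary three-element state chosen for illustration.

% Format the observation as a cell array matching the observation specification
obs = {[1;0.5;-0.2]};

% The wrappers dispatch to getActionImpl and getActionWithExplorationImpl
act = getAction(agent,obs);
actExplore = getActionWithExploration(agent,obs);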
Within your implementation functions, to evaluate your actor and critic representations, you can use the getValue, getAction, and getMaxQValue functions.
To evaluate an rlValueRepresentation critic with only observation input signals, obtain the state value function V using the following syntax.
V = getValue(Critic,Observation);
To evaluate an rlQValueRepresentation critic with both observation and action input signals, obtain the state-action value function Q using the following syntax.
Q = getValue(Critic,[Observation,Action]);
To evaluate an rlQValueRepresentation critic with only observation input signals, obtain the state-action value function Q for all possible discrete actions using the following syntax.
Q = getValue(Critic,Observation);
For a discrete action space rlQValueRepresentation critic, obtain the maximum state-action value function Q over all possible discrete actions using the following syntax.
[MaxQ,MaxActionIndex] = getMaxQValue(Critic,Observation);
To evaluate an actor representation (rlStochasticActorRepresentation or rlDeterministicActorRepresentation), obtain the action A using the following syntax.
A = getAction(Actor,Observation);
For each of these cases, if your actor or critic network uses a recurrent neural network, the functions can also return the current values of the network state after obtaining the corresponding network output.
getActionImpl Function
The getActionImpl function evaluates the policy of your agent and selects an action. This function must have the following signature, where obj is the agent object, Observation is the current observation, and action is the selected action.
function action = getActionImpl(obj,Observation)
For the custom LQR agent, you select an action by applying the u=-Kx control law.
function action = getActionImpl(obj,Observation)
    % Given the current state of the system, return an action
    action = -obj.K*Observation{:};
end
getActionWithExplorationImpl Function
The getActionWithExplorationImpl function selects an action using the exploration model of your agent. Using this function, you can implement algorithms such as epsilon-greedy exploration. This function must have the following signature, where obj is the agent object, Observation is the current observation, and action is the selected action.
function action = getActionWithExplorationImpl(obj,Observation)
For the custom LQR agent, the getActionWithExplorationImpl function adds random white noise to an action selected using the current agent policy.
function action = getActionWithExplorationImpl(obj,Observation)
    % Given the current observation, select an action
    action = getAction(obj,Observation);

    % Add random noise to the action
    num = size(obj.R,1);
    action = action + 0.1*randn(num,1);
end
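The preceding example adds Gaussian noise, which suits a continuous action space. For an agent with a discrete action space (rlFiniteSetSpec), you could instead sketch an epsilon-greedy version along the following lines; the Epsilon property is hypothetical and not part of the LQR example.

function action = getActionWithExplorationImpl(obj,Observation)
    % With probability Epsilon, pick a random action from the finite
    % action set (assumed here to contain scalar numeric actions);
    % otherwise follow the current greedy policy.
    if rand < obj.Epsilon
        elements = obj.ActionInfo.Elements;
        action = elements(randi(numel(elements)));
    else
        action = getAction(obj,Observation);
    end
end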
learnImpl Function
The learnImpl function defines how the agent learns from the current experience. This function implements the custom learning algorithm of your agent by updating the policy parameters and selecting an action with exploration. This function must have the following signature, where obj is the agent object, exp is the current agent experience, and action is the selected action.
function action = learnImpl(obj,exp)
The agent experience is the cell array exp = {state,action,reward,nextstate,isdone}, where:
state is the current observation.
action is the current action.
reward is the current reward.
nextstate is the next observation.
isdone is a logical flag indicating that the training episode is complete.
For the custom LQR agent, the critic parameters are updated every N steps.
function action = learnImpl(obj,exp)
    % Parse the experience input
    x = exp{1}{1};
    u = exp{2}{1};
    dx = exp{4}{1};
    y = (x'*obj.Q*x + u'*obj.R*u);
    num = size(obj.Q,1) + size(obj.R,1);

    % Wait N steps before updating the critic parameters
    N = obj.EstimateNum;
    h1 = computeQuadraticBasis(x,u,num);
    h2 = computeQuadraticBasis(dx,-obj.K*dx,num);
    H = h1 - obj.Gamma*h2;

    if obj.Counter<=N
        obj.YBuffer(obj.Counter) = y;
        obj.HBuffer(obj.Counter,:) = H;
        obj.Counter = obj.Counter + 1;
    else
        % Update the critic parameters based on the batch of
        % experiences
        H_buf = obj.HBuffer;
        y_buf = obj.YBuffer;
        theta = (H_buf'*H_buf)\H_buf'*y_buf;
        obj.Critic = setLearnableParameters(obj.Critic,{theta});

        % Derive a new gain matrix based on the new critic parameters
        obj.K = getNewK(obj);

        % Reset the experience buffers
        obj.Counter = 1;
        obj.YBuffer = zeros(N,1);
        obj.HBuffer = zeros(N,0.5*num*(num+1));
        obj.KUpdate = obj.KUpdate + 1;
        obj.KBuffer{obj.KUpdate} = obj.K;
    end

    % Find and return an action with exploration
    action = getActionWithExploration(obj,exp{4});
end
Optionally, you can define how your agent is reset at the start of training by specifying a resetImpl function with the following function signature, where obj is the agent object. Using this function, you can set the agent into a known or random condition before training.
function resetImpl(obj)
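For instance, an agent like the custom LQR agent could use resetImpl to return its estimation buffers and counters to their initial values. The shipped example does not define this function; the following is only a sketch of what it might contain.

function resetImpl(obj)
    % Reinitialize the counters and buffers used by learnImpl (sketch)
    obj.Counter = 1;
    obj.YBuffer = zeros(obj.EstimateNum,1);
    num = size(obj.Q,1) + size(obj.R,1);
    obj.HBuffer = zeros(obj.EstimateNum,0.5*num*(num+1));
end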
Also, you can define any other helper functions in your custom agent class as required.
For example, the custom LQR agent defines a createCritic function for creating the critic representation and a getNewK function that derives the feedback gain matrix from the trained critic parameters.
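The getNewK implementation is contained in the example file. As a sketch of the underlying idea only, assuming the critic weights correspond to a row-wise upper-triangular quadratic basis (as in the computeQuadraticBasis sketch earlier), you could rebuild the symmetric matrix H in Q(x,u) = [x;u]'*H*[x;u] and read the gain from its partitions. The shipped implementation may differ in detail.

function K = getNewK(obj)
    % Sketch only: assumes the basis ordering from the earlier
    % computeQuadraticBasis sketch (upper triangular, row by row).
    theta = getLearnableParameters(obj.Critic);
    theta = theta{1};
    nQ = size(obj.Q,1);
    nR = size(obj.R,1);
    n = nQ + nR;

    % Rebuild the symmetric matrix H from the weight vector
    H = zeros(n);
    idx = 1;
    for r = 1:n
        for c = r:n
            if r == c
                H(r,c) = theta(idx);
            else
                H(r,c) = theta(idx)/2;
                H(c,r) = theta(idx)/2;
            end
            idx = idx + 1;
        end
    end

    % Partition H and compute the gain for the control law u = -K*x
    Hux = H(nQ+1:end,1:nQ);
    Huu = H(nQ+1:end,nQ+1:end);
    K = Huu\Hux;
end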
After you define your custom agent class, create an instance of it in the MATLAB workspace. For example, to create the custom LQR agent, define the Q, R, and InitialK values and call the constructor function.
Q = [10,3,1;3,5,4;1,4,9];
R = 0.5*eye(3);
% A and B are the system matrices of the environment model,
% as defined in the Train Custom LQR Agent example
K0 = place(A,B,[0.4,0.8,0.5]);
agent = LQRCustomAgent(Q,R,K0);
After creating the agent object, you can train it within a reinforcement learning environment. For an example that trains the custom LQR agent, see Train Custom LQR Agent.
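For instance, assuming you have already created an environment object env for your problem, a typical training call looks like the following sketch. The option values shown here are illustrative rather than taken from the example.

% Illustrative training options; adjust these values for your problem
trainOpts = rlTrainingOptions(...
    'MaxEpisodes',1000,...
    'MaxStepsPerEpisode',10,...
    'Verbose',false,...
    'Plots','training-progress');

% Train the custom agent against the environment env (assumed to exist)
trainingStats = train(agent,env,trainOpts);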