Source: Reinforcement Learning (DQN) Tutorial - PyTorch
I’m revisiting DQN using the CartPole-v1 gym environment
- Possible actions: ← or →
- State: Cart position, cart velocity, pole angle, pole angular velocity
- Goal: Keep the pole balanced
- Rewards: +1 for every timestep the pole stays up. The env terminates if the pole falls over too far or the cart moves >2.4 units away from the center
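A quick sanity check of the environment (a minimal sketch assuming the Gymnasium fork of Gym, which recent versions of the tutorial use):

```python
import gymnasium as gym

# CartPole-v1 observation: (cart position, cart velocity, pole angle, pole angular velocity)
env = gym.make("CartPole-v1")

state, info = env.reset()
print(env.observation_space.shape)  # (4,)  -> four state variables
print(env.action_space.n)           # 2     -> push cart left or right
```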
Using Experience Replay Memory (ERM) for training the DQN
- [b] Move ERM to its own note
- Detailed reading on ERM
ERM
ERM stores the transitions observed by the agent. These can be reused at any point later
Random sampling from the ERM ensures that the transitions in a batch are decorrelated, which stabilizes and improves training
Transition: an agent in an environment, given a state, performs an action that moves the environment to the next state and produces a reward
A cyclic buffer (`deque` with a fixed `maxlen`) is used here because we want it to store only the most recent state transitions. If it is already full, the newest transition is stored at the head of the deque, discarding the oldest element in the container, i.e. the tail end of the deque
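A sketch of the ERM as a cyclic buffer, along the lines of the tutorial's implementation (the `Transition` namedtuple and class layout follow the tutorial's conventions):

```python
import random
from collections import deque, namedtuple

# One transition: (state, action) -> (next_state, reward)
Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))

class ReplayMemory:
    """Bounded cyclic buffer that keeps only the most recent transitions."""

    def __init__(self, capacity):
        # deque with maxlen: appending when full silently drops the oldest element
        self.memory = deque([], maxlen=capacity)

    def push(self, *args):
        """Save a transition."""
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        """Draw a random batch of decorrelated transitions for a training step."""
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)
```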
- [b] Move DQN to own note
$R_{t_0} = \sum_{t=t_0}^{\infty} \gamma^{t-t_0} r_t$ is the cumulative discounted return, $t$ is the time step with $t_0$ being the starting time step, $\gamma$ is the discount rate, and $r_t$ is the reward per transition/time step. This is an infinite-horizon equation as the cumulative sum goes on to infinity
Role of the discount rate
Intuitively looking at the equation, the discount coefficient $\gamma^{t-t_0}$ would expand to $\gamma^0, \gamma^1, \gamma^2, \ldots$ and $\gamma$ is between 0 and 1. How does it work? Higher terms diminish because $\gamma^{t-t_0}$ becomes exponentially smaller as time goes on. This ensures that rewards anticipated in the uncertain far future are given lesser importance and are gradually discarded. This gives precedence to rewards obtained in the earlier stages of the training process, improving the speed and performance of the algorithm. In short, it ensures convergence of returns
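A toy numerical check of this convergence, assuming a stream of +1 rewards (as in CartPole) and a made-up $\gamma = 0.99$:

```python
GAMMA = 0.99  # assumed discount rate for illustration

def discounted_return(rewards, gamma=GAMMA):
    # R = sum_t gamma**t * r_t, truncated at the length of the reward list
    return sum(gamma ** t * r for t, r in enumerate(rewards))

print(discounted_return([1.0] * 100))     # ~63.4: later terms already contribute very little
print(discounted_return([1.0] * 10_000))  # ~100.0: the geometric series converges to 1 / (1 - gamma)
```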
The fundamental idea of Q-learning is that we generate a table of Q-values for each state-action pair $(s, a)$; in other words, $Q: S \times A \rightarrow \mathbb{R}$. The higher the Q-value, the better. The objective is to identify an optimal policy that produces the best Q-value for $(s, a)$. However, in reality, we don't know what this is, so we'll have to approximate it. This is where DQN enters, by the usage of a neural network: NNs are good function approximators, which makes them good for approximating the optimal Q-function (and hence the optimal policy).

$Q(s, a)$ obeys the Bellman equation, and this is used for the training update of the DQN. As with any NN, the goal is to backpropagate the loss and update the weights to minimize it. The loss here is the difference between $Q(s, a)$ and the Bellman target, i.e. the temporal difference (TD) error. The TD error is minimized using the Huber loss formula and is calculated over a batch of transitions drawn from the ERM
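For reference, the Bellman equation and the TD error in the standard DQN formulation used by the tutorial:

$$Q^{\pi}(s, a) = r + \gamma \, Q^{\pi}(s', \pi(s'))$$

$$\delta = Q(s, a) - \left( r + \gamma \max_{a'} Q(s', a') \right)$$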
Huber Loss
The Huber Loss is a loss function that acts like the mean squared error when the TD error is small and the mean absolute error when the TD error is large. This makes it more robust to noise and outliers
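A minimal sketch of how this loss might be computed over a sampled batch (names like `policy_net`, `target_net`, `GAMMA`, and the batch tensors are assumptions in the spirit of the tutorial; terminal states are not masked out here for brevity):

```python
import torch
import torch.nn as nn

GAMMA = 0.99  # assumed discount rate

def td_huber_loss(policy_net, target_net, state_batch, action_batch,
                  reward_batch, next_state_batch):
    # Q(s, a) from the policy network for the actions that were actually taken.
    # action_batch is expected to be a LongTensor of shape (batch_size, 1).
    q_sa = policy_net(state_batch).gather(1, action_batch).squeeze(1)

    # Bellman target: r + gamma * max_a' Q_target(s', a'), computed with the target network
    with torch.no_grad():
        next_q = target_net(next_state_batch).max(1).values
    target = reward_batch + GAMMA * next_q

    # Huber loss (SmoothL1): quadratic for small TD errors, linear for large ones
    criterion = nn.SmoothL1Loss()
    return criterion(q_sa, target)
```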
`n_observations` is the size of the state space of the environment. For example, in this case, there are 4 variables used to track the state, so `n_observations = 4`. This forms the input dimensions of the DQN. As for why 128 was chosen for the hidden layers, this is purely a matter of trial-and-error. I would have to mess around and see what works optimally, although I'm sure there is some logic to choosing such numbers
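A sketch of the network described above (input size `n_observations`, two hidden layers of 128 units, output size `n_actions`, mirroring the tutorial's architecture):

```python
import torch.nn as nn
import torch.nn.functional as F

class DQN(nn.Module):
    def __init__(self, n_observations, n_actions):
        super().__init__()
        self.layer1 = nn.Linear(n_observations, 128)  # input: the 4 state variables for CartPole
        self.layer2 = nn.Linear(128, 128)
        self.layer3 = nn.Linear(128, n_actions)       # output: one Q-value per action (2 for CartPole)

    def forward(self, x):
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        return self.layer3(x)
```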
The epsilon-greedy policy introduces a probability factor that decides whether an action is selected from the model or drawn from a uniform sample. As training progresses, this probability decays at a rate controlled by EPS_DECAY
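A sketch of the epsilon-greedy action selection with exponential decay (hyperparameter values are assumptions; the decay formula follows the tutorial's approach):

```python
import math
import random
import torch

# Assumed exploration schedule
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 1000

steps_done = 0

def select_action(state, policy_net, n_actions):
    """With probability 1 - eps pick the greedy action from the model, otherwise sample uniformly."""
    global steps_done
    # Exploration probability decays exponentially with the number of steps taken
    eps_threshold = EPS_END + (EPS_START - EPS_END) * math.exp(-steps_done / EPS_DECAY)
    steps_done += 1
    if random.random() > eps_threshold:
        with torch.no_grad():
            # Greedy action: argmax over the Q-values predicted by the policy network
            return policy_net(state).argmax(dim=1).view(1, 1)
    # Uniformly sampled action
    return torch.tensor([[random.randrange(n_actions)]], dtype=torch.long)
```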
In DQN, we have a policy network and a target network. The function of the target network is to enhance the stability of the DQN. The weights of the policy network are copied over to the target network periodically, albeit slowly. The hyperparameter that controls the soft update of the target network is TAU
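A sketch of the soft update, assuming the tutorial's convention of calling the rate `TAU` ($\theta' \leftarrow \tau\theta + (1 - \tau)\theta'$):

```python
TAU = 0.005  # assumed soft-update rate

def soft_update(policy_net, target_net):
    """Blend the policy network's weights into the target network slowly."""
    target_state = target_net.state_dict()
    policy_state = policy_net.state_dict()
    for key in policy_state:
        target_state[key] = TAU * policy_state[key] + (1 - TAU) * target_state[key]
    target_net.load_state_dict(target_state)
```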
- Can the `namedtuple` name be different from its typename argument?
- What is the meaning of `class ClassName(object):`? What does the `object` argument do?
- This is a silly question but what if $\gamma$ was negative?
- Understanding the discount factor - CrossValidated (Stack Exchange) - Further reading on $\gamma$