add Every-Visit Monte-Carlo

This commit is contained in:
Jan Löwenstrom 2020-04-02 15:56:11 +02:00
parent 740289ee2b
commit 9a3452ff9c
4 changed files with 52 additions and 48 deletions

View File

@ -5,7 +5,6 @@ import core.Environment;
import core.LearningConfig;
import core.StepResult;
import core.listener.LearningListener;
import core.policy.EpsilonGreedyPolicy;
import example.DinoSampling;
import lombok.Getter;
import lombok.Setter;
@ -125,6 +124,7 @@ public abstract class EpisodicLearning<A extends Enum> extends Learning<A> imple
private void startLearning(){
dispatchLearningStart();
while(episodesToLearn.get() > 0){
dispatchEpisodeStart();
timestampCurrentEpisode = 0;
nextEpisode();

View File

@ -5,5 +5,5 @@ package core.algo;
* which RL-algorithm should be used.
*/
public enum Method {
MC_CONTROL_FIRST_VISIT, SARSA_EPISODIC, Q_LEARNING_OFF_POLICY_CONTROL
MC_CONTROL_FIRST_VISIT, MC_CONTROL_EVERY_VISIT, SARSA_ON_POLICY_CONTROL, Q_LEARNING_OFF_POLICY_CONTROL
}

View File

@ -3,52 +3,40 @@ package core.algo.mc;
import core.*;
import core.algo.EpisodicLearning;
import core.policy.EpsilonGreedyPolicy;
import core.policy.GreedyPolicy;
import core.policy.Policy;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import java.io.*;
import java.net.URI;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardOpenOption;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.*;
/**
* TODO: Major problem:
* StateActionPairs are only unique when accounting for their position in the episode.
* For example:
* <p>
* startingState -> MOVE_LEFT : very first state action in the episode i = 1
imagine the agent does not collect the food and does not drop it onto start; the agent will receive
* -1 for every timestamp hence (startingState -> MOVE_LEFT) will get a value of -10;
* <p>
* BUT imagine moving left from the starting position will have no impact on the state because
* the agent ran into a wall. The known world stays the same.
* Taking an action after that will have the exact same state but a different action
* making the value of this stateActionPair -9 because the stateAction pair took place on the second
* timestamp, summing up all remaining rewards will be -9...
* <p>
* How to counter this problem?
*
* Includes both variants of Monte-Carlo methods
* Default method is First-Visit.
* Change to Every-Visit by setting flag "useEveryVisit" in the constructor to true.
* @param <A>
*/
public class MonteCarloControlFirstVisitEGreedy<A extends Enum> extends EpisodicLearning<A> {
private Map<Pair<State, A>, Double> returnSum;
private Map<Pair<State, A>, Integer> returnCount;
private boolean isEveryVisit;
public MonteCarloControlFirstVisitEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay) {
public MonteCarloControlFirstVisitEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay, boolean useEveryVisit) {
super(environment, actionSpace, discountFactor, delay);
isEveryVisit = useEveryVisit;
this.policy = new EpsilonGreedyPolicy<>(epsilon);
this.stateActionTable = new DeterministicStateActionTable<>(this.actionSpace);
returnSum = new HashMap<>();
returnCount = new HashMap<>();
}
public MonteCarloControlFirstVisitEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay) {
this(environment, actionSpace, discountFactor, epsilon, delay, false);
}
public MonteCarloControlFirstVisitEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, int delay) {
this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, LearningConfig.DEFAULT_EPSILON, delay);
}
@ -89,35 +77,47 @@ public class MonteCarloControlFirstVisitEGreedy<A extends Enum> extends Episodic
// System.out.printf("Episode %d \t Reward: %f \n", currentEpisode, sumOfRewards);
Set<Pair<State, A>> stateActionPairs = new LinkedHashSet<>();
HashMap<Pair<State, A>, List<Integer>> stateActionPairs = new LinkedHashMap<>();
int firstOccurrenceIndex = 0;
for(StepResult<A> sr : episode) {
stateActionPairs.add(new ImmutablePair<>(sr.getState(), sr.getAction()));
Pair<State, A> pair = new ImmutablePair<>(sr.getState(), sr.getAction());
if(!stateActionPairs.containsKey(pair)) {
List<Integer> l = new ArrayList<>();
l.add(firstOccurrenceIndex);
stateActionPairs.put(pair, l);
}
/*
This is the only difference between First-Visit and Every-Visit.
When First-Visit is selected, only the first index of the occurrence is put into the list.
When Every-Visit is selected, every following occurrence is saved
into the list as well.
*/
else if(isEveryVisit) {
stateActionPairs.get(pair).add(firstOccurrenceIndex);
}
++firstOccurrenceIndex;
}
//System.out.println("stateActionPairs " + stateActionPairs.size());
for(Pair<State, A> stateActionPair : stateActionPairs) {
int firstOccurenceIndex = 0;
// find first occurrence of the state-action pair
for(StepResult<A> sr : episode) {
if(stateActionPair.getKey().equals(sr.getState()) && stateActionPair.getValue().equals(sr.getAction())) {
break;
for(Map.Entry<Pair<State, A>, List<Integer>> entry : stateActionPairs.entrySet()) {
Pair<State, A> stateActionPair = entry.getKey();
List<Integer> firstOccurrences = entry.getValue();
for(Integer firstOccurrencesIdx : firstOccurrences) {
double G = 0;
for(int l = firstOccurrencesIdx; l < episode.size(); ++l) {
G += episode.get(l).getReward() * (Math.pow(discountFactor, l - firstOccurrencesIdx));
}
firstOccurenceIndex++;
// slick trick to add G to the entry.
// if the key does not exist, it will create a new entry with G as default value
returnSum.merge(stateActionPair, G, Double::sum);
returnCount.merge(stateActionPair, 1, Integer::sum);
stateActionTable.setValue(stateActionPair.getKey(), stateActionPair.getValue(), returnSum.get(stateActionPair) / returnCount.get(stateActionPair));
}
double G = 0;
for(int l = firstOccurenceIndex; l < episode.size(); ++l) {
G += episode.get(l).getReward() * (Math.pow(discountFactor, l - firstOccurenceIndex));
}
// slick trick to add G to the entry.
// if the key does not exist, it will create a new entry with G as default value
returnSum.merge(stateActionPair, G, Double::sum);
returnCount.merge(stateActionPair, 1, Integer::sum);
stateActionTable.setValue(stateActionPair.getKey(), stateActionPair.getValue(), returnSum.get(stateActionPair) / returnCount.get(stateActionPair));
}
}
@Override
public void save(ObjectOutputStream oos) throws IOException {
super.save(oos);

View File

@ -51,7 +51,11 @@ public class RLController<A extends Enum> implements LearningListener {
case MC_CONTROL_FIRST_VISIT:
learning = new MonteCarloControlFirstVisitEGreedy<>(environment, discreteActionSpace, discountFactor, epsilon, delay);
break;
case SARSA_EPISODIC:
case MC_CONTROL_EVERY_VISIT:
learning = new MonteCarloControlFirstVisitEGreedy<>(environment, discreteActionSpace, discountFactor, epsilon, delay, true);
break;
case SARSA_ON_POLICY_CONTROL:
learning = new SARSA<>(environment, discreteActionSpace, discountFactor, epsilon, learningRate, delay);
break;
case Q_LEARNING_OFF_POLICY_CONTROL: