From 9a3452ff9ceac26dc4fbfc7b2a95a4a4bd7a2e7c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jan=20L=C3=B6wenstrom?=
Date: Thu, 2 Apr 2020 15:56:11 +0200
Subject: [PATCH] add Every-Visit Monte-Carlo

---
 src/main/java/core/algo/EpisodicLearning.java |  2 +-
 src/main/java/core/algo/Method.java           |  2 +-
 .../MonteCarloControlFirstVisitEGreedy.java   | 90 +++++++++----------
 .../java/core/controller/RLController.java    |  6 +-
 4 files changed, 52 insertions(+), 48 deletions(-)

diff --git a/src/main/java/core/algo/EpisodicLearning.java b/src/main/java/core/algo/EpisodicLearning.java
index d7b6655..6e57943 100644
--- a/src/main/java/core/algo/EpisodicLearning.java
+++ b/src/main/java/core/algo/EpisodicLearning.java
@@ -5,7 +5,6 @@ import core.Environment;
 import core.LearningConfig;
 import core.StepResult;
 import core.listener.LearningListener;
-import core.policy.EpsilonGreedyPolicy;
 import example.DinoSampling;
 import lombok.Getter;
 import lombok.Setter;
@@ -125,6 +124,7 @@ public abstract class EpisodicLearning<A extends Enum> extends Learning<A> imple
     private void startLearning(){
         dispatchLearningStart();
         while(episodesToLearn.get() > 0){
+            dispatchEpisodeStart();
             timestampCurrentEpisode = 0;
             nextEpisode();
 
diff --git a/src/main/java/core/algo/Method.java b/src/main/java/core/algo/Method.java
index 6372b24..99966fc 100644
--- a/src/main/java/core/algo/Method.java
+++ b/src/main/java/core/algo/Method.java
@@ -5,5 +5,5 @@ package core.algo;
  * which RL-algorithm should be used.
  */
 public enum Method {
-    MC_CONTROL_FIRST_VISIT, SARSA_EPISODIC, Q_LEARNING_OFF_POLICY_CONTROL
+    MC_CONTROL_FIRST_VISIT, MC_CONTROL_EVERY_VISIT, SARSA_ON_POLICY_CONTROL, Q_LEARNING_OFF_POLICY_CONTROL
 }
diff --git a/src/main/java/core/algo/mc/MonteCarloControlFirstVisitEGreedy.java b/src/main/java/core/algo/mc/MonteCarloControlFirstVisitEGreedy.java
index ff3247b..f00fa4b 100644
--- a/src/main/java/core/algo/mc/MonteCarloControlFirstVisitEGreedy.java
+++ b/src/main/java/core/algo/mc/MonteCarloControlFirstVisitEGreedy.java
@@ -3,52 +3,40 @@ package core.algo.mc;
 import core.*;
 import core.algo.EpisodicLearning;
 import core.policy.EpsilonGreedyPolicy;
-import core.policy.GreedyPolicy;
-import core.policy.Policy;
 import org.apache.commons.lang3.tuple.ImmutablePair;
 import org.apache.commons.lang3.tuple.Pair;
-import java.io.*;
-import java.net.URI;
-import java.nio.file.Files;
-import java.nio.file.Path;
-import java.nio.file.Paths;
-import java.nio.file.StandardOpenOption;
+import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
 import java.util.*;
 
 /**
- * TODO: Major problem:
- * StateActionPairs are only unique accounting for their position in the episode.
- * For example:
- * <p>
- * startingState -> MOVE_LEFT : very first state action in the episode i = 1
- * image the agent does not collect the food and does not drop it onto start, the agent will receive
- * -1 for every timestamp hence (startingState -> MOVE_LEFT) will get a value of -10;
- * <p>
- * BUT image moving left from the starting position will have no impact on the state because
- * the agent ran into a wall. The known world stays the same.
- * Taking an action after that will have the exact same state but a different action
- * making the value of this stateActionPair -9 because the stateAction pair took place on the second
- * timestamp, summing up all remaining rewards will be -9...
- * <p>
- * How to encounter this problem?
- *
+ * Implements both variants of Monte-Carlo control: First-Visit and Every-Visit.
+ * The default method is First-Visit.
+ * Switch to Every-Visit by setting the constructor flag "useEveryVisit" to true.
  * @param <A>
  */
 public class MonteCarloControlFirstVisitEGreedy<A extends Enum> extends EpisodicLearning<A> {
 
     private Map<Pair<State, A>, Double> returnSum;
     private Map<Pair<State, A>, Integer> returnCount;
+    private boolean isEveryVisit;
 
-    public MonteCarloControlFirstVisitEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay) {
+    public MonteCarloControlFirstVisitEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay, boolean useEveryVisit) {
         super(environment, actionSpace, discountFactor, delay);
+        isEveryVisit = useEveryVisit;
         this.policy = new EpsilonGreedyPolicy<>(epsilon);
         this.stateActionTable = new DeterministicStateActionTable<>(this.actionSpace);
         returnSum = new HashMap<>();
         returnCount = new HashMap<>();
     }
 
+    public MonteCarloControlFirstVisitEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay) {
+        this(environment, actionSpace, discountFactor, epsilon, delay, false);
+    }
+
     public MonteCarloControlFirstVisitEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, int delay) {
         this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, LearningConfig.DEFAULT_EPSILON, delay);
     }
@@ -89,35 +77,47 @@ public class MonteCarloControlFirstVisitEGreedy<A extends Enum> extends Episodic
         // System.out.printf("Episode %d \t Reward: %f \n", currentEpisode, sumOfRewards);
-        Set<Pair<State, A>> stateActionPairs = new LinkedHashSet<>();
+        HashMap<Pair<State, A>, List<Integer>> stateActionPairs = new LinkedHashMap<>();
+        int firstOccurrenceIndex = 0;
         for(StepResult<A> sr : episode) {
-            stateActionPairs.add(new ImmutablePair<>(sr.getState(), sr.getAction()));
+            Pair<State, A> pair = new ImmutablePair<>(sr.getState(), sr.getAction());
+            if(!stateActionPairs.containsKey(pair)) {
+                List<Integer> l = new ArrayList<>();
+                l.add(firstOccurrenceIndex);
+                stateActionPairs.put(pair, l);
+            }
+
+            /*
+            This is the only difference between First-Visit and Every-Visit.
+            When First-Visit is selected, only the index of the first occurrence is put into the list.
+            When Every-Visit is selected, every following occurrence is saved
+            into the list as well.
+             */
+            else if(isEveryVisit) {
+                stateActionPairs.get(pair).add(firstOccurrenceIndex);
+            }
+            ++firstOccurrenceIndex;
         }
-        //System.out.println("stateActionPairs " + stateActionPairs.size());
-        for(Pair<State, A> stateActionPair : stateActionPairs) {
-            int firstOccurenceIndex = 0;
-            // find first occurance of state action pair
-            for(StepResult<A> sr : episode) {
-                if(stateActionPair.getKey().equals(sr.getState()) && stateActionPair.getValue().equals(sr.getAction())) {
-                    break;
+        for(Map.Entry<Pair<State, A>, List<Integer>> entry : stateActionPairs.entrySet()) {
+            Pair<State, A> stateActionPair = entry.getKey();
+            List<Integer> firstOccurrences = entry.getValue();
+            for(Integer firstOccurrencesIdx : firstOccurrences) {
+                double G = 0;
+                for(int l = firstOccurrencesIdx; l < episode.size(); ++l) {
+                    G += episode.get(l).getReward() * (Math.pow(discountFactor, l - firstOccurrencesIdx));
                 }
-                firstOccurenceIndex++;
+                // slick trick to add G to the entry.
+                // if the key does not exist, it will create a new entry with G as the default value
+                returnSum.merge(stateActionPair, G, Double::sum);
+                returnCount.merge(stateActionPair, 1, Integer::sum);
+                stateActionTable.setValue(stateActionPair.getKey(), stateActionPair.getValue(), returnSum.get(stateActionPair) / returnCount.get(stateActionPair));
             }
-
-            double G = 0;
-            for(int l = firstOccurenceIndex; l < episode.size(); ++l) {
-                G += episode.get(l).getReward() * (Math.pow(discountFactor, l - firstOccurenceIndex));
-            }
-            // slick trick to add G to the entry.
-            // if the key does not exists, it will create a new entry with G as default value
-            returnSum.merge(stateActionPair, G, Double::sum);
-            returnCount.merge(stateActionPair, 1, Integer::sum);
-            stateActionTable.setValue(stateActionPair.getKey(), stateActionPair.getValue(), returnSum.get(stateActionPair) / returnCount.get(stateActionPair));
         }
     }
 
+    @Override
     public void save(ObjectOutputStream oos) throws IOException {
         super.save(oos);
diff --git a/src/main/java/core/controller/RLController.java b/src/main/java/core/controller/RLController.java
index fc33239..895102a 100644
--- a/src/main/java/core/controller/RLController.java
+++ b/src/main/java/core/controller/RLController.java
@@ -51,7 +51,11 @@ public class RLController<A extends Enum> implements LearningListener {
             case MC_CONTROL_FIRST_VISIT:
                 learning = new MonteCarloControlFirstVisitEGreedy<>(environment, discreteActionSpace, discountFactor, epsilon, delay);
                 break;
-            case SARSA_EPISODIC:
+            case MC_CONTROL_EVERY_VISIT:
+                learning = new MonteCarloControlFirstVisitEGreedy<>(environment, discreteActionSpace, discountFactor, epsilon, delay, true);
+                break;
+
+            case SARSA_ON_POLICY_CONTROL:
                 learning = new SARSA<>(environment, discreteActionSpace, discountFactor, epsilon, learningRate, delay);
                 break;
             case Q_LEARNING_OFF_POLICY_CONTROL:
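
Reference note (not part of the commit): for a recorded occurrence index t of a state-action pair (s, a), the inner loop over l computes the discounted return, and returnSum / returnCount keep its running sample average as the tabular action-value estimate. With gamma = discountFactor, r_k the reward at step k, T the episode length and N(s, a) the number of recorded returns, this is

    G_t = \sum_{k=t}^{T-1} \gamma^{k-t} r_k, \qquad
    Q(s,a) = \frac{\mathrm{returnSum}(s,a)}{\mathrm{returnCount}(s,a)} = \frac{1}{N(s,a)} \sum_{i=1}^{N(s,a)} G^{(i)}.

First-Visit contributes one return per state-action pair per episode; Every-Visit contributes one return per recorded occurrence, which is exactly what the extra indices stored in the list add.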
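
The standalone sketch below mirrors the bookkeeping introduced by this patch under simplified assumptions: a local Step class and String state/action keys stand in for the project's State, StepResult and Pair types. It illustrates the First-Visit/Every-Visit difference on a toy episode and is not code from this repository.

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

public class VisitVariantSketch {

    // Simplified stand-in for one transition of an episode.
    static final class Step {
        final String state;
        final String action;
        final double reward;

        Step(String state, String action, double reward) {
            this.state = state;
            this.action = action;
            this.reward = reward;
        }
    }

    public static void main(String[] args) {
        // Toy episode in which the pair "S/LEFT" occurs twice (t = 0 and t = 2).
        List<Step> episode = Arrays.asList(
                new Step("S", "LEFT", -1),
                new Step("A", "UP", -1),
                new Step("S", "LEFT", -1),
                new Step("B", "RIGHT", 10));

        System.out.println("First-Visit: " + evaluate(episode, 0.9, false));
        System.out.println("Every-Visit: " + evaluate(episode, 0.9, true));
    }

    static Map<String, Double> evaluate(List<Step> episode, double discountFactor, boolean everyVisit) {
        // Occurrence indices per state-action pair, kept in insertion order like the patch's LinkedHashMap.
        Map<String, List<Integer>> visits = new LinkedHashMap<>();
        for (int t = 0; t < episode.size(); t++) {
            String key = episode.get(t).state + "/" + episode.get(t).action;
            List<Integer> indices = visits.get(key);
            if (indices == null) {
                indices = new ArrayList<>();
                indices.add(t);              // the first occurrence is always recorded
                visits.put(key, indices);
            } else if (everyVisit) {
                indices.add(t);              // later occurrences are recorded only for Every-Visit
            }
        }

        Map<String, Double> returnSum = new HashMap<>();
        Map<String, Integer> returnCount = new HashMap<>();
        Map<String, Double> q = new LinkedHashMap<>();
        for (Map.Entry<String, List<Integer>> entry : visits.entrySet()) {
            for (int start : entry.getValue()) {
                // Discounted return G from this occurrence to the end of the episode.
                double g = 0;
                for (int k = start; k < episode.size(); k++) {
                    g += episode.get(k).reward * Math.pow(discountFactor, k - start);
                }
                // Average all recorded returns, like returnSum / returnCount in the patch.
                returnSum.merge(entry.getKey(), g, Double::sum);
                returnCount.merge(entry.getKey(), 1, Integer::sum);
                q.put(entry.getKey(), returnSum.get(entry.getKey()) / returnCount.get(entry.getKey()));
            }
        }
        return q;
    }
}

With discountFactor 0.9, the repeated pair "S/LEFT" yields one return (about 4.58) under First-Visit and the average of two returns (about 6.29) under Every-Visit, which is the behavioural difference the useEveryVisit flag enables.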