diff --git a/src/main/java/core/algo/EpisodicLearning.java b/src/main/java/core/algo/EpisodicLearning.java
index 49bde9f..05a84d5 100644
--- a/src/main/java/core/algo/EpisodicLearning.java
+++ b/src/main/java/core/algo/EpisodicLearning.java
@@ -5,7 +5,7 @@ import core.Environment;
 import core.LearningConfig;
 import core.StepResult;
 import core.listener.LearningListener;
-import core.policy.EpsilonGreedyPolicy;
+import example.DinoSampling;
 import lombok.Getter;
 import lombok.Setter;
 
@@ -104,10 +104,10 @@ public abstract class EpisodicLearning extends Learning imple
             timestamp++;
             timestampCurrentEpisode++;
             // TODO: more sophisticated way to check convergence
-            if(timestampCurrentEpisode > 300000){
+            if(timestampCurrentEpisode > 50000) {
                 converged = true;
                 // t
-                File file = new File("convergenceAdv.txt");
+                File file = new File(DinoSampling.FILE);
                 try {
                     Files.writeString(Path.of(file.getPath()), currentEpisode/2 + ",", StandardOpenOption.APPEND);
                 } catch (IOException e) {
diff --git a/src/main/java/core/algo/Method.java b/src/main/java/core/algo/Method.java
index 6372b24..99966fc 100644
--- a/src/main/java/core/algo/Method.java
+++ b/src/main/java/core/algo/Method.java
@@ -5,5 +5,5 @@ package core.algo;
  * which RL-algorithm should be used.
  */
 public enum Method {
-    MC_CONTROL_FIRST_VISIT, SARSA_EPISODIC, Q_LEARNING_OFF_POLICY_CONTROL
+    MC_CONTROL_FIRST_VISIT, MC_CONTROL_EVERY_VISIT, SARSA_ON_POLICY_CONTROL, Q_LEARNING_OFF_POLICY_CONTROL
 }
diff --git a/src/main/java/core/algo/mc/MonteCarloControlFirstVisitEGreedy.java b/src/main/java/core/algo/mc/MonteCarloControlFirstVisitEGreedy.java
index 5bc93d3..3618eb0 100644
--- a/src/main/java/core/algo/mc/MonteCarloControlFirstVisitEGreedy.java
+++ b/src/main/java/core/algo/mc/MonteCarloControlFirstVisitEGreedy.java
@@ -8,37 +8,22 @@ import core.policy.Policy;
 import org.apache.commons.lang3.tuple.ImmutablePair;
 import org.apache.commons.lang3.tuple.Pair;
 
-import java.io.*;
-import java.net.URI;
-import java.nio.file.Files;
-import java.nio.file.Path;
-import java.nio.file.Paths;
-import java.nio.file.StandardOpenOption;
+import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
 import java.util.*;
 
 /**
- * TODO: Major problem:
- * StateActionPairs are only unique accounting for their position in the episode.
- * For example:
- * <br>
- * startingState -> MOVE_LEFT : very first state action in the episode i = 1
- * image the agent does not collect the food and does not drop it onto start, the agent will receive
- * -1 for every timestamp hence (startingState -> MOVE_LEFT) will get a value of -10;
- * <br>
- * BUT image moving left from the starting position will have no impact on the state because
- * the agent ran into a wall. The known world stays the same.
- * Taking an action after that will have the exact same state but a different action
- * making the value of this stateActionPair -9 because the stateAction pair took place on the second
- * timestamp, summing up all remaining rewards will be -9...
- * <br>
- * How to encounter this problem?
- *
+ * Includes both variants of Monte-Carlo methods.
+ * The default method is First-Visit.
+ * Switch to Every-Visit by setting the constructor flag "useEveryVisit" to true.
  * @param <A>
  */
 public class MonteCarloControlFirstVisitEGreedy<A extends Enum> extends EpisodicLearning<A> {
 
     private Map<Pair<State, A>, Double> returnSum;
     private Map<Pair<State, A>, Integer> returnCount;
+    private boolean isEveryVisit; // t
 
     private float epsilon;
 
@@ -46,8 +31,9 @@ public class MonteCarloControlFirstVisitEGreedy extends Episodic
     private Policy<A> greedyPolicy = new GreedyPolicy<>();
 
-    public MonteCarloControlFirstVisitEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay) {
+    public MonteCarloControlFirstVisitEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay, boolean useEveryVisit) {
         super(environment, actionSpace, discountFactor, delay);
+        isEveryVisit = useEveryVisit; // t
         this.epsilon = epsilon;
         this.policy = new EpsilonGreedyPolicy<>(epsilon);
@@ -56,6 +42,10 @@ public class MonteCarloControlFirstVisitEGreedy extends Episodic
         returnCount = new HashMap<>();
     }
 
+    public MonteCarloControlFirstVisitEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay) {
+        this(environment, actionSpace, discountFactor, epsilon, delay, false);
+    }
+
     public MonteCarloControlFirstVisitEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, int delay) {
         this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, LearningConfig.DEFAULT_EPSILON, delay);
     }
@@ -104,35 +94,47 @@ public class MonteCarloControlFirstVisitEGreedy extends Episodic
         }
        // System.out.printf("Episode %d \t Reward: %f \n", currentEpisode, sumOfRewards);
-        Set<Pair<State, A>> stateActionPairs = new LinkedHashSet<>();
+        HashMap<Pair<State, A>, List<Integer>> stateActionPairs = new LinkedHashMap<>();
+        int firstOccurrenceIndex = 0;
 
         for(StepResult<A> sr : episode) {
-            stateActionPairs.add(new ImmutablePair<>(sr.getState(), sr.getAction()));
+            Pair<State, A> pair = new ImmutablePair<>(sr.getState(), sr.getAction());
+            if(!stateActionPairs.containsKey(pair)) {
+                List<Integer> l = new ArrayList<>();
+                l.add(firstOccurrenceIndex);
+                stateActionPairs.put(pair, l);
+            }
+
+            /*
+            This is the only difference between First-Visit and Every-Visit.
+            When First-Visit is selected, only the index of the first occurrence is put into the list.
+            When Every-Visit is selected, every following occurrence is saved
+            into the list as well.
+            */
+            else if(isEveryVisit) {
+                stateActionPairs.get(pair).add(firstOccurrenceIndex);
+            }
+            ++firstOccurrenceIndex;
         }
-        //System.out.println("stateActionPairs " + stateActionPairs.size());
-        for(Pair<State, A> stateActionPair : stateActionPairs) {
-            int firstOccurenceIndex = 0;
-            // find first occurance of state action pair
-            for(StepResult<A> sr : episode) {
-                if(stateActionPair.getKey().equals(sr.getState()) && stateActionPair.getValue().equals(sr.getAction())) {
-                    break;
+        for(Map.Entry<Pair<State, A>, List<Integer>> entry : stateActionPairs.entrySet()) {
+            Pair<State, A> stateActionPair = entry.getKey();
+            List<Integer> firstOccurrences = entry.getValue();
+            for(Integer firstOccurrencesIdx : firstOccurrences) {
+                double G = 0;
+                for(int l = firstOccurrencesIdx; l < episode.size(); ++l) {
+                    G += episode.get(l).getReward() * (Math.pow(discountFactor, l - firstOccurrencesIdx));
                 }
-                firstOccurenceIndex++;
+                // slick trick to add G to the entry.
+                // if the key does not exist, it creates a new entry with G as the default value
+                returnSum.merge(stateActionPair, G, Double::sum);
+                returnCount.merge(stateActionPair, 1, Integer::sum);
+                stateActionTable.setValue(stateActionPair.getKey(), stateActionPair.getValue(), returnSum.get(stateActionPair) / returnCount.get(stateActionPair));
             }
-
-            double G = 0;
-            for(int l = firstOccurenceIndex; l < episode.size(); ++l) {
-                G += episode.get(l).getReward() * (Math.pow(discountFactor, l - firstOccurenceIndex));
-            }
-            // slick trick to add G to the entry.
-            // if the key does not exists, it will create a new entry with G as default value
-            returnSum.merge(stateActionPair, G, Double::sum);
-            returnCount.merge(stateActionPair, 1, Integer::sum);
-            stateActionTable.setValue(stateActionPair.getKey(), stateActionPair.getValue(), returnSum.get(stateActionPair) / returnCount.get(stateActionPair));
         }
     }
 
+
     @Override
     public void save(ObjectOutputStream oos) throws IOException {
         super.save(oos);
diff --git a/src/main/java/core/controller/RLController.java b/src/main/java/core/controller/RLController.java
index fc33239..895102a 100644
--- a/src/main/java/core/controller/RLController.java
+++ b/src/main/java/core/controller/RLController.java
@@ -51,7 +51,11 @@ public class RLController implements LearningListener {
             case MC_CONTROL_FIRST_VISIT:
                 learning = new MonteCarloControlFirstVisitEGreedy<>(environment, discreteActionSpace, discountFactor, epsilon, delay);
                 break;
-            case SARSA_EPISODIC:
+            case MC_CONTROL_EVERY_VISIT:
+                learning = new MonteCarloControlFirstVisitEGreedy<>(environment, discreteActionSpace, discountFactor, epsilon, delay, true);
+                break;
+
+            case SARSA_ON_POLICY_CONTROL:
                 learning = new SARSA<>(environment, discreteActionSpace, discountFactor, epsilon, learningRate, delay);
                 break;
            case Q_LEARNING_OFF_POLICY_CONTROL:
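
Below is a minimal, self-contained sketch (not part of the diff) of the bookkeeping the reworked loop performs: it records the step indices at which each state-action pair is visited and averages the discounted return G over those visits. Plain strings and arrays stand in for the repository's StepResult, StateActionTable and generic action type, so those stand-ins are assumptions; only the First-Visit/Every-Visit branching mirrors the change above.

// Sketch only: toy episode data replaces the repo's StepResult/StateActionTable types.
import java.util.*;

public class VisitDemo {
    public static void main(String[] args) {
        // A toy episode: one state-action pair and one reward per time step.
        String[] stateAction = {"s0/LEFT", "s1/RIGHT", "s0/LEFT", "s2/UP"};
        double[] reward      = {-1,        -1,         -1,        10};
        double discount = 0.9;
        boolean everyVisit = true; // false -> First-Visit

        // Collect the step indices at which each pair is evaluated,
        // mirroring the LinkedHashMap<pair, List<index>> built in the diff.
        Map<String, List<Integer>> occurrences = new LinkedHashMap<>();
        for (int t = 0; t < stateAction.length; t++) {
            if (!occurrences.containsKey(stateAction[t])) {
                occurrences.put(stateAction[t], new ArrayList<>(List.of(t)));
            } else if (everyVisit) {
                // Every-Visit also records repeated occurrences of the same pair.
                occurrences.get(stateAction[t]).add(t);
            }
        }

        // Average the discounted return G over all recorded visits of each pair.
        Map<String, Double> returnSum = new HashMap<>();
        Map<String, Integer> returnCount = new HashMap<>();
        for (Map.Entry<String, List<Integer>> e : occurrences.entrySet()) {
            for (int start : e.getValue()) {
                double g = 0;
                for (int t = start; t < reward.length; t++) {
                    g += reward[t] * Math.pow(discount, t - start);
                }
                returnSum.merge(e.getKey(), g, Double::sum);
                returnCount.merge(e.getKey(), 1, Integer::sum);
            }
            System.out.printf("%s -> Q = %.3f%n",
                    e.getKey(), returnSum.get(e.getKey()) / returnCount.get(e.getKey()));
        }
    }
}

Flipping everyVisit toggles between averaging G from only the first occurrence of a pair and averaging it over every occurrence, which is what the new useEveryVisit constructor flag and the MC_CONTROL_EVERY_VISIT case in RLController select.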