diff --git a/src/main/java/core/algo/EpisodicLearning.java b/src/main/java/core/algo/EpisodicLearning.java
index 49bde9f..05a84d5 100644
--- a/src/main/java/core/algo/EpisodicLearning.java
+++ b/src/main/java/core/algo/EpisodicLearning.java
@@ -5,7 +5,7 @@ import core.Environment;
import core.LearningConfig;
import core.StepResult;
import core.listener.LearningListener;
-import core.policy.EpsilonGreedyPolicy;
+import example.DinoSampling;
import lombok.Getter;
import lombok.Setter;
@@ -104,10 +104,10 @@ public abstract class EpisodicLearning extends Learning imple
timestamp++;
timestampCurrentEpisode++;
// TODO: more sophisticated way to check convergence
- if(timestampCurrentEpisode > 300000){
+ if(timestampCurrentEpisode > 50000) {
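+ // Convergence heuristic (assumption): once a single episode exceeds 50000 timesteps
+ // the run is treated as converged, and the halved episode count is appended to the
+ // DinoSampling output file below.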
converged = true;
// t
- File file = new File("convergenceAdv.txt");
+ File file = new File(DinoSampling.FILE);
try {
Files.writeString(Path.of(file.getPath()), currentEpisode/2 + ",", StandardOpenOption.APPEND);
} catch (IOException e) {
diff --git a/src/main/java/core/algo/Method.java b/src/main/java/core/algo/Method.java
index 6372b24..99966fc 100644
--- a/src/main/java/core/algo/Method.java
+++ b/src/main/java/core/algo/Method.java
@@ -5,5 +5,5 @@ package core.algo;
* which RL-algorithm should be used.
*/
public enum Method {
- MC_CONTROL_FIRST_VISIT, SARSA_EPISODIC, Q_LEARNING_OFF_POLICY_CONTROL
+ MC_CONTROL_FIRST_VISIT, MC_CONTROL_EVERY_VISIT, SARSA_ON_POLICY_CONTROL, Q_LEARNING_OFF_POLICY_CONTROL
}
diff --git a/src/main/java/core/algo/mc/MonteCarloControlFirstVisitEGreedy.java b/src/main/java/core/algo/mc/MonteCarloControlFirstVisitEGreedy.java
index 5bc93d3..3618eb0 100644
--- a/src/main/java/core/algo/mc/MonteCarloControlFirstVisitEGreedy.java
+++ b/src/main/java/core/algo/mc/MonteCarloControlFirstVisitEGreedy.java
@@ -8,37 +8,22 @@ import core.policy.Policy;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
-import java.io.*;
-import java.net.URI;
-import java.nio.file.Files;
-import java.nio.file.Path;
-import java.nio.file.Paths;
-import java.nio.file.StandardOpenOption;
+import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
import java.util.*;
/**
- * TODO: Major problem:
- * StateActionPairs are only unique accounting for their position in the episode.
- * For example:
- *
- * startingState -> MOVE_LEFT : very first state action in the episode i = 1
- * image the agent does not collect the food and does not drop it onto start, the agent will receive
- * -1 for every timestamp hence (startingState -> MOVE_LEFT) will get a value of -10;
- *
- * BUT image moving left from the starting position will have no impact on the state because
- * the agent ran into a wall. The known world stays the same.
- * Taking an action after that will have the exact same state but a different action
- * making the value of this stateActionPair -9 because the stateAction pair took place on the second
- * timestamp, summing up all remaining rewards will be -9...
- *
- * How to encounter this problem?
- *
+ * Implements both variants of Monte-Carlo control.
+ * The default method is First-Visit; switch to Every-Visit by passing
+ * {@code useEveryVisit = true} to the constructor.
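+ *
+ * A minimal usage sketch (hypothetical environment {@code env}, action space {@code actionSpace}
+ * and enum type {@code MyAction}; parameter values are illustrative):
+ * <pre>{@code
+ * MonteCarloControlFirstVisitEGreedy<MyAction> mc =
+ *         new MonteCarloControlFirstVisitEGreedy<>(env, actionSpace, 0.95f, 0.1f, 0, true);
+ * }</pre>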
* @param <A>
*/
public class MonteCarloControlFirstVisitEGreedy<A extends Enum> extends EpisodicLearning<A> {
private Map<Pair<State, A>, Double> returnSum;
private Map<Pair<State, A>, Integer> returnCount;
+ private boolean isEveryVisit;
// t
private float epsilon;
@@ -46,8 +31,9 @@ public class MonteCarloControlFirstVisitEGreedy extends Episodic
private Policy<A> greedyPolicy = new GreedyPolicy<>();
- public MonteCarloControlFirstVisitEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay) {
+ public MonteCarloControlFirstVisitEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay, boolean useEveryVisit) {
super(environment, actionSpace, discountFactor, delay);
+ isEveryVisit = useEveryVisit;
// t
this.epsilon = epsilon;
this.policy = new EpsilonGreedyPolicy<>(epsilon);
@@ -56,6 +42,10 @@ public class MonteCarloControlFirstVisitEGreedy extends Episodic
returnCount = new HashMap<>();
}
+ public MonteCarloControlFirstVisitEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay) {
+ this(environment, actionSpace, discountFactor, epsilon, delay, false);
+ }
+
public MonteCarloControlFirstVisitEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, int delay) {
this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, LearningConfig.DEFAULT_EPSILON, delay);
}
@@ -104,35 +94,47 @@ public class MonteCarloControlFirstVisitEGreedy extends Episodic
}
// System.out.printf("Episode %d \t Reward: %f \n", currentEpisode, sumOfRewards);
- Set<Pair<State, A>> stateActionPairs = new LinkedHashSet<>();
+ HashMap<Pair<State, A>, List<Integer>> stateActionPairs = new LinkedHashMap<>();
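+ // Maps every state-action pair to the list of episode indices at which it occurred;
+ // the LinkedHashMap preserves insertion order.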
+ int firstOccurrenceIndex = 0;
for(StepResult sr : episode) {
- stateActionPairs.add(new ImmutablePair<>(sr.getState(), sr.getAction()));
+ Pair<State, A> pair = new ImmutablePair<>(sr.getState(), sr.getAction());
+ if(!stateActionPairs.containsKey(pair)) {
+ List<Integer> l = new ArrayList<>();
+ l.add(firstOccurrenceIndex);
+ stateActionPairs.put(pair, l);
+ }
+
+ /*
+ This is the only difference between First-Visit and Every-Visit:
+ First-Visit stores only the index of the first occurrence of a
+ state-action pair, while Every-Visit also appends the index of
+ every subsequent occurrence to the list.
+ */
+ else if(isEveryVisit) {
+ stateActionPairs.get(pair).add(firstOccurrenceIndex);
+ }
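+ /*
+ Illustrative example (hypothetical): if the pair (s, MOVE_LEFT) occurs at
+ episode indices 0 and 3, First-Visit records [0] while Every-Visit records
+ [0, 3], so the averaging loop below uses one return or two returns respectively.
+ */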
+ ++firstOccurrenceIndex;
}
-
//System.out.println("stateActionPairs " + stateActionPairs.size());
- for(Pair<State, A> stateActionPair : stateActionPairs) {
- int firstOccurenceIndex = 0;
- // find first occurance of state action pair
- for(StepResult sr : episode) {
- if(stateActionPair.getKey().equals(sr.getState()) && stateActionPair.getValue().equals(sr.getAction())) {
- break;
+ for(Map.Entry<Pair<State, A>, List<Integer>> entry : stateActionPairs.entrySet()) {
+ Pair<State, A> stateActionPair = entry.getKey();
+ List<Integer> firstOccurrences = entry.getValue();
+ for(Integer firstOccurrencesIdx : firstOccurrences) {
+ double G = 0;
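+ // Discounted return from this occurrence to the end of the episode:
+ // G = r_t + γ*r_(t+1) + γ^2*r_(t+2) + ...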
+ for(int l = firstOccurrencesIdx; l < episode.size(); ++l) {
+ G += episode.get(l).getReward() * (Math.pow(discountFactor, l - firstOccurrencesIdx));
}
- firstOccurenceIndex++;
+ // merge() adds G to the running return sum for this pair,
+ // or creates a new entry with G if the key does not exist yet.
+ returnSum.merge(stateActionPair, G, Double::sum);
+ returnCount.merge(stateActionPair, 1, Integer::sum);
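+ // The Q-value estimate is the sample mean of all returns observed for this state-action pair.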
+ stateActionTable.setValue(stateActionPair.getKey(), stateActionPair.getValue(), returnSum.get(stateActionPair) / returnCount.get(stateActionPair));
}
-
- double G = 0;
- for(int l = firstOccurenceIndex; l < episode.size(); ++l) {
- G += episode.get(l).getReward() * (Math.pow(discountFactor, l - firstOccurenceIndex));
- }
- // slick trick to add G to the entry.
- // if the key does not exists, it will create a new entry with G as default value
- returnSum.merge(stateActionPair, G, Double::sum);
- returnCount.merge(stateActionPair, 1, Integer::sum);
- stateActionTable.setValue(stateActionPair.getKey(), stateActionPair.getValue(), returnSum.get(stateActionPair) / returnCount.get(stateActionPair));
}
}
+
@Override
public void save(ObjectOutputStream oos) throws IOException {
super.save(oos);
diff --git a/src/main/java/core/controller/RLController.java b/src/main/java/core/controller/RLController.java
index fc33239..895102a 100644
--- a/src/main/java/core/controller/RLController.java
+++ b/src/main/java/core/controller/RLController.java
@@ -51,7 +51,11 @@ public class RLController implements LearningListener {
case MC_CONTROL_FIRST_VISIT:
learning = new MonteCarloControlFirstVisitEGreedy<>(environment, discreteActionSpace, discountFactor, epsilon, delay);
break;
- case SARSA_EPISODIC:
+ case MC_CONTROL_EVERY_VISIT:
+ learning = new MonteCarloControlFirstVisitEGreedy<>(environment, discreteActionSpace, discountFactor, epsilon, delay, true);
+ break;
+
+ case SARSA_ON_POLICY_CONTROL:
learning = new SARSA<>(environment, discreteActionSpace, discountFactor, epsilon, learningRate, delay);
break;
case Q_LEARNING_OFF_POLICY_CONTROL: