add Every-Visit Monte-Carlo

This commit is contained in:
Jan Löwenstrom 2020-04-02 15:56:11 +02:00
parent ee1d62842d
commit 6477251545
4 changed files with 54 additions and 48 deletions

View File: EpisodicLearning.java

@@ -5,7 +5,7 @@ import core.Environment;
 import core.LearningConfig;
 import core.StepResult;
 import core.listener.LearningListener;
-import core.policy.EpsilonGreedyPolicy;
+import example.DinoSampling;
 import lombok.Getter;
 import lombok.Setter;

@@ -104,10 +104,10 @@ public abstract class EpisodicLearning<A extends Enum> extends Learning<A> imple
             timestamp++;
             timestampCurrentEpisode++;
             // TODO: more sophisticated way to check convergence
-            if(timestampCurrentEpisode > 300000){
+            if(timestampCurrentEpisode > 50000) {
                 converged = true;
                 // t
-                File file = new File("convergenceAdv.txt");
+                File file = new File(DinoSampling.FILE);
                 try {
                     Files.writeString(Path.of(file.getPath()), currentEpisode/2 + ",", StandardOpenOption.APPEND);
                 } catch (IOException e) {
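A note on the convergence logging in this hunk: Files.writeString called with only StandardOpenOption.APPEND throws NoSuchFileException when the target file does not exist yet, so the file named by DinoSampling.FILE has to be created before the first run. A minimal sketch of a more defensive variant; the class and method names here are illustrative, not part of the project:

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;

public final class ConvergenceLog {
    // Appends one "<episode>," entry per converged run; CREATE lets the very first write
    // succeed even when the log file is missing, instead of failing with NoSuchFileException.
    public static void append(String fileName, long episode) throws IOException {
        Files.writeString(Path.of(fileName), episode + ",",
                StandardOpenOption.CREATE, StandardOpenOption.APPEND);
    }
}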

View File: Method.java

@@ -5,5 +5,5 @@ package core.algo;
  * which RL-algorithm should be used.
  */
 public enum Method {
-    MC_CONTROL_FIRST_VISIT, SARSA_EPISODIC, Q_LEARNING_OFF_POLICY_CONTROL
+    MC_CONTROL_FIRST_VISIT, MC_CONTROL_EVERY_VISIT, SARSA_ON_POLICY_CONTROL, Q_LEARNING_OFF_POLICY_CONTROL
 }

View File: MonteCarloControlFirstVisitEGreedy.java

@@ -8,37 +8,22 @@ import core.policy.Policy;
 import org.apache.commons.lang3.tuple.ImmutablePair;
 import org.apache.commons.lang3.tuple.Pair;
-import java.io.*;
-import java.net.URI;
-import java.nio.file.Files;
-import java.nio.file.Path;
-import java.nio.file.Paths;
-import java.nio.file.StandardOpenOption;
+import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
 import java.util.*;

 /**
- * TODO: Major problem:
- * StateActionPairs are only unique accounting for their position in the episode.
- * For example:
- * <p>
- * startingState -> MOVE_LEFT : very first state action in the episode i = 1
- * image the agent does not collect the food and does not drop it onto start, the agent will receive
- * -1 for every timestamp hence (startingState -> MOVE_LEFT) will get a value of -10;
- * <p>
- * BUT image moving left from the starting position will have no impact on the state because
- * the agent ran into a wall. The known world stays the same.
- * Taking an action after that will have the exact same state but a different action
- * making the value of this stateActionPair -9 because the stateAction pair took place on the second
- * timestamp, summing up all remaining rewards will be -9...
- * <p>
- * How to encounter this problem?
- *
+ * Includes both variants of Monte-Carlo methods
+ * Default method is First-Visit.
+ * Change to Every-Visit by setting flag "useEveryVisit" in the constructor to true.
  * @param <A>
  */
 public class MonteCarloControlFirstVisitEGreedy<A extends Enum> extends EpisodicLearning<A> {

     private Map<Pair<State, A>, Double> returnSum;
     private Map<Pair<State, A>, Integer> returnCount;
+    private boolean isEveryVisit;
     // t

     private float epsilon;
@@ -46,8 +31,9 @@ public class MonteCarloControlFirstVisitEGreedy<A extends Enum> extends Episodic
     private Policy<A> greedyPolicy = new GreedyPolicy<>();

-    public MonteCarloControlFirstVisitEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay) {
+    public MonteCarloControlFirstVisitEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay, boolean useEveryVisit) {
         super(environment, actionSpace, discountFactor, delay);
+        isEveryVisit = useEveryVisit;
         // t
         this.epsilon = epsilon;
         this.policy = new EpsilonGreedyPolicy<>(epsilon);

@@ -56,6 +42,10 @@ public class MonteCarloControlFirstVisitEGreedy<A extends Enum> extends Episodic
         returnCount = new HashMap<>();
     }

+    public MonteCarloControlFirstVisitEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay) {
+        this(environment, actionSpace, discountFactor, epsilon, delay, false);
+    }
+
     public MonteCarloControlFirstVisitEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, int delay) {
         this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, LearningConfig.DEFAULT_EPSILON, delay);
     }
@@ -104,35 +94,47 @@ public class MonteCarloControlFirstVisitEGreedy<A extends Enum> extends Episodic
         }
         // System.out.printf("Episode %d \t Reward: %f \n", currentEpisode, sumOfRewards);
-        Set<Pair<State, A>> stateActionPairs = new LinkedHashSet<>();
+        HashMap<Pair<State, A>, List<Integer>> stateActionPairs = new LinkedHashMap<>();
+        int firstOccurrenceIndex = 0;
         for(StepResult<A> sr : episode) {
-            stateActionPairs.add(new ImmutablePair<>(sr.getState(), sr.getAction()));
+            Pair<State, A> pair = new ImmutablePair<>(sr.getState(), sr.getAction());
+            if(!stateActionPairs.containsKey(pair)) {
+                List<Integer> l = new ArrayList<>();
+                l.add(firstOccurrenceIndex);
+                stateActionPairs.put(pair, l);
+            }
+            /*
+             This is the only difference between First-Visit and Every-Visit.
+             When First-Visit is selected, only the first index of the occurrence is put into the list.
+             When Every-Visit is selected, every following occurrence is saved
+             into the list as well.
+             */
+            else if(isEveryVisit) {
+                stateActionPairs.get(pair).add(firstOccurrenceIndex);
+            }
+            ++firstOccurrenceIndex;
         }
         //System.out.println("stateActionPairs " + stateActionPairs.size());
-        for(Pair<State, A> stateActionPair : stateActionPairs) {
-            int firstOccurenceIndex = 0;
-            // find first occurance of state action pair
-            for(StepResult<A> sr : episode) {
-                if(stateActionPair.getKey().equals(sr.getState()) && stateActionPair.getValue().equals(sr.getAction())) {
-                    break;
-                }
-                firstOccurenceIndex++;
-            }
-            double G = 0;
-            for(int l = firstOccurenceIndex; l < episode.size(); ++l) {
-                G += episode.get(l).getReward() * (Math.pow(discountFactor, l - firstOccurenceIndex));
-            }
-            // slick trick to add G to the entry.
-            // if the key does not exists, it will create a new entry with G as default value
-            returnSum.merge(stateActionPair, G, Double::sum);
-            returnCount.merge(stateActionPair, 1, Integer::sum);
-            stateActionTable.setValue(stateActionPair.getKey(), stateActionPair.getValue(), returnSum.get(stateActionPair) / returnCount.get(stateActionPair));
+        for(Map.Entry<Pair<State, A>, List<Integer>> entry : stateActionPairs.entrySet()) {
+            Pair<State, A> stateActionPair = entry.getKey();
+            List<Integer> firstOccurrences = entry.getValue();
+            for(Integer firstOccurrencesIdx : firstOccurrences) {
+                double G = 0;
+                for(int l = firstOccurrencesIdx; l < episode.size(); ++l) {
+                    G += episode.get(l).getReward() * (Math.pow(discountFactor, l - firstOccurrencesIdx));
+                }
+                // slick trick to add G to the entry.
+                // if the key does not exists, it will create a new entry with G as default value
+                returnSum.merge(stateActionPair, G, Double::sum);
+                returnCount.merge(stateActionPair, 1, Integer::sum);
+                stateActionTable.setValue(stateActionPair.getKey(), stateActionPair.getValue(), returnSum.get(stateActionPair) / returnCount.get(stateActionPair));
+            }
         }
     }

     @Override
     public void save(ObjectOutputStream oos) throws IOException {
         super.save(oos);
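For reference outside the diff: the hunk above is the whole behavioural difference between the two variants, namely whether one or all occurrence indices of a state-action pair contribute a discounted return G to the running average. A compact standalone sketch of that bookkeeping, with a simplified Step record standing in for the project's StepResult<A> and plain strings as states and actions (both assumptions made for brevity):

import java.util.*;

public final class MonteCarloReturnsSketch {

    // Simplified stand-in for StepResult<A>; the field names are assumptions, not the project's API.
    record Step(String state, String action, double reward) {}

    static Map<String, Double> averageReturns(List<Step> episode, double discountFactor, boolean everyVisit) {
        // "state|action" -> episode indices at which a return G is computed for that pair
        Map<String, List<Integer>> occurrences = new LinkedHashMap<>();
        for (int i = 0; i < episode.size(); i++) {
            Step sr = episode.get(i);
            String key = sr.state() + "|" + sr.action();
            List<Integer> idx = occurrences.computeIfAbsent(key, k -> new ArrayList<>());
            // First-Visit keeps only the first index of each pair; Every-Visit keeps all of them.
            if (idx.isEmpty() || everyVisit) {
                idx.add(i);
            }
        }
        Map<String, Double> returnSum = new HashMap<>();
        Map<String, Integer> returnCount = new HashMap<>();
        Map<String, Double> qValue = new HashMap<>();
        for (Map.Entry<String, List<Integer>> entry : occurrences.entrySet()) {
            for (int start : entry.getValue()) {
                // Discounted return G from this occurrence to the end of the episode.
                double G = 0;
                for (int l = start; l < episode.size(); l++) {
                    G += episode.get(l).reward() * Math.pow(discountFactor, l - start);
                }
                returnSum.merge(entry.getKey(), G, Double::sum);
                returnCount.merge(entry.getKey(), 1, Integer::sum);
                qValue.put(entry.getKey(), returnSum.get(entry.getKey()) / returnCount.get(entry.getKey()));
            }
        }
        return qValue;
    }

    public static void main(String[] args) {
        // The repeated (s0, LEFT) pair is where the two variants differ.
        List<Step> episode = List.of(
                new Step("s0", "LEFT", -1),
                new Step("s0", "LEFT", -1),
                new Step("s1", "RIGHT", 10));
        System.out.println(averageReturns(episode, 1.0, false)); // First-Visit: s0|LEFT -> 8.0
        System.out.println(averageReturns(episode, 1.0, true));  // Every-Visit: s0|LEFT -> (8 + 9) / 2 = 8.5
    }
}

With discountFactor = 1, First-Visit averages only the return from the first occurrence (G = 8) for the repeated pair, while Every-Visit averages 8 and 9 to 8.5, which is exactly what the extra index list in the diff enables.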

View File: RLController.java

@@ -51,7 +51,11 @@ public class RLController<A extends Enum> implements LearningListener {
             case MC_CONTROL_FIRST_VISIT:
                 learning = new MonteCarloControlFirstVisitEGreedy<>(environment, discreteActionSpace, discountFactor, epsilon, delay);
                 break;
-            case SARSA_EPISODIC:
+            case MC_CONTROL_EVERY_VISIT:
+                learning = new MonteCarloControlFirstVisitEGreedy<>(environment, discreteActionSpace, discountFactor, epsilon, delay, true);
+                break;
+            case SARSA_ON_POLICY_CONTROL:
                 learning = new SARSA<>(environment, discreteActionSpace, discountFactor, epsilon, learningRate, delay);
                 break;
             case Q_LEARNING_OFF_POLICY_CONTROL:
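The switch above gains one case per variant even though both Monte-Carlo branches differ only in the boolean flag. A hedged alternative sketch of the same Method-to-learner dispatch as a factory map; Learning, MonteCarlo, Sarsa and QLearning are illustrative placeholder types, not the project's classes:

import java.util.Map;
import java.util.function.Supplier;

public final class MethodDispatchSketch {

    // Mirrors core.algo.Method after this commit.
    enum Method { MC_CONTROL_FIRST_VISIT, MC_CONTROL_EVERY_VISIT, SARSA_ON_POLICY_CONTROL, Q_LEARNING_OFF_POLICY_CONTROL }

    // Placeholder learner types; in the project these would be the real learning implementations.
    interface Learning {}
    record MonteCarlo(boolean everyVisit) implements Learning {}
    record Sarsa() implements Learning {}
    record QLearning() implements Learning {}

    // Both Monte-Carlo entries reuse one class and differ only in the flag,
    // matching the extra boolean constructor argument introduced in this commit.
    static final Map<Method, Supplier<Learning>> FACTORIES = Map.of(
            Method.MC_CONTROL_FIRST_VISIT, () -> new MonteCarlo(false),
            Method.MC_CONTROL_EVERY_VISIT, () -> new MonteCarlo(true),
            Method.SARSA_ON_POLICY_CONTROL, Sarsa::new,
            Method.Q_LEARNING_OFF_POLICY_CONTROL, QLearning::new);

    public static void main(String[] args) {
        Learning learning = FACTORIES.get(Method.MC_CONTROL_EVERY_VISIT).get();
        System.out.println(learning); // MonteCarlo[everyVisit=true]
    }
}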