add Every-Visit Monte-Carlo

commit 9a3452ff9c
parent 740289ee2b
@@ -5,7 +5,6 @@ import core.Environment;
 import core.LearningConfig;
 import core.StepResult;
 import core.listener.LearningListener;
-import core.policy.EpsilonGreedyPolicy;
 import example.DinoSampling;
 import lombok.Getter;
 import lombok.Setter;
@@ -125,6 +124,7 @@ public abstract class EpisodicLearning<A extends Enum> extends Learning<A> imple
     private void startLearning(){
         dispatchLearningStart();
         while(episodesToLearn.get() > 0){
+
             dispatchEpisodeStart();
             timestampCurrentEpisode = 0;
             nextEpisode();
@@ -5,5 +5,5 @@ package core.algo;
  * which RL-algorithm should be used.
  */
 public enum Method {
-    MC_CONTROL_FIRST_VISIT, SARSA_EPISODIC, Q_LEARNING_OFF_POLICY_CONTROL
+    MC_CONTROL_FIRST_VISIT, MC_CONTROL_EVERY_VISIT, SARSA_ON_POLICY_CONTROL, Q_LEARNING_OFF_POLICY_CONTROL
 }
@@ -3,52 +3,40 @@ package core.algo.mc;
 import core.*;
 import core.algo.EpisodicLearning;
 import core.policy.EpsilonGreedyPolicy;
-import core.policy.GreedyPolicy;
-import core.policy.Policy;
 import org.apache.commons.lang3.tuple.ImmutablePair;
 import org.apache.commons.lang3.tuple.Pair;
 
-import java.io.*;
-import java.net.URI;
-import java.nio.file.Files;
-import java.nio.file.Path;
-import java.nio.file.Paths;
-import java.nio.file.StandardOpenOption;
+import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
 import java.util.*;
 
 /**
- * TODO: Major problem:
- * StateActionPairs are only unique accounting for their position in the episode.
- * For example:
- * <p>
- * startingState -> MOVE_LEFT : very first state action in the episode i = 1
- * image the agent does not collect the food and does not drop it onto start, the agent will receive
- * -1 for every timestamp hence (startingState -> MOVE_LEFT) will get a value of -10;
- * <p>
- * BUT image moving left from the starting position will have no impact on the state because
- * the agent ran into a wall. The known world stays the same.
- * Taking an action after that will have the exact same state but a different action
- * making the value of this stateActionPair -9 because the stateAction pair took place on the second
- * timestamp, summing up all remaining rewards will be -9...
- * <p>
- * How to encounter this problem?
- *
+ * Includes both variants of Monte-Carlo control: First-Visit and Every-Visit.
+ * The default variant is First-Visit.
+ * Switch to Every-Visit by setting the constructor flag "useEveryVisit" to true.
  * @param <A>
  */
 public class MonteCarloControlFirstVisitEGreedy<A extends Enum> extends EpisodicLearning<A> {
 
     private Map<Pair<State, A>, Double> returnSum;
     private Map<Pair<State, A>, Integer> returnCount;
+    private boolean isEveryVisit;
 
 
-    public MonteCarloControlFirstVisitEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay) {
+    public MonteCarloControlFirstVisitEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay, boolean useEveryVisit) {
         super(environment, actionSpace, discountFactor, delay);
+        isEveryVisit = useEveryVisit;
         this.policy = new EpsilonGreedyPolicy<>(epsilon);
         this.stateActionTable = new DeterministicStateActionTable<>(this.actionSpace);
         returnSum = new HashMap<>();
         returnCount = new HashMap<>();
     }
 
+    public MonteCarloControlFirstVisitEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay) {
+        this(environment, actionSpace, discountFactor, epsilon, delay, false);
+    }
+
     public MonteCarloControlFirstVisitEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, int delay) {
         this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, LearningConfig.DEFAULT_EPSILON, delay);
     }
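Note on the constructor changes above: the new boolean parameter defaults to false through the added delegating overload, so existing call sites keep First-Visit behaviour. Below is a minimal standalone sketch of that default-flag delegation pattern; the class and field names are illustrative only, not repository code.

// Illustration only (not part of the commit): the default-flag delegation pattern used above,
// shown on a standalone class so it compiles on its own.
public class VisitModeExample {
    private final boolean isEveryVisit;

    // Full constructor: the caller picks the Monte-Carlo variant explicitly.
    public VisitModeExample(boolean useEveryVisit) {
        this.isEveryVisit = useEveryVisit;
    }

    // Convenience constructor: delegates with useEveryVisit = false, i.e. First-Visit by default.
    public VisitModeExample() {
        this(false);
    }

    public static void main(String[] args) {
        System.out.println(new VisitModeExample().isEveryVisit);      // false -> First-Visit
        System.out.println(new VisitModeExample(true).isEveryVisit);  // true  -> Every-Visit
    }
}

Keeping the flag behind a delegating overload is what lets the existing MC_CONTROL_FIRST_VISIT wiring stay untouched.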
@@ -89,35 +77,47 @@ public class MonteCarloControlFirstVisitEGreedy<A extends Enum> extends Episodic
 
 
         // System.out.printf("Episode %d \t Reward: %f \n", currentEpisode, sumOfRewards);
-        Set<Pair<State, A>> stateActionPairs = new LinkedHashSet<>();
+        HashMap<Pair<State, A>, List<Integer>> stateActionPairs = new LinkedHashMap<>();
+
+        int firstOccurrenceIndex = 0;
         for(StepResult<A> sr : episode) {
-            stateActionPairs.add(new ImmutablePair<>(sr.getState(), sr.getAction()));
+            Pair<State, A> pair = new ImmutablePair<>(sr.getState(), sr.getAction());
+            if(!stateActionPairs.containsKey(pair)) {
+                List<Integer> l = new ArrayList<>();
+                l.add(firstOccurrenceIndex);
+                stateActionPairs.put(pair, l);
+            }
+
+            /*
+             This is the only difference between First-Visit and Every-Visit.
+             When First-Visit is selected, only the index of the first occurrence is put into the list.
+             When Every-Visit is selected, every following occurrence is saved
+             into the list as well.
+             */
+            else if(isEveryVisit) {
+                stateActionPairs.get(pair).add(firstOccurrenceIndex);
+            }
+            ++firstOccurrenceIndex;
         }
 
         //System.out.println("stateActionPairs " + stateActionPairs.size());
-        for(Pair<State, A> stateActionPair : stateActionPairs) {
-            int firstOccurenceIndex = 0;
-            // find first occurance of state action pair
-            for(StepResult<A> sr : episode) {
-                if(stateActionPair.getKey().equals(sr.getState()) && stateActionPair.getValue().equals(sr.getAction())) {
-                    break;
+        for(Map.Entry<Pair<State, A>, List<Integer>> entry : stateActionPairs.entrySet()) {
+            Pair<State, A> stateActionPair = entry.getKey();
+            List<Integer> firstOccurrences = entry.getValue();
+            for(Integer firstOccurrencesIdx : firstOccurrences) {
+                double G = 0;
+                for(int l = firstOccurrencesIdx; l < episode.size(); ++l) {
+                    G += episode.get(l).getReward() * (Math.pow(discountFactor, l - firstOccurrencesIdx));
                 }
-                firstOccurenceIndex++;
+                // slick trick to add G to the entry:
+                // if the key does not exist, it will create a new entry with G as default value
+                returnSum.merge(stateActionPair, G, Double::sum);
+                returnCount.merge(stateActionPair, 1, Integer::sum);
+                stateActionTable.setValue(stateActionPair.getKey(), stateActionPair.getValue(), returnSum.get(stateActionPair) / returnCount.get(stateActionPair));
             }
-
-            double G = 0;
-            for(int l = firstOccurenceIndex; l < episode.size(); ++l) {
-                G += episode.get(l).getReward() * (Math.pow(discountFactor, l - firstOccurenceIndex));
-            }
-            // slick trick to add G to the entry.
-            // if the key does not exists, it will create a new entry with G as default value
-            returnSum.merge(stateActionPair, G, Double::sum);
-            returnCount.merge(stateActionPair, 1, Integer::sum);
-            stateActionTable.setValue(stateActionPair.getKey(), stateActionPair.getValue(), returnSum.get(stateActionPair) / returnCount.get(stateActionPair));
         }
     }
-
 
     @Override
     public void save(ObjectOutputStream oos) throws IOException {
         super.save(oos);
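To make the difference concrete: the loop above records, per state-action pair, the first occurrence index (First-Visit) or every occurrence index (Every-Visit), then averages the discounted return G computed from each recorded index via Map.merge. The following self-contained toy sketch reproduces that computation on a made-up four-step episode, with plain strings standing in for state-action pairs; none of it is repository code.

import java.util.*;

// Toy illustration only (not part of the commit): First-Visit vs Every-Visit return averaging
// on a made-up episode, with plain strings standing in for state-action pairs.
public class VisitReturnsExample {
    public static void main(String[] args) {
        List<String> pairs   = Arrays.asList("s0/LEFT", "s1/RIGHT", "s0/LEFT", "s2/UP");
        List<Double> rewards = Arrays.asList(-1.0, -1.0, -1.0, 10.0);
        double discountFactor = 0.9;
        boolean isEveryVisit = true;   // flip to false for First-Visit

        // Record occurrence indices per pair: only the first one, or every one when isEveryVisit is set.
        Map<String, List<Integer>> occurrences = new LinkedHashMap<>();
        for (int i = 0; i < pairs.size(); i++) {
            String pair = pairs.get(i);
            if (!occurrences.containsKey(pair)) {
                List<Integer> l = new ArrayList<>();
                l.add(i);
                occurrences.put(pair, l);
            } else if (isEveryVisit) {
                occurrences.get(pair).add(i);
            }
        }

        // Average the discounted return G over all recorded occurrences, mirroring the merge calls above.
        Map<String, Double> returnSum = new HashMap<>();
        Map<String, Integer> returnCount = new HashMap<>();
        for (Map.Entry<String, List<Integer>> entry : occurrences.entrySet()) {
            for (int start : entry.getValue()) {
                double G = 0;
                for (int l = start; l < rewards.size(); l++) {
                    G += rewards.get(l) * Math.pow(discountFactor, l - start);
                }
                returnSum.merge(entry.getKey(), G, Double::sum);
                returnCount.merge(entry.getKey(), 1, Integer::sum);
            }
            System.out.printf("%s -> average return %.3f%n",
                    entry.getKey(), returnSum.get(entry.getKey()) / returnCount.get(entry.getKey()));
        }
    }
}

With isEveryVisit = true, the pair "s0/LEFT" contributes two returns (from indices 0 and 2) that get averaged; with false it contributes only the return from index 0.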
@@ -51,7 +51,11 @@ public class RLController<A extends Enum> implements LearningListener {
             case MC_CONTROL_FIRST_VISIT:
                 learning = new MonteCarloControlFirstVisitEGreedy<>(environment, discreteActionSpace, discountFactor, epsilon, delay);
                 break;
-            case SARSA_EPISODIC:
+            case MC_CONTROL_EVERY_VISIT:
+                learning = new MonteCarloControlFirstVisitEGreedy<>(environment, discreteActionSpace, discountFactor, epsilon, delay, true);
+                break;
+
+            case SARSA_ON_POLICY_CONTROL:
                 learning = new SARSA<>(environment, discreteActionSpace, discountFactor, epsilon, learningRate, delay);
                 break;
             case Q_LEARNING_OFF_POLICY_CONTROL: