add Every-Visit Monte-Carlo
This commit is contained in:
parent
740289ee2b
commit
9a3452ff9c
|
@ -5,7 +5,6 @@ import core.Environment;
|
|||
import core.LearningConfig;
|
||||
import core.StepResult;
|
||||
import core.listener.LearningListener;
|
||||
import core.policy.EpsilonGreedyPolicy;
|
||||
import example.DinoSampling;
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
|
@ -125,6 +124,7 @@ public abstract class EpisodicLearning<A extends Enum> extends Learning<A> imple
|
|||
private void startLearning(){
|
||||
dispatchLearningStart();
|
||||
while(episodesToLearn.get() > 0){
|
||||
|
||||
dispatchEpisodeStart();
|
||||
timestampCurrentEpisode = 0;
|
||||
nextEpisode();
|
||||
|
|
|
@ -5,5 +5,5 @@ package core.algo;
|
|||
* which RL-algorithm should be used.
|
||||
*/
|
||||
public enum Method {
|
||||
MC_CONTROL_FIRST_VISIT, SARSA_EPISODIC, Q_LEARNING_OFF_POLICY_CONTROL
|
||||
MC_CONTROL_FIRST_VISIT, MC_CONTROL_EVERY_VISIT, SARSA_ON_POLICY_CONTROL, Q_LEARNING_OFF_POLICY_CONTROL
|
||||
}
|
||||
|
|
|
@ -3,52 +3,40 @@ package core.algo.mc;
|
|||
import core.*;
|
||||
import core.algo.EpisodicLearning;
|
||||
import core.policy.EpsilonGreedyPolicy;
|
||||
import core.policy.GreedyPolicy;
|
||||
import core.policy.Policy;
|
||||
import org.apache.commons.lang3.tuple.ImmutablePair;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
|
||||
import java.io.*;
|
||||
import java.net.URI;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.nio.file.Paths;
|
||||
import java.nio.file.StandardOpenOption;
|
||||
import java.io.IOException;
|
||||
import java.io.ObjectInputStream;
|
||||
import java.io.ObjectOutputStream;
|
||||
import java.util.*;
|
||||
|
||||
/**
|
||||
* TODO: Major problem:
|
||||
* StateActionPairs are only unique accounting for their position in the episode.
|
||||
* For example:
|
||||
* <p>
|
||||
* startingState -> MOVE_LEFT : very first state action in the episode i = 1
|
||||
* image the agent does not collect the food and does not drop it onto start, the agent will receive
|
||||
* -1 for every timestamp hence (startingState -> MOVE_LEFT) will get a value of -10;
|
||||
* <p>
|
||||
* BUT image moving left from the starting position will have no impact on the state because
|
||||
* the agent ran into a wall. The known world stays the same.
|
||||
* Taking an action after that will have the exact same state but a different action
|
||||
* making the value of this stateActionPair -9 because the stateAction pair took place on the second
|
||||
* timestamp, summing up all remaining rewards will be -9...
|
||||
* <p>
|
||||
* How to encounter this problem?
|
||||
*
|
||||
* Includes both variants of Monte-Carlo methods
|
||||
* Default method is First-Visit.
|
||||
* Change to Every-Visit by setting flag "useEveryVisit" in the constructor to true.
|
||||
* @param <A>
|
||||
*/
|
||||
public class MonteCarloControlFirstVisitEGreedy<A extends Enum> extends EpisodicLearning<A> {
|
||||
|
||||
private Map<Pair<State, A>, Double> returnSum;
|
||||
private Map<Pair<State, A>, Integer> returnCount;
|
||||
private boolean isEveryVisit;
|
||||
|
||||
|
||||
public MonteCarloControlFirstVisitEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay) {
|
||||
public MonteCarloControlFirstVisitEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay, boolean useEveryVisit) {
|
||||
super(environment, actionSpace, discountFactor, delay);
|
||||
isEveryVisit = useEveryVisit;
|
||||
this.policy = new EpsilonGreedyPolicy<>(epsilon);
|
||||
this.stateActionTable = new DeterministicStateActionTable<>(this.actionSpace);
|
||||
returnSum = new HashMap<>();
|
||||
returnCount = new HashMap<>();
|
||||
}
|
||||
|
||||
public MonteCarloControlFirstVisitEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay) {
|
||||
this(environment, actionSpace, discountFactor, epsilon, delay, false);
|
||||
}
|
||||
|
||||
public MonteCarloControlFirstVisitEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, int delay) {
|
||||
this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, LearningConfig.DEFAULT_EPSILON, delay);
|
||||
}
|
||||
|
@ -89,35 +77,47 @@ public class MonteCarloControlFirstVisitEGreedy<A extends Enum> extends Episodic
|
|||
|
||||
|
||||
// System.out.printf("Episode %d \t Reward: %f \n", currentEpisode, sumOfRewards);
|
||||
Set<Pair<State, A>> stateActionPairs = new LinkedHashSet<>();
|
||||
HashMap<Pair<State, A>, List<Integer>> stateActionPairs = new LinkedHashMap<>();
|
||||
|
||||
int firstOccurrenceIndex = 0;
|
||||
for(StepResult<A> sr : episode) {
|
||||
stateActionPairs.add(new ImmutablePair<>(sr.getState(), sr.getAction()));
|
||||
Pair<State, A> pair = new ImmutablePair<>(sr.getState(), sr.getAction());
|
||||
if(!stateActionPairs.containsKey(pair)) {
|
||||
List<Integer> l = new ArrayList<>();
|
||||
l.add(firstOccurrenceIndex);
|
||||
stateActionPairs.put(pair, l);
|
||||
}
|
||||
|
||||
/*
|
||||
This is the only difference between First-Visit and Every-Visit.
|
||||
When First-Visit is selected, only the first index of the occurrence is put into the list.
|
||||
When Every-Visit is selected, every following occurrence is saved
|
||||
into the list as well.
|
||||
*/
|
||||
else if(isEveryVisit) {
|
||||
stateActionPairs.get(pair).add(firstOccurrenceIndex);
|
||||
}
|
||||
++firstOccurrenceIndex;
|
||||
}
|
||||
|
||||
//System.out.println("stateActionPairs " + stateActionPairs.size());
|
||||
for(Pair<State, A> stateActionPair : stateActionPairs) {
|
||||
int firstOccurenceIndex = 0;
|
||||
// find first occurance of state action pair
|
||||
for(StepResult<A> sr : episode) {
|
||||
if(stateActionPair.getKey().equals(sr.getState()) && stateActionPair.getValue().equals(sr.getAction())) {
|
||||
break;
|
||||
for(Map.Entry<Pair<State, A>, List<Integer>> entry : stateActionPairs.entrySet()) {
|
||||
Pair<State, A> stateActionPair = entry.getKey();
|
||||
List<Integer> firstOccurrences = entry.getValue();
|
||||
for(Integer firstOccurrencesIdx : firstOccurrences) {
|
||||
double G = 0;
|
||||
for(int l = firstOccurrencesIdx; l < episode.size(); ++l) {
|
||||
G += episode.get(l).getReward() * (Math.pow(discountFactor, l - firstOccurrencesIdx));
|
||||
}
|
||||
firstOccurenceIndex++;
|
||||
// slick trick to add G to the entry.
|
||||
// if the key does not exists, it will create a new entry with G as default value
|
||||
returnSum.merge(stateActionPair, G, Double::sum);
|
||||
returnCount.merge(stateActionPair, 1, Integer::sum);
|
||||
stateActionTable.setValue(stateActionPair.getKey(), stateActionPair.getValue(), returnSum.get(stateActionPair) / returnCount.get(stateActionPair));
|
||||
}
|
||||
|
||||
double G = 0;
|
||||
for(int l = firstOccurenceIndex; l < episode.size(); ++l) {
|
||||
G += episode.get(l).getReward() * (Math.pow(discountFactor, l - firstOccurenceIndex));
|
||||
}
|
||||
// slick trick to add G to the entry.
|
||||
// if the key does not exists, it will create a new entry with G as default value
|
||||
returnSum.merge(stateActionPair, G, Double::sum);
|
||||
returnCount.merge(stateActionPair, 1, Integer::sum);
|
||||
stateActionTable.setValue(stateActionPair.getKey(), stateActionPair.getValue(), returnSum.get(stateActionPair) / returnCount.get(stateActionPair));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public void save(ObjectOutputStream oos) throws IOException {
|
||||
super.save(oos);
|
||||
|
|
|
@ -51,7 +51,11 @@ public class RLController<A extends Enum> implements LearningListener {
|
|||
case MC_CONTROL_FIRST_VISIT:
|
||||
learning = new MonteCarloControlFirstVisitEGreedy<>(environment, discreteActionSpace, discountFactor, epsilon, delay);
|
||||
break;
|
||||
case SARSA_EPISODIC:
|
||||
case MC_CONTROL_EVERY_VISIT:
|
||||
learning = new MonteCarloControlFirstVisitEGreedy<>(environment, discreteActionSpace, discountFactor, epsilon, delay, true);
|
||||
break;
|
||||
|
||||
case SARSA_ON_POLICY_CONTROL:
|
||||
learning = new SARSA<>(environment, discreteActionSpace, discountFactor, epsilon, learningRate, delay);
|
||||
break;
|
||||
case Q_LEARNING_OFF_POLICY_CONTROL:
|
||||
|
|
Loading…
Reference in New Issue