first epsilon decaying method

Jan Löwenstrom 2020-02-27 15:29:15 +01:00
parent cff1a4e531
commit 0e4f52a48e
12 changed files with 115 additions and 22 deletions

View File

@@ -5,6 +5,7 @@ import core.Environment;
 import core.LearningConfig;
 import core.StepResult;
 import core.listener.LearningListener;
+import core.policy.EpsilonGreedyPolicy;
 import lombok.Getter;
 import lombok.Setter;
@@ -74,12 +75,35 @@ public abstract class EpisodicLearning<A extends Enum> extends Learning<A> imple
     protected void dispatchEpisodeStart(){
         ++currentEpisode;
+        /*
+            2f   0.02 => 100
+            1.5f 0.02 => 75
+            1.4f 0.02 => fail
+            1.5f 0.1  => 16 !
+        */
+        if(this.policy instanceof EpsilonGreedyPolicy){
+            float ep = 1.5f/(float)currentEpisode;
+            if(ep < 0.10) ep = 0;
+            ((EpsilonGreedyPolicy<A>) this.policy).setEpsilon(ep);
+            System.out.println(ep);
+        }
         episodesToLearn.decrementAndGet();
         for(LearningListener l: learningListeners){
             l.onEpisodeStart();
         }
     }
+
+    @Override
+    protected void dispatchStepEnd() {
+        super.dispatchStepEnd();
+        timestamp++;
+        // TODO: more sophisticated way to check convergence
+        if(timestamp > 300000){
+            System.out.println("converged after: " + currentEpisode + " episode!");
+            interruptLearning();
+        }
+    }
+
     @Override
     public void learn(){
         learn(LearningConfig.DEFAULT_NR_OF_EPISODES);
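The decay added above is hard-coded inline: at every episode start epsilon is set to 1.5f / currentEpisode and forced to 0 once it drops below 0.10, so the policy becomes fully greedy from episode 16 onwards (1.5 / 16 ≈ 0.094 < 0.1), which matches the "1.5f 0.1 => 16" note. A minimal standalone sketch of the same schedule, with the class and method names being illustrative rather than part of this commit:

public final class EpsilonDecaySketch {

    // Hyperbolic schedule: epsilon_k = numerator / k, cut to 0 once below the threshold.
    static float decayedEpsilon(int episode, float numerator, float cutoff) {
        float ep = numerator / (float) episode;
        return ep < cutoff ? 0f : ep;
    }

    public static void main(String[] args) {
        // With numerator 1.5 and cutoff 0.1, episode 15 still explores (epsilon = 0.1)
        // and episode 16 is the first fully greedy one.
        for (int episode = 1; episode <= 20; episode++) {
            System.out.printf("episode %2d -> epsilon %.3f%n",
                    episode, decayedEpsilon(episode, 1.5f, 0.1f));
        }
    }
}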

View File

@@ -12,7 +12,6 @@ import lombok.Setter;
 import java.io.IOException;
 import java.io.ObjectInputStream;
 import java.io.ObjectOutputStream;
-import java.io.Serializable;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Set;
@@ -26,6 +25,13 @@ import java.util.concurrent.Executors;
  */
 @Getter
 public abstract class Learning<A extends Enum>{
+    // TODO: temp testing -> extract to dedicated test
+    protected int checkSum;
+    protected int rewardCheckSum;
+    // current discrete timestamp t
+    protected int timestamp;
+    protected int currentEpisode;
+
     protected Policy<A> policy;
     protected DiscreteActionSpace<A> actionSpace;
     @Setter
@@ -83,6 +89,8 @@ public abstract class Learning<A extends Enum>{
     protected void dispatchLearningEnd() {
         currentlyLearning = false;
+        System.out.println("Checksum: " + checkSum);
+        System.out.println("Reward Checksum: " + rewardCheckSum);
         for (LearningListener l : learningListeners) {
             l.onLearningEnd();
         }

View File

@@ -5,5 +5,5 @@ package core.algo;
  * which RL-algorithm should be used.
  */
 public enum Method {
-    MC_CONTROL_EGREEDY, SARSA_EPISODIC, Q_LEARNING_OFF_POLICY_CONTROL
+    MC_CONTROL_FIRST_VISIT, SARSA_EPISODIC, Q_LEARNING_OFF_POLICY_CONTROL
 }

View File

@@ -17,7 +17,7 @@ import java.util.*;
  * For example:
  * <p>
  * startingState -> MOVE_LEFT : very first state action in the episode i = 1
- * image the agent does not collect the food and drops it to the start, the agent will receive
+ * image the agent does not collect the food and does not drop it onto start, the agent will receive
  * -1 for every timestamp hence (startingState -> MOVE_LEFT) will get a value of -10;
  * <p>
  * BUT image moving left from the starting position will have no impact on the state because
@@ -30,12 +30,12 @@ import java.util.*;
  *
  * @param <A>
  */
-public class MonteCarloControlEGreedy<A extends Enum> extends EpisodicLearning<A> {
+public class MonteCarloControlFirstVisitEGreedy<A extends Enum> extends EpisodicLearning<A> {
     private Map<Pair<State, A>, Double> returnSum;
     private Map<Pair<State, A>, Integer> returnCount;

-    public MonteCarloControlEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay) {
+    public MonteCarloControlFirstVisitEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay) {
         super(environment, actionSpace, discountFactor, delay);
         this.policy = new EpsilonGreedyPolicy<>(epsilon);
         this.stateActionTable = new DeterministicStateActionTable<>(this.actionSpace);
@@ -43,7 +43,7 @@ public class MonteCarloControlEGreedy<A extends Enum> extends EpisodicLearning<A
         returnCount = new HashMap<>();
     }

-    public MonteCarloControlEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, int delay) {
+    public MonteCarloControlFirstVisitEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, int delay) {
         this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, LearningConfig.DEFAULT_EPSILON, delay);
     }
@@ -58,12 +58,16 @@ public class MonteCarloControlEGreedy<A extends Enum> extends EpisodicLearning<A
         }
         sumOfRewards = 0;
         StepResultEnvironment envResult = null;
+        //TODO extract to learning
+        int timestamp = 0;
         while(envResult == null || !envResult.isDone()) {
             Map<A, Double> actionValues = stateActionTable.getActionValues(state);
             A chosenAction = policy.chooseAction(actionValues);
+            checkSum += chosenAction.ordinal();
             envResult = environment.step(chosenAction);
             State nextState = envResult.getState();
             sumOfRewards += envResult.getReward();
+            rewardCheckSum += envResult.getReward();
             episode.add(new StepResult<>(state, chosenAction, envResult.getReward()));
             state = nextState;
@@ -73,6 +77,7 @@ public class MonteCarloControlEGreedy<A extends Enum> extends EpisodicLearning<A
             } catch (InterruptedException e) {
                 e.printStackTrace();
             }
+            timestamp++;
             dispatchStepEnd();
         }
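The returnSum / returnCount maps are what make this a first-visit method: for each state-action pair, only the return following its earliest occurrence in an episode is averaged into the action value. A standalone sketch of that averaging step; the Step record, the string keys, and the toy episode are illustrative and not the project's actual types:

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

public class FirstVisitMcSketch {

    record Step(String state, String action, double reward) {}

    public static void main(String[] args) {
        double discountFactor = 1.0;
        List<Step> episode = List.of(
                new Step("start", "MOVE_LEFT", -1),
                new Step("start", "MOVE_RIGHT", -1),
                new Step("food", "COLLECT", 10));

        // Backwards pass: returns.get(t) is the discounted return G_t from step t onwards.
        List<Double> returns = new ArrayList<>();
        double g = 0;
        for (int t = episode.size() - 1; t >= 0; t--) {
            g = discountFactor * g + episode.get(t).reward();
            returns.add(0, g);
        }

        // First-visit averaging: only the earliest occurrence of each (state, action)
        // pair in the episode contributes its return to the running mean.
        Map<String, Double> returnSum = new HashMap<>();
        Map<String, Integer> returnCount = new HashMap<>();
        Map<String, Double> q = new HashMap<>();
        Set<String> visited = new HashSet<>();
        for (int t = 0; t < episode.size(); t++) {
            Step step = episode.get(t);
            String key = step.state() + "/" + step.action();
            if (!visited.add(key)) continue;   // not the first visit of this pair
            returnSum.merge(key, returns.get(t), Double::sum);
            returnCount.merge(key, 1, Integer::sum);
            q.put(key, returnSum.get(key) / returnCount.get(key));
        }
        // start/MOVE_LEFT -> 8.0, start/MOVE_RIGHT -> 9.0, food/COLLECT -> 10.0
        System.out.println(q);
    }
}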

View File

@@ -46,7 +46,7 @@ public class SARSA<A extends Enum> extends EpisodicLearning<A> {
             actionValues = stateActionTable.getActionValues(nextState);
             A nextAction = policy.chooseAction(actionValues);
-            // TD update
+            // td update
             // target = reward + gamma * Q(nextState, nextAction)
             double currentQValue = stateActionTable.getActionValues(state).get(action);
             double nextQValue = stateActionTable.getActionValues(nextState).get(nextAction);
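Written out, the target in the comment feeds the usual SARSA step Q(state, action) <- Q(state, action) + learningRate * (target - Q(state, action)). A minimal numeric sketch of that single update; the alpha, gamma, and sample values here are assumptions, not taken from the diff:

public class SarsaUpdateSketch {
    public static void main(String[] args) {
        double learningRate = 0.5;    // alpha (assumed)
        double gamma = 0.9;           // discount factor (assumed)
        double reward = -1.0;         // reward observed after taking `action` in `state`
        double currentQValue = 0.0;   // Q(state, action)
        double nextQValue = 2.0;      // Q(nextState, nextAction), picked by the same epsilon-greedy policy

        double target = reward + gamma * nextQValue;                                    // -1 + 0.9 * 2 = 0.8
        double updatedQValue = currentQValue + learningRate * (target - currentQValue); // 0 + 0.5 * 0.8
        System.out.println(updatedQValue);                                              // prints 0.4
    }
}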

View File

@@ -7,7 +7,7 @@ import core.ListDiscreteActionSpace;
 import core.algo.EpisodicLearning;
 import core.algo.Learning;
 import core.algo.Method;
-import core.algo.mc.MonteCarloControlEGreedy;
+import core.algo.mc.MonteCarloControlFirstVisitEGreedy;
 import core.algo.td.QLearningOffPolicyTDControl;
 import core.algo.td.SARSA;
 import core.listener.LearningListener;
@@ -48,8 +48,8 @@ public class RLController<A extends Enum> implements LearningListener {
     public void start() {
         switch(method) {
-            case MC_CONTROL_EGREEDY:
-                learning = new MonteCarloControlEGreedy<>(environment, discreteActionSpace, discountFactor, epsilon, delay);
+            case MC_CONTROL_FIRST_VISIT:
+                learning = new MonteCarloControlFirstVisitEGreedy<>(environment, discreteActionSpace, discountFactor, epsilon, delay);
                 break;
             case SARSA_EPISODIC:
                 learning = new SARSA<>(environment, discreteActionSpace, discountFactor, epsilon, learningRate, delay);
@@ -115,7 +115,7 @@ public class RLController<A extends Enum> implements LearningListener {
         try {
             fis = new FileInputStream(fileName);
             in = new ObjectInputStream(fis);
-            System.out.println("interrup" + Thread.currentThread().getId());
+            System.out.println("interrupt" + Thread.currentThread().getId());
             learning.interruptLearning();
             learning.load(in);
             in.close();

View File

@@ -50,7 +50,7 @@ public class DinoWorld implements Environment<DinoAction>, Visualizable {
     @Override
     public StepResultEnvironment step(DinoAction action) {
         boolean done = false;
-        int reward = 0;
+        int reward = 1;

         if(action == DinoAction.JUMP){
            dino.jump();
@@ -74,7 +74,7 @@ public class DinoWorld implements Environment<DinoAction>, Visualizable {
             spawnNewObstacle();
         }
         if(ranIntoObstacle()) {
-            reward = -1;
+            reward = 0;
             done = true;
         }

View File

@@ -0,0 +1,27 @@
+package example;
+
+import core.RNG;
+import core.algo.Method;
+import core.controller.RLController;
+import evironment.jumpingDino.DinoAction;
+import evironment.jumpingDino.DinoWorld;
+
+public class DinoSampling {
+    public static void main(String[] args) {
+        for (int i = 0; i < 10 ; i++) {
+            RNG.setSeed(55);
+
+            RLController<DinoAction> rl = new RLController<>(
+                    new DinoWorld(false, false),
+                    Method.MC_CONTROL_FIRST_VISIT,
+                    DinoAction.values());
+
+            rl.setDelay(0);
+            rl.setDiscountFactor(1f);
+            rl.setEpsilon(0.15f);
+            rl.setLearningRate(1f);
+            rl.setNrOfEpisodes(400);
+            rl.start();
+        }
+    }
+}

View File

@@ -3,7 +3,6 @@ package example;
 import core.RNG;
 import core.algo.Method;
 import core.controller.RLController;
-import core.controller.RLControllerGUI;
 import evironment.jumpingDino.DinoAction;
 import evironment.jumpingDino.DinoWorld;
@@ -11,16 +10,16 @@ public class JumpingDino {
     public static void main(String[] args) {
         RNG.setSeed(55);
-        RLController<DinoAction> rl = new RLControllerGUI<>(
+        RLController<DinoAction> rl = new RLController<>(
                 new DinoWorld(false, false),
-                Method.Q_LEARNING_OFF_POLICY_CONTROL,
+                Method.MC_CONTROL_FIRST_VISIT,
                 DinoAction.values());
-        rl.setDelay(1000);
-        rl.setDiscountFactor(0.9f);
-        rl.setEpsilon(0.1f);
-        rl.setLearningRate(0.5f);
-        rl.setNrOfEpisodes(4000000);
+        rl.setDelay(0);
+        rl.setDiscountFactor(1f);
+        rl.setEpsilon(0.15f);
+        rl.setLearningRate(1f);
+        rl.setNrOfEpisodes(400);
         rl.start();
     }
 }

View File

@@ -13,7 +13,7 @@ public class RunningAnt {
         RLController<AntAction> rl = new RLControllerGUI<>(
                 new AntWorld(3, 3, 0.1),
-                Method.MC_CONTROL_EGREEDY,
+                Method.MC_CONTROL_FIRST_VISIT,
                 AntAction.values());
         rl.setDelay(200);

View File

@@ -0,0 +1,30 @@
+import core.RNG;
+import core.algo.Method;
+import core.controller.RLController;
+import core.controller.RLControllerGUI;
+import evironment.jumpingDino.DinoAction;
+import evironment.jumpingDino.DinoWorld;
+import org.junit.Test;
+
+public class MCFirstVisit {
+
+    /**
+     * Test if the action sequence is deterministic
+     */
+    @Test
+    public void deterministicActionSequence(){
+        RNG.setSeed(55);
+
+        RLController<DinoAction> rl = new RLControllerGUI<>(
+                new DinoWorld(false, false),
+                Method.MC_CONTROL_FIRST_VISIT,
+                DinoAction.values());
+
+        rl.setDelay(10);
+        rl.setDiscountFactor(1f);
+        rl.setEpsilon(0.1f);
+        rl.setLearningRate(0.8f);
+        rl.setNrOfEpisodes(4000000);
+        rl.start();
+    }
+}