first epsilon-decaying method

This commit is contained in:
Jan Löwenstrom 2020-02-27 15:29:15 +01:00
parent cff1a4e531
commit 0e4f52a48e
12 changed files with 115 additions and 22 deletions

View File

@@ -5,6 +5,7 @@ import core.Environment;
import core.LearningConfig;
import core.StepResult;
import core.listener.LearningListener;
import core.policy.EpsilonGreedyPolicy;
import lombok.Getter;
import lombok.Setter;
@@ -74,12 +75,35 @@ public abstract class EpisodicLearning<A extends Enum> extends Learning<A> imple
protected void dispatchEpisodeStart(){
++currentEpisode;
/*
decay numerator, epsilon cut-off => observed result:
2f 0.02 => 100
1.5f 0.02 => 75
1.4f 0.02 => fail
1.5f 0.1 => 16 !
*/
if(this.policy instanceof EpsilonGreedyPolicy){
// decay epsilon with the episode count; once it falls below 0.1 the policy becomes fully greedy
float ep = 1.5f/(float)currentEpisode;
if(ep < 0.10) ep = 0;
((EpsilonGreedyPolicy<A>) this.policy).setEpsilon(ep);
System.out.println(ep);
}
episodesToLearn.decrementAndGet();
for(LearningListener l: learningListeners){
l.onEpisodeStart();
}
}
@Override
protected void dispatchStepEnd() {
super.dispatchStepEnd();
timestamp++;
// TODO: more sophisticated way to check convergence
if(timestamp > 300000){
System.out.println("converged after: " + currentEpisode + " episodes!");
interruptLearning();
}
}
@Override
public void learn(){
learn(LearningConfig.DEFAULT_NR_OF_EPISODES);

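The decay schedule added in dispatchEpisodeStart() is easiest to sanity-check in isolation. Below is a minimal, self-contained sketch (class and constant names are illustrative, not part of the commit): epsilon is 1.5 divided by the episode number and forced to 0 once it drops below 0.1, so it reaches 0 at episode 16, matching the "1.5f 0.1 => 16 !" note in the comment block. For the very first episode the value is 1.5, i.e. above 1, which presumably just means fully random action selection.

public class EpsilonDecaySketch {

    // hypothetical names; the commit hard-codes 1.5f and 0.10 inline
    static final float DECAY_NUMERATOR = 1.5f;
    static final float EPSILON_FLOOR = 0.10f;

    static float decayedEpsilon(int episode) {
        float ep = DECAY_NUMERATOR / (float) episode;
        // below the floor the policy becomes fully greedy
        return ep < EPSILON_FLOOR ? 0f : ep;
    }

    public static void main(String[] args) {
        // prints 1.500, 0.750, 0.500, ... and 0.000 from episode 16 onwards
        for (int episode = 1; episode <= 20; episode++) {
            System.out.printf("episode %2d -> epsilon %.3f%n", episode, decayedEpsilon(episode));
        }
    }
}
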
View File

@@ -12,7 +12,6 @@ import lombok.Setter;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
@@ -26,6 +25,13 @@ import java.util.concurrent.Executors;
*/
@Getter
public abstract class Learning<A extends Enum>{
// TODO: temp testing -> extract to dedicated test
protected int checkSum;
protected int rewardCheckSum;
// current discrete timestamp t
protected int timestamp;
protected int currentEpisode;
protected Policy<A> policy;
protected DiscreteActionSpace<A> actionSpace;
@Setter
@@ -83,6 +89,8 @@ public abstract class Learning<A extends Enum>{
protected void dispatchLearningEnd() {
currentlyLearning = false;
System.out.println("Checksum: " + checkSum);
System.out.println("Reward Checksum: " + rewardCheckSum);
for (LearningListener l : learningListeners) {
l.onLearningEnd();
}

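The new checkSum and rewardCheckSum fields only become meaningful when two runs are compared: every chosen action's ordinal and every reward are summed up, so two runs with the same RNG seed should end with identical totals (which is what the DinoSampling loop and the MCFirstVisit test further down appear to be aimed at). A toy, self-contained illustration of the idea, using a stand-in enum and java.util.Random rather than the project's classes:

import java.util.Random;

public class ChecksumSketch {
    // toy stand-in for an action enum such as DinoAction
    enum Action { MOVE_LEFT, MOVE_RIGHT, JUMP }

    // accumulate action ordinals the same way checkSum does during learning
    static int runOnce(long seed, int steps) {
        Random rng = new Random(seed);
        int checkSum = 0;
        for (int i = 0; i < steps; i++) {
            Action chosen = Action.values()[rng.nextInt(Action.values().length)];
            checkSum += chosen.ordinal();
        }
        return checkSum;
    }

    public static void main(String[] args) {
        // same seed => same action sequence => same checksum
        System.out.println(runOnce(55, 1000) == runOnce(55, 1000)); // true
    }
}
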
View File

@@ -5,5 +5,5 @@ package core.algo;
* which RL-algorithm should be used.
*/
public enum Method {
MC_CONTROL_EGREEDY, SARSA_EPISODIC, Q_LEARNING_OFF_POLICY_CONTROL
MC_CONTROL_FIRST_VISIT, SARSA_EPISODIC, Q_LEARNING_OFF_POLICY_CONTROL
}

View File

@@ -17,7 +17,7 @@ import java.util.*;
* For example:
* <p>
* startingState -> MOVE_LEFT : very first state action in the episode i = 1
imagine the agent does not collect the food and drops it to the start, the agent will receive
imagine the agent does not collect the food and does not drop it onto the start, the agent will receive
-1 for every timestamp, hence (startingState -> MOVE_LEFT) will get a value of -10;
* <p>
* BUT imagine moving left from the starting position has no impact on the state because
@@ -30,12 +30,12 @@ import java.util.*;
*
* @param <A>
*/
public class MonteCarloControlEGreedy<A extends Enum> extends EpisodicLearning<A> {
public class MonteCarloControlFirstVisitEGreedy<A extends Enum> extends EpisodicLearning<A> {
private Map<Pair<State, A>, Double> returnSum;
private Map<Pair<State, A>, Integer> returnCount;
public MonteCarloControlEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay) {
public MonteCarloControlFirstVisitEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay) {
super(environment, actionSpace, discountFactor, delay);
this.policy = new EpsilonGreedyPolicy<>(epsilon);
this.stateActionTable = new DeterministicStateActionTable<>(this.actionSpace);
@@ -43,7 +43,7 @@ public class MonteCarloControlEGreedy<A extends Enum> extends EpisodicLearning<A
returnCount = new HashMap<>();
}
public MonteCarloControlEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, int delay) {
public MonteCarloControlFirstVisitEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, int delay) {
this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, LearningConfig.DEFAULT_EPSILON, delay);
}
@@ -58,12 +58,16 @@ public class MonteCarloControlEGreedy<A extends Enum> extends EpisodicLearning<A
}
sumOfRewards = 0;
StepResultEnvironment envResult = null;
//TODO extract to Learning; this local variable shadows the timestamp field inherited from Learning
int timestamp = 0;
while(envResult == null || !envResult.isDone()) {
Map<A, Double> actionValues = stateActionTable.getActionValues(state);
A chosenAction = policy.chooseAction(actionValues);
checkSum += chosenAction.ordinal();
envResult = environment.step(chosenAction);
State nextState = envResult.getState();
sumOfRewards += envResult.getReward();
rewardCheckSum += envResult.getReward();
episode.add(new StepResult<>(state, chosenAction, envResult.getReward()));
state = nextState;
@@ -73,6 +77,7 @@ public class MonteCarloControlEGreedy<A extends Enum> extends EpisodicLearning<A
} catch (InterruptedException e) {
e.printStackTrace();
}
timestamp++;
dispatchStepEnd();
}

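This hunk only shows how the episode is collected; the update that gives MonteCarloControlFirstVisitEGreedy its "first visit" name is outside the diff. A generic sketch of what such an update typically looks like, with plain maps and a stand-in Step type instead of the project's returnSum/returnCount and StepResult (an illustration of the textbook rule, not the class's actual code): every state-action pair is credited with the return following its first occurrence in the episode, and Q is the running average of those returns.

import java.util.*;

public class FirstVisitUpdateSketch {
    // stand-in for the project's StepResult<A>
    static class Step {
        final String state; final String action; final double reward;
        Step(String state, String action, double reward) {
            this.state = state; this.action = action; this.reward = reward;
        }
        String key() { return state + "|" + action; }
    }

    public static void main(String[] args) {
        List<Step> episode = Arrays.asList(
                new Step("s0", "MOVE_LEFT", -1),
                new Step("s1", "MOVE_RIGHT", -1),
                new Step("s0", "MOVE_LEFT", 10));  // second visit of (s0, MOVE_LEFT)

        double discountFactor = 1.0;
        Map<String, Double> returnSum = new HashMap<>();
        Map<String, Integer> returnCount = new HashMap<>();
        Map<String, Double> qValues = new HashMap<>();

        // index of the FIRST occurrence of each (state, action) pair
        Map<String, Integer> firstVisit = new HashMap<>();
        for (int i = 0; i < episode.size(); i++) {
            firstVisit.putIfAbsent(episode.get(i).key(), i);
        }

        // walk the episode backwards, accumulating the return G
        double g = 0;
        for (int t = episode.size() - 1; t >= 0; t--) {
            Step step = episode.get(t);
            g = discountFactor * g + step.reward;
            // first-visit: only the earliest occurrence of the pair contributes
            if (firstVisit.get(step.key()) == t) {
                returnSum.merge(step.key(), g, Double::sum);
                returnCount.merge(step.key(), 1, Integer::sum);
                qValues.put(step.key(), returnSum.get(step.key()) / returnCount.get(step.key()));
            }
        }
        // Q(s0|MOVE_LEFT) = 8.0 (return from its first visit), Q(s1|MOVE_RIGHT) = 9.0
        System.out.println(qValues);
    }
}
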
View File

@@ -46,7 +46,7 @@ public class SARSA<A extends Enum> extends EpisodicLearning<A> {
actionValues = stateActionTable.getActionValues(nextState);
A nextAction = policy.chooseAction(actionValues);
// TD update
// td update
// target = reward + gamma * Q(nextState, nextAction)
double currentQValue = stateActionTable.getActionValues(state).get(action);
double nextQValue = stateActionTable.getActionValues(nextState).get(nextAction);

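The SARSA hunk is cut off before the update itself, but the comment already spells out the target. Written out with plain doubles and made-up numbers, the update the comment describes is:

public class SarsaUpdateSketch {
    public static void main(String[] args) {
        // illustrative numbers, not taken from the project
        double learningRate = 0.5;     // alpha
        double discountFactor = 0.9;   // gamma
        double reward = -1;

        double currentQValue = 2.0;    // Q(state, action)
        double nextQValue = 3.0;       // Q(nextState, nextAction)

        // SARSA: Q(s,a) <- Q(s,a) + alpha * (reward + gamma * Q(s',a') - Q(s,a))
        double target = reward + discountFactor * nextQValue;
        double updatedQValue = currentQValue + learningRate * (target - currentQValue);

        System.out.println(updatedQValue); // 2.0 + 0.5 * (1.7 - 2.0) = 1.85
    }
}
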
View File

@@ -7,7 +7,7 @@ import core.ListDiscreteActionSpace;
import core.algo.EpisodicLearning;
import core.algo.Learning;
import core.algo.Method;
import core.algo.mc.MonteCarloControlEGreedy;
import core.algo.mc.MonteCarloControlFirstVisitEGreedy;
import core.algo.td.QLearningOffPolicyTDControl;
import core.algo.td.SARSA;
import core.listener.LearningListener;
@@ -48,8 +48,8 @@ public class RLController<A extends Enum> implements LearningListener {
public void start() {
switch(method) {
case MC_CONTROL_EGREEDY:
learning = new MonteCarloControlEGreedy<>(environment, discreteActionSpace, discountFactor, epsilon, delay);
case MC_CONTROL_FIRST_VISIT:
learning = new MonteCarloControlFirstVisitEGreedy<>(environment, discreteActionSpace, discountFactor, epsilon, delay);
break;
case SARSA_EPISODIC:
learning = new SARSA<>(environment, discreteActionSpace, discountFactor, epsilon, learningRate, delay);
@@ -115,7 +115,7 @@ public class RLController<A extends Enum> implements LearningListener {
try {
fis = new FileInputStream(fileName);
in = new ObjectInputStream(fis);
System.out.println("interrup" + Thread.currentThread().getId());
System.out.println("interrupt" + Thread.currentThread().getId());
learning.interruptLearning();
learning.load(in);
in.close();

View File

@@ -50,7 +50,7 @@ public class DinoWorld implements Environment<DinoAction>, Visualizable {
@Override
public StepResultEnvironment step(DinoAction action) {
boolean done = false;
int reward = 0;
int reward = 1;
if(action == DinoAction.JUMP){
dino.jump();
@@ -74,7 +74,7 @@ public class DinoWorld implements Environment<DinoAction>, Visualizable {
spawnNewObstacle();
}
if(ranIntoObstacle()) {
reward = -1;
reward = 0;
done = true;
}

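This flips the reward scheme: previously every step paid 0 and the crash step paid -1, now every survived step pays +1 and the crash step pays 0. Presumably the point is that with the undiscounted settings used in the examples below (discount factor 1f), the return now equals survival time instead of being -1 for every episode regardless of length. A tiny illustration with an assumed episode length:

public class DinoRewardSketch {
    public static void main(String[] args) {
        int steps = 50;  // hypothetical episode: 49 survived steps, then a crash

        // old scheme: 0 per step, -1 on the crash step (undiscounted return)
        double oldReturn = -1;
        // new scheme: +1 per survived step, 0 on the crash step (undiscounted return)
        double newReturn = steps - 1;

        System.out.println(oldReturn + " vs " + newReturn);  // -1.0 vs 49.0
    }
}
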
View File

@@ -0,0 +1,27 @@
package example;
import core.RNG;
import core.algo.Method;
import core.controller.RLController;
import evironment.jumpingDino.DinoAction;
import evironment.jumpingDino.DinoWorld;
public class DinoSampling {
public static void main(String[] args) {
for (int i = 0; i < 10 ; i++) {
RNG.setSeed(55);
RLController<DinoAction> rl = new RLController<>(
new DinoWorld(false, false),
Method.MC_CONTROL_FIRST_VISIT,
DinoAction.values());
rl.setDelay(0);
rl.setDiscountFactor(1f);
rl.setEpsilon(0.15f);
rl.setLearningRate(1f);
rl.setNrOfEpisodes(400);
rl.start();
}
}
}

View File

@@ -3,7 +3,6 @@ package example;
import core.RNG;
import core.algo.Method;
import core.controller.RLController;
import core.controller.RLControllerGUI;
import evironment.jumpingDino.DinoAction;
import evironment.jumpingDino.DinoWorld;
@@ -11,16 +10,16 @@ public class JumpingDino {
public static void main(String[] args) {
RNG.setSeed(55);
RLController<DinoAction> rl = new RLControllerGUI<>(
RLController<DinoAction> rl = new RLController<>(
new DinoWorld(false, false),
Method.Q_LEARNING_OFF_POLICY_CONTROL,
Method.MC_CONTROL_FIRST_VISIT,
DinoAction.values());
rl.setDelay(1000);
rl.setDiscountFactor(0.9f);
rl.setEpsilon(0.1f);
rl.setLearningRate(0.5f);
rl.setNrOfEpisodes(4000000);
rl.setDelay(0);
rl.setDiscountFactor(1f);
rl.setEpsilon(0.15f);
rl.setLearningRate(1f);
rl.setNrOfEpisodes(400);
rl.start();
}
}

View File

@@ -13,7 +13,7 @@ public class RunningAnt {
RLController<AntAction> rl = new RLControllerGUI<>(
new AntWorld(3, 3, 0.1),
Method.MC_CONTROL_EGREEDY,
Method.MC_CONTROL_FIRST_VISIT,
AntAction.values());
rl.setDelay(200);

View File

@@ -0,0 +1,30 @@
import core.RNG;
import core.algo.Method;
import core.controller.RLController;
import core.controller.RLControllerGUI;
import evironment.jumpingDino.DinoAction;
import evironment.jumpingDino.DinoWorld;
import org.junit.Test;
public class MCFirstVisit {
/**
* Test if the action sequence is deterministic
*/
@Test
public void deterministicActionSequence(){
RNG.setSeed(55);
RLController<DinoAction> rl = new RLControllerGUI<>(
new DinoWorld(false, false),
Method.MC_CONTROL_FIRST_VISIT,
DinoAction.values());
rl.setDelay(10);
rl.setDiscountFactor(1f);
rl.setEpsilon(0.1f);
rl.setLearningRate(0.8f);
rl.setNrOfEpisodes(4000000);
rl.start();
}
}
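As written, this test builds the GUI controller, schedules 4,000,000 episodes with a delay of 10 and contains no assertion, so it can only fail by throwing. Below is a sketch of a variant closer to the DinoSampling settings above, under the assumption that rl.start() blocks until learning is finished (the diff does not make that explicit); the determinism assertion itself is left as a TODO because no accessor for the accumulated checksum is visible in this commit:

import core.RNG;
import core.algo.Method;
import core.controller.RLController;
import evironment.jumpingDino.DinoAction;
import evironment.jumpingDino.DinoWorld;
import org.junit.Test;

public class MCFirstVisitSketch {

    @Test
    public void deterministicActionSequence() {
        RNG.setSeed(55);
        // headless controller so the test does not open a window
        RLController<DinoAction> rl = new RLController<>(
                new DinoWorld(false, false),
                Method.MC_CONTROL_FIRST_VISIT,
                DinoAction.values());
        rl.setDelay(0);               // no artificial delay between steps
        rl.setDiscountFactor(1f);
        rl.setEpsilon(0.15f);
        rl.setLearningRate(1f);
        rl.setNrOfEpisodes(400);      // small enough to finish as a unit test
        rl.start();
        // TODO: run twice with the same seed and assert that the accumulated
        //       checkSum values match once they are exposed to the test.
    }
}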