Merge remote-tracking branch 'origin/epsilonTest'
commit 33f896ff40
@@ -5,6 +5,7 @@ import core.Environment;
 import core.LearningConfig;
 import core.StepResult;
 import core.listener.LearningListener;
+import core.policy.EpsilonGreedyPolicy;
 import lombok.Getter;
 import lombok.Setter;
 
@@ -74,12 +75,35 @@ public abstract class EpisodicLearning<A extends Enum> extends Learning<A> imple
 
     protected void dispatchEpisodeStart(){
         ++currentEpisode;
+        /*
+        2f 0.02 => 100
+        1.5f 0.02 => 75
+        1.4f 0.02 => fail
+        1.5f 0.1 => 16 !
+        */
+        if(this.policy instanceof EpsilonGreedyPolicy){
+            float ep = 1.5f/(float)currentEpisode;
+            if(ep < 0.10) ep = 0;
+            ((EpsilonGreedyPolicy<A>) this.policy).setEpsilon(ep);
+            System.out.println(ep);
+        }
         episodesToLearn.decrementAndGet();
         for(LearningListener l: learningListeners){
             l.onEpisodeStart();
         }
     }
 
+    @Override
+    protected void dispatchStepEnd() {
+        super.dispatchStepEnd();
+        timestamp++;
+        // TODO: more sophisticated way to check convergence
+        if(timestamp > 300000){
+            System.out.println("converged after: " + currentEpisode + " episode!");
+            interruptLearning();
+        }
+    }
+
     @Override
     public void learn(){
         learn(LearningConfig.DEFAULT_NR_OF_EPISODES);
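
The measurement comments inside dispatchEpisodeStart() appear to read as "numerator, epsilon floor => episodes until convergence"; the committed constants correspond to the last entry (1.5f with a 0.1 floor, 16 episodes). Below is a minimal standalone sketch of that decay rule, assuming 1.5f and 0.10f are simply tunable values and not anything fixed by the framework:

// Sketch of the epsilon schedule used above: epsilon = numerator / episode,
// forced to 0 once it drops below a floor. Constants mirror the committed values.
public class EpsilonDecaySketch {
    static float epsilonFor(int episode, float numerator, float floor) {
        float ep = numerator / (float) episode;
        return ep < floor ? 0f : ep;   // fully greedy once below the floor
    }

    public static void main(String[] args) {
        for (int episode = 1; episode <= 20; episode++) {
            System.out.printf("episode %d -> epsilon %.3f%n",
                    episode, epsilonFor(episode, 1.5f, 0.10f));
        }
    }
}

With these values 1.5f / 15 still equals the 0.10f floor, so exploration switches off from episode 16 onward, which matches the "1.5f 0.1 => 16 !" note.
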
@@ -12,7 +12,6 @@ import lombok.Setter;
 import java.io.IOException;
 import java.io.ObjectInputStream;
 import java.io.ObjectOutputStream;
 import java.io.Serializable;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Set;
@@ -26,6 +25,13 @@ import java.util.concurrent.Executors;
  */
 @Getter
 public abstract class Learning<A extends Enum>{
+    // TODO: temp testing -> extract to dedicated test
+    protected int checkSum;
+    protected int rewardCheckSum;
+
+    // current discrete timestamp t
+    protected int timestamp;
+    protected int currentEpisode;
     protected Policy<A> policy;
     protected DiscreteActionSpace<A> actionSpace;
     @Setter
@@ -83,6 +89,8 @@ public abstract class Learning<A extends Enum>{
 
     protected void dispatchLearningEnd() {
         currentlyLearning = false;
+        System.out.println("Checksum: " + checkSum);
+        System.out.println("Reward Checksum: " + rewardCheckSum);
         for (LearningListener l : learningListeners) {
             l.onLearningEnd();
         }
@@ -5,5 +5,5 @@ package core.algo;
  * which RL-algorithm should be used.
  */
 public enum Method {
-    MC_CONTROL_EGREEDY, SARSA_EPISODIC, Q_LEARNING_OFF_POLICY_CONTROL
+    MC_CONTROL_FIRST_VISIT, SARSA_EPISODIC, Q_LEARNING_OFF_POLICY_CONTROL
 }
@@ -17,7 +17,7 @@ import java.util.*;
  * For example:
  * <p>
  * startingState -> MOVE_LEFT : very first state action in the episode i = 1
- * image the agent does not collect the food and drops it to the start, the agent will receive
+ * image the agent does not collect the food and does not drop it onto start, the agent will receive
  * -1 for every timestamp hence (startingState -> MOVE_LEFT) will get a value of -10;
  * <p>
 * BUT image moving left from the starting position will have no impact on the state because
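
The javadoc above is describing the first-visit rule: when the same state-action pair occurs several times within one episode, only the return following its first occurrence is averaged into the estimate. A self-contained sketch of that bookkeeping, using plain strings instead of the Pair<State, A> keys of returnSum and returnCount (the example episode and its rewards are invented for illustration):

import java.util.*;

// First-visit Monte Carlo return averaging for a single recorded episode.
public class FirstVisitMCSketch {
    public static void main(String[] args) {
        // one finished episode: state:action pairs and the reward received at each step
        String[] pairs   = {"start:MOVE_LEFT", "start:MOVE_LEFT", "food:COLLECT"};
        double[] rewards = {-1, -1, 10};
        double gamma = 1.0;

        Map<String, Double>  returnSum   = new HashMap<>();
        Map<String, Integer> returnCount = new HashMap<>();

        Set<String> visited = new HashSet<>();
        for (int i = 0; i < pairs.length; i++) {
            if (!visited.add(pairs[i])) continue;      // first-visit: skip later occurrences
            double g = 0;                              // return following the first visit
            for (int k = i; k < pairs.length; k++) {
                g += Math.pow(gamma, k - i) * rewards[k];
            }
            returnSum.merge(pairs[i], g, Double::sum);
            returnCount.merge(pairs[i], 1, Integer::sum);
        }
        // the Q(s,a) estimate is the average return over all first visits
        returnSum.forEach((sa, sum) -> System.out.println(sa + " -> " + sum / returnCount.get(sa)));
    }
}
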
@@ -30,12 +30,12 @@ import java.util.*;
  *
  * @param <A>
  */
-public class MonteCarloControlEGreedy<A extends Enum> extends EpisodicLearning<A> {
+public class MonteCarloControlFirstVisitEGreedy<A extends Enum> extends EpisodicLearning<A> {
 
     private Map<Pair<State, A>, Double> returnSum;
     private Map<Pair<State, A>, Integer> returnCount;
 
-    public MonteCarloControlEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay) {
+    public MonteCarloControlFirstVisitEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay) {
         super(environment, actionSpace, discountFactor, delay);
         this.policy = new EpsilonGreedyPolicy<>(epsilon);
         this.stateActionTable = new DeterministicStateActionTable<>(this.actionSpace);
@@ -43,7 +43,7 @@ public class MonteCarloControlEGreedy<A extends Enum> extends EpisodicLearning<A
         returnCount = new HashMap<>();
     }
 
-    public MonteCarloControlEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, int delay) {
+    public MonteCarloControlFirstVisitEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, int delay) {
         this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, LearningConfig.DEFAULT_EPSILON, delay);
     }
 
@@ -58,12 +58,16 @@ public class MonteCarloControlEGreedy<A extends Enum> extends EpisodicLearning<A
         }
         sumOfRewards = 0;
         StepResultEnvironment envResult = null;
+        //TODO extract to learning
+        int timestamp = 0;
         while(envResult == null || !envResult.isDone()) {
             Map<A, Double> actionValues = stateActionTable.getActionValues(state);
             A chosenAction = policy.chooseAction(actionValues);
+            checkSum += chosenAction.ordinal();
             envResult = environment.step(chosenAction);
             State nextState = envResult.getState();
             sumOfRewards += envResult.getReward();
+            rewardCheckSum += envResult.getReward();
             episode.add(new StepResult<>(state, chosenAction, envResult.getReward()));
 
             state = nextState;
@@ -73,6 +77,7 @@ public class MonteCarloControlEGreedy<A extends Enum> extends EpisodicLearning<A
             } catch (InterruptedException e) {
                 e.printStackTrace();
             }
+            timestamp++;
             dispatchStepEnd();
         }
@@ -46,7 +46,7 @@ public class SARSA<A extends Enum> extends EpisodicLearning<A> {
         actionValues = stateActionTable.getActionValues(nextState);
         A nextAction = policy.chooseAction(actionValues);
 
-        // TD update
+        // td update
         // target = reward + gamma * Q(nextState, nextAction)
         double currentQValue = stateActionTable.getActionValues(state).get(action);
         double nextQValue = stateActionTable.getActionValues(nextState).get(nextAction);
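
The target in the comment above is the usual SARSA rule Q(s,a) <- Q(s,a) + alpha * (reward + gamma * Q(s',a') - Q(s,a)). A minimal sketch of that single update against a plain array-backed table; the numbers and the table layout are illustrative only, the class itself goes through stateActionTable and the learningRate passed to its constructor:

// One SARSA update on a toy Q-table.
public class SarsaUpdateSketch {
    public static void main(String[] args) {
        double[][] q = new double[4][2];       // q[state][action], all zero initially
        double alpha = 0.5, gamma = 0.9;

        int state = 0, action = 1;             // (s, a) just taken
        double reward = 1.0;
        int nextState = 2, nextAction = 0;     // (s', a') chosen by the same policy

        // target = reward + gamma * Q(nextState, nextAction)
        double target = reward + gamma * q[nextState][nextAction];
        q[state][action] += alpha * (target - q[state][action]);

        System.out.println("updated Q(s,a) = " + q[state][action]);   // 0.5 here
    }
}
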
@@ -7,7 +7,7 @@ import core.ListDiscreteActionSpace;
 import core.algo.EpisodicLearning;
 import core.algo.Learning;
 import core.algo.Method;
-import core.algo.mc.MonteCarloControlEGreedy;
+import core.algo.mc.MonteCarloControlFirstVisitEGreedy;
 import core.algo.td.QLearningOffPolicyTDControl;
 import core.algo.td.SARSA;
 import core.listener.LearningListener;
@@ -48,8 +48,8 @@ public class RLController<A extends Enum> implements LearningListener {
 
     public void start() {
         switch(method) {
-            case MC_CONTROL_EGREEDY:
-                learning = new MonteCarloControlEGreedy<>(environment, discreteActionSpace, discountFactor, epsilon, delay);
+            case MC_CONTROL_FIRST_VISIT:
+                learning = new MonteCarloControlFirstVisitEGreedy<>(environment, discreteActionSpace, discountFactor, epsilon, delay);
                 break;
             case SARSA_EPISODIC:
                 learning = new SARSA<>(environment, discreteActionSpace, discountFactor, epsilon, learningRate, delay);
@@ -115,7 +115,7 @@ public class RLController<A extends Enum> implements LearningListener {
         try {
             fis = new FileInputStream(fileName);
             in = new ObjectInputStream(fis);
-            System.out.println("interrup" + Thread.currentThread().getId());
+            System.out.println("interrupt" + Thread.currentThread().getId());
             learning.interruptLearning();
             learning.load(in);
             in.close();
@@ -50,7 +50,7 @@ public class DinoWorld implements Environment<DinoAction>, Visualizable {
     @Override
     public StepResultEnvironment step(DinoAction action) {
         boolean done = false;
-        int reward = 0;
+        int reward = 1;
 
         if(action == DinoAction.JUMP){
             dino.jump();
@@ -74,7 +74,7 @@ public class DinoWorld implements Environment<DinoAction>, Visualizable {
             spawnNewObstacle();
         }
         if(ranIntoObstacle()) {
-            reward = -1;
+            reward = 0;
             done = true;
         }
 
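
Taken together, the two DinoWorld changes flip the reward scheme from "0 per step, -1 on crashing" into "1 per survived step, 0 on crashing". With the discount factor of 1f used in the examples further down, an episode's return then simply equals the number of steps the dino survived, so longer runs score strictly higher. A small sketch of that arithmetic (the step counts are made up; only the reward values come from the diff):

// Episode returns under the old and new DinoWorld reward schemes, undiscounted (gamma = 1).
public class DinoRewardSketch {
    static double episodeReturn(int survivedSteps, double stepReward, double crashReward) {
        return survivedSteps * stepReward + crashReward;
    }

    public static void main(String[] args) {
        int shortRun = 20, longRun = 200;
        // old scheme: every episode ends at -1 no matter how long the dino survived
        System.out.println(episodeReturn(shortRun, 0, -1) + " vs " + episodeReturn(longRun, 0, -1));
        // new scheme: longer survival directly yields a larger return
        System.out.println(episodeReturn(shortRun, 1, 0) + " vs " + episodeReturn(longRun, 1, 0));
    }
}
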
@@ -0,0 +1,27 @@
+package example;
+
+import core.RNG;
+import core.algo.Method;
+import core.controller.RLController;
+import evironment.jumpingDino.DinoAction;
+import evironment.jumpingDino.DinoWorld;
+
+public class DinoSampling {
+    public static void main(String[] args) {
+        for (int i = 0; i < 10 ; i++) {
+            RNG.setSeed(55);
+
+            RLController<DinoAction> rl = new RLController<>(
+                    new DinoWorld(false, false),
+                    Method.MC_CONTROL_FIRST_VISIT,
+                    DinoAction.values());
+
+            rl.setDelay(0);
+            rl.setDiscountFactor(1f);
+            rl.setEpsilon(0.15f);
+            rl.setLearningRate(1f);
+            rl.setNrOfEpisodes(400);
+            rl.start();
+        }
+    }
+}
@@ -3,7 +3,6 @@ package example;
 import core.RNG;
 import core.algo.Method;
 import core.controller.RLController;
-import core.controller.RLControllerGUI;
 import evironment.jumpingDino.DinoAction;
 import evironment.jumpingDino.DinoWorld;
 
@@ -11,16 +10,16 @@ public class JumpingDino {
     public static void main(String[] args) {
         RNG.setSeed(55);
 
-        RLController<DinoAction> rl = new RLControllerGUI<>(
+        RLController<DinoAction> rl = new RLController<>(
                 new DinoWorld(false, false),
-                Method.Q_LEARNING_OFF_POLICY_CONTROL,
+                Method.MC_CONTROL_FIRST_VISIT,
                 DinoAction.values());
 
-        rl.setDelay(1000);
-        rl.setDiscountFactor(0.9f);
-        rl.setEpsilon(0.1f);
-        rl.setLearningRate(0.5f);
-        rl.setNrOfEpisodes(4000000);
+        rl.setDelay(0);
+        rl.setDiscountFactor(1f);
+        rl.setEpsilon(0.15f);
+        rl.setLearningRate(1f);
+        rl.setNrOfEpisodes(400);
         rl.start();
     }
 }
@@ -13,7 +13,7 @@ public class RunningAnt {
 
         RLController<AntAction> rl = new RLControllerGUI<>(
                 new AntWorld(3, 3, 0.1),
-                Method.MC_CONTROL_EGREEDY,
+                Method.MC_CONTROL_FIRST_VISIT,
                 AntAction.values());
 
         rl.setDelay(200);
@@ -0,0 +1,30 @@
+import core.RNG;
+import core.algo.Method;
+import core.controller.RLController;
+import core.controller.RLControllerGUI;
+import evironment.jumpingDino.DinoAction;
+import evironment.jumpingDino.DinoWorld;
+import org.junit.Test;
+
+public class MCFirstVisit {
+
+    /**
+     * Test if the action sequence is deterministic
+     */
+    @Test
+    public void deterministicActionSequence(){
+        RNG.setSeed(55);
+
+        RLController<DinoAction> rl = new RLControllerGUI<>(
+                new DinoWorld(false, false),
+                Method.MC_CONTROL_FIRST_VISIT,
+                DinoAction.values());
+
+        rl.setDelay(10);
+        rl.setDiscountFactor(1f);
+        rl.setEpsilon(0.1f);
+        rl.setLearningRate(0.8f);
+        rl.setNrOfEpisodes(4000000);
+        rl.start();
+    }
+}
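
The checkSum and rewardCheckSum fields added to Learning earlier in this diff, together with this seeded test, point at a determinism check: with a fixed RNG seed, two runs should produce the same action sequence and therefore identical checksums. The diff does not show how the test asserts this, so the following is only a sketch of the checksum idea itself, outside the framework and with an invented enum and sequences:

import java.util.List;

// Collapse an action sequence into one int (sum of enum ordinals), mirroring
// checkSum += chosenAction.ordinal(), so two runs can be compared cheaply.
public class ChecksumSketch {
    enum Action { JUMP, NOTHING }

    static int checksum(List<Action> actions) {
        return actions.stream().mapToInt(Enum::ordinal).sum();
    }

    public static void main(String[] args) {
        List<Action> run1 = List.of(Action.JUMP, Action.NOTHING, Action.JUMP);
        List<Action> run2 = List.of(Action.JUMP, Action.NOTHING, Action.JUMP);
        // with the same seed, two learning runs should yield the same sequence, hence equal checksums
        System.out.println(checksum(run1) == checksum(run2));   // true
    }
}

Summing ordinals is only a weak fingerprint, since different sequences can share a sum, which is presumably why the reward checksum is tracked alongside it.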