Merge remote-tracking branch 'origin/epsilonTest'
commit 33f896ff40

@@ -5,6 +5,7 @@ import core.Environment;
 import core.LearningConfig;
 import core.StepResult;
 import core.listener.LearningListener;
+import core.policy.EpsilonGreedyPolicy;
 import lombok.Getter;
 import lombok.Setter;

@@ -74,12 +75,35 @@ public abstract class EpisodicLearning<A extends Enum> extends Learning<A> imple

     protected void dispatchEpisodeStart(){
         ++currentEpisode;
+        /*
+        2f 0.02 => 100
+        1.5f 0.02 => 75
+        1.4f 0.02 => fail
+        1.5f 0.1 => 16 !
+        */
+        if(this.policy instanceof EpsilonGreedyPolicy){
+            float ep = 1.5f/(float)currentEpisode;
+            if(ep < 0.10) ep = 0;
+            ((EpsilonGreedyPolicy<A>) this.policy).setEpsilon(ep);
+            System.out.println(ep);
+        }
         episodesToLearn.decrementAndGet();
         for(LearningListener l: learningListeners){
             l.onEpisodeStart();
         }
     }

+    @Override
+    protected void dispatchStepEnd() {
+        super.dispatchStepEnd();
+        timestamp++;
+        // TODO: more sophisticated way to check convergence
+        if(timestamp > 300000){
+            System.out.println("converged after: " + currentEpisode + " episode!");
+            interruptLearning();
+        }
+    }
+
     @Override
     public void learn(){
         learn(LearningConfig.DEFAULT_NR_OF_EPISODES);
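
The epsilon decay added to dispatchEpisodeStart follows a 1/episode schedule: epsilon = 1.5 / currentEpisode, forced to 0 once it drops below 0.10, so exploration stops entirely from episode 16 onward (1.5 / 0.1 = 15). The commented figures appear to record the outcome of different (numerator, cutoff) pairs, e.g. "1.5f 0.1 => 16 !". A minimal standalone sketch of that schedule, using hypothetical names that are not part of the repository:

// Hypothetical sketch of the epsilon schedule hard-coded above; not repository code.
// epsilon(e) = numerator / e, clamped to 0 below the cutoff, so the policy becomes
// fully greedy after roughly numerator / cutoff episodes.
public final class EpsilonSchedule {
    private final float numerator;  // 1.5f in the diff
    private final float cutoff;     // 0.10f in the diff

    public EpsilonSchedule(float numerator, float cutoff) {
        this.numerator = numerator;
        this.cutoff = cutoff;
    }

    public float epsilonFor(int episode) {
        float ep = numerator / (float) episode;
        return ep < cutoff ? 0f : ep;
    }

    public static void main(String[] args) {
        EpsilonSchedule schedule = new EpsilonSchedule(1.5f, 0.10f);
        // Prints 1.5, 0.75, 0.5, ... and 0 from episode 16 on.
        for (int episode = 1; episode <= 20; episode++) {
            System.out.println(episode + " -> " + schedule.epsilonFor(episode));
        }
    }
}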

@@ -12,7 +12,6 @@ import lombok.Setter;
 import java.io.IOException;
 import java.io.ObjectInputStream;
 import java.io.ObjectOutputStream;
-import java.io.Serializable;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Set;
@@ -26,6 +25,13 @@ import java.util.concurrent.Executors;
  */
 @Getter
 public abstract class Learning<A extends Enum>{
+    // TODO: temp testing -> extract to dedicated test
+    protected int checkSum;
+    protected int rewardCheckSum;
+
+    // current discrete timestamp t
+    protected int timestamp;
+    protected int currentEpisode;
     protected Policy<A> policy;
     protected DiscreteActionSpace<A> actionSpace;
     @Setter
@@ -83,6 +89,8 @@ public abstract class Learning<A extends Enum>{

     protected void dispatchLearningEnd() {
         currentlyLearning = false;
+        System.out.println("Checksum: " + checkSum);
+        System.out.println("Reward Checksum: " + rewardCheckSum);
         for (LearningListener l : learningListeners) {
             l.onLearningEnd();
         }
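
The checkSum and rewardCheckSum fields added to Learning sum action ordinals and rewards over a whole run and are printed in dispatchLearningEnd; with a fixed RNG seed, two runs of the same configuration should therefore print identical values, which is the cheap determinism signal the new MCFirstVisit test at the end of this commit relies on. A hypothetical illustration of the idea (standalone, not repository code):

// Equal checksums across two seeded runs are a necessary (not sufficient) condition
// for the runs having chosen identical action sequences.
static int actionChecksum(int[] chosenActionOrdinals) {
    int sum = 0;
    for (int ordinal : chosenActionOrdinals) {
        sum += ordinal;   // mirrors checkSum += chosenAction.ordinal(); in the MC class below
    }
    return sum;
}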

@@ -5,5 +5,5 @@ package core.algo;
  * which RL-algorithm should be used.
  */
 public enum Method {
-    MC_CONTROL_EGREEDY, SARSA_EPISODIC, Q_LEARNING_OFF_POLICY_CONTROL
+    MC_CONTROL_FIRST_VISIT, SARSA_EPISODIC, Q_LEARNING_OFF_POLICY_CONTROL
 }

@@ -17,7 +17,7 @@ import java.util.*;
 * For example:
 * <p>
 * startingState -> MOVE_LEFT : very first state action in the episode i = 1
-* image the agent does not collect the food and drops it to the start, the agent will receive
+* image the agent does not collect the food and does not drop it onto start, the agent will receive
 * -1 for every timestamp hence (startingState -> MOVE_LEFT) will get a value of -10;
 * <p>
 * BUT image moving left from the starting position will have no impact on the state because
@@ -30,12 +30,12 @@ import java.util.*;
 *
 * @param <A>
 */
-public class MonteCarloControlEGreedy<A extends Enum> extends EpisodicLearning<A> {
+public class MonteCarloControlFirstVisitEGreedy<A extends Enum> extends EpisodicLearning<A> {

     private Map<Pair<State, A>, Double> returnSum;
     private Map<Pair<State, A>, Integer> returnCount;

-    public MonteCarloControlEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay) {
+    public MonteCarloControlFirstVisitEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay) {
         super(environment, actionSpace, discountFactor, delay);
         this.policy = new EpsilonGreedyPolicy<>(epsilon);
         this.stateActionTable = new DeterministicStateActionTable<>(this.actionSpace);
@@ -43,7 +43,7 @@ public class MonteCarloControlEGreedy<A extends Enum> extends EpisodicLearning<A
         returnCount = new HashMap<>();
     }

-    public MonteCarloControlEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, int delay) {
+    public MonteCarloControlFirstVisitEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, int delay) {
         this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, LearningConfig.DEFAULT_EPSILON, delay);
     }

@@ -58,12 +58,16 @@ public class MonteCarloControlEGreedy<A extends Enum> extends EpisodicLearning<A
         }
         sumOfRewards = 0;
         StepResultEnvironment envResult = null;
+        //TODO extract to learning
+        int timestamp = 0;
         while(envResult == null || !envResult.isDone()) {
             Map<A, Double> actionValues = stateActionTable.getActionValues(state);
             A chosenAction = policy.chooseAction(actionValues);
+            checkSum += chosenAction.ordinal();
             envResult = environment.step(chosenAction);
             State nextState = envResult.getState();
             sumOfRewards += envResult.getReward();
+            rewardCheckSum += envResult.getReward();
             episode.add(new StepResult<>(state, chosenAction, envResult.getReward()));

             state = nextState;
@@ -73,6 +77,7 @@ public class MonteCarloControlEGreedy<A extends Enum> extends EpisodicLearning<A
             } catch (InterruptedException e) {
                 e.printStackTrace();
             }
+            timestamp++;
             dispatchStepEnd();
         }
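
The rename to MonteCarloControlFirstVisitEGreedy makes the variant explicit: first-visit Monte Carlo control averages, for each state-action pair, only the return of its first occurrence in an episode, using the returnSum and returnCount maps declared above. The averaging loop itself is not part of this diff, so the following is only a minimal self-contained sketch of the technique, with hypothetical types (plain string keys instead of Pair<State, A>):

import java.util.*;

// First-visit Monte Carlo control sketch; hypothetical standalone code, not the repository class.
public class FirstVisitSketch {
    static class Step {
        final String stateAction;
        final double reward;
        Step(String stateAction, double reward) {
            this.stateAction = stateAction;
            this.reward = reward;
        }
    }

    static final Map<String, Double> returnSum = new HashMap<>();
    static final Map<String, Integer> returnCount = new HashMap<>();
    static final Map<String, Double> q = new HashMap<>();

    static void updateFromEpisode(List<Step> episode, double gamma) {
        // Returns computed backwards: G_t = r_t + gamma * G_{t+1}
        double[] returns = new double[episode.size()];
        double g = 0;
        for (int t = episode.size() - 1; t >= 0; t--) {
            g = episode.get(t).reward + gamma * g;
            returns[t] = g;
        }
        Set<String> visited = new HashSet<>();
        for (int t = 0; t < episode.size(); t++) {
            String key = episode.get(t).stateAction;
            if (!visited.add(key)) continue;               // only the first visit counts
            returnSum.merge(key, returns[t], Double::sum);
            returnCount.merge(key, 1, Integer::sum);
            q.put(key, returnSum.get(key) / returnCount.get(key));
        }
    }
}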

@@ -46,7 +46,7 @@ public class SARSA<A extends Enum> extends EpisodicLearning<A> {
             actionValues = stateActionTable.getActionValues(nextState);
             A nextAction = policy.chooseAction(actionValues);

-            // TD update
+            // td update
             // target = reward + gamma * Q(nextState, nextAction)
             double currentQValue = stateActionTable.getActionValues(state).get(action);
             double nextQValue = stateActionTable.getActionValues(nextState).get(nextAction);
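
The comment kept in this hunk spells out the TD target, target = reward + gamma * Q(nextState, nextAction), which is the standard SARSA update. A one-function sketch under that assumption, with hypothetical parameter names (learningRate and discountFactor mirror the values passed to the SARSA constructor in RLController below):

// SARSA update as described by the comment above; standalone sketch, not repository code.
// Q(s,a) <- Q(s,a) + alpha * (r + gamma * Q(s',a') - Q(s,a))
static double sarsaUpdate(double currentQValue, double nextQValue, double reward,
                          double learningRate, double discountFactor) {
    double target = reward + discountFactor * nextQValue;            // TD target
    return currentQValue + learningRate * (target - currentQValue);  // move Q toward the target
}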

@@ -7,7 +7,7 @@ import core.ListDiscreteActionSpace;
 import core.algo.EpisodicLearning;
 import core.algo.Learning;
 import core.algo.Method;
-import core.algo.mc.MonteCarloControlEGreedy;
+import core.algo.mc.MonteCarloControlFirstVisitEGreedy;
 import core.algo.td.QLearningOffPolicyTDControl;
 import core.algo.td.SARSA;
 import core.listener.LearningListener;
@@ -48,8 +48,8 @@ public class RLController<A extends Enum> implements LearningListener {

     public void start() {
         switch(method) {
-            case MC_CONTROL_EGREEDY:
-                learning = new MonteCarloControlEGreedy<>(environment, discreteActionSpace, discountFactor, epsilon, delay);
+            case MC_CONTROL_FIRST_VISIT:
+                learning = new MonteCarloControlFirstVisitEGreedy<>(environment, discreteActionSpace, discountFactor, epsilon, delay);
                 break;
             case SARSA_EPISODIC:
                 learning = new SARSA<>(environment, discreteActionSpace, discountFactor, epsilon, learningRate, delay);
@@ -115,7 +115,7 @@ public class RLController<A extends Enum> implements LearningListener {
         try {
             fis = new FileInputStream(fileName);
             in = new ObjectInputStream(fis);
-            System.out.println("interrup" + Thread.currentThread().getId());
+            System.out.println("interrupt" + Thread.currentThread().getId());
             learning.interruptLearning();
             learning.load(in);
             in.close();

@@ -50,7 +50,7 @@ public class DinoWorld implements Environment<DinoAction>, Visualizable {
     @Override
     public StepResultEnvironment step(DinoAction action) {
         boolean done = false;
-        int reward = 0;
+        int reward = 1;

         if(action == DinoAction.JUMP){
             dino.jump();
@@ -74,7 +74,7 @@ public class DinoWorld implements Environment<DinoAction>, Visualizable {
             spawnNewObstacle();
         }
         if(ranIntoObstacle()) {
-            reward = -1;
+            reward = 0;
             done = true;
         }
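
The reward change flips the sign convention in DinoWorld.step: instead of 0 per step and -1 on a crash, the environment now pays +1 for every survived step and 0 on the terminal crash step. With the discount factor of 1 used in the examples below, the episode return then equals the number of steps survived. A trivial hypothetical check of that equivalence (standalone, not repository code):

// Undiscounted return of an episode that survives n steps and then crashes:
// +1 per survived step, 0 for the crash step, so the return is exactly n.
static int dinoEpisodeReturn(int survivedSteps) {
    int episodeReturn = 0;
    for (int t = 0; t < survivedSteps; t++) {
        episodeReturn += 1;   // reward while no obstacle was hit
    }
    return episodeReturn;     // the crash step contributes 0
}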

@@ -0,0 +1,27 @@
+package example;
+
+import core.RNG;
+import core.algo.Method;
+import core.controller.RLController;
+import evironment.jumpingDino.DinoAction;
+import evironment.jumpingDino.DinoWorld;
+
+public class DinoSampling {
+    public static void main(String[] args) {
+        for (int i = 0; i < 10 ; i++) {
+            RNG.setSeed(55);
+
+            RLController<DinoAction> rl = new RLController<>(
+                    new DinoWorld(false, false),
+                    Method.MC_CONTROL_FIRST_VISIT,
+                    DinoAction.values());
+
+            rl.setDelay(0);
+            rl.setDiscountFactor(1f);
+            rl.setEpsilon(0.15f);
+            rl.setLearningRate(1f);
+            rl.setNrOfEpisodes(400);
+            rl.start();
+        }
+    }
+}

@@ -3,7 +3,6 @@ package example;
 import core.RNG;
 import core.algo.Method;
 import core.controller.RLController;
-import core.controller.RLControllerGUI;
 import evironment.jumpingDino.DinoAction;
 import evironment.jumpingDino.DinoWorld;

@@ -11,16 +10,16 @@ public class JumpingDino {
     public static void main(String[] args) {
         RNG.setSeed(55);

-        RLController<DinoAction> rl = new RLControllerGUI<>(
+        RLController<DinoAction> rl = new RLController<>(
                 new DinoWorld(false, false),
-                Method.Q_LEARNING_OFF_POLICY_CONTROL,
+                Method.MC_CONTROL_FIRST_VISIT,
                 DinoAction.values());

-        rl.setDelay(1000);
-        rl.setDiscountFactor(0.9f);
-        rl.setEpsilon(0.1f);
-        rl.setLearningRate(0.5f);
-        rl.setNrOfEpisodes(4000000);
+        rl.setDelay(0);
+        rl.setDiscountFactor(1f);
+        rl.setEpsilon(0.15f);
+        rl.setLearningRate(1f);
+        rl.setNrOfEpisodes(400);
         rl.start();
     }
 }

@@ -13,7 +13,7 @@ public class RunningAnt {

         RLController<AntAction> rl = new RLControllerGUI<>(
                 new AntWorld(3, 3, 0.1),
-                Method.MC_CONTROL_EGREEDY,
+                Method.MC_CONTROL_FIRST_VISIT,
                 AntAction.values());

         rl.setDelay(200);

@@ -0,0 +1,30 @@
+import core.RNG;
+import core.algo.Method;
+import core.controller.RLController;
+import core.controller.RLControllerGUI;
+import evironment.jumpingDino.DinoAction;
+import evironment.jumpingDino.DinoWorld;
+import org.junit.Test;
+
+public class MCFirstVisit {
+
+    /**
+     * Test if the action sequence is deterministic
+     */
+    @Test
+    public void deterministicActionSequence(){
+        RNG.setSeed(55);
+
+        RLController<DinoAction> rl = new RLControllerGUI<>(
+                new DinoWorld(false, false),
+                Method.MC_CONTROL_FIRST_VISIT,
+                DinoAction.values());
+
+        rl.setDelay(10);
+        rl.setDiscountFactor(1f);
+        rl.setEpsilon(0.1f);
+        rl.setLearningRate(0.8f);
+        rl.setNrOfEpisodes(4000000);
+        rl.start();
+    }
+}