diff --git a/.idea/codeStyles/codeStyleConfig.xml b/.idea/codeStyles/codeStyleConfig.xml
new file mode 100644
index 0000000..a55e7a1
--- /dev/null
+++ b/.idea/codeStyles/codeStyleConfig.xml
@@ -0,0 +1,5 @@
+
+
+
+
+
\ No newline at end of file
diff --git a/src/main/java/core/LearningConfig.java b/src/main/java/core/LearningConfig.java
index a4ecb68..1933b05 100644
--- a/src/main/java/core/LearningConfig.java
+++ b/src/main/java/core/LearningConfig.java
@@ -2,6 +2,9 @@ package core;
public class LearningConfig {
public static final int DEFAULT_DELAY = 30;
+ public static final int DEFAULT_NR_OF_EPISODES = 10000;
public static final float DEFAULT_EPSILON = 0.1f;
public static final float DEFAULT_DISCOUNT_FACTOR = 1.0f;
+ // Learning rate
+ public static final float DEFAULT_ALPHA = 0.9f;
}
diff --git a/src/main/java/core/algo/EpisodicLearning.java b/src/main/java/core/algo/EpisodicLearning.java
index 3a56363..4f997d3 100644
--- a/src/main/java/core/algo/EpisodicLearning.java
+++ b/src/main/java/core/algo/EpisodicLearning.java
@@ -2,6 +2,7 @@ package core.algo;
import core.DiscreteActionSpace;
import core.Environment;
+import core.LearningConfig;
import core.StepResult;
import core.listener.LearningListener;
import lombok.Getter;
@@ -16,7 +17,7 @@ import java.util.concurrent.atomic.AtomicInteger;
public abstract class EpisodicLearning extends Learning implements Episodic {
@Setter
- protected int currentEpisode;
+ protected int currentEpisode = 0;
protected volatile AtomicInteger episodesToLearn = new AtomicInteger(0);
@Getter
protected volatile int episodePerSecond;
@@ -81,7 +82,7 @@ public abstract class EpisodicLearning extends Learning imple
@Override
public void learn(){
- // TODO remove or learn with default episode number
+ learn(LearningConfig.DEFAULT_NR_OF_EPISODES);
}
private void startLearning(){
@@ -132,6 +133,15 @@ public abstract class EpisodicLearning extends Learning imple
return episodesToLearn.get();
}
+
+ public int getCurrentEpisode() {
+ return currentEpisode;
+ }
+
+ public int getEpisodesPerSecond() {
+ return episodePerSecond;
+ }
+
@Override
public synchronized void save(ObjectOutputStream oos) throws IOException {
super.save(oos);
diff --git a/src/main/java/core/algo/MC/MonteCarloOnPolicyEGreedy.java b/src/main/java/core/algo/MC/MonteCarloControlEGreedy.java
similarity index 79%
rename from src/main/java/core/algo/MC/MonteCarloOnPolicyEGreedy.java
rename to src/main/java/core/algo/MC/MonteCarloControlEGreedy.java
index c4ca4bc..d7cd028 100644
--- a/src/main/java/core/algo/MC/MonteCarloOnPolicyEGreedy.java
+++ b/src/main/java/core/algo/MC/MonteCarloControlEGreedy.java
@@ -30,21 +30,20 @@ import java.util.*;
*
* @param
*/
-public class MonteCarloOnPolicyEGreedy extends EpisodicLearning {
+public class MonteCarloControlEGreedy extends EpisodicLearning {
private Map, Double> returnSum;
private Map, Integer> returnCount;
- public MonteCarloOnPolicyEGreedy(Environment environment, DiscreteActionSpace actionSpace, float discountFactor, float epsilon, int delay) {
+ public MonteCarloControlEGreedy(Environment environment, DiscreteActionSpace actionSpace, float discountFactor, float epsilon, int delay) {
super(environment, actionSpace, discountFactor, delay);
- currentEpisode = 0;
this.policy = new EpsilonGreedyPolicy<>(epsilon);
this.stateActionTable = new DeterministicStateActionTable<>(this.actionSpace);
returnSum = new HashMap<>();
returnCount = new HashMap<>();
}
- public MonteCarloOnPolicyEGreedy(Environment environment, DiscreteActionSpace actionSpace, int delay) {
+ public MonteCarloControlEGreedy(Environment environment, DiscreteActionSpace actionSpace, int delay) {
this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, LearningConfig.DEFAULT_EPSILON, delay);
}
@@ -59,7 +58,7 @@ public class MonteCarloOnPolicyEGreedy extends EpisodicLearning<
}
sumOfRewards = 0;
StepResultEnvironment envResult = null;
- while(envResult == null || !envResult.isDone()){
+ while(envResult == null || !envResult.isDone()) {
Map actionValues = stateActionTable.getActionValues(state);
A chosenAction = policy.chooseAction(actionValues);
envResult = environment.step(chosenAction);
@@ -77,26 +76,26 @@ public class MonteCarloOnPolicyEGreedy extends EpisodicLearning<
dispatchStepEnd();
}
- // System.out.printf("Episode %d \t Reward: %f \n", currentEpisode, sumOfRewards);
+ // System.out.printf("Episode %d \t Reward: %f \n", currentEpisode, sumOfRewards);
Set> stateActionPairs = new LinkedHashSet<>();
- for (StepResult sr : episode) {
+ for(StepResult sr : episode) {
stateActionPairs.add(new ImmutablePair<>(sr.getState(), sr.getAction()));
}
//System.out.println("stateActionPairs " + stateActionPairs.size());
- for (Pair stateActionPair : stateActionPairs) {
+ for(Pair stateActionPair : stateActionPairs) {
int firstOccurenceIndex = 0;
// find first occurance of state action pair
- for (StepResult sr : episode) {
- if (stateActionPair.getKey().equals(sr.getState()) && stateActionPair.getValue().equals(sr.getAction())) {
+ for(StepResult sr : episode) {
+ if(stateActionPair.getKey().equals(sr.getState()) && stateActionPair.getValue().equals(sr.getAction())) {
break;
}
firstOccurenceIndex++;
}
double G = 0;
- for (int l = firstOccurenceIndex; l < episode.size(); ++l) {
+ for(int l = firstOccurenceIndex; l < episode.size(); ++l) {
G += episode.get(l).getReward() * (Math.pow(discountFactor, l - firstOccurenceIndex));
}
// slick trick to add G to the entry.
@@ -107,16 +106,6 @@ public class MonteCarloOnPolicyEGreedy extends EpisodicLearning<
}
}
- @Override
- public int getCurrentEpisode() {
- return currentEpisode;
- }
-
- @Override
- public int getEpisodesPerSecond(){
- return episodePerSecond;
- }
-
@Override
public void save(ObjectOutputStream oos) throws IOException {
super.save(oos);
diff --git a/src/main/java/core/algo/Method.java b/src/main/java/core/algo/Method.java
index 3ac50cc..473a889 100644
--- a/src/main/java/core/algo/Method.java
+++ b/src/main/java/core/algo/Method.java
@@ -5,5 +5,5 @@ package core.algo;
* which RL-algorithm should be used.
*/
public enum Method {
- MC_ONPOLICY_EGREEDY, TD_ONPOLICY
+ MC_CONTROL_EGREEDY, SARSA_EPISODIC, Q_LEARNING_OFF_POLICY_CONTROL
}
diff --git a/src/main/java/core/algo/td/QLearningOffPolicyTDControl.java b/src/main/java/core/algo/td/QLearningOffPolicyTDControl.java
new file mode 100644
index 0000000..274e2ab
--- /dev/null
+++ b/src/main/java/core/algo/td/QLearningOffPolicyTDControl.java
@@ -0,0 +1,68 @@
+package core.algo.td;
+
+import core.*;
+import core.algo.EpisodicLearning;
+import core.policy.EpsilonGreedyPolicy;
+import core.policy.GreedyPolicy;
+import core.policy.Policy;
+
+import java.util.Map;
+
+public class QLearningOffPolicyTDControl extends EpisodicLearning {
+ private float alpha;
+ private Policy greedyPolicy = new GreedyPolicy<>();
+
+ public QLearningOffPolicyTDControl(Environment environment, DiscreteActionSpace actionSpace, float discountFactor, float epsilon, float learningRate, int delay) {
+ super(environment, actionSpace, discountFactor, delay);
+ alpha = learningRate;
+ this.policy = new EpsilonGreedyPolicy<>(epsilon);
+ this.stateActionTable = new DeterministicStateActionTable<>(this.actionSpace);
+ }
+
+ public QLearningOffPolicyTDControl(Environment environment, DiscreteActionSpace actionSpace, int delay) {
+ this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, LearningConfig.DEFAULT_EPSILON, LearningConfig.DEFAULT_ALPHA, delay);
+ }
+
+ @Override
+ protected void nextEpisode() {
+ State state = environment.reset();
+ try {
+ Thread.sleep(delay);
+ } catch (InterruptedException e) {
+ e.printStackTrace();
+ }
+
+ StepResultEnvironment envResult = null;
+ Map actionValues = null;
+
+
+ sumOfRewards = 0;
+ while(envResult == null || !envResult.isDone()) {
+ actionValues = stateActionTable.getActionValues(state);
+ A action = policy.chooseAction(actionValues);
+
+ // Take a step
+ envResult = environment.step(action);
+ double reward = envResult.getReward();
+ State nextState = envResult.getState();
+ sumOfRewards += reward;
+
+ // Q Update
+ double currentQValue = stateActionTable.getActionValues(state).get(action);
+ // maxQ(S', a);
+ // Using intern "greedy policy" as a helper to determine the highest action-value
+ double highestValueNextState = stateActionTable.getActionValues(nextState).get(greedyPolicy.chooseAction(stateActionTable.getActionValues(nextState)));
+
+ double updatedQValue = currentQValue + alpha * (reward + discountFactor * highestValueNextState - currentQValue);
+ stateActionTable.setValue(state, action, updatedQValue);
+
+ state = nextState;
+ try {
+ Thread.sleep(delay);
+ } catch (InterruptedException e) {
+ e.printStackTrace();
+ }
+ dispatchStepEnd();
+ }
+ }
+}
diff --git a/src/main/java/core/algo/td/SARSA.java b/src/main/java/core/algo/td/SARSA.java
new file mode 100644
index 0000000..81d64fc
--- /dev/null
+++ b/src/main/java/core/algo/td/SARSA.java
@@ -0,0 +1,68 @@
+package core.algo.td;
+
+import core.*;
+import core.algo.EpisodicLearning;
+import core.policy.EpsilonGreedyPolicy;
+
+import java.util.Map;
+
+
+public class SARSA extends EpisodicLearning {
+ private float alpha;
+
+ public SARSA(Environment environment, DiscreteActionSpace actionSpace, float discountFactor, float epsilon, float learningRate, int delay) {
+ super(environment, actionSpace, discountFactor, delay);
+ alpha = learningRate;
+ this.policy = new EpsilonGreedyPolicy<>(epsilon);
+ this.stateActionTable = new DeterministicStateActionTable<>(this.actionSpace);
+ }
+
+ public SARSA(Environment environment, DiscreteActionSpace actionSpace, int delay) {
+ this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, LearningConfig.DEFAULT_EPSILON, LearningConfig.DEFAULT_ALPHA, delay);
+ }
+
+ @Override
+ protected void nextEpisode() {
+ State state = environment.reset();
+ try {
+ Thread.sleep(delay);
+ } catch (InterruptedException e) {
+ e.printStackTrace();
+ }
+
+ StepResultEnvironment envResult = null;
+ Map actionValues = stateActionTable.getActionValues(state);
+ A action = policy.chooseAction(actionValues);
+
+ sumOfRewards = 0;
+ while(envResult == null || !envResult.isDone()) {
+ // Take a step
+ envResult = environment.step(action);
+ sumOfRewards += envResult.getReward();
+
+ State nextState = envResult.getState();
+
+ // Pick next action
+ actionValues = stateActionTable.getActionValues(nextState);
+ A nextAction = policy.chooseAction(actionValues);
+
+ // TD update
+ // target = reward + gamma * Q(nextState, nextAction)
+ double currentQValue = stateActionTable.getActionValues(state).get(action);
+ double nextQValue = stateActionTable.getActionValues(nextState).get(nextAction);
+ double reward = envResult.getReward();
+ double updatedQValue = currentQValue + alpha * (reward + discountFactor * nextQValue - currentQValue);
+ stateActionTable.setValue(state, action, updatedQValue);
+
+ state = nextState;
+ action = nextAction;
+
+ try {
+ Thread.sleep(delay);
+ } catch (InterruptedException e) {
+ e.printStackTrace();
+ }
+ dispatchStepEnd();
+ }
+ }
+}
diff --git a/src/main/java/core/algo/TD/TemporalDifferenceOnPolicy.java b/src/main/java/core/algo/TD/TemporalDifferenceOnPolicy.java
deleted file mode 100644
index 33716f8..0000000
--- a/src/main/java/core/algo/TD/TemporalDifferenceOnPolicy.java
+++ /dev/null
@@ -1,4 +0,0 @@
-package core.algo.TD;
-
-public class TemporalDifferenceOnPolicy {
-}
diff --git a/src/main/java/core/controller/RLController.java b/src/main/java/core/controller/RLController.java
index 126ee2a..5bd3381 100644
--- a/src/main/java/core/controller/RLController.java
+++ b/src/main/java/core/controller/RLController.java
@@ -1,18 +1,19 @@
package core.controller;
-import core.*;
+import core.DiscreteActionSpace;
+import core.Environment;
+import core.LearningConfig;
+import core.ListDiscreteActionSpace;
import core.algo.EpisodicLearning;
import core.algo.Learning;
import core.algo.Method;
-import core.algo.mc.MonteCarloOnPolicyEGreedy;
-import core.gui.LearningView;
-import core.gui.View;
+import core.algo.mc.MonteCarloControlEGreedy;
+import core.algo.td.QLearningOffPolicyTDControl;
+import core.algo.td.SARSA;
import core.listener.LearningListener;
-import core.listener.ViewListener;
import core.policy.EpsilonPolicy;
import lombok.Setter;
-import javax.swing.*;
import java.io.*;
import java.util.List;
@@ -27,6 +28,8 @@ public class RLController implements LearningListener {
@Setter
protected float discountFactor = LearningConfig.DEFAULT_DISCOUNT_FACTOR;
@Setter
+ protected float learningRate = LearningConfig.DEFAULT_ALPHA;
+ @Setter
protected float epsilon = LearningConfig.DEFAULT_EPSILON;
protected Learning learning;
protected boolean fastLearning;
@@ -45,10 +48,14 @@ public class RLController implements LearningListener {
public void start() {
switch(method) {
- case MC_ONPOLICY_EGREEDY:
- learning = new MonteCarloOnPolicyEGreedy<>(environment, discreteActionSpace, discountFactor, epsilon, delay);
+ case MC_CONTROL_EGREEDY:
+ learning = new MonteCarloControlEGreedy<>(environment, discreteActionSpace, discountFactor, epsilon, delay);
break;
- case TD_ONPOLICY:
+ case SARSA_EPISODIC:
+ learning = new SARSA<>(environment, discreteActionSpace, discountFactor, epsilon, learningRate, delay);
+ break;
+ case Q_LEARNING_OFF_POLICY_CONTROL:
+ learning = new QLearningOffPolicyTDControl<>(environment, discreteActionSpace, discountFactor, epsilon, learningRate, delay);
break;
default:
throw new IllegalArgumentException("Undefined method");
diff --git a/src/main/java/core/gui/StateActionRow.java b/src/main/java/core/gui/StateActionRow.java
index 53e8e36..76d238c 100644
--- a/src/main/java/core/gui/StateActionRow.java
+++ b/src/main/java/core/gui/StateActionRow.java
@@ -30,7 +30,6 @@ public class StateActionRow extends JTextArea {
protected void refreshLabels(){
if(state == null || actionValues == null) return;
- System.out.println("refreshing");
StringBuilder sb = new StringBuilder(state.toString()).append("\n");
for(Map.Entry actionValue: actionValues.entrySet()){
sb.append("\t").append(actionValue.getKey()).append("\t").append(actionValue.getValue()).append("\n");
diff --git a/src/main/java/evironment/antGame/AntState.java b/src/main/java/evironment/antGame/AntState.java
index 368271b..ca917c9 100644
--- a/src/main/java/evironment/antGame/AntState.java
+++ b/src/main/java/evironment/antGame/AntState.java
@@ -29,7 +29,6 @@ public class AntState implements State, Visualizable {
private int computeHash() {
int hash = 7;
int prime = 31;
-
int unknown = 0;
int diff = 0;
for (Cell[] cells : knownWorld) {
diff --git a/src/main/java/evironment/jumpingDino/Dino.java b/src/main/java/evironment/jumpingDino/Dino.java
index 0c29ce3..125fdfd 100644
--- a/src/main/java/evironment/jumpingDino/Dino.java
+++ b/src/main/java/evironment/jumpingDino/Dino.java
@@ -28,9 +28,11 @@ public class Dino extends RenderObject {
@Override
public void tick(){
// reached max jump height
- if(y + dy < Config.FRAME_HEIGHT - Config.GROUND_Y -Config.OBSTACLE_SIZE - Config.MAX_JUMP_HEIGHT){
+ int topOfDino = Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE;
+
+ if(y + dy <= topOfDino - Config.MAX_JUMP_HEIGHT) {
fall();
- }else if(y + dy >= Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE){
+ } else if(y + dy >= topOfDino) {
inJump = false;
dy = 0;
y = Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE;
diff --git a/src/main/java/evironment/jumpingDino/DinoWorld.java b/src/main/java/evironment/jumpingDino/DinoWorld.java
index 7792b14..8e39889 100644
--- a/src/main/java/evironment/jumpingDino/DinoWorld.java
+++ b/src/main/java/evironment/jumpingDino/DinoWorld.java
@@ -56,18 +56,28 @@ public class DinoWorld implements Environment, Visualizable {
dino.jump();
}
- for(int i= 0; i < 5; ++i){
- dino.tick();
- currentObstacle.tick();
- if(currentObstacle.getX() < -Config.OBSTACLE_SIZE){
- spawnNewObstacle();
- }
- comp.repaint();
- if(ranIntoObstacle()){
- done = true;
- break;
- }
+// for(int i= 0; i < 5; ++i){
+// dino.tick();
+// currentObstacle.tick();
+// if(currentObstacle.getX() < -Config.OBSTACLE_SIZE){
+// spawnNewObstacle();
+// }
+// comp.repaint();
+// if(ranIntoObstacle()){
+// done = true;
+// break;
+// }
+// }
+ dino.tick();
+ currentObstacle.tick();
+ if(currentObstacle.getX() < -Config.OBSTACLE_SIZE) {
+ spawnNewObstacle();
}
+ if(ranIntoObstacle()) {
+ reward = 0;
+ done = true;
+ }
+
return new StepResultEnvironment(new DinoStateWithSpeed(getDistanceToObstacle(), getCurrentObstacle().getDx()), reward, done, "");
}
diff --git a/src/main/java/evironment/jumpingDino/gui/DinoWorldComponent.java b/src/main/java/evironment/jumpingDino/gui/DinoWorldComponent.java
index 461934f..238624c 100644
--- a/src/main/java/evironment/jumpingDino/gui/DinoWorldComponent.java
+++ b/src/main/java/evironment/jumpingDino/gui/DinoWorldComponent.java
@@ -19,7 +19,7 @@ public class DinoWorldComponent extends JComponent {
protected void paintComponent(Graphics g) {
super.paintComponent(g);
g.setColor(Color.BLACK);
- g.fillRect(0, Config.FRAME_HEIGHT - Config.GROUND_Y, Config.FRAME_WIDTH, 2);
+ g.fillRect(0, Config.FRAME_HEIGHT - Config.GROUND_Y, getWidth(), 2);
dinoWorld.getDino().render(g);
dinoWorld.getCurrentObstacle().render(g);
diff --git a/src/main/java/example/JumpingDino.java b/src/main/java/example/JumpingDino.java
index 84e3c7e..ff62761 100644
--- a/src/main/java/example/JumpingDino.java
+++ b/src/main/java/example/JumpingDino.java
@@ -12,15 +12,17 @@ public class JumpingDino {
RNG.setSeed(55);
RLController rl = new RLControllerGUI<>(
- new DinoWorld(true, true),
- Method.MC_ONPOLICY_EGREEDY,
+ new DinoWorld(false, false),
+ Method.Q_LEARNING_OFF_POLICY_CONTROL,
DinoAction.values());
- rl.setDelay(100);
- rl.setDiscountFactor(1f);
- rl.setEpsilon(0.15f);
- rl.setNrOfEpisodes(100000);
-
+ rl.setDelay(10);
+ rl.setDiscountFactor(0.8f);
+ rl.setEpsilon(0.1f);
+ rl.setLearningRate(0.5f);
+ rl.setNrOfEpisodes(10000);
rl.start();
+
+
}
}
diff --git a/src/main/java/example/RunningAnt.java b/src/main/java/example/RunningAnt.java
index ade0e92..9b83316 100644
--- a/src/main/java/example/RunningAnt.java
+++ b/src/main/java/example/RunningAnt.java
@@ -3,16 +3,17 @@ package example;
import core.RNG;
import core.algo.Method;
import core.controller.RLController;
+import core.controller.RLControllerGUI;
import evironment.antGame.AntAction;
import evironment.antGame.AntWorld;
public class RunningAnt {
public static void main(String[] args) {
- RNG.setSeed(123);
+ RNG.setSeed(56);
- RLController rl = new RLController<>(
+ RLController rl = new RLControllerGUI<>(
new AntWorld(3, 3, 0.1),
- Method.MC_ONPOLICY_EGREEDY,
+ Method.MC_CONTROL_EGREEDY,
AntAction.values());
rl.setDelay(200);