diff --git a/src/main/java/core/algo/mc/MonteCarloControlFirstVisitEGreedy.java b/src/main/java/core/algo/mc/MonteCarloControlEGreedy.java
similarity index 88%
rename from src/main/java/core/algo/mc/MonteCarloControlFirstVisitEGreedy.java
rename to src/main/java/core/algo/mc/MonteCarloControlEGreedy.java
index f00fa4b..c0f78d8 100644
--- a/src/main/java/core/algo/mc/MonteCarloControlFirstVisitEGreedy.java
+++ b/src/main/java/core/algo/mc/MonteCarloControlEGreedy.java
@@ -12,19 +12,19 @@ import java.io.ObjectOutputStream;
import java.util.*;
/**
- * Includes both variants of Monte-Carlo methods
+ * Includes both variants of Monte-Carlo methods
* Default method is First-Visit.
* Change to Every-Visit by setting flag "useEveryVisit" in the constructor to true.
* @param
*/
-public class MonteCarloControlFirstVisitEGreedy extends EpisodicLearning {
+public class MonteCarloControlEGreedy extends EpisodicLearning {
private Map, Double> returnSum;
private Map, Integer> returnCount;
private boolean isEveryVisit;
- public MonteCarloControlFirstVisitEGreedy(Environment environment, DiscreteActionSpace actionSpace, float discountFactor, float epsilon, int delay, boolean useEveryVisit) {
+ public MonteCarloControlEGreedy(Environment environment, DiscreteActionSpace actionSpace, float discountFactor, float epsilon, int delay, boolean useEveryVisit) {
super(environment, actionSpace, discountFactor, delay);
isEveryVisit = useEveryVisit;
this.policy = new EpsilonGreedyPolicy<>(epsilon);
@@ -33,11 +33,11 @@ public class MonteCarloControlFirstVisitEGreedy extends Episodic
returnCount = new HashMap<>();
}
- public MonteCarloControlFirstVisitEGreedy(Environment environment, DiscreteActionSpace actionSpace, float discountFactor, float epsilon, int delay) {
+ public MonteCarloControlEGreedy(Environment environment, DiscreteActionSpace actionSpace, float discountFactor, float epsilon, int delay) {
this(environment, actionSpace, discountFactor, epsilon, delay, false);
}
- public MonteCarloControlFirstVisitEGreedy(Environment environment, DiscreteActionSpace actionSpace, int delay) {
+ public MonteCarloControlEGreedy(Environment environment, DiscreteActionSpace actionSpace, int delay) {
this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, LearningConfig.DEFAULT_EPSILON, delay);
}
diff --git a/src/main/java/core/algo/td/QLearningOffPolicyTDControl.java b/src/main/java/core/algo/td/QLearningOffPolicyTDControl.java
index fe0c5cf..2da1c60 100644
--- a/src/main/java/core/algo/td/QLearningOffPolicyTDControl.java
+++ b/src/main/java/core/algo/td/QLearningOffPolicyTDControl.java
@@ -45,9 +45,7 @@ public class QLearningOffPolicyTDControl extends EpisodicLearnin
sumOfRewards = 0;
int timestampTilFood = 0;
- int rewardsPer1000 = 0;
int foodCollected = 0;
- int iterations = 0;
int foodTimestampsTotal= 0;
while(envResult == null || !envResult.isDone()) {
actionValues = stateActionTable.getActionValues(state);
@@ -58,20 +56,8 @@ public class QLearningOffPolicyTDControl extends EpisodicLearnin
double reward = envResult.getReward();
State nextState = envResult.getState();
sumOfRewards += reward;
- rewardsPer1000+=reward;
timestampTilFood++;
- /* if(iterations == 100){
- File file = new File(ContinuousAnt.FILE_NAME);
- try {
- Files.writeString(Path.of(file.getPath()), "\n", StandardOpenOption.APPEND);
- } catch (IOException e) {
- e.printStackTrace();
- }
- return;
- }*/
-
-
if(reward == Reward.FOOD_DROP_DOWN_SUCCESS) {
foodCollected++;
foodTimestampsTotal += timestampTilFood;
@@ -95,7 +81,7 @@ public class QLearningOffPolicyTDControl extends EpisodicLearnin
((EpsilonGreedyPolicy) this.policy).setEpsilon(0.05f);
}
if(foodCollected == 4000){
- System.out.println("final 0 expl");
+ System.out.println("Reached 0 exploration");
((EpsilonGreedyPolicy) this.policy).setEpsilon(0.00f);
}
if(foodCollected == 15000){
@@ -106,7 +92,6 @@ public class QLearningOffPolicyTDControl extends EpisodicLearnin
}
return;
}
- iterations++;
timestampTilFood = 0;
}
diff --git a/src/main/java/core/controller/RLController.java b/src/main/java/core/controller/RLController.java
index 895102a..e4df7e0 100644
--- a/src/main/java/core/controller/RLController.java
+++ b/src/main/java/core/controller/RLController.java
@@ -7,7 +7,7 @@ import core.ListDiscreteActionSpace;
import core.algo.EpisodicLearning;
import core.algo.Learning;
import core.algo.Method;
-import core.algo.mc.MonteCarloControlFirstVisitEGreedy;
+import core.algo.mc.MonteCarloControlEGreedy;
import core.algo.td.QLearningOffPolicyTDControl;
import core.algo.td.SARSA;
import core.listener.LearningListener;
@@ -49,10 +49,10 @@ public class RLController implements LearningListener {
public void start() {
switch(method) {
case MC_CONTROL_FIRST_VISIT:
- learning = new MonteCarloControlFirstVisitEGreedy<>(environment, discreteActionSpace, discountFactor, epsilon, delay);
+ learning = new MonteCarloControlEGreedy<>(environment, discreteActionSpace, discountFactor, epsilon, delay);
break;
case MC_CONTROL_EVERY_VISIT:
- learning = new MonteCarloControlFirstVisitEGreedy<>(environment, discreteActionSpace, discountFactor, epsilon, delay, true);
+ learning = new MonteCarloControlEGreedy<>(environment, discreteActionSpace, discountFactor, epsilon, delay, true);
break;
case SARSA_ON_POLICY_CONTROL:
diff --git a/src/main/java/example/ContinuousAnt.java b/src/main/java/example/ContinuousAnt.java
index 7a49b1d..e9d7ace 100644
--- a/src/main/java/example/ContinuousAnt.java
+++ b/src/main/java/example/ContinuousAnt.java
@@ -11,30 +11,28 @@ import java.io.File;
import java.io.IOException;
public class ContinuousAnt {
- public static final String FILE_NAME = "converge22.txt";
+ public static final String FILE_NAME = "converge.txt";
+
public static void main(String[] args) {
- int i = 4+4+4+6+6+6+8+10+12+14+14+16+16+16+18+18+18+20+20+20+22+22+22+24+24+24+24+26+26+26+26+26+28+28+28+28+28+30+30+30+30+32+32+32+34+34+34+36+36+38+40+42;
- System.out.println(i/52f);
File file = new File(FILE_NAME);
try {
file.createNewFile();
} catch (IOException e) {
e.printStackTrace();
}
- RNG.setSeed(13);
+ RNG.setSeed(13);
RLController rl = new RLControllerGUI<>(
- new AntWorldContinuous(8, 8),
- Method.Q_LEARNING_OFF_POLICY_CONTROL,
- AntAction.values());
+ new AntWorldContinuous(8, 8),
+ Method.Q_LEARNING_OFF_POLICY_CONTROL,
+ AntAction.values());
rl.setDelay(20);
- rl.setNrOfEpisodes(1);
- //0.99 0.9 0.5
- //0.99 0.95 0.9 0.7 0.5 0.3 0.1
+ rl.setNrOfEpisodes(1);
+ // 0.05, 0.1, 0.3, 0.5, 0.7, 0.9, 0.95, 0.99
rl.setDiscountFactor(0.05f);
- // 0.1, 0.3, 0.5, 0.7 0.9
- rl.setLearningRate(0.9f);
- rl.setEpsilon(0.2f);
- rl.start();
+        // 0.1, 0.3, 0.5, 0.7, 0.9
+ rl.setLearningRate(0.9f);
+ rl.setEpsilon(0.2f);
+ rl.start();
}
diff --git a/src/main/java/example/DinoSampling.java b/src/main/java/example/DinoSampling.java
index c4bbde7..6c748ac 100644
--- a/src/main/java/example/DinoSampling.java
+++ b/src/main/java/example/DinoSampling.java
@@ -14,8 +14,8 @@ import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
public class DinoSampling {
- public static final float f =0.05f;
public static final String FILE_NAME = "converge.txt";
+
public static void main(String[] args) {
File file = new File(FILE_NAME);
try {
@@ -23,15 +23,16 @@ public class DinoSampling {
} catch (IOException e) {
e.printStackTrace();
}
- for(float f = 0.05f; f <=1.003 ; f+=0.05f) {
+ for(float f = 0.05f; f <= 1.003; f += 0.05f) {
try {
Files.writeString(Path.of(file.getPath()), f + ",", StandardOpenOption.APPEND);
} catch (IOException e) {
e.printStackTrace();
}
- for (int i = 1; i <= 100; i++) {
- System.out.println("seed: " + i * 13);
- RNG.setSeed(i * 13);
+ for(int i = 1; i <= 100; i++) {
+ int seed = i * 13;
+ System.out.println("seed: " + seed);
+ RNG.setSeed(seed);
RLController rl = new RLControllerGUI<>(
new DinoWorld(),
diff --git a/src/main/java/example/Results b/src/main/java/example/Results
deleted file mode 100644
index ba42fc7..0000000
--- a/src/main/java/example/Results
+++ /dev/null
@@ -1,15 +0,0 @@
-Method:
-Epsilon = k / currentEpisode
-set to 0 if Epsilon < b
-
-k = 1.5
-b = 0.1 => conv. 16
-
-k = 1.5
-b = 0.02 => 75
-
-k = 1.4
-b = 0.02 => fail
-
-k = 2.0
-b = 0.02 => conv. 100
diff --git a/src/main/java/example/RunningAnt.java b/src/main/java/example/RunningAnt.java
index d01461c..eddea17 100644
--- a/src/main/java/example/RunningAnt.java
+++ b/src/main/java/example/RunningAnt.java
@@ -19,8 +19,8 @@ public class RunningAnt {
rl.setDelay(200);
rl.setNrOfEpisodes(10000);
rl.setDiscountFactor(0.9f);
+ rl.setLearningRate(0.9f);
rl.setEpsilon(0.15f);
-
rl.start();
}
}
diff --git a/src/main/java/example/Test.java b/src/main/java/example/Test.java
deleted file mode 100644
index c7ee691..0000000
--- a/src/main/java/example/Test.java
+++ /dev/null
@@ -1,52 +0,0 @@
-package example;
-
-public class Test {
- interface Drawable{
- void draw();
- }
- interface State{
- int getInt();
- }
-
- static class A implements Drawable, State{
- private int k;
- public A(int a){
- k = a;
- }
- @Override
- public void draw() {
- System.out.println("draw " + k);
- }
-
- @Override
- public int getInt() {
- System.out.println("getInt" + k);
- return k;
- }
- }
-
- static class B implements State{
- @Override
- public int getInt() {
- return 0;
- }
- }
-
- public static void main(String[] args) {
- State state = new A(24);
- State state2 = new B();
- state.getInt();
-
- System.out.println(state2 instanceof Drawable);
- drawState(state2);
- }
-
- static void drawState(State s){
- if(s instanceof Drawable){
- Drawable d = (Drawable) s;
- d.draw();
- }else{
- System.out.println("invalid");
- }
- }
-}