diff --git a/src/main/java/core/algo/EpisodicLearning.java b/src/main/java/core/algo/EpisodicLearning.java
index 8bce333..534f8b4 100644
--- a/src/main/java/core/algo/EpisodicLearning.java
+++ b/src/main/java/core/algo/EpisodicLearning.java
@@ -16,12 +16,12 @@ import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
public abstract class EpisodicLearning extends Learning implements Episodic {
+ private volatile AtomicInteger episodesToLearn = new AtomicInteger(0);
+ private int episodeSumCurrentSecond;
@Setter
protected int currentEpisode = 0;
- protected volatile AtomicInteger episodesToLearn = new AtomicInteger(0);
@Getter
protected volatile int episodePerSecond;
- protected int episodeSumCurrentSecond;
protected double sumOfRewards;
protected List> episode = new ArrayList<>();
@@ -84,7 +84,6 @@ public abstract class EpisodicLearning extends Learning imple
protected void dispatchStepEnd() {
super.dispatchStepEnd();
timestamp++;
- timestampCurrentEpisode++;
}
@Override
@@ -95,9 +94,7 @@ public abstract class EpisodicLearning extends Learning imple
private void startLearning(){
dispatchLearningStart();
while(episodesToLearn.get() > 0){
-
dispatchEpisodeStart();
- timestampCurrentEpisode = 0;
nextEpisode();
dispatchEpisodeEnd();
}
diff --git a/src/main/java/core/algo/Learning.java b/src/main/java/core/algo/Learning.java
index 20a9ec1..f6bb4ac 100644
--- a/src/main/java/core/algo/Learning.java
+++ b/src/main/java/core/algo/Learning.java
@@ -23,10 +23,6 @@ import java.util.concurrent.CopyOnWriteArrayList;
*/
@Getter
public abstract class Learning{
- // TODO: temp testing -> extract to dedicated test
- protected int checkSum;
- protected int rewardCheckSum;
-
// current discrete timestamp t
protected int timestamp;
protected int currentEpisode;
diff --git a/src/main/java/core/algo/mc/MonteCarloControlEGreedy.java b/src/main/java/core/algo/mc/MonteCarloControlEGreedy.java
index c0f78d8..beb9e4e 100644
--- a/src/main/java/core/algo/mc/MonteCarloControlEGreedy.java
+++ b/src/main/java/core/algo/mc/MonteCarloControlEGreedy.java
@@ -60,7 +60,6 @@ public class MonteCarloControlEGreedy extends EpisodicLearning(state, chosenAction, envResult.getReward()));
state = nextState;
@@ -74,8 +73,6 @@ public class MonteCarloControlEGreedy extends EpisodicLearning, List> stateActionPairs = new LinkedHashMap<>();
diff --git a/src/main/java/core/algo/td/QLearningOffPolicyTDControl.java b/src/main/java/core/algo/td/QLearningOffPolicyTDControl.java
index 2da1c60..fcc59fa 100644
--- a/src/main/java/core/algo/td/QLearningOffPolicyTDControl.java
+++ b/src/main/java/core/algo/td/QLearningOffPolicyTDControl.java
@@ -5,18 +5,12 @@ import core.algo.EpisodicLearning;
import core.policy.EpsilonGreedyPolicy;
import core.policy.GreedyPolicy;
import core.policy.Policy;
-import evironment.antGame.Reward;
-import example.ContinuousAnt;
-import java.io.File;
-import java.io.IOException;
-import java.nio.file.Files;
-import java.nio.file.Path;
-import java.nio.file.StandardOpenOption;
import java.util.Map;
public class QLearningOffPolicyTDControl extends EpisodicLearning {
private float alpha;
+
private Policy greedyPolicy = new GreedyPolicy<>();
public QLearningOffPolicyTDControl(Environment environment, DiscreteActionSpace actionSpace, float discountFactor, float epsilon, float learningRate, int delay) {
@@ -42,11 +36,7 @@ public class QLearningOffPolicyTDControl extends EpisodicLearnin
StepResultEnvironment envResult = null;
Map actionValues = null;
-
sumOfRewards = 0;
- int timestampTilFood = 0;
- int foodCollected = 0;
- int foodTimestampsTotal= 0;
while(envResult == null || !envResult.isDone()) {
actionValues = stateActionTable.getActionValues(state);
A action = policy.chooseAction(actionValues);
@@ -56,44 +46,6 @@ public class QLearningOffPolicyTDControl extends EpisodicLearnin
double reward = envResult.getReward();
State nextState = envResult.getState();
sumOfRewards += reward;
- timestampTilFood++;
-
- if(reward == Reward.FOOD_DROP_DOWN_SUCCESS) {
- foodCollected++;
- foodTimestampsTotal += timestampTilFood;
- File file = new File(ContinuousAnt.FILE_NAME);
- if(foodCollected % 1000 == 0) {
- System.out.println(foodTimestampsTotal / 1000f + " " + timestampCurrentEpisode);
- try {
- Files.writeString(Path.of(file.getPath()), foodTimestampsTotal / 1000f + ",", StandardOpenOption.APPEND);
- } catch (IOException e) {
- e.printStackTrace();
- }
- foodTimestampsTotal = 0;
- }
- if(foodCollected == 1000){
- ((EpsilonGreedyPolicy) this.policy).setEpsilon(0.15f);
- }
- if(foodCollected == 2000){
- ((EpsilonGreedyPolicy) this.policy).setEpsilon(0.10f);
- }
- if(foodCollected == 3000){
- ((EpsilonGreedyPolicy) this.policy).setEpsilon(0.05f);
- }
- if(foodCollected == 4000){
- System.out.println("Reached 0 exploration");
- ((EpsilonGreedyPolicy) this.policy).setEpsilon(0.00f);
- }
- if(foodCollected == 15000){
- try {
- Files.writeString(Path.of(file.getPath()), "\n", StandardOpenOption.APPEND);
- } catch (IOException e) {
- e.printStackTrace();
- }
- return;
- }
- timestampTilFood = 0;
- }
// Q Update
double currentQValue = stateActionTable.getActionValues(state).get(action);
diff --git a/src/main/java/core/algo/td/SARSA.java b/src/main/java/core/algo/td/SARSA.java
index b64d59e..7ebeacf 100644
--- a/src/main/java/core/algo/td/SARSA.java
+++ b/src/main/java/core/algo/td/SARSA.java
@@ -3,8 +3,6 @@ package core.algo.td;
import core.*;
import core.algo.EpisodicLearning;
import core.policy.EpsilonGreedyPolicy;
-import core.policy.GreedyPolicy;
-import core.policy.Policy;
import java.util.Map;
@@ -35,10 +33,8 @@ public class SARSA extends EpisodicLearning {
StepResultEnvironment envResult = null;
Map actionValues = stateActionTable.getActionValues(state);
A action = policy.chooseAction(actionValues);
-
- //A action = policy.chooseAction(actionValues);
-
sumOfRewards = 0;
+
while(envResult == null || !envResult.isDone()) {
// Take a step
diff --git a/src/main/java/example/ContinuousAnt.java b/src/main/java/example/ContinuousAnt.java
index 5c1df45..7c017a4 100644
--- a/src/main/java/example/ContinuousAnt.java
+++ b/src/main/java/example/ContinuousAnt.java
@@ -7,31 +7,18 @@ import core.controller.RLControllerGUI;
import evironment.antGame.AntAction;
import evironment.antGame.AntWorldContinuous;
-import java.io.File;
-import java.io.IOException;
-
public class ContinuousAnt {
- public static final String FILE_NAME = "converge.txt";
-
public static void main(String[] args) {
- File file = new File(FILE_NAME);
- try {
- file.createNewFile();
- } catch (IOException e) {
- e.printStackTrace();
- }
RNG.setSeed(13, true);
RLController rl = new RLControllerGUI<>(
new AntWorldContinuous(8, 8),
Method.Q_LEARNING_OFF_POLICY_CONTROL,
AntAction.values());
- rl.setDelay(20);
+ rl.setDelay(200);
rl.setNrOfEpisodes(1);
- // 0.05, 0.1, 0.3, 0.5, 0.7, 0.9, 0.95, 0.99
- rl.setDiscountFactor(0.05f);
- // 0.1, 0.3, 0.5, 0.7 0.9
+ rl.setDiscountFactor(0.3f);
rl.setLearningRate(0.9f);
- rl.setEpsilon(0.2f);
+ rl.setEpsilon(0.15f);
rl.start();
}
}
diff --git a/src/main/java/example/JumpingDino.java b/src/main/java/example/JumpingDino.java
index 0a03b16..d0f2a7b 100644
--- a/src/main/java/example/JumpingDino.java
+++ b/src/main/java/example/JumpingDino.java
@@ -3,20 +3,20 @@ package example;
import core.RNG;
import core.algo.Method;
import core.controller.RLController;
+import core.controller.RLControllerGUI;
import evironment.jumpingDino.DinoAction;
-import evironment.jumpingDino.DinoWorld;
import evironment.jumpingDino.DinoWorldAdvanced;
public class JumpingDino {
public static void main(String[] args) {
RNG.setSeed(29);
- RLController rl = new RLController<>(
+ RLController rl = new RLControllerGUI<>(
new DinoWorldAdvanced(),
Method.MC_CONTROL_FIRST_VISIT,
DinoAction.values());
- rl.setDelay(0);
+ rl.setDelay(200);
rl.setDiscountFactor(9f);
rl.setEpsilon(0.05f);
rl.setLearningRate(0.8f);