diff --git a/src/main/java/core/algo/EpisodicLearning.java b/src/main/java/core/algo/EpisodicLearning.java index 8bce333..534f8b4 100644 --- a/src/main/java/core/algo/EpisodicLearning.java +++ b/src/main/java/core/algo/EpisodicLearning.java @@ -16,12 +16,12 @@ import java.util.List; import java.util.concurrent.atomic.AtomicInteger; public abstract class EpisodicLearning extends Learning implements Episodic { + private volatile AtomicInteger episodesToLearn = new AtomicInteger(0); + private int episodeSumCurrentSecond; @Setter protected int currentEpisode = 0; - protected volatile AtomicInteger episodesToLearn = new AtomicInteger(0); @Getter protected volatile int episodePerSecond; - protected int episodeSumCurrentSecond; protected double sumOfRewards; protected List> episode = new ArrayList<>(); @@ -84,7 +84,6 @@ public abstract class EpisodicLearning extends Learning imple protected void dispatchStepEnd() { super.dispatchStepEnd(); timestamp++; - timestampCurrentEpisode++; } @Override @@ -95,9 +94,7 @@ public abstract class EpisodicLearning extends Learning imple private void startLearning(){ dispatchLearningStart(); while(episodesToLearn.get() > 0){ - dispatchEpisodeStart(); - timestampCurrentEpisode = 0; nextEpisode(); dispatchEpisodeEnd(); } diff --git a/src/main/java/core/algo/Learning.java b/src/main/java/core/algo/Learning.java index 20a9ec1..f6bb4ac 100644 --- a/src/main/java/core/algo/Learning.java +++ b/src/main/java/core/algo/Learning.java @@ -23,10 +23,6 @@ import java.util.concurrent.CopyOnWriteArrayList; */ @Getter public abstract class Learning{ - // TODO: temp testing -> extract to dedicated test - protected int checkSum; - protected int rewardCheckSum; - // current discrete timestamp t protected int timestamp; protected int currentEpisode; diff --git a/src/main/java/core/algo/mc/MonteCarloControlEGreedy.java b/src/main/java/core/algo/mc/MonteCarloControlEGreedy.java index c0f78d8..beb9e4e 100644 --- a/src/main/java/core/algo/mc/MonteCarloControlEGreedy.java +++ b/src/main/java/core/algo/mc/MonteCarloControlEGreedy.java @@ -60,7 +60,6 @@ public class MonteCarloControlEGreedy extends EpisodicLearning(state, chosenAction, envResult.getReward())); state = nextState; @@ -74,8 +73,6 @@ public class MonteCarloControlEGreedy extends EpisodicLearning, List> stateActionPairs = new LinkedHashMap<>(); diff --git a/src/main/java/core/algo/td/QLearningOffPolicyTDControl.java b/src/main/java/core/algo/td/QLearningOffPolicyTDControl.java index 2da1c60..fcc59fa 100644 --- a/src/main/java/core/algo/td/QLearningOffPolicyTDControl.java +++ b/src/main/java/core/algo/td/QLearningOffPolicyTDControl.java @@ -5,18 +5,12 @@ import core.algo.EpisodicLearning; import core.policy.EpsilonGreedyPolicy; import core.policy.GreedyPolicy; import core.policy.Policy; -import evironment.antGame.Reward; -import example.ContinuousAnt; -import java.io.File; -import java.io.IOException; -import java.nio.file.Files; -import java.nio.file.Path; -import java.nio.file.StandardOpenOption; import java.util.Map; public class QLearningOffPolicyTDControl extends EpisodicLearning { private float alpha; + private Policy greedyPolicy = new GreedyPolicy<>(); public QLearningOffPolicyTDControl(Environment environment, DiscreteActionSpace actionSpace, float discountFactor, float epsilon, float learningRate, int delay) { @@ -42,11 +36,7 @@ public class QLearningOffPolicyTDControl extends EpisodicLearnin StepResultEnvironment envResult = null; Map actionValues = null; - sumOfRewards = 0; - int timestampTilFood = 0; - int foodCollected = 0; - int foodTimestampsTotal= 0; while(envResult == null || !envResult.isDone()) { actionValues = stateActionTable.getActionValues(state); A action = policy.chooseAction(actionValues); @@ -56,44 +46,6 @@ public class QLearningOffPolicyTDControl extends EpisodicLearnin double reward = envResult.getReward(); State nextState = envResult.getState(); sumOfRewards += reward; - timestampTilFood++; - - if(reward == Reward.FOOD_DROP_DOWN_SUCCESS) { - foodCollected++; - foodTimestampsTotal += timestampTilFood; - File file = new File(ContinuousAnt.FILE_NAME); - if(foodCollected % 1000 == 0) { - System.out.println(foodTimestampsTotal / 1000f + " " + timestampCurrentEpisode); - try { - Files.writeString(Path.of(file.getPath()), foodTimestampsTotal / 1000f + ",", StandardOpenOption.APPEND); - } catch (IOException e) { - e.printStackTrace(); - } - foodTimestampsTotal = 0; - } - if(foodCollected == 1000){ - ((EpsilonGreedyPolicy) this.policy).setEpsilon(0.15f); - } - if(foodCollected == 2000){ - ((EpsilonGreedyPolicy) this.policy).setEpsilon(0.10f); - } - if(foodCollected == 3000){ - ((EpsilonGreedyPolicy) this.policy).setEpsilon(0.05f); - } - if(foodCollected == 4000){ - System.out.println("Reached 0 exploration"); - ((EpsilonGreedyPolicy) this.policy).setEpsilon(0.00f); - } - if(foodCollected == 15000){ - try { - Files.writeString(Path.of(file.getPath()), "\n", StandardOpenOption.APPEND); - } catch (IOException e) { - e.printStackTrace(); - } - return; - } - timestampTilFood = 0; - } // Q Update double currentQValue = stateActionTable.getActionValues(state).get(action); diff --git a/src/main/java/core/algo/td/SARSA.java b/src/main/java/core/algo/td/SARSA.java index b64d59e..7ebeacf 100644 --- a/src/main/java/core/algo/td/SARSA.java +++ b/src/main/java/core/algo/td/SARSA.java @@ -3,8 +3,6 @@ package core.algo.td; import core.*; import core.algo.EpisodicLearning; import core.policy.EpsilonGreedyPolicy; -import core.policy.GreedyPolicy; -import core.policy.Policy; import java.util.Map; @@ -35,10 +33,8 @@ public class SARSA extends EpisodicLearning { StepResultEnvironment envResult = null; Map actionValues = stateActionTable.getActionValues(state); A action = policy.chooseAction(actionValues); - - //A action = policy.chooseAction(actionValues); - sumOfRewards = 0; + while(envResult == null || !envResult.isDone()) { // Take a step diff --git a/src/main/java/example/ContinuousAnt.java b/src/main/java/example/ContinuousAnt.java index 5c1df45..7c017a4 100644 --- a/src/main/java/example/ContinuousAnt.java +++ b/src/main/java/example/ContinuousAnt.java @@ -7,31 +7,18 @@ import core.controller.RLControllerGUI; import evironment.antGame.AntAction; import evironment.antGame.AntWorldContinuous; -import java.io.File; -import java.io.IOException; - public class ContinuousAnt { - public static final String FILE_NAME = "converge.txt"; - public static void main(String[] args) { - File file = new File(FILE_NAME); - try { - file.createNewFile(); - } catch (IOException e) { - e.printStackTrace(); - } RNG.setSeed(13, true); RLController rl = new RLControllerGUI<>( new AntWorldContinuous(8, 8), Method.Q_LEARNING_OFF_POLICY_CONTROL, AntAction.values()); - rl.setDelay(20); + rl.setDelay(200); rl.setNrOfEpisodes(1); - // 0.05, 0.1, 0.3, 0.5, 0.7, 0.9, 0.95, 0.99 - rl.setDiscountFactor(0.05f); - // 0.1, 0.3, 0.5, 0.7 0.9 + rl.setDiscountFactor(0.3f); rl.setLearningRate(0.9f); - rl.setEpsilon(0.2f); + rl.setEpsilon(0.15f); rl.start(); } } diff --git a/src/main/java/example/JumpingDino.java b/src/main/java/example/JumpingDino.java index 0a03b16..d0f2a7b 100644 --- a/src/main/java/example/JumpingDino.java +++ b/src/main/java/example/JumpingDino.java @@ -3,20 +3,20 @@ package example; import core.RNG; import core.algo.Method; import core.controller.RLController; +import core.controller.RLControllerGUI; import evironment.jumpingDino.DinoAction; -import evironment.jumpingDino.DinoWorld; import evironment.jumpingDino.DinoWorldAdvanced; public class JumpingDino { public static void main(String[] args) { RNG.setSeed(29); - RLController rl = new RLController<>( + RLController rl = new RLControllerGUI<>( new DinoWorldAdvanced(), Method.MC_CONTROL_FIRST_VISIT, DinoAction.values()); - rl.setDelay(0); + rl.setDelay(200); rl.setDiscountFactor(9f); rl.setEpsilon(0.05f); rl.setLearningRate(0.8f);