diff --git a/.idea/codeStyles/codeStyleConfig.xml b/.idea/codeStyles/codeStyleConfig.xml
new file mode 100644
index 0000000..a55e7a1
--- /dev/null
+++ b/.idea/codeStyles/codeStyleConfig.xml
@@ -0,0 +1,5 @@
+
+
+
+
\ No newline at end of file
diff --git a/src/main/java/core/LearningConfig.java b/src/main/java/core/LearningConfig.java
index a4ecb68..1933b05 100644
--- a/src/main/java/core/LearningConfig.java
+++ b/src/main/java/core/LearningConfig.java
@@ -2,6 +2,9 @@ package core;
 
 public class LearningConfig {
     public static final int DEFAULT_DELAY = 30;
+    public static final int DEFAULT_NR_OF_EPISODES = 10000;
     public static final float DEFAULT_EPSILON = 0.1f;
     public static final float DEFAULT_DISCOUNT_FACTOR = 1.0f;
+    // Learning rate
+    public static final float DEFAULT_ALPHA = 0.9f;
 }
diff --git a/src/main/java/core/algo/EpisodicLearning.java b/src/main/java/core/algo/EpisodicLearning.java
index 3a56363..4f997d3 100644
--- a/src/main/java/core/algo/EpisodicLearning.java
+++ b/src/main/java/core/algo/EpisodicLearning.java
@@ -2,6 +2,7 @@ package core.algo;
 
 import core.DiscreteActionSpace;
 import core.Environment;
+import core.LearningConfig;
 import core.StepResult;
 import core.listener.LearningListener;
 import lombok.Getter;
@@ -16,7 +17,7 @@ import java.util.concurrent.atomic.AtomicInteger;
 public abstract class EpisodicLearning<A extends Enum> extends Learning<A> implements Episodic {
 
     @Setter
-    protected int currentEpisode;
+    protected int currentEpisode = 0;
     protected volatile AtomicInteger episodesToLearn = new AtomicInteger(0);
     @Getter
     protected volatile int episodePerSecond;
@@ -81,7 +82,7 @@ public abstract class EpisodicLearning<A extends Enum> extends Learning<A> imple
 
     @Override
     public void learn(){
-        // TODO remove or learn with default episode number
+        learn(LearningConfig.DEFAULT_NR_OF_EPISODES);
     }
 
     private void startLearning(){
@@ -132,6 +133,15 @@
         return episodesToLearn.get();
     }
 
+    public int getCurrentEpisode() {
+        return currentEpisode;
+    }
+
+    public int getEpisodesPerSecond() {
+        return episodePerSecond;
+    }
+
     @Override
     public synchronized void save(ObjectOutputStream oos) throws IOException {
         super.save(oos);
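Note: the parameterless learn() is no longer a TODO stub; it now falls back to LearningConfig.DEFAULT_NR_OF_EPISODES. A minimal sketch of that delegation pattern, using a hypothetical Trainer class in place of EpisodicLearning (judging by the AtomicInteger/volatile fields above, the real class additionally wraps this in a worker thread and listener dispatch):

    // Sketch: a parameterless learn() delegating to a default episode budget.
    abstract class Trainer {
        static final int DEFAULT_NR_OF_EPISODES = 10_000;

        void learn() {
            // No explicit budget given: fall back to the configured default.
            learn(DEFAULT_NR_OF_EPISODES);
        }

        void learn(int nrOfEpisodes) {
            for (int episode = 0; episode < nrOfEpisodes; episode++) {
                nextEpisode();
            }
        }

        abstract void nextEpisode();
    }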
diff --git a/src/main/java/core/algo/MC/MonteCarloOnPolicyEGreedy.java b/src/main/java/core/algo/MC/MonteCarloControlEGreedy.java
similarity index 79%
rename from src/main/java/core/algo/MC/MonteCarloOnPolicyEGreedy.java
rename to src/main/java/core/algo/MC/MonteCarloControlEGreedy.java
index c4ca4bc..d7cd028 100644
--- a/src/main/java/core/algo/MC/MonteCarloOnPolicyEGreedy.java
+++ b/src/main/java/core/algo/MC/MonteCarloControlEGreedy.java
@@ -30,21 +30,20 @@ import java.util.*;
  *
  * @param <A>
  */
-public class MonteCarloOnPolicyEGreedy<A extends Enum> extends EpisodicLearning<A> {
+public class MonteCarloControlEGreedy<A extends Enum> extends EpisodicLearning<A> {
 
     private Map<Pair<State, A>, Double> returnSum;
     private Map<Pair<State, A>, Integer> returnCount;
 
-    public MonteCarloOnPolicyEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay) {
+    public MonteCarloControlEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay) {
         super(environment, actionSpace, discountFactor, delay);
-        currentEpisode = 0;
         this.policy = new EpsilonGreedyPolicy<>(epsilon);
         this.stateActionTable = new DeterministicStateActionTable<>(this.actionSpace);
         returnSum = new HashMap<>();
         returnCount = new HashMap<>();
     }
 
-    public MonteCarloOnPolicyEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, int delay) {
+    public MonteCarloControlEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, int delay) {
         this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, LearningConfig.DEFAULT_EPSILON, delay);
     }
 
@@ -59,7 +58,7 @@ public class MonteCarloOnPolicyEGreedy<A extends Enum> extends EpisodicLearning<
         }
         sumOfRewards = 0;
         StepResultEnvironment envResult = null;
-        while(envResult == null || !envResult.isDone()){
+        while(envResult == null || !envResult.isDone()) {
             Map<A, Double> actionValues = stateActionTable.getActionValues(state);
             A chosenAction = policy.chooseAction(actionValues);
             envResult = environment.step(chosenAction);
@@ -77,26 +76,26 @@ public class MonteCarloOnPolicyEGreedy<A extends Enum> extends EpisodicLearning<
             dispatchStepEnd();
         }
 
-        // System.out.printf("Episode %d \t Reward: %f \n", currentEpisode, sumOfRewards);
+        // System.out.printf("Episode %d \t Reward: %f \n", currentEpisode, sumOfRewards);
         Set<Pair<State, A>> stateActionPairs = new LinkedHashSet<>();
 
-        for (StepResult<A> sr : episode) {
+        for(StepResult<A> sr : episode) {
             stateActionPairs.add(new ImmutablePair<>(sr.getState(), sr.getAction()));
         }
         //System.out.println("stateActionPairs " + stateActionPairs.size());
-        for (Pair<State, A> stateActionPair : stateActionPairs) {
+        for(Pair<State, A> stateActionPair : stateActionPairs) {
             int firstOccurenceIndex = 0;
             // find first occurrence of state-action pair
-            for (StepResult<A> sr : episode) {
-                if (stateActionPair.getKey().equals(sr.getState()) && stateActionPair.getValue().equals(sr.getAction())) {
+            for(StepResult<A> sr : episode) {
+                if(stateActionPair.getKey().equals(sr.getState()) && stateActionPair.getValue().equals(sr.getAction())) {
                     break;
                 }
                 firstOccurenceIndex++;
             }
 
             double G = 0;
-            for (int l = firstOccurenceIndex; l < episode.size(); ++l) {
+            for(int l = firstOccurenceIndex; l < episode.size(); ++l) {
                 G += episode.get(l).getReward() * (Math.pow(discountFactor, l - firstOccurenceIndex));
             }
             // slick trick to add G to the entry.
@@ -107,16 +106,6 @@ public class MonteCarloOnPolicyEGreedy<A extends Enum> extends EpisodicLearning<
         }
     }
 
-    @Override
-    public int getCurrentEpisode() {
-        return currentEpisode;
-    }
-
-    @Override
-    public int getEpisodesPerSecond(){
-        return episodePerSecond;
-    }
-
     @Override
     public void save(ObjectOutputStream oos) throws IOException {
         super.save(oos);
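The loop above is first-visit Monte Carlo control: for every state-action pair seen in an episode, the discounted return G is computed from the pair's first occurrence onward and then averaged into the action-value table via returnSum and returnCount. A self-contained sketch of just the return computation, using plain doubles instead of the repo's StepResult objects:

    import java.util.List;

    // Sketch: G = sum over l >= firstIndex of gamma^(l - firstIndex) * reward_l
    final class FirstVisitReturn {
        static double returnFrom(List<Double> rewards, int firstIndex, double gamma) {
            double g = 0;
            for (int l = firstIndex; l < rewards.size(); ++l) {
                g += rewards.get(l) * Math.pow(gamma, l - firstIndex);
            }
            return g;
        }

        public static void main(String[] args) {
            // Three steps of reward 1 with gamma = 0.9: G = 1 + 0.9 + 0.81 = 2.71
            System.out.println(returnFrom(List.of(1.0, 1.0, 1.0), 0, 0.9));
        }
    }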
diff --git a/src/main/java/core/algo/Method.java b/src/main/java/core/algo/Method.java
index 3ac50cc..473a889 100644
--- a/src/main/java/core/algo/Method.java
+++ b/src/main/java/core/algo/Method.java
@@ -5,5 +5,5 @@ package core.algo;
  * which RL-algorithm should be used.
  */
 public enum Method {
-    MC_ONPOLICY_EGREEDY, TD_ONPOLICY
+    MC_CONTROL_EGREEDY, SARSA_EPISODIC, Q_LEARNING_OFF_POLICY_CONTROL
 }
diff --git a/src/main/java/core/algo/TD/QLearningOffPolicyTDControl.java b/src/main/java/core/algo/TD/QLearningOffPolicyTDControl.java
new file mode 100644
index 0000000..274e2ab
--- /dev/null
+++ b/src/main/java/core/algo/TD/QLearningOffPolicyTDControl.java
@@ -0,0 +1,68 @@
+package core.algo.td;
+
+import core.*;
+import core.algo.EpisodicLearning;
+import core.policy.EpsilonGreedyPolicy;
+import core.policy.GreedyPolicy;
+import core.policy.Policy;
+
+import java.util.Map;
+
+public class QLearningOffPolicyTDControl<A extends Enum> extends EpisodicLearning<A> {
+    private float alpha;
+    private Policy<A> greedyPolicy = new GreedyPolicy<>();
+
+    public QLearningOffPolicyTDControl(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, float learningRate, int delay) {
+        super(environment, actionSpace, discountFactor, delay);
+        alpha = learningRate;
+        this.policy = new EpsilonGreedyPolicy<>(epsilon);
+        this.stateActionTable = new DeterministicStateActionTable<>(this.actionSpace);
+    }
+
+    public QLearningOffPolicyTDControl(Environment<A> environment, DiscreteActionSpace<A> actionSpace, int delay) {
+        this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, LearningConfig.DEFAULT_EPSILON, LearningConfig.DEFAULT_ALPHA, delay);
+    }
+
+    @Override
+    protected void nextEpisode() {
+        State state = environment.reset();
+        try {
+            Thread.sleep(delay);
+        } catch (InterruptedException e) {
+            e.printStackTrace();
+        }
+
+        StepResultEnvironment envResult = null;
+        Map<A, Double> actionValues = null;
+
+        sumOfRewards = 0;
+        while(envResult == null || !envResult.isDone()) {
+            actionValues = stateActionTable.getActionValues(state);
+            A action = policy.chooseAction(actionValues);
+
+            // Take a step
+            envResult = environment.step(action);
+            double reward = envResult.getReward();
+            State nextState = envResult.getState();
+            sumOfRewards += reward;
+
+            // Q update
+            double currentQValue = stateActionTable.getActionValues(state).get(action);
+            // maxQ(S', a): using the internal greedy policy as a helper to determine the highest action-value
+            double highestValueNextState = stateActionTable.getActionValues(nextState).get(greedyPolicy.chooseAction(stateActionTable.getActionValues(nextState)));
+
+            double updatedQValue = currentQValue + alpha * (reward + discountFactor * highestValueNextState - currentQValue);
+            stateActionTable.setValue(state, action, updatedQValue);
+
+            state = nextState;
+            try {
+                Thread.sleep(delay);
+            } catch (InterruptedException e) {
+                e.printStackTrace();
+            }
+            dispatchStepEnd();
+        }
+    }
+}
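The update in nextEpisode() is the textbook Q-learning rule Q(s,a) <- Q(s,a) + alpha * (r + gamma * max_a' Q(s',a') - Q(s,a)). It is off-policy because the bootstrap term takes the greedy maximum over the next state's actions while behaviour follows the epsilon-greedy policy. A minimal sketch of the same update on a dense double[][] table; the integer state/action indices are an illustration, not the repo's StateActionTable API:

    // Sketch: tabular Q-learning update; q[s][a] holds Q(s, a).
    final class QLearningUpdate {
        static void update(double[][] q, int s, int a, double reward, int sNext,
                           double alpha, double gamma) {
            double maxNext = Double.NEGATIVE_INFINITY;
            for (double v : q[sNext]) {       // max over a' of Q(s', a')
                maxNext = Math.max(maxNext, v);
            }
            q[s][a] += alpha * (reward + gamma * maxNext - q[s][a]);
        }

        public static void main(String[] args) {
            double[][] q = new double[2][2];
            // Q(0,1) <- 0 + 0.5 * (1.0 + 0.8 * 0 - 0) = 0.5
            update(q, 0, 1, 1.0, 1, 0.5, 0.8);
            System.out.println(q[0][1]);
        }
    }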
diff --git a/src/main/java/core/algo/TD/SARSA.java b/src/main/java/core/algo/TD/SARSA.java
new file mode 100644
index 0000000..81d64fc
--- /dev/null
+++ b/src/main/java/core/algo/TD/SARSA.java
@@ -0,0 +1,68 @@
+package core.algo.td;
+
+import core.*;
+import core.algo.EpisodicLearning;
+import core.policy.EpsilonGreedyPolicy;
+
+import java.util.Map;
+
+public class SARSA<A extends Enum> extends EpisodicLearning<A> {
+    private float alpha;
+
+    public SARSA(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, float learningRate, int delay) {
+        super(environment, actionSpace, discountFactor, delay);
+        alpha = learningRate;
+        this.policy = new EpsilonGreedyPolicy<>(epsilon);
+        this.stateActionTable = new DeterministicStateActionTable<>(this.actionSpace);
+    }
+
+    public SARSA(Environment<A> environment, DiscreteActionSpace<A> actionSpace, int delay) {
+        this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, LearningConfig.DEFAULT_EPSILON, LearningConfig.DEFAULT_ALPHA, delay);
+    }
+
+    @Override
+    protected void nextEpisode() {
+        State state = environment.reset();
+        try {
+            Thread.sleep(delay);
+        } catch (InterruptedException e) {
+            e.printStackTrace();
+        }
+
+        StepResultEnvironment envResult = null;
+        Map<A, Double> actionValues = stateActionTable.getActionValues(state);
+        A action = policy.chooseAction(actionValues);
+
+        sumOfRewards = 0;
+        while(envResult == null || !envResult.isDone()) {
+            // Take a step
+            envResult = environment.step(action);
+            sumOfRewards += envResult.getReward();
+
+            State nextState = envResult.getState();
+
+            // Pick next action
+            actionValues = stateActionTable.getActionValues(nextState);
+            A nextAction = policy.chooseAction(actionValues);
+
+            // TD update
+            // target = reward + gamma * Q(nextState, nextAction)
+            double currentQValue = stateActionTable.getActionValues(state).get(action);
+            double nextQValue = stateActionTable.getActionValues(nextState).get(nextAction);
+            double reward = envResult.getReward();
+            double updatedQValue = currentQValue + alpha * (reward + discountFactor * nextQValue - currentQValue);
+            stateActionTable.setValue(state, action, updatedQValue);
+
+            state = nextState;
+            action = nextAction;
+
+            try {
+                Thread.sleep(delay);
+            } catch (InterruptedException e) {
+                e.printStackTrace();
+            }
+            dispatchStepEnd();
+        }
+    }
+}
diff --git a/src/main/java/core/algo/TD/TemporalDifferenceOnPolicy.java b/src/main/java/core/algo/TD/TemporalDifferenceOnPolicy.java
deleted file mode 100644
index 33716f8..0000000
--- a/src/main/java/core/algo/TD/TemporalDifferenceOnPolicy.java
+++ /dev/null
@@ -1,4 +0,0 @@
-package core.algo.TD;
-
-public class TemporalDifferenceOnPolicy {
-}
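SARSA above differs from QLearningOffPolicyTDControl in exactly one place: the bootstrap uses Q(s',a') for the action a' that the epsilon-greedy policy actually chose, which makes the update on-policy. Side by side, the two TD targets look like this (plain arrays and hypothetical integer indices, as in the previous sketch):

    // Sketch: the only difference between the two learners is the TD target.
    final class TdTargets {
        // SARSA (on-policy): bootstrap from the action the policy will really take.
        static double sarsaTarget(double[][] q, double r, int sNext, int aNext, double gamma) {
            return r + gamma * q[sNext][aNext];
        }

        // Q-learning (off-policy): bootstrap from the greedy action in s'.
        static double qLearningTarget(double[][] q, double r, int sNext, double gamma) {
            double maxNext = Double.NEGATIVE_INFINITY;
            for (double v : q[sNext]) {
                maxNext = Math.max(maxNext, v);
            }
            return r + gamma * maxNext;
        }
    }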
diff --git a/src/main/java/core/controller/RLController.java b/src/main/java/core/controller/RLController.java
index 126ee2a..5bd3381 100644
--- a/src/main/java/core/controller/RLController.java
+++ b/src/main/java/core/controller/RLController.java
@@ -1,18 +1,19 @@
 package core.controller;
 
-import core.*;
+import core.DiscreteActionSpace;
+import core.Environment;
+import core.LearningConfig;
+import core.ListDiscreteActionSpace;
 import core.algo.EpisodicLearning;
 import core.algo.Learning;
 import core.algo.Method;
-import core.algo.mc.MonteCarloOnPolicyEGreedy;
-import core.gui.LearningView;
-import core.gui.View;
+import core.algo.mc.MonteCarloControlEGreedy;
+import core.algo.td.QLearningOffPolicyTDControl;
+import core.algo.td.SARSA;
 import core.listener.LearningListener;
-import core.listener.ViewListener;
 import core.policy.EpsilonPolicy;
 import lombok.Setter;
 
-import javax.swing.*;
 import java.io.*;
 import java.util.List;
@@ -27,6 +28,8 @@ public class RLController<A extends Enum> implements LearningListener {
     @Setter
     protected float discountFactor = LearningConfig.DEFAULT_DISCOUNT_FACTOR;
     @Setter
+    protected float learningRate = LearningConfig.DEFAULT_ALPHA;
+    @Setter
     protected float epsilon = LearningConfig.DEFAULT_EPSILON;
     protected Learning<A> learning;
     protected boolean fastLearning;
@@ -45,10 +48,14 @@ public class RLController<A extends Enum> implements LearningListener {
 
     public void start() {
         switch(method) {
-            case MC_ONPOLICY_EGREEDY:
-                learning = new MonteCarloOnPolicyEGreedy<>(environment, discreteActionSpace, discountFactor, epsilon, delay);
+            case MC_CONTROL_EGREEDY:
+                learning = new MonteCarloControlEGreedy<>(environment, discreteActionSpace, discountFactor, epsilon, delay);
                 break;
-            case TD_ONPOLICY:
+            case SARSA_EPISODIC:
+                learning = new SARSA<>(environment, discreteActionSpace, discountFactor, epsilon, learningRate, delay);
+                break;
+            case Q_LEARNING_OFF_POLICY_CONTROL:
+                learning = new QLearningOffPolicyTDControl<>(environment, discreteActionSpace, discountFactor, epsilon, learningRate, delay);
                 break;
             default:
                 throw new IllegalArgumentException("Undefined method");
diff --git a/src/main/java/core/gui/StateActionRow.java b/src/main/java/core/gui/StateActionRow.java
index 53e8e36..76d238c 100644
--- a/src/main/java/core/gui/StateActionRow.java
+++ b/src/main/java/core/gui/StateActionRow.java
@@ -30,7 +30,6 @@ public class StateActionRow extends JTextArea {
 
     protected void refreshLabels(){
         if(state == null || actionValues == null) return;
-        System.out.println("refreshing");
         StringBuilder sb = new StringBuilder(state.toString()).append("\n");
         for(Map.Entry actionValue: actionValues.entrySet()){
             sb.append("\t").append(actionValue.getKey()).append("\t").append(actionValue.getValue()).append("\n");
diff --git a/src/main/java/evironment/antGame/AntState.java b/src/main/java/evironment/antGame/AntState.java
index 368271b..ca917c9 100644
--- a/src/main/java/evironment/antGame/AntState.java
+++ b/src/main/java/evironment/antGame/AntState.java
@@ -29,7 +29,6 @@ public class AntState implements State, Visualizable {
     private int computeHash() {
         int hash = 7;
         int prime = 31;
-        int unknown = 0;
         int diff = 0;
 
         for (Cell[] cells : knownWorld) {
diff --git a/src/main/java/evironment/jumpingDino/Dino.java b/src/main/java/evironment/jumpingDino/Dino.java
index 0c29ce3..125fdfd 100644
--- a/src/main/java/evironment/jumpingDino/Dino.java
+++ b/src/main/java/evironment/jumpingDino/Dino.java
@@ -28,9 +28,11 @@ public class Dino extends RenderObject {
     @Override
     public void tick(){
         // reached max jump height
-        if(y + dy < Config.FRAME_HEIGHT - Config.GROUND_Y -Config.OBSTACLE_SIZE - Config.MAX_JUMP_HEIGHT){
+        int topOfDino = Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE;
+
+        if(y + dy <= topOfDino - Config.MAX_JUMP_HEIGHT) {
             fall();
-        }else if(y + dy >= Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE){
+        } else if(y + dy >= topOfDino) {
             inJump = false;
             dy = 0;
             y = Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE;
diff --git a/src/main/java/evironment/jumpingDino/DinoWorld.java b/src/main/java/evironment/jumpingDino/DinoWorld.java
index 7792b14..8e39889 100644
--- a/src/main/java/evironment/jumpingDino/DinoWorld.java
+++ b/src/main/java/evironment/jumpingDino/DinoWorld.java
@@ -56,18 +56,16 @@ public class DinoWorld implements Environment<DinoAction>, Visualizable {
             dino.jump();
         }
 
-        for(int i= 0; i < 5; ++i){
-            dino.tick();
-            currentObstacle.tick();
-            if(currentObstacle.getX() < -Config.OBSTACLE_SIZE){
-                spawnNewObstacle();
-            }
-            comp.repaint();
-            if(ranIntoObstacle()){
-                done = true;
-                break;
-            }
-        }
+        dino.tick();
+        currentObstacle.tick();
+        if(currentObstacle.getX() < -Config.OBSTACLE_SIZE) {
+            spawnNewObstacle();
+        }
+        if(ranIntoObstacle()) {
+            reward = 0;
+            done = true;
+        }
+        return new StepResultEnvironment(new DinoStateWithSpeed(getDistanceToObstacle(), getCurrentObstacle().getDx()), reward, done, "");
     }
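Two behavioural notes on the jumpingDino changes: step() now advances the world by a single tick per chosen action instead of five, and a collision ends the episode with the reward zeroed. In Dino.tick(), the new topOfDino variable is the dino's resting top edge in Swing's y-grows-downward coordinates (the old apex bound subtracted OBSTACLE_SIZE rather than DINO_SIZE). A standalone sketch of the clamp logic, with made-up constants and a simplified stand-in for fall():

    // Sketch: jump apex/landing checks when y grows downward.
    // Constants are illustrative; the repo reads them from Config.
    final class DinoJumpSketch {
        static final int FRAME_HEIGHT = 400, GROUND_Y = 50, DINO_SIZE = 40, MAX_JUMP_HEIGHT = 120;
        // Resting top edge: 400 - 50 - 40 = 310 (smaller y = higher on screen).
        static final int TOP_OF_DINO = FRAME_HEIGHT - GROUND_Y - DINO_SIZE;

        int y = TOP_OF_DINO;  // current top edge
        int dy = 0;           // vertical speed; a jump would set this negative, e.g. -10

        void tick() {
            if (y + dy <= TOP_OF_DINO - MAX_JUMP_HEIGHT) {
                dy = Math.abs(dy);   // apex reached: reverse into a fall
            } else if (y + dy >= TOP_OF_DINO) {
                dy = 0;              // would sink below the ground line: land
                y = TOP_OF_DINO;
            }
            y += dy;
        }
    }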
diff --git a/src/main/java/evironment/jumpingDino/gui/DinoWorldComponent.java b/src/main/java/evironment/jumpingDino/gui/DinoWorldComponent.java
index 461934f..238624c 100644
--- a/src/main/java/evironment/jumpingDino/gui/DinoWorldComponent.java
+++ b/src/main/java/evironment/jumpingDino/gui/DinoWorldComponent.java
@@ -19,7 +19,7 @@ public class DinoWorldComponent extends JComponent {
     protected void paintComponent(Graphics g) {
         super.paintComponent(g);
         g.setColor(Color.BLACK);
-        g.fillRect(0, Config.FRAME_HEIGHT - Config.GROUND_Y, Config.FRAME_WIDTH, 2);
+        g.fillRect(0, Config.FRAME_HEIGHT - Config.GROUND_Y, getWidth(), 2);
 
         dinoWorld.getDino().render(g);
         dinoWorld.getCurrentObstacle().render(g);
diff --git a/src/main/java/example/JumpingDino.java b/src/main/java/example/JumpingDino.java
index 84e3c7e..ff62761 100644
--- a/src/main/java/example/JumpingDino.java
+++ b/src/main/java/example/JumpingDino.java
@@ -12,15 +12,17 @@
         RNG.setSeed(55);
 
         RLController<DinoAction> rl = new RLControllerGUI<>(
-                new DinoWorld(true, true),
-                Method.MC_ONPOLICY_EGREEDY,
+                new DinoWorld(false, false),
+                Method.Q_LEARNING_OFF_POLICY_CONTROL,
                 DinoAction.values());
 
-        rl.setDelay(100);
-        rl.setDiscountFactor(1f);
-        rl.setEpsilon(0.15f);
-        rl.setNrOfEpisodes(100000);
-
+        rl.setDelay(10);
+        rl.setDiscountFactor(0.8f);
+        rl.setEpsilon(0.1f);
+        rl.setLearningRate(0.5f);
+        rl.setNrOfEpisodes(10000);
         rl.start();
+
+
     }
 }
diff --git a/src/main/java/example/RunningAnt.java b/src/main/java/example/RunningAnt.java
index ade0e92..9b83316 100644
--- a/src/main/java/example/RunningAnt.java
+++ b/src/main/java/example/RunningAnt.java
@@ -3,16 +3,17 @@ package example;
 import core.RNG;
 import core.algo.Method;
 import core.controller.RLController;
+import core.controller.RLControllerGUI;
 import evironment.antGame.AntAction;
 import evironment.antGame.AntWorld;
 
 public class RunningAnt {
     public static void main(String[] args) {
-        RNG.setSeed(123);
+        RNG.setSeed(56);
 
-        RLController<AntAction> rl = new RLController<>(
+        RLController<AntAction> rl = new RLControllerGUI<>(
                 new AntWorld(3, 3, 0.1),
-                Method.MC_ONPOLICY_EGREEDY,
+                Method.MC_CONTROL_EGREEDY,
                 AntAction.values());
 
         rl.setDelay(200);