diff --git a/src/main/java/core/StateActionHashTable.java b/src/main/java/core/DeterministicStateActionTable.java similarity index 74% rename from src/main/java/core/StateActionHashTable.java rename to src/main/java/core/DeterministicStateActionTable.java index dd981d0..6c9d55d 100644 --- a/src/main/java/core/StateActionHashTable.java +++ b/src/main/java/core/DeterministicStateActionTable.java @@ -1,20 +1,19 @@ package core; -import evironment.antGame.AntAction; - -import java.util.HashMap; +import java.io.Serializable; +import java.util.LinkedHashMap; import java.util.Map; /** * Premise: All states have the complete action space */ -public class StateActionHashTable implements StateActionTable { +public class DeterministicStateActionTable implements StateActionTable, Serializable { private final Map> table; private DiscreteActionSpace discreteActionSpace; - public StateActionHashTable(DiscreteActionSpace discreteActionSpace){ - table = new HashMap<>(); + public DeterministicStateActionTable(DiscreteActionSpace discreteActionSpace){ + table = new LinkedHashMap<>(); this.discreteActionSpace = discreteActionSpace; } @@ -61,19 +60,15 @@ public class StateActionHashTable implements StateActionTable return table.get(state); } - public static void main(String[] args) { - DiscreteActionSpace da = new ListDiscreteActionSpace<>(AntAction.MOVE_RIGHT, AntAction.PICK_UP); - StateActionTable sat = new StateActionHashTable<>(da); - State t = new State() { - }; - - System.out.println(sat.getActionValues(t)); - } private Map createDefaultActionValues(){ - final Map defaultActionValues = new HashMap<>(); + final Map defaultActionValues = new LinkedHashMap<>(); for(A action: discreteActionSpace.getAllActions()){ defaultActionValues.put(action, DEFAULT_VALUE); } return defaultActionValues; } + @Override + public int getStateCount(){ + return table.size(); + } } diff --git a/src/main/java/core/LearningConfig.java b/src/main/java/core/LearningConfig.java index 916de16..a4ecb68 100644 --- a/src/main/java/core/LearningConfig.java +++ b/src/main/java/core/LearningConfig.java @@ -1,7 +1,7 @@ package core; public class LearningConfig { - public static final int DEFAULT_DELAY = 1; + public static final int DEFAULT_DELAY = 30; public static final float DEFAULT_EPSILON = 0.1f; public static final float DEFAULT_DISCOUNT_FACTOR = 1.0f; } diff --git a/src/main/java/core/ListDiscreteActionSpace.java b/src/main/java/core/ListDiscreteActionSpace.java index 42de87a..3e1b862 100644 --- a/src/main/java/core/ListDiscreteActionSpace.java +++ b/src/main/java/core/ListDiscreteActionSpace.java @@ -1,10 +1,11 @@ package core; +import java.io.Serializable; import java.util.ArrayList; import java.util.Arrays; import java.util.List; -public class ListDiscreteActionSpace implements DiscreteActionSpace { +public class ListDiscreteActionSpace implements DiscreteActionSpace, Serializable { private List actions; public ListDiscreteActionSpace(){ diff --git a/src/main/java/core/RNG.java b/src/main/java/core/RNG.java index 2fe8929..daf03b4 100644 --- a/src/main/java/core/RNG.java +++ b/src/main/java/core/RNG.java @@ -1,12 +1,14 @@ package core; +import java.security.SecureRandom; import java.util.Random; public class RNG { - private static Random rng; + private static SecureRandom rng; private static int seed = 123; static { - rng = new Random(seed); + rng = new SecureRandom(); + rng.setSeed(seed); } public static Random getRandom() { diff --git a/src/main/java/core/SaveState.java b/src/main/java/core/SaveState.java new file mode 100644 index 0000000..140ac7a --- /dev/null +++ b/src/main/java/core/SaveState.java @@ -0,0 +1,13 @@ +package core; + +import lombok.AllArgsConstructor; +import lombok.Getter; + +import java.io.Serializable; + +@AllArgsConstructor +@Getter +public class SaveState implements Serializable { + private StateActionTable stateActionTable; + private int currentEpisode; +} diff --git a/src/main/java/core/StateActionTable.java b/src/main/java/core/StateActionTable.java index 851ffba..3f863ba 100644 --- a/src/main/java/core/StateActionTable.java +++ b/src/main/java/core/StateActionTable.java @@ -7,6 +7,6 @@ public interface StateActionTable { double getValue(State state, A action); void setValue(State state, A action, double value); - + int getStateCount(); Map getActionValues(State state); } diff --git a/src/main/java/core/algo/EpisodicLearning.java b/src/main/java/core/algo/EpisodicLearning.java index f9aa580..9211224 100644 --- a/src/main/java/core/algo/EpisodicLearning.java +++ b/src/main/java/core/algo/EpisodicLearning.java @@ -3,13 +3,17 @@ package core.algo; import core.DiscreteActionSpace; import core.Environment; import core.listener.LearningListener; +import lombok.Getter; +import lombok.Setter; -public abstract class EpisodicLearning extends Learning implements Episodic{ +public abstract class EpisodicLearning extends Learning implements Episodic { + @Setter + @Getter protected int currentEpisode; protected int episodesToLearn; protected volatile int episodePerSecond; protected int episodeSumCurrentSecond; - private volatile boolean meseaureEpisodeBenchMark; + private volatile boolean measureEpisodeBenchMark; public EpisodicLearning(Environment environment, DiscreteActionSpace actionSpace, float discountFactor, int delay) { super(environment, actionSpace, discountFactor, delay); @@ -29,6 +33,9 @@ public abstract class EpisodicLearning extends Learning imple protected void dispatchEpisodeEnd(double recentSumOfRewards){ ++episodeSumCurrentSecond; + if(rewardHistory.size() > 10000){ + rewardHistory.clear(); + } rewardHistory.add(recentSumOfRewards); for(LearningListener l: learningListeners) { l.onEpisodeEnd(rewardHistory); @@ -47,9 +54,9 @@ public abstract class EpisodicLearning extends Learning imple } public void learn(int nrOfEpisodes){ - meseaureEpisodeBenchMark = true; + measureEpisodeBenchMark = true; new Thread(()->{ - while(meseaureEpisodeBenchMark){ + while(measureEpisodeBenchMark){ episodePerSecond = episodeSumCurrentSecond; episodeSumCurrentSecond = 0; try { @@ -65,7 +72,7 @@ public abstract class EpisodicLearning extends Learning imple nextEpisode(); } dispatchLearningEnd(); - meseaureEpisodeBenchMark = false; + measureEpisodeBenchMark = false; } protected abstract void nextEpisode(); diff --git a/src/main/java/core/algo/Learning.java b/src/main/java/core/algo/Learning.java index 1999f46..a9c73c7 100644 --- a/src/main/java/core/algo/Learning.java +++ b/src/main/java/core/algo/Learning.java @@ -9,15 +9,17 @@ import core.policy.Policy; import lombok.Getter; import lombok.Setter; +import java.io.Serializable; import java.util.HashSet; import java.util.List; import java.util.Set; import java.util.concurrent.CopyOnWriteArrayList; @Getter -public abstract class Learning { +public abstract class Learning implements Serializable { protected Policy policy; protected DiscreteActionSpace actionSpace; + @Setter protected StateActionTable stateActionTable; protected Environment environment; protected float discountFactor; @@ -26,7 +28,7 @@ public abstract class Learning { protected int delay; protected List rewardHistory; - public Learning(Environment environment, DiscreteActionSpace actionSpace, float discountFactor, int delay){ + public Learning(Environment environment, DiscreteActionSpace actionSpace, float discountFactor, int delay) { this.environment = environment; this.actionSpace = actionSpace; this.discountFactor = discountFactor; @@ -35,39 +37,41 @@ public abstract class Learning { rewardHistory = new CopyOnWriteArrayList<>(); } - public Learning(Environment environment, DiscreteActionSpace actionSpace, float discountFactor){ + public Learning(Environment environment, DiscreteActionSpace actionSpace, float discountFactor) { this(environment, actionSpace, discountFactor, LearningConfig.DEFAULT_DELAY); } - public Learning(Environment environment, DiscreteActionSpace actionSpace, int delay){ + public Learning(Environment environment, DiscreteActionSpace actionSpace, int delay) { this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, delay); } - public Learning(Environment environment, DiscreteActionSpace actionSpace){ + public Learning(Environment environment, DiscreteActionSpace actionSpace) { this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, LearningConfig.DEFAULT_DELAY); } + public abstract void learn(); - public void addListener(LearningListener learningListener){ + public void addListener(LearningListener learningListener) { learningListeners.add(learningListener); } - protected void dispatchStepEnd(){ - for(LearningListener l: learningListeners){ + protected void dispatchStepEnd() { + for (LearningListener l : learningListeners) { l.onStepEnd(); } } - protected void dispatchLearningStart(){ - for(LearningListener l: learningListeners){ + protected void dispatchLearningStart() { + for (LearningListener l : learningListeners) { l.onLearningStart(); } } - protected void dispatchLearningEnd(){ - for(LearningListener l: learningListeners){ + protected void dispatchLearningEnd() { + for (LearningListener l : learningListeners) { l.onLearningEnd(); } } + } diff --git a/src/main/java/core/algo/MC/MonteCarloOnPolicyEGreedy.java b/src/main/java/core/algo/MC/MonteCarloOnPolicyEGreedy.java index 544fd1b..6f468c0 100644 --- a/src/main/java/core/algo/MC/MonteCarloOnPolicyEGreedy.java +++ b/src/main/java/core/algo/MC/MonteCarloOnPolicyEGreedy.java @@ -35,7 +35,7 @@ public class MonteCarloOnPolicyEGreedy extends EpisodicLearning< super(environment, actionSpace, discountFactor, delay); currentEpisode = 0; this.policy = new EpsilonGreedyPolicy<>(epsilon); - this.stateActionTable = new StateActionHashTable<>(this.actionSpace); + this.stateActionTable = new DeterministicStateActionTable<>(this.actionSpace); returnSum = new HashMap<>(); returnCount = new HashMap<>(); } @@ -57,16 +57,15 @@ public class MonteCarloOnPolicyEGreedy extends EpisodicLearning< e.printStackTrace(); } double sumOfRewards = 0; - for (int j = 0; j < 10; ++j) { + StepResultEnvironment envResult = null; + while(envResult == null || !envResult.isDone()){ Map actionValues = stateActionTable.getActionValues(state); A chosenAction = policy.chooseAction(actionValues); - StepResultEnvironment envResult = environment.step(chosenAction); + envResult = environment.step(chosenAction); State nextState = envResult.getState(); sumOfRewards += envResult.getReward(); episode.add(new StepResult<>(state, chosenAction, envResult.getReward())); - if (envResult.isDone()) break; - state = nextState; try { @@ -78,13 +77,13 @@ public class MonteCarloOnPolicyEGreedy extends EpisodicLearning< } dispatchEpisodeEnd(sumOfRewards); - System.out.printf("Episode %d \t Reward: %f \n", currentEpisode, sumOfRewards); - Set> stateActionPairs = new HashSet<>(); + // System.out.printf("Episode %d \t Reward: %f \n", currentEpisode, sumOfRewards); + Set> stateActionPairs = new LinkedHashSet<>(); for (StepResult sr : episode) { stateActionPairs.add(new Pair<>(sr.getState(), sr.getAction())); } - System.out.println("stateActionPairs " + stateActionPairs.size()); + //System.out.println("stateActionPairs " + stateActionPairs.size()); for (Pair stateActionPair : stateActionPairs) { int firstOccurenceIndex = 0; // find first occurance of state action pair diff --git a/src/main/java/core/controller/RLController.java b/src/main/java/core/controller/RLController.java index 4a149e5..b733807 100644 --- a/src/main/java/core/controller/RLController.java +++ b/src/main/java/core/controller/RLController.java @@ -1,8 +1,6 @@ package core.controller; -import core.DiscreteActionSpace; -import core.Environment; -import core.ListDiscreteActionSpace; +import core.*; import core.algo.EpisodicLearning; import core.algo.Learning; import core.algo.Method; @@ -14,6 +12,7 @@ import core.listener.ViewListener; import core.policy.EpsilonPolicy; import javax.swing.*; +import java.io.*; import java.util.List; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; @@ -23,10 +22,12 @@ public class RLController implements ViewListener, LearningListe protected Learning learning; protected DiscreteActionSpace discreteActionSpace; protected LearningView learningView; - private int delay; private int nrOfEpisodes; private Method method; private int prevDelay; + private int delay = LearningConfig.DEFAULT_DELAY; + private float discountFactor = LearningConfig.DEFAULT_DISCOUNT_FACTOR; + private float epsilon = LearningConfig.DEFAULT_EPSILON; private boolean fastLearning; private boolean currentlyLearning; private ExecutorService learningExecutor; @@ -36,6 +37,7 @@ public class RLController implements ViewListener, LearningListe learningExecutor = Executors.newSingleThreadExecutor(); } + public void start(){ if(environment == null || discreteActionSpace == null || method == null){ throw new RuntimeException("Set environment, discreteActionSpace and method before calling .start()"); @@ -43,7 +45,7 @@ public class RLController implements ViewListener, LearningListe switch (method){ case MC_ONPOLICY_EGREEDY: - learning = new MonteCarloOnPolicyEGreedy<>(environment, discreteActionSpace, delay); + learning = new MonteCarloOnPolicyEGreedy<>(environment, discreteActionSpace, discountFactor, epsilon, delay); break; case TD_ONPOLICY: break; @@ -76,6 +78,44 @@ public class RLController implements ViewListener, LearningListe } } + @Override + public void onLoadState(String fileName) { + FileInputStream fis; + ObjectInput in; + try { + fis = new FileInputStream(fileName); + in = new ObjectInputStream(fis); + SaveState saveState = (SaveState) in.readObject(); + learning.setStateActionTable(saveState.getStateActionTable()); + if(learning instanceof EpisodicLearning){ + ((EpisodicLearning) learning).setCurrentEpisode(saveState.getCurrentEpisode()); + } + in.close(); + } catch (IOException | ClassNotFoundException e) { + e.printStackTrace(); + } + } + + @Override + public void onSaveState(String fileName) { + FileOutputStream fos; + ObjectOutputStream out; + try{ + fos = new FileOutputStream(fileName); + out = new ObjectOutputStream(fos); + int currentEpisode; + if(learning instanceof EpisodicLearning){ + currentEpisode = ((EpisodicLearning) learning).getCurrentEpisode(); + }else{ + currentEpisode = 0; + } + out.writeObject(new SaveState<>(learning.getStateActionTable(), currentEpisode)); + out.close(); + }catch (IOException e){ + e.printStackTrace(); + } + } + @Override public void onEpsilonChange(float epsilon) { if(learning.getPolicy() instanceof EpsilonPolicy){ @@ -169,4 +209,13 @@ public class RLController implements ViewListener, LearningListe this.nrOfEpisodes = nrOfEpisodes; return this; } + + public RLController setDiscountFactor(float discountFactor){ + this.discountFactor = discountFactor; + return this; + } + public RLController setEpsilon(float epsilon){ + this.epsilon = epsilon; + return this; + } } diff --git a/src/main/java/core/gui/View.java b/src/main/java/core/gui/View.java index d89cc97..5ad77db 100644 --- a/src/main/java/core/gui/View.java +++ b/src/main/java/core/gui/View.java @@ -11,6 +11,8 @@ import org.knowm.xchart.XYChart; import javax.swing.*; import java.awt.*; +import java.awt.event.ActionEvent; +import java.io.File; import java.util.List; import java.util.concurrent.CopyOnWriteArrayList; @@ -26,6 +28,8 @@ public class View implements LearningView{ private JFrame environmentFrame; private XChartPanel rewardChartPanel; private ViewListener viewListener; + private JMenuBar menuBar; + private JMenu fileMenu; public View(Learning learning, Environment environment, ViewListener viewListener) { this.learning = learning; @@ -38,7 +42,32 @@ public class View implements LearningView{ mainFrame = new JFrame(); mainFrame.setPreferredSize(new Dimension(1280, 720)); mainFrame.setLayout(new BorderLayout()); + menuBar = new JMenuBar(); + fileMenu = new JMenu("File"); + menuBar.add(fileMenu); + fileMenu.add(new JMenuItem(new AbstractAction("Load") { + @Override + public void actionPerformed(ActionEvent e) { + final JFileChooser fc = new JFileChooser(); + fc.setCurrentDirectory(new File(System.getProperty("user.dir"))); + int returnVal = fc.showOpenDialog(mainFrame); + if (returnVal == JFileChooser.APPROVE_OPTION) { + viewListener.onLoadState(fc.getSelectedFile().toString()); + } + } + })); + + fileMenu.add(new JMenuItem(new AbstractAction("Save") { + @Override + public void actionPerformed(ActionEvent e) { + String fileName = JOptionPane.showInputDialog("Enter file name", "save"); + if(fileName != null){ + viewListener.onSaveState(fileName); + } + } + })); + mainFrame.setJMenuBar(menuBar); initLearningInfoPanel(); initRewardChart(); diff --git a/src/main/java/core/listener/ViewListener.java b/src/main/java/core/listener/ViewListener.java index dbf01d4..abd0004 100644 --- a/src/main/java/core/listener/ViewListener.java +++ b/src/main/java/core/listener/ViewListener.java @@ -5,4 +5,6 @@ public interface ViewListener { void onDelayChange(int delay); void onFastLearnChange(boolean isFastLearn); void onLearnMoreEpisodes(int nrOfEpisodes); + void onLoadState(String fileName); + void onSaveState(String fileName); } diff --git a/src/main/java/core/policy/EpsilonGreedyPolicy.java b/src/main/java/core/policy/EpsilonGreedyPolicy.java index 550889f..cfe506c 100644 --- a/src/main/java/core/policy/EpsilonGreedyPolicy.java +++ b/src/main/java/core/policy/EpsilonGreedyPolicy.java @@ -29,7 +29,8 @@ public class EpsilonGreedyPolicy implements EpsilonPolicy{ @Override public A chooseAction(Map actionValues) { - if(RNG.getRandom().nextFloat() < epsilon){ + float f = RNG.getRandom().nextFloat(); + if(f < epsilon){ // Take random action return randomPolicy.chooseAction(actionValues); }else{ diff --git a/src/main/java/core/policy/GreedyPolicy.java b/src/main/java/core/policy/GreedyPolicy.java index 6ff7739..30d901f 100644 --- a/src/main/java/core/policy/GreedyPolicy.java +++ b/src/main/java/core/policy/GreedyPolicy.java @@ -1,9 +1,10 @@ package core.policy; +import core.RNG; + import java.util.ArrayList; import java.util.List; import java.util.Map; -import java.util.Random; public class GreedyPolicy implements Policy { @@ -26,6 +27,6 @@ public class GreedyPolicy implements Policy { } } - return equalHigh.get(new Random().nextInt(equalHigh.size())); + return equalHigh.get(RNG.getRandom().nextInt(equalHigh.size())); } } diff --git a/src/main/java/core/policy/RandomPolicy.java b/src/main/java/core/policy/RandomPolicy.java index 094b41c..ea99dd5 100644 --- a/src/main/java/core/policy/RandomPolicy.java +++ b/src/main/java/core/policy/RandomPolicy.java @@ -1,18 +1,17 @@ package core.policy; import core.RNG; + import java.util.Map; public class RandomPolicy implements Policy{ @Override public A chooseAction(Map actionValues) { int idx = RNG.getRandom().nextInt(actionValues.size()); - System.out.println("selected action " + idx); int i = 0; for(A action : actionValues.keySet()){ - if(i++ == idx) return action; + if(i++ == idx) return action; } - return null; } } diff --git a/src/main/java/evironment/jumpingDino/Config.java b/src/main/java/evironment/jumpingDino/Config.java new file mode 100644 index 0000000..29af520 --- /dev/null +++ b/src/main/java/evironment/jumpingDino/Config.java @@ -0,0 +1,13 @@ +package evironment.jumpingDino; + +public class Config { + public static final int FRAME_WIDTH = 1280; + public static final int FRAME_HEIGHT = 720; + public static final int GROUND_Y = 50; + public static final int DINO_STARTING_X = 50; + public static final int DINO_SIZE = 50; + public static final int OBSTACLE_SIZE = 60; + public static final int OBSTACLE_SPEED = 30; + public static final int DINO_JUMP_SPEED = 20; + public static final int MAX_JUMP_HEIGHT = 200; +} diff --git a/src/main/java/evironment/jumpingDino/Dino.java b/src/main/java/evironment/jumpingDino/Dino.java new file mode 100644 index 0000000..0c29ce3 --- /dev/null +++ b/src/main/java/evironment/jumpingDino/Dino.java @@ -0,0 +1,40 @@ +package evironment.jumpingDino; + +import lombok.Getter; + +import java.awt.*; + +public class Dino extends RenderObject { + @Getter + private boolean inJump; + + public Dino(int size, int x, int y, int dx, int dy, Color color) { + super(size, x, y, dx, dy, color); + } + + public void jump(){ + if(!inJump){ + dy = -Config.DINO_JUMP_SPEED; + inJump = true; + } + } + + private void fall(){ + if(inJump){ + dy = Config.DINO_JUMP_SPEED; + } + } + + @Override + public void tick(){ + // reached max jump height + if(y + dy < Config.FRAME_HEIGHT - Config.GROUND_Y -Config.OBSTACLE_SIZE - Config.MAX_JUMP_HEIGHT){ + fall(); + }else if(y + dy >= Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE){ + inJump = false; + dy = 0; + y = Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE; + } + super.tick(); + } +} diff --git a/src/main/java/evironment/jumpingDino/DinoAction.java b/src/main/java/evironment/jumpingDino/DinoAction.java new file mode 100644 index 0000000..8adc627 --- /dev/null +++ b/src/main/java/evironment/jumpingDino/DinoAction.java @@ -0,0 +1,6 @@ +package evironment.jumpingDino; + +public enum DinoAction { + JUMP, + NOTHING, +} diff --git a/src/main/java/evironment/jumpingDino/DinoState.java b/src/main/java/evironment/jumpingDino/DinoState.java new file mode 100644 index 0000000..0f59dc1 --- /dev/null +++ b/src/main/java/evironment/jumpingDino/DinoState.java @@ -0,0 +1,32 @@ +package evironment.jumpingDino; + +import core.State; +import lombok.AllArgsConstructor; +import lombok.Getter; + +import java.io.Serializable; + +@AllArgsConstructor +@Getter +public class DinoState implements State, Serializable { + private int xDistanceToObstacle; + + @Override + public String toString() { + return Integer.toString(xDistanceToObstacle); + } + + @Override + public int hashCode() { + return this.xDistanceToObstacle; + } + + @Override + public boolean equals(Object obj) { + if(obj instanceof DinoState){ + DinoState toCompare = (DinoState) obj; + return toCompare.getXDistanceToObstacle() == this.xDistanceToObstacle; + } + return super.equals(obj); + } +} diff --git a/src/main/java/evironment/jumpingDino/DinoWorld.java b/src/main/java/evironment/jumpingDino/DinoWorld.java new file mode 100644 index 0000000..3f40524 --- /dev/null +++ b/src/main/java/evironment/jumpingDino/DinoWorld.java @@ -0,0 +1,78 @@ +package evironment.jumpingDino; + +import core.Environment; +import core.State; +import core.StepResultEnvironment; +import core.gui.Visualizable; +import evironment.jumpingDino.gui.DinoWorldComponent; +import lombok.Getter; + +import javax.swing.*; +import java.awt.*; + +@Getter +public class DinoWorld implements Environment, Visualizable { + private Dino dino; + private Obstacle currentObstacle; + + public DinoWorld(){ + dino = new Dino(Config.DINO_SIZE, Config.DINO_STARTING_X, Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE, 0, 0, Color.GREEN); + spawnNewObstacle(); + } + + private boolean ranIntoObstacle(){ + Obstacle o = currentObstacle; + Dino p = dino; + + boolean xAxis = (o.getX() <= p.getX() && p.getX() < o.getX() + Config.OBSTACLE_SIZE) + || (o.getX() <= p.getX() + Config.DINO_SIZE && p.getX() + Config.DINO_SIZE < o.getX() + Config.OBSTACLE_SIZE); + + boolean yAxis = (o.getY() <= p.getY() && p.getY() < o.getY() + Config.OBSTACLE_SIZE) + || (o.getY() <= p.getY() + Config.DINO_SIZE && p.getY() + Config.DINO_SIZE < o.getY() + Config.OBSTACLE_SIZE); + + return xAxis && yAxis; + } + private int getDistanceToObstacle(){ + return currentObstacle.getX() - dino.getX() + Config.DINO_SIZE; + } + + @Override + public StepResultEnvironment step(DinoAction action) { + boolean done = false; + int reward = 1; + + if(action == DinoAction.JUMP){ + dino.jump(); + } + + dino.tick(); + currentObstacle.tick(); + if(currentObstacle.getX() < -Config.OBSTACLE_SIZE){ + spawnNewObstacle(); + } + + if(ranIntoObstacle()){ + done = true; + } + + return new StepResultEnvironment(new DinoState(getDistanceToObstacle()), reward, done, ""); + } + private void spawnNewObstacle(){ + currentObstacle = new Obstacle(Config.OBSTACLE_SIZE, Config.FRAME_WIDTH + Config.OBSTACLE_SIZE, Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE, -Config.OBSTACLE_SPEED, 0, Color.BLACK); + } + + private void spawnDino(){ + dino = new Dino(Config.DINO_SIZE, Config.DINO_STARTING_X, Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE, 0, 0, Color.GREEN); + } + @Override + public State reset() { + spawnDino(); + spawnNewObstacle(); + return new DinoState(getDistanceToObstacle()); + } + + @Override + public JComponent visualize() { + return new DinoWorldComponent(this); + } +} diff --git a/src/main/java/evironment/jumpingDino/Obstacle.java b/src/main/java/evironment/jumpingDino/Obstacle.java new file mode 100644 index 0000000..0d01c4b --- /dev/null +++ b/src/main/java/evironment/jumpingDino/Obstacle.java @@ -0,0 +1,10 @@ +package evironment.jumpingDino; + +import java.awt.*; + +public class Obstacle extends RenderObject { + + public Obstacle(int size, int x, int y, int dx, int dy, Color color) { + super(size, x, y, dx, dy, color); + } +} diff --git a/src/main/java/evironment/jumpingDino/RenderObject.java b/src/main/java/evironment/jumpingDino/RenderObject.java new file mode 100644 index 0000000..3b1430b --- /dev/null +++ b/src/main/java/evironment/jumpingDino/RenderObject.java @@ -0,0 +1,28 @@ +package evironment.jumpingDino; + +import lombok.AllArgsConstructor; +import lombok.Getter; + +import java.awt.*; + + +@AllArgsConstructor +@Getter +public abstract class RenderObject { + protected int size; + protected int x; + protected int y; + protected int dx; + protected int dy; + protected Color color; + + public void render(Graphics g){ + g.setColor(color); + g.fillRect(x, y, size, size); + } + + public void tick(){ + y += dy; + x += dx; + } +} diff --git a/src/main/java/evironment/jumpingDino/gui/DinoWorldComponent.java b/src/main/java/evironment/jumpingDino/gui/DinoWorldComponent.java new file mode 100644 index 0000000..461934f --- /dev/null +++ b/src/main/java/evironment/jumpingDino/gui/DinoWorldComponent.java @@ -0,0 +1,27 @@ +package evironment.jumpingDino.gui; + +import evironment.jumpingDino.Config; +import evironment.jumpingDino.DinoWorld; + +import javax.swing.*; +import java.awt.*; + +public class DinoWorldComponent extends JComponent { + private DinoWorld dinoWorld; + + public DinoWorldComponent(DinoWorld dinoWorld){ + this.dinoWorld = dinoWorld; + setPreferredSize(new Dimension(Config.FRAME_WIDTH, Config.FRAME_HEIGHT)); + setVisible(true); + } + + @Override + protected void paintComponent(Graphics g) { + super.paintComponent(g); + g.setColor(Color.BLACK); + g.fillRect(0, Config.FRAME_HEIGHT - Config.GROUND_Y, Config.FRAME_WIDTH, 2); + + dinoWorld.getDino().render(g); + dinoWorld.getCurrentObstacle().render(g); + } +} diff --git a/src/main/java/example/JumpingDino.java b/src/main/java/example/JumpingDino.java new file mode 100644 index 0000000..87b1ceb --- /dev/null +++ b/src/main/java/example/JumpingDino.java @@ -0,0 +1,23 @@ +package example; + +import core.RNG; +import core.algo.Method; +import core.controller.RLController; +import evironment.jumpingDino.DinoAction; +import evironment.jumpingDino.DinoWorld; + +public class JumpingDino { + public static void main(String[] args) { + RNG.setSeed(55); + + RLController rl = new RLController() + .setEnvironment(new DinoWorld()) + .setAllowedActions(DinoAction.values()) + .setMethod(Method.MC_ONPOLICY_EGREEDY) + .setDiscountFactor(1f) + .setEpsilon(0.15f) + .setDelay(200) + .setEpisodes(100000); + rl.start(); + } +}