From e67f40ad65e62f2bf84ce69af91ef5573d3fcd6b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20L=C3=B6wenstrom?= Date: Thu, 5 Mar 2020 11:58:57 +0100 Subject: [PATCH] split DinoWorld between simple and advanced example # Conflicts: # src/main/java/example/JumpingDino.java --- .../evironment/jumpingDino/DinoWorld.java | 43 ++++++----------- .../jumpingDino/DinoWorldAdvanced.java | 26 ++++++++++ src/main/java/example/DinoSampling.java | 44 ++++++++++++----- src/main/java/example/JumpingDino.java | 47 +++++-------------- 4 files changed, 83 insertions(+), 77 deletions(-) create mode 100644 src/main/java/evironment/jumpingDino/DinoWorldAdvanced.java diff --git a/src/main/java/evironment/jumpingDino/DinoWorld.java b/src/main/java/evironment/jumpingDino/DinoWorld.java index 103fc08..9c18bd9 100644 --- a/src/main/java/evironment/jumpingDino/DinoWorld.java +++ b/src/main/java/evironment/jumpingDino/DinoWorld.java @@ -12,24 +12,15 @@ import java.awt.*; @Getter public class DinoWorld implements Environment, Visualizable { - private Dino dino; - private Obstacle currentObstacle; - private boolean randomObstacleSpeed; - private boolean randomObstacleDistance; + protected Dino dino; + protected Obstacle currentObstacle; private DinoWorldComponent comp; - public DinoWorld(boolean randomObstacleSpeed, boolean randomObstacleDistance){ - this.randomObstacleSpeed = randomObstacleSpeed; - this.randomObstacleDistance = randomObstacleDistance; + public DinoWorld(){ dino = new Dino(Config.DINO_SIZE, Config.DINO_STARTING_X, Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE, 0, 0, Color.GREEN); spawnNewObstacle(); comp = new DinoWorldComponent(this); } - - public DinoWorld(){ - this(false, false); - } - private boolean ranIntoObstacle(){ Obstacle o = currentObstacle; Dino p = dino; @@ -43,7 +34,7 @@ public class DinoWorld implements Environment, Visualizable { return xAxis && yAxis; } - private int getDistanceToObstacle(){ + protected int getDistanceToObstacle(){ return currentObstacle.getX() - dino.getX() + Config.DINO_SIZE; } @@ -78,27 +69,20 @@ public class DinoWorld implements Environment, Visualizable { done = true; } - return new StepResultEnvironment(new DinoStateWithSpeed(getDistanceToObstacle(), dino.isInJump(), getCurrentObstacle().getDx()), reward, done, ""); + return new StepResultEnvironment(generateReturnState(), reward, done, ""); } + protected State generateReturnState(){ + return new DinoState(getDistanceToObstacle(), dino.isInJump()); + } - private void spawnNewObstacle(){ + protected void spawnNewObstacle(){ int dx; int xSpawn; - if(randomObstacleSpeed){ - dx = -(int)((Math.random() + 0.5) * Config.OBSTACLE_SPEED); - }else{ - dx = -Config.OBSTACLE_SPEED; - } - - if(randomObstacleDistance){ - // randomly spawning more right outside of the screen - xSpawn = (int)(Math.random() + 0.5 * Config.FRAME_WIDTH + Config.FRAME_WIDTH + Config.OBSTACLE_SIZE); - }else{ - // instantly respawning on the left screen border - xSpawn = Config.FRAME_WIDTH + Config.OBSTACLE_SIZE; - } + dx = -Config.OBSTACLE_SPEED; + // instantly respawning on the left screen border + xSpawn = Config.FRAME_WIDTH + Config.OBSTACLE_SIZE; currentObstacle = new Obstacle(Config.OBSTACLE_SIZE, xSpawn, Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE, dx, 0, Color.BLACK); } @@ -106,11 +90,12 @@ public class DinoWorld implements Environment, Visualizable { private void spawnDino(){ dino = new Dino(Config.DINO_SIZE, Config.DINO_STARTING_X, Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE, 0, 0, Color.GREEN); } + @Override public State reset() { spawnDino(); spawnNewObstacle(); - return new DinoState(getDistanceToObstacle(), dino.isInJump()); + return generateReturnState(); } @Override diff --git a/src/main/java/evironment/jumpingDino/DinoWorldAdvanced.java b/src/main/java/evironment/jumpingDino/DinoWorldAdvanced.java new file mode 100644 index 0000000..7513ff0 --- /dev/null +++ b/src/main/java/evironment/jumpingDino/DinoWorldAdvanced.java @@ -0,0 +1,26 @@ +package evironment.jumpingDino; + +import core.State; + +import java.awt.*; + +public class DinoWorldAdvanced extends DinoWorld{ + public DinoWorldAdvanced(){ + super(); + } + + @Override + protected State generateReturnState() { + return new DinoStateWithSpeed(getDistanceToObstacle(), dino.isInJump(), getCurrentObstacle().getDx()); + } + + @Override + protected void spawnNewObstacle() { + int dx; + int xSpawn; + dx = -(int)((Math.random() + 0.5) * Config.OBSTACLE_SPEED); + // randomly spawning more right outside of the screen + xSpawn = (int)(Math.random() + 0.5 * Config.FRAME_WIDTH + Config.FRAME_WIDTH + Config.OBSTACLE_SIZE); + currentObstacle = new Obstacle(Config.OBSTACLE_SIZE, xSpawn, Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE, dx, 0, Color.BLACK); + } +} diff --git a/src/main/java/example/DinoSampling.java b/src/main/java/example/DinoSampling.java index 09d5c5b..38c03ff 100644 --- a/src/main/java/example/DinoSampling.java +++ b/src/main/java/example/DinoSampling.java @@ -6,22 +6,42 @@ import core.controller.RLController; import evironment.jumpingDino.DinoAction; import evironment.jumpingDino.DinoWorld; +import java.io.File; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.StandardOpenOption; + public class DinoSampling { public static void main(String[] args) { - for (int i = 0; i < 10 ; i++) { - RNG.setSeed(55); + File file = new File("convergence.txt"); + for(float f = 0.05f; f <=1.003 ; f+=0.05f){ + try { + Files.writeString(Path.of(file.getPath()), f + ",", StandardOpenOption.APPEND); + } catch (IOException e) { + e.printStackTrace(); + } + for(int i = 1; i <= 100; i++) { + System.out.println("seed: " + i *13); + RNG.setSeed(i *13); - RLController rl = new RLController<>( - new DinoWorld(false, false), - Method.MC_CONTROL_FIRST_VISIT, - DinoAction.values()); + RLController rl = new RLController<>( + new DinoWorld(), + Method.MC_CONTROL_FIRST_VISIT, + DinoAction.values()); + rl.setDelay(0); + rl.setDiscountFactor(1f); + rl.setEpsilon(f); + rl.setLearningRate(1f); + rl.setNrOfEpisodes(20000); + rl.start(); - rl.setDelay(0); - rl.setDiscountFactor(1f); - rl.setEpsilon(0.15f); - rl.setLearningRate(1f); - rl.setNrOfEpisodes(400); - rl.start(); + } + try { + Files.writeString(Path.of(file.getPath()), "\n", StandardOpenOption.APPEND); + } catch (IOException e) { + e.printStackTrace(); + } } } } diff --git a/src/main/java/example/JumpingDino.java b/src/main/java/example/JumpingDino.java index 13efb4e..3f71ca1 100644 --- a/src/main/java/example/JumpingDino.java +++ b/src/main/java/example/JumpingDino.java @@ -7,45 +7,20 @@ import core.controller.RLControllerGUI; import evironment.jumpingDino.DinoAction; import evironment.jumpingDino.DinoWorld; -import java.io.File; -import java.io.IOException; -import java.nio.file.Files; -import java.nio.file.Path; -import java.nio.file.StandardOpenOption; - public class JumpingDino { public static void main(String[] args) { - File file = new File("convergence.txt"); - for(float f = 0.05f; f <=1.003 ; f+=0.05f){ - try { - Files.writeString(Path.of(file.getPath()), f + ",", StandardOpenOption.APPEND); - } catch (IOException e) { - e.printStackTrace(); - } - for(int i = 1; i <= 100; i++) { - System.out.println("seed: " + i *13); - RNG.setSeed(i *13); + RNG.setSeed(55); - RLController rl = new RLController<>( - new DinoWorld(false, false), - Method.MC_CONTROL_FIRST_VISIT, - DinoAction.values()); + RLController rl = new RLControllerGUI<>( + new DinoWorld(), + Method.MC_CONTROL_FIRST_VISIT, + DinoAction.values()); - - rl.setDelay(0); - rl.setDiscountFactor(1f); - rl.setEpsilon(f); - rl.setLearningRate(1f); - rl.setNrOfEpisodes(20000); - rl.start(); - - } - try { - Files.writeString(Path.of(file.getPath()), "\n", StandardOpenOption.APPEND); - } catch (IOException e) { - e.printStackTrace(); - } - } - System.out.println("kek"); + rl.setDelay(100); + rl.setDiscountFactor(1f); + rl.setEpsilon(0.15f); + rl.setLearningRate(1f); + rl.setNrOfEpisodes(10000); + rl.start(); } }