diff --git a/src/main/java/core/gui/QTableFrame.java b/src/main/java/core/gui/QTableFrame.java index c820921..cb902a9 100644 --- a/src/main/java/core/gui/QTableFrame.java +++ b/src/main/java/core/gui/QTableFrame.java @@ -42,7 +42,6 @@ public class QTableFrame extends JFrame { } } protected void refreshQTable() { - System.out.println("ref"); int stateCount = stateActionTable.getStateCount(); stateCountLabel.setText("Total states: " + stateCount); int idx = -1; diff --git a/src/main/java/core/policy/EpsilonGreedyPolicy.java b/src/main/java/core/policy/EpsilonGreedyPolicy.java index cfe506c..826159c 100644 --- a/src/main/java/core/policy/EpsilonGreedyPolicy.java +++ b/src/main/java/core/policy/EpsilonGreedyPolicy.java @@ -37,5 +37,6 @@ public class EpsilonGreedyPolicy implements EpsilonPolicy{ // Take the action with the highest value return greedyPolicy.chooseAction(actionValues); } + } } diff --git a/src/main/java/evironment/jumpingDino/DinoState.java b/src/main/java/evironment/jumpingDino/DinoState.java index 8f9a783..0a52a7a 100644 --- a/src/main/java/evironment/jumpingDino/DinoState.java +++ b/src/main/java/evironment/jumpingDino/DinoState.java @@ -14,12 +14,15 @@ import java.util.Objects; @Getter public class DinoState implements State, Serializable, Visualizable { private int xDistanceToObstacle; + private boolean isJumping; + protected final double scale = 0.5; @Override public String toString() { return "DinoState{" + "xDistanceToObstacle=" + xDistanceToObstacle + + "isJumping=" + isJumping + '}'; } @@ -28,12 +31,12 @@ public class DinoState implements State, Serializable, Visualizable { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; DinoState dinoState = (DinoState) o; - return xDistanceToObstacle == dinoState.xDistanceToObstacle; + return xDistanceToObstacle == dinoState.xDistanceToObstacle && isJumping == dinoState.isJumping; } @Override public int hashCode() { - return Objects.hash(xDistanceToObstacle); + return Objects.hash(xDistanceToObstacle, isJumping); } @Override diff --git a/src/main/java/evironment/jumpingDino/DinoStateWithSpeed.java b/src/main/java/evironment/jumpingDino/DinoStateWithSpeed.java index 4e3dd08..03466e3 100644 --- a/src/main/java/evironment/jumpingDino/DinoStateWithSpeed.java +++ b/src/main/java/evironment/jumpingDino/DinoStateWithSpeed.java @@ -10,8 +10,8 @@ import java.util.Objects; public class DinoStateWithSpeed extends DinoState implements Visualizable { private int obstacleSpeed; - public DinoStateWithSpeed(int xDistanceToObstacle, int obstacleSpeed) { - super(xDistanceToObstacle); + public DinoStateWithSpeed(int xDistanceToObstacle, boolean isJumping, int obstacleSpeed) { + super(xDistanceToObstacle, isJumping); this.obstacleSpeed = obstacleSpeed; } diff --git a/src/main/java/evironment/jumpingDino/DinoWorld.java b/src/main/java/evironment/jumpingDino/DinoWorld.java index 8e39889..953e2c4 100644 --- a/src/main/java/evironment/jumpingDino/DinoWorld.java +++ b/src/main/java/evironment/jumpingDino/DinoWorld.java @@ -50,7 +50,7 @@ public class DinoWorld implements Environment, Visualizable { @Override public StepResultEnvironment step(DinoAction action) { boolean done = false; - int reward = 1; + int reward = 0; if(action == DinoAction.JUMP){ dino.jump(); @@ -74,11 +74,11 @@ public class DinoWorld implements Environment, Visualizable { spawnNewObstacle(); } if(ranIntoObstacle()) { - reward = 0; + reward = -1; done = true; } - return new StepResultEnvironment(new DinoStateWithSpeed(getDistanceToObstacle(), getCurrentObstacle().getDx()), reward, done, ""); + return new StepResultEnvironment(new DinoStateWithSpeed(getDistanceToObstacle(), dino.isInJump(), getCurrentObstacle().getDx()), reward, done, ""); } @@ -110,7 +110,7 @@ public class DinoWorld implements Environment, Visualizable { public State reset() { spawnDino(); spawnNewObstacle(); - return new DinoState(getDistanceToObstacle()); + return new DinoState(getDistanceToObstacle(), dino.isInJump()); } @Override diff --git a/src/main/java/example/JumpingDino.java b/src/main/java/example/JumpingDino.java index ff62761..c033316 100644 --- a/src/main/java/example/JumpingDino.java +++ b/src/main/java/example/JumpingDino.java @@ -16,13 +16,11 @@ public class JumpingDino { Method.Q_LEARNING_OFF_POLICY_CONTROL, DinoAction.values()); - rl.setDelay(10); - rl.setDiscountFactor(0.8f); + rl.setDelay(1000); + rl.setDiscountFactor(0.9f); rl.setEpsilon(0.1f); rl.setLearningRate(0.5f); - rl.setNrOfEpisodes(10000); + rl.setNrOfEpisodes(4000000); rl.start(); - - } }