diff --git a/src/main/java/core/ListDiscreteActionSpace.java b/src/main/java/core/ListDiscreteActionSpace.java
index 5c73a65..7ea9f4d 100644
--- a/src/main/java/core/ListDiscreteActionSpace.java
+++ b/src/main/java/core/ListDiscreteActionSpace.java
@@ -1,7 +1,10 @@
package core;
import java.io.Serializable;
-import java.util.*;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Iterator;
+import java.util.List;
/**
* Implementation of a discrete action space.
@@ -18,6 +21,7 @@ public class ListDiscreteActionSpace implements DiscreteActionSp
actions = new ArrayList<>();
}
+ @SafeVarargs
public ListDiscreteActionSpace(A... actions){
this.actions = new ArrayList<>(Arrays.asList(actions));
}
diff --git a/src/main/java/core/algo/mc/MonteCarloControlEGreedy.java b/src/main/java/core/algo/mc/MonteCarloControlEGreedy.java
index beb9e4e..835bbcf 100644
--- a/src/main/java/core/algo/mc/MonteCarloControlEGreedy.java
+++ b/src/main/java/core/algo/mc/MonteCarloControlEGreedy.java
@@ -21,7 +21,7 @@ public class MonteCarloControlEGreedy extends EpisodicLearning, Double> returnSum;
private Map, Integer> returnCount;
- private boolean isEveryVisit;
+ private final boolean isEveryVisit;
public MonteCarloControlEGreedy(Environment environment, DiscreteActionSpace actionSpace, float discountFactor, float epsilon, int delay, boolean useEveryVisit) {
diff --git a/src/main/java/core/algo/td/QLearningOffPolicyTDControl.java b/src/main/java/core/algo/td/QLearningOffPolicyTDControl.java
index fcc59fa..896e775 100644
--- a/src/main/java/core/algo/td/QLearningOffPolicyTDControl.java
+++ b/src/main/java/core/algo/td/QLearningOffPolicyTDControl.java
@@ -34,7 +34,7 @@ public class QLearningOffPolicyTDControl extends EpisodicLearnin
}
StepResultEnvironment envResult = null;
- Map actionValues = null;
+ Map actionValues;
sumOfRewards = 0;
while(envResult == null || !envResult.isDone()) {
diff --git a/src/main/java/core/controller/RLController.java b/src/main/java/core/controller/RLController.java
index 33bd168..b1e64ea 100644
--- a/src/main/java/core/controller/RLController.java
+++ b/src/main/java/core/controller/RLController.java
@@ -39,6 +39,7 @@ public class RLController implements LearningListener {
protected int prevDelay;
protected volatile boolean printNextEpisode;
+ @SafeVarargs
public RLController(Environment env, Method method, A... actions) {
setEnvironment(env);
setMethod(method);
@@ -102,9 +103,7 @@ public class RLController implements LearningListener {
if(learning.isCurrentlyLearning()){
((EpisodicLearning) learning).learnMoreEpisodes(nrOfEpisodes);
}else{
- new Thread(() -> {
- ((EpisodicLearning) learning).learn(nrOfEpisodes);
- }).start();
+ new Thread(() -> ((EpisodicLearning) learning).learn(nrOfEpisodes)).start();
}
} else {
throw new RuntimeException("Triggering onLearnMoreEpisodes on non-episodic learning!");
@@ -179,7 +178,7 @@ public class RLController implements LearningListener {
public void onEpisodeEnd(List rewardHistory) {
latestRewardsHistory = rewardHistory;
if(printNextEpisode) {
- System.out.println("Episode " + ((EpisodicLearning) learning).getCurrentEpisode() + " Latest Reward: " + rewardHistory.get(rewardHistory.size() - 1));
+ System.out.println("Episode " + learning.getCurrentEpisode() + " Latest Reward: " + rewardHistory.get(rewardHistory.size() - 1));
System.out.println("Eps/sec: " + ((EpisodicLearning) learning).getEpisodePerSecond());
printNextEpisode = false;
}
diff --git a/src/main/java/core/controller/RLControllerGUI.java b/src/main/java/core/controller/RLControllerGUI.java
index be2ea56..adc8a14 100644
--- a/src/main/java/core/controller/RLControllerGUI.java
+++ b/src/main/java/core/controller/RLControllerGUI.java
@@ -13,6 +13,7 @@ import java.util.List;
public class RLControllerGUI extends RLController implements ViewListener {
private LearningView learningView;
+ @SafeVarargs
public RLControllerGUI(Environment env, Method method, A... actions) {
super(env, method, actions);
}
@@ -102,7 +103,7 @@ public class RLControllerGUI extends RLController implements
@Override
public void onLearningEnd() {
super.onLearningEnd();
- onSaveState(method.toString() + System.currentTimeMillis() / 1000 + (learning instanceof EpisodicLearning ? "e" + ((EpisodicLearning) learning).getCurrentEpisode() : ""));
+ onSaveState(method.toString() + System.currentTimeMillis() / 1000 + (learning instanceof EpisodicLearning ? "e" + learning.getCurrentEpisode() : ""));
SwingUtilities.invokeLater(() -> learningView.updateRewardGraph(latestRewardsHistory));
}
}
diff --git a/src/main/java/core/gui/LearningInfoPanel.java b/src/main/java/core/gui/LearningInfoPanel.java
index ed2ebad..dfd8dbc 100644
--- a/src/main/java/core/gui/LearningInfoPanel.java
+++ b/src/main/java/core/gui/LearningInfoPanel.java
@@ -85,9 +85,7 @@ public class LearningInfoPanel extends JPanel {
add(learnMoreEpisodesButton);
}
showQTableButton = new JButton("Show Q-Table");
- showQTableButton.addActionListener(e -> {
- viewListener.onShowQTable();
- });
+ showQTableButton.addActionListener(e -> viewListener.onShowQTable());
add(drawEnvironmentCheckbox);
add(smoothGraphCheckbox);
add(last100Checkbox);
diff --git a/src/main/java/evironment/antGame/Ant.java b/src/main/java/evironment/antGame/Ant.java
index 4bd4599..d9a3cb6 100644
--- a/src/main/java/evironment/antGame/Ant.java
+++ b/src/main/java/evironment/antGame/Ant.java
@@ -1,7 +1,6 @@
package evironment.antGame;
import lombok.AccessLevel;
-import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.Setter;
diff --git a/src/main/java/evironment/antGame/AntState.java b/src/main/java/evironment/antGame/AntState.java
index ca917c9..229b2de 100644
--- a/src/main/java/evironment/antGame/AntState.java
+++ b/src/main/java/evironment/antGame/AntState.java
@@ -86,12 +86,12 @@ public class AntState implements State, Visualizable {
public JComponent visualize() {
return new JScrollPane() {
private int cellSize;
- private final int paneWidth = 500;
- private final int paneHeight = 500;
private Font font;
{
+ int paneWidth = 500;
+ int paneHeight = 500;
setPreferredSize(new Dimension(paneWidth, paneHeight));
- cellSize = (paneWidth- knownWorld.length) /knownWorld.length;
+ cellSize = (paneWidth - knownWorld.length) / knownWorld.length;
font = new Font("plain", Font.BOLD, cellSize);
JPanel worldPanel = new JPanel(){
{
diff --git a/src/main/java/evironment/antGame/AntWorld.java b/src/main/java/evironment/antGame/AntWorld.java
index c9f9453..6f2dfe0 100644
--- a/src/main/java/evironment/antGame/AntWorld.java
+++ b/src/main/java/evironment/antGame/AntWorld.java
@@ -139,11 +139,9 @@ public class AntWorld implements Environment, Visualizable {
// valid movement
if(!sc.stayOnCell) {
myAnt.getPos().setLocation(sc.potentialNextPos);
- if(antAgent.getCell(myAnt.getPos()).getType() == CellType.UNKNOWN){
- // the ant will move to a cell that was previously unknown
- // TODO: not optimal for going straight for food
- // sc.reward = Reward.UNKNOWN_FIELD_EXPLORED;
- }
+ antAgent.getCell(myAnt.getPos());// the ant will move to a cell that was previously unknown
+// TODO: not optimal for going straight for food
+// sc.reward = Reward.UNKNOWN_FIELD_EXPLORED;
}
diff --git a/src/main/java/evironment/antGame/Grid.java b/src/main/java/evironment/antGame/Grid.java
index c1d03f1..c6ef661 100644
--- a/src/main/java/evironment/antGame/Grid.java
+++ b/src/main/java/evironment/antGame/Grid.java
@@ -33,7 +33,6 @@ public class Grid {
spawnNewFood(initialGrid);
spawnObstacles();
initialGrid[start.x][start.y] = new Cell(new Point(start.x, start.y), CellType.START);
- ;
}
diff --git a/src/main/java/evironment/antGame/gui/AntWorldComponent.java b/src/main/java/evironment/antGame/gui/AntWorldComponent.java
index 4b148b5..32e7755 100644
--- a/src/main/java/evironment/antGame/gui/AntWorldComponent.java
+++ b/src/main/java/evironment/antGame/gui/AntWorldComponent.java
@@ -7,10 +7,8 @@ import javax.swing.*;
import java.awt.*;
public class AntWorldComponent extends JComponent {
- private AntWorld antWorld;
public AntWorldComponent(AntWorld antWorld, AntAgent antAgent){
- this.antWorld = antWorld;
setLayout(new BorderLayout());
CellsScrollPane worldPane = new CellsScrollPane(antWorld.getCellArray(), antWorld.getAnt(), 10);
CellsScrollPane antBrainPane = new CellsScrollPane(antAgent.getKnownWorld(), antWorld.getAnt(), 10);
diff --git a/src/main/java/evironment/jumpingDino/DinoState.java b/src/main/java/evironment/jumpingDino/DinoState.java
index 0a52a7a..3826456 100644
--- a/src/main/java/evironment/jumpingDino/DinoState.java
+++ b/src/main/java/evironment/jumpingDino/DinoState.java
@@ -2,21 +2,20 @@ package evironment.jumpingDino;
import core.State;
import core.gui.Visualizable;
-import lombok.AllArgsConstructor;
import lombok.Getter;
-import javax.swing.*;
import java.awt.*;
import java.io.Serializable;
import java.util.Objects;
-@AllArgsConstructor
@Getter
-public class DinoState implements State, Serializable, Visualizable {
- private int xDistanceToObstacle;
+public class DinoState extends DinoStateSimple implements State, Serializable, Visualizable {
private boolean isJumping;
- protected final double scale = 0.5;
+ public DinoState(int xDistanceToObstacle, boolean isJumping) {
+ super(xDistanceToObstacle);
+ this.isJumping = isJumping;
+ }
@Override
public String toString() {
@@ -40,29 +39,15 @@ public class DinoState implements State, Serializable, Visualizable {
}
@Override
- public JComponent visualize() {
- return new JComponent() {
- {
- setPreferredSize(new Dimension(Config.FRAME_WIDTH, (int)(scale * Config.FRAME_HEIGHT)));
- setVisible(true);
- }
-
- @Override
- protected void paintComponent(Graphics g) {
- super.paintComponents(g);
- drawObjects(g);
- }
- };
- }
-
- public void drawObjects(Graphics g){
- g.setColor(Color.BLACK);
- g.fillRect(0, (int) (scale * (Config.FRAME_HEIGHT - Config.GROUND_Y)), Config.FRAME_WIDTH, 2);
-
- g.fillRect((int)(scale * Config.DINO_STARTING_X), (int)(scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE)), (int)(scale * Config.DINO_SIZE), (int) (scale * Config.DINO_SIZE));
- g.drawString("Distance: " + xDistanceToObstacle, (int)(scale * Config.DINO_STARTING_X),(int)(scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE - 40) ));
-
- g.fillRect((int)(scale * (Config.DINO_STARTING_X + getXDistanceToObstacle())), (int)(scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE)), (int)(scale * Config.OBSTACLE_SIZE), (int)(scale *Config.OBSTACLE_SIZE));
-
+ protected void drawDinoInfo(Graphics g) {
+ int dinoY;
+ if(!isJumping) {
+ dinoY = Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE;
+ g.fillRect((int) (scale * Config.DINO_STARTING_X), (int) (scale * (dinoY)), (int) (scale * Config.DINO_SIZE), (int) (scale * Config.DINO_SIZE));
+ } else {
+ dinoY = Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE - (int) (scale * Config.MAX_JUMP_HEIGHT);
+ g.fillRect((int) (scale * Config.DINO_STARTING_X), (int) (scale * (dinoY)), (int) (scale * Config.DINO_SIZE), (int) (scale * Config.DINO_SIZE));
+ }
+ g.drawString("Distance: " + xDistanceToObstacle + " inJump: " + isJumping, (int) (scale * Config.DINO_STARTING_X), (int) (scale * (dinoY - 20)));
}
}
diff --git a/src/main/java/evironment/jumpingDino/DinoStateSimple.java b/src/main/java/evironment/jumpingDino/DinoStateSimple.java
index 4552489..0c4dfa6 100644
--- a/src/main/java/evironment/jumpingDino/DinoStateSimple.java
+++ b/src/main/java/evironment/jumpingDino/DinoStateSimple.java
@@ -14,7 +14,7 @@ import java.util.Objects;
@Getter
public class DinoStateSimple implements State, Serializable, Visualizable {
protected final double scale = 0.5;
- private int xDistanceToObstacle;
+ protected int xDistanceToObstacle;
@Override
public String toString() {
@@ -40,7 +40,7 @@ public class DinoStateSimple implements State, Serializable, Visualizable {
public JComponent visualize() {
return new JComponent() {
{
- setPreferredSize(new Dimension(Config.FRAME_WIDTH, (int) (scale * Config.FRAME_HEIGHT)));
+ setPreferredSize(new Dimension((int) (scale * Config.FRAME_WIDTH), (int) (scale * Config.FRAME_HEIGHT)));
setVisible(true);
}
@@ -52,14 +52,15 @@ public class DinoStateSimple implements State, Serializable, Visualizable {
};
}
+ protected void drawDinoInfo(Graphics g) {
+ g.fillRect((int) (scale * Config.DINO_STARTING_X), (int) (scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE)), (int) (scale * Config.DINO_SIZE), (int) (scale * Config.DINO_SIZE));
+ g.drawString("Distance: " + xDistanceToObstacle, (int) (scale * Config.DINO_STARTING_X), (int) (scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE - 40)));
+ }
+
public void drawObjects(Graphics g) {
g.setColor(Color.BLACK);
g.fillRect(0, (int) (scale * (Config.FRAME_HEIGHT - Config.GROUND_Y)), Config.FRAME_WIDTH, 2);
-
- g.fillRect((int) (scale * Config.DINO_STARTING_X), (int) (scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE)), (int) (scale * Config.DINO_SIZE), (int) (scale * Config.DINO_SIZE));
- g.drawString("Distance: " + xDistanceToObstacle, (int) (scale * Config.DINO_STARTING_X), (int) (scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE - 40)));
-
g.fillRect((int) (scale * (Config.DINO_STARTING_X + getXDistanceToObstacle())), (int) (scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE)), (int) (scale * Config.OBSTACLE_SIZE), (int) (scale * Config.OBSTACLE_SIZE));
-
+ drawDinoInfo(g);
}
}
diff --git a/src/main/java/evironment/jumpingDino/DinoWorld.java b/src/main/java/evironment/jumpingDino/DinoWorld.java
index fdc66a4..b17d26c 100644
--- a/src/main/java/evironment/jumpingDino/DinoWorld.java
+++ b/src/main/java/evironment/jumpingDino/DinoWorld.java
@@ -62,9 +62,6 @@ public class DinoWorld implements Environment, Visualizable {
return new StepResultEnvironment(generateReturnState(), reward, done, "");
}
- protected State generateReturnState(){
- return new DinoStateSimple(getDistanceToObstacle());
- }
protected State generateReturnState(){
return new DinoState(getDistanceToObstacle(), dino.isInJump());
}
diff --git a/src/main/java/example/JumpingDino.java b/src/main/java/example/JumpingDino.java
index d0f2a7b..e6ea1f5 100644
--- a/src/main/java/example/JumpingDino.java
+++ b/src/main/java/example/JumpingDino.java
@@ -13,13 +13,13 @@ public class JumpingDino {
RLController rl = new RLControllerGUI<>(
new DinoWorldAdvanced(),
- Method.MC_CONTROL_FIRST_VISIT,
+ Method.MC_CONTROL_EVERY_VISIT,
DinoAction.values());
rl.setDelay(200);
- rl.setDiscountFactor(9f);
+ rl.setDiscountFactor(1f);
rl.setEpsilon(0.05f);
- rl.setLearningRate(0.8f);
+ rl.setLearningRate(1f);
rl.setNrOfEpisodes(100000);
rl.start();
}