From 9d1f8dfd46343efa70551f39fb8f177a6d1e03fb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20L=C3=B6wenstrom?= Date: Sun, 5 Apr 2020 14:44:48 +0200 Subject: [PATCH] apply code improvements suggested by intelliJ --- .../java/core/ListDiscreteActionSpace.java | 6 ++- .../algo/mc/MonteCarloControlEGreedy.java | 2 +- .../algo/td/QLearningOffPolicyTDControl.java | 2 +- .../java/core/controller/RLController.java | 7 ++- .../java/core/controller/RLControllerGUI.java | 3 +- src/main/java/core/gui/LearningInfoPanel.java | 4 +- src/main/java/evironment/antGame/Ant.java | 1 - .../java/evironment/antGame/AntState.java | 6 +-- .../java/evironment/antGame/AntWorld.java | 8 ++-- src/main/java/evironment/antGame/Grid.java | 1 - .../antGame/gui/AntWorldComponent.java | 2 - .../evironment/jumpingDino/DinoState.java | 45 +++++++------------ .../jumpingDino/DinoStateSimple.java | 15 ++++--- .../evironment/jumpingDino/DinoWorld.java | 3 -- src/main/java/example/JumpingDino.java | 6 +-- 15 files changed, 45 insertions(+), 66 deletions(-) diff --git a/src/main/java/core/ListDiscreteActionSpace.java b/src/main/java/core/ListDiscreteActionSpace.java index 5c73a65..7ea9f4d 100644 --- a/src/main/java/core/ListDiscreteActionSpace.java +++ b/src/main/java/core/ListDiscreteActionSpace.java @@ -1,7 +1,10 @@ package core; import java.io.Serializable; -import java.util.*; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; /** * Implementation of a discrete action space. @@ -18,6 +21,7 @@ public class ListDiscreteActionSpace implements DiscreteActionSp actions = new ArrayList<>(); } + @SafeVarargs public ListDiscreteActionSpace(A... actions){ this.actions = new ArrayList<>(Arrays.asList(actions)); } diff --git a/src/main/java/core/algo/mc/MonteCarloControlEGreedy.java b/src/main/java/core/algo/mc/MonteCarloControlEGreedy.java index beb9e4e..835bbcf 100644 --- a/src/main/java/core/algo/mc/MonteCarloControlEGreedy.java +++ b/src/main/java/core/algo/mc/MonteCarloControlEGreedy.java @@ -21,7 +21,7 @@ public class MonteCarloControlEGreedy extends EpisodicLearning, Double> returnSum; private Map, Integer> returnCount; - private boolean isEveryVisit; + private final boolean isEveryVisit; public MonteCarloControlEGreedy(Environment environment, DiscreteActionSpace actionSpace, float discountFactor, float epsilon, int delay, boolean useEveryVisit) { diff --git a/src/main/java/core/algo/td/QLearningOffPolicyTDControl.java b/src/main/java/core/algo/td/QLearningOffPolicyTDControl.java index fcc59fa..896e775 100644 --- a/src/main/java/core/algo/td/QLearningOffPolicyTDControl.java +++ b/src/main/java/core/algo/td/QLearningOffPolicyTDControl.java @@ -34,7 +34,7 @@ public class QLearningOffPolicyTDControl extends EpisodicLearnin } StepResultEnvironment envResult = null; - Map actionValues = null; + Map actionValues; sumOfRewards = 0; while(envResult == null || !envResult.isDone()) { diff --git a/src/main/java/core/controller/RLController.java b/src/main/java/core/controller/RLController.java index 33bd168..b1e64ea 100644 --- a/src/main/java/core/controller/RLController.java +++ b/src/main/java/core/controller/RLController.java @@ -39,6 +39,7 @@ public class RLController implements LearningListener { protected int prevDelay; protected volatile boolean printNextEpisode; + @SafeVarargs public RLController(Environment env, Method method, A... actions) { setEnvironment(env); setMethod(method); @@ -102,9 +103,7 @@ public class RLController implements LearningListener { if(learning.isCurrentlyLearning()){ ((EpisodicLearning) learning).learnMoreEpisodes(nrOfEpisodes); }else{ - new Thread(() -> { - ((EpisodicLearning) learning).learn(nrOfEpisodes); - }).start(); + new Thread(() -> ((EpisodicLearning) learning).learn(nrOfEpisodes)).start(); } } else { throw new RuntimeException("Triggering onLearnMoreEpisodes on non-episodic learning!"); @@ -179,7 +178,7 @@ public class RLController implements LearningListener { public void onEpisodeEnd(List rewardHistory) { latestRewardsHistory = rewardHistory; if(printNextEpisode) { - System.out.println("Episode " + ((EpisodicLearning) learning).getCurrentEpisode() + " Latest Reward: " + rewardHistory.get(rewardHistory.size() - 1)); + System.out.println("Episode " + learning.getCurrentEpisode() + " Latest Reward: " + rewardHistory.get(rewardHistory.size() - 1)); System.out.println("Eps/sec: " + ((EpisodicLearning) learning).getEpisodePerSecond()); printNextEpisode = false; } diff --git a/src/main/java/core/controller/RLControllerGUI.java b/src/main/java/core/controller/RLControllerGUI.java index be2ea56..adc8a14 100644 --- a/src/main/java/core/controller/RLControllerGUI.java +++ b/src/main/java/core/controller/RLControllerGUI.java @@ -13,6 +13,7 @@ import java.util.List; public class RLControllerGUI extends RLController implements ViewListener { private LearningView learningView; + @SafeVarargs public RLControllerGUI(Environment env, Method method, A... actions) { super(env, method, actions); } @@ -102,7 +103,7 @@ public class RLControllerGUI extends RLController implements @Override public void onLearningEnd() { super.onLearningEnd(); - onSaveState(method.toString() + System.currentTimeMillis() / 1000 + (learning instanceof EpisodicLearning ? "e" + ((EpisodicLearning) learning).getCurrentEpisode() : "")); + onSaveState(method.toString() + System.currentTimeMillis() / 1000 + (learning instanceof EpisodicLearning ? "e" + learning.getCurrentEpisode() : "")); SwingUtilities.invokeLater(() -> learningView.updateRewardGraph(latestRewardsHistory)); } } diff --git a/src/main/java/core/gui/LearningInfoPanel.java b/src/main/java/core/gui/LearningInfoPanel.java index ed2ebad..dfd8dbc 100644 --- a/src/main/java/core/gui/LearningInfoPanel.java +++ b/src/main/java/core/gui/LearningInfoPanel.java @@ -85,9 +85,7 @@ public class LearningInfoPanel extends JPanel { add(learnMoreEpisodesButton); } showQTableButton = new JButton("Show Q-Table"); - showQTableButton.addActionListener(e -> { - viewListener.onShowQTable(); - }); + showQTableButton.addActionListener(e -> viewListener.onShowQTable()); add(drawEnvironmentCheckbox); add(smoothGraphCheckbox); add(last100Checkbox); diff --git a/src/main/java/evironment/antGame/Ant.java b/src/main/java/evironment/antGame/Ant.java index 4bd4599..d9a3cb6 100644 --- a/src/main/java/evironment/antGame/Ant.java +++ b/src/main/java/evironment/antGame/Ant.java @@ -1,7 +1,6 @@ package evironment.antGame; import lombok.AccessLevel; -import lombok.AllArgsConstructor; import lombok.Getter; import lombok.Setter; diff --git a/src/main/java/evironment/antGame/AntState.java b/src/main/java/evironment/antGame/AntState.java index ca917c9..229b2de 100644 --- a/src/main/java/evironment/antGame/AntState.java +++ b/src/main/java/evironment/antGame/AntState.java @@ -86,12 +86,12 @@ public class AntState implements State, Visualizable { public JComponent visualize() { return new JScrollPane() { private int cellSize; - private final int paneWidth = 500; - private final int paneHeight = 500; private Font font; { + int paneWidth = 500; + int paneHeight = 500; setPreferredSize(new Dimension(paneWidth, paneHeight)); - cellSize = (paneWidth- knownWorld.length) /knownWorld.length; + cellSize = (paneWidth - knownWorld.length) / knownWorld.length; font = new Font("plain", Font.BOLD, cellSize); JPanel worldPanel = new JPanel(){ { diff --git a/src/main/java/evironment/antGame/AntWorld.java b/src/main/java/evironment/antGame/AntWorld.java index c9f9453..6f2dfe0 100644 --- a/src/main/java/evironment/antGame/AntWorld.java +++ b/src/main/java/evironment/antGame/AntWorld.java @@ -139,11 +139,9 @@ public class AntWorld implements Environment, Visualizable { // valid movement if(!sc.stayOnCell) { myAnt.getPos().setLocation(sc.potentialNextPos); - if(antAgent.getCell(myAnt.getPos()).getType() == CellType.UNKNOWN){ - // the ant will move to a cell that was previously unknown - // TODO: not optimal for going straight for food - // sc.reward = Reward.UNKNOWN_FIELD_EXPLORED; - } + antAgent.getCell(myAnt.getPos());// the ant will move to a cell that was previously unknown +// TODO: not optimal for going straight for food +// sc.reward = Reward.UNKNOWN_FIELD_EXPLORED; } diff --git a/src/main/java/evironment/antGame/Grid.java b/src/main/java/evironment/antGame/Grid.java index c1d03f1..c6ef661 100644 --- a/src/main/java/evironment/antGame/Grid.java +++ b/src/main/java/evironment/antGame/Grid.java @@ -33,7 +33,6 @@ public class Grid { spawnNewFood(initialGrid); spawnObstacles(); initialGrid[start.x][start.y] = new Cell(new Point(start.x, start.y), CellType.START); - ; } diff --git a/src/main/java/evironment/antGame/gui/AntWorldComponent.java b/src/main/java/evironment/antGame/gui/AntWorldComponent.java index 4b148b5..32e7755 100644 --- a/src/main/java/evironment/antGame/gui/AntWorldComponent.java +++ b/src/main/java/evironment/antGame/gui/AntWorldComponent.java @@ -7,10 +7,8 @@ import javax.swing.*; import java.awt.*; public class AntWorldComponent extends JComponent { - private AntWorld antWorld; public AntWorldComponent(AntWorld antWorld, AntAgent antAgent){ - this.antWorld = antWorld; setLayout(new BorderLayout()); CellsScrollPane worldPane = new CellsScrollPane(antWorld.getCellArray(), antWorld.getAnt(), 10); CellsScrollPane antBrainPane = new CellsScrollPane(antAgent.getKnownWorld(), antWorld.getAnt(), 10); diff --git a/src/main/java/evironment/jumpingDino/DinoState.java b/src/main/java/evironment/jumpingDino/DinoState.java index 0a52a7a..3826456 100644 --- a/src/main/java/evironment/jumpingDino/DinoState.java +++ b/src/main/java/evironment/jumpingDino/DinoState.java @@ -2,21 +2,20 @@ package evironment.jumpingDino; import core.State; import core.gui.Visualizable; -import lombok.AllArgsConstructor; import lombok.Getter; -import javax.swing.*; import java.awt.*; import java.io.Serializable; import java.util.Objects; -@AllArgsConstructor @Getter -public class DinoState implements State, Serializable, Visualizable { - private int xDistanceToObstacle; +public class DinoState extends DinoStateSimple implements State, Serializable, Visualizable { private boolean isJumping; - protected final double scale = 0.5; + public DinoState(int xDistanceToObstacle, boolean isJumping) { + super(xDistanceToObstacle); + this.isJumping = isJumping; + } @Override public String toString() { @@ -40,29 +39,15 @@ public class DinoState implements State, Serializable, Visualizable { } @Override - public JComponent visualize() { - return new JComponent() { - { - setPreferredSize(new Dimension(Config.FRAME_WIDTH, (int)(scale * Config.FRAME_HEIGHT))); - setVisible(true); - } - - @Override - protected void paintComponent(Graphics g) { - super.paintComponents(g); - drawObjects(g); - } - }; - } - - public void drawObjects(Graphics g){ - g.setColor(Color.BLACK); - g.fillRect(0, (int) (scale * (Config.FRAME_HEIGHT - Config.GROUND_Y)), Config.FRAME_WIDTH, 2); - - g.fillRect((int)(scale * Config.DINO_STARTING_X), (int)(scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE)), (int)(scale * Config.DINO_SIZE), (int) (scale * Config.DINO_SIZE)); - g.drawString("Distance: " + xDistanceToObstacle, (int)(scale * Config.DINO_STARTING_X),(int)(scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE - 40) )); - - g.fillRect((int)(scale * (Config.DINO_STARTING_X + getXDistanceToObstacle())), (int)(scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE)), (int)(scale * Config.OBSTACLE_SIZE), (int)(scale *Config.OBSTACLE_SIZE)); - + protected void drawDinoInfo(Graphics g) { + int dinoY; + if(!isJumping) { + dinoY = Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE; + g.fillRect((int) (scale * Config.DINO_STARTING_X), (int) (scale * (dinoY)), (int) (scale * Config.DINO_SIZE), (int) (scale * Config.DINO_SIZE)); + } else { + dinoY = Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE - (int) (scale * Config.MAX_JUMP_HEIGHT); + g.fillRect((int) (scale * Config.DINO_STARTING_X), (int) (scale * (dinoY)), (int) (scale * Config.DINO_SIZE), (int) (scale * Config.DINO_SIZE)); + } + g.drawString("Distance: " + xDistanceToObstacle + " inJump: " + isJumping, (int) (scale * Config.DINO_STARTING_X), (int) (scale * (dinoY - 20))); } } diff --git a/src/main/java/evironment/jumpingDino/DinoStateSimple.java b/src/main/java/evironment/jumpingDino/DinoStateSimple.java index 4552489..0c4dfa6 100644 --- a/src/main/java/evironment/jumpingDino/DinoStateSimple.java +++ b/src/main/java/evironment/jumpingDino/DinoStateSimple.java @@ -14,7 +14,7 @@ import java.util.Objects; @Getter public class DinoStateSimple implements State, Serializable, Visualizable { protected final double scale = 0.5; - private int xDistanceToObstacle; + protected int xDistanceToObstacle; @Override public String toString() { @@ -40,7 +40,7 @@ public class DinoStateSimple implements State, Serializable, Visualizable { public JComponent visualize() { return new JComponent() { { - setPreferredSize(new Dimension(Config.FRAME_WIDTH, (int) (scale * Config.FRAME_HEIGHT))); + setPreferredSize(new Dimension((int) (scale * Config.FRAME_WIDTH), (int) (scale * Config.FRAME_HEIGHT))); setVisible(true); } @@ -52,14 +52,15 @@ public class DinoStateSimple implements State, Serializable, Visualizable { }; } + protected void drawDinoInfo(Graphics g) { + g.fillRect((int) (scale * Config.DINO_STARTING_X), (int) (scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE)), (int) (scale * Config.DINO_SIZE), (int) (scale * Config.DINO_SIZE)); + g.drawString("Distance: " + xDistanceToObstacle, (int) (scale * Config.DINO_STARTING_X), (int) (scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE - 40))); + } + public void drawObjects(Graphics g) { g.setColor(Color.BLACK); g.fillRect(0, (int) (scale * (Config.FRAME_HEIGHT - Config.GROUND_Y)), Config.FRAME_WIDTH, 2); - - g.fillRect((int) (scale * Config.DINO_STARTING_X), (int) (scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE)), (int) (scale * Config.DINO_SIZE), (int) (scale * Config.DINO_SIZE)); - g.drawString("Distance: " + xDistanceToObstacle, (int) (scale * Config.DINO_STARTING_X), (int) (scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE - 40))); - g.fillRect((int) (scale * (Config.DINO_STARTING_X + getXDistanceToObstacle())), (int) (scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE)), (int) (scale * Config.OBSTACLE_SIZE), (int) (scale * Config.OBSTACLE_SIZE)); - + drawDinoInfo(g); } } diff --git a/src/main/java/evironment/jumpingDino/DinoWorld.java b/src/main/java/evironment/jumpingDino/DinoWorld.java index fdc66a4..b17d26c 100644 --- a/src/main/java/evironment/jumpingDino/DinoWorld.java +++ b/src/main/java/evironment/jumpingDino/DinoWorld.java @@ -62,9 +62,6 @@ public class DinoWorld implements Environment, Visualizable { return new StepResultEnvironment(generateReturnState(), reward, done, ""); } - protected State generateReturnState(){ - return new DinoStateSimple(getDistanceToObstacle()); - } protected State generateReturnState(){ return new DinoState(getDistanceToObstacle(), dino.isInJump()); } diff --git a/src/main/java/example/JumpingDino.java b/src/main/java/example/JumpingDino.java index d0f2a7b..e6ea1f5 100644 --- a/src/main/java/example/JumpingDino.java +++ b/src/main/java/example/JumpingDino.java @@ -13,13 +13,13 @@ public class JumpingDino { RLController rl = new RLControllerGUI<>( new DinoWorldAdvanced(), - Method.MC_CONTROL_FIRST_VISIT, + Method.MC_CONTROL_EVERY_VISIT, DinoAction.values()); rl.setDelay(200); - rl.setDiscountFactor(9f); + rl.setDiscountFactor(1f); rl.setEpsilon(0.05f); - rl.setLearningRate(0.8f); + rl.setLearningRate(1f); rl.setNrOfEpisodes(100000); rl.start(); }