Apply code improvements suggested by IntelliJ

This commit is contained in:
Jan Löwenstrom 2020-04-05 14:44:48 +02:00
parent 94ad976a1f
commit 9d1f8dfd46
15 changed files with 45 additions and 66 deletions

View File

@ -1,7 +1,10 @@
package core; package core;
import java.io.Serializable; import java.io.Serializable;
import java.util.*; import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
/** /**
* Implementation of a discrete action space. * Implementation of a discrete action space.
@ -18,6 +21,7 @@ public class ListDiscreteActionSpace<A extends Enum> implements DiscreteActionSp
actions = new ArrayList<>(); actions = new ArrayList<>();
} }
@SafeVarargs
public ListDiscreteActionSpace(A... actions){ public ListDiscreteActionSpace(A... actions){
this.actions = new ArrayList<>(Arrays.asList(actions)); this.actions = new ArrayList<>(Arrays.asList(actions));
} }

View File

@ -21,7 +21,7 @@ public class MonteCarloControlEGreedy<A extends Enum> extends EpisodicLearning<A
private Map<Pair<State, A>, Double> returnSum; private Map<Pair<State, A>, Double> returnSum;
private Map<Pair<State, A>, Integer> returnCount; private Map<Pair<State, A>, Integer> returnCount;
private boolean isEveryVisit; private final boolean isEveryVisit;
public MonteCarloControlEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay, boolean useEveryVisit) { public MonteCarloControlEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay, boolean useEveryVisit) {

View File

@ -34,7 +34,7 @@ public class QLearningOffPolicyTDControl<A extends Enum> extends EpisodicLearnin
} }
StepResultEnvironment envResult = null; StepResultEnvironment envResult = null;
Map<A, Double> actionValues = null; Map<A, Double> actionValues;
sumOfRewards = 0; sumOfRewards = 0;
while(envResult == null || !envResult.isDone()) { while(envResult == null || !envResult.isDone()) {

View File

@ -39,6 +39,7 @@ public class RLController<A extends Enum> implements LearningListener {
protected int prevDelay; protected int prevDelay;
protected volatile boolean printNextEpisode; protected volatile boolean printNextEpisode;
@SafeVarargs
public RLController(Environment<A> env, Method method, A... actions) { public RLController(Environment<A> env, Method method, A... actions) {
setEnvironment(env); setEnvironment(env);
setMethod(method); setMethod(method);
@ -102,9 +103,7 @@ public class RLController<A extends Enum> implements LearningListener {
if(learning.isCurrentlyLearning()){ if(learning.isCurrentlyLearning()){
((EpisodicLearning<A>) learning).learnMoreEpisodes(nrOfEpisodes); ((EpisodicLearning<A>) learning).learnMoreEpisodes(nrOfEpisodes);
}else{ }else{
new Thread(() -> { new Thread(() -> ((EpisodicLearning<A>) learning).learn(nrOfEpisodes)).start();
((EpisodicLearning<A>) learning).learn(nrOfEpisodes);
}).start();
} }
} else { } else {
throw new RuntimeException("Triggering onLearnMoreEpisodes on non-episodic learning!"); throw new RuntimeException("Triggering onLearnMoreEpisodes on non-episodic learning!");
@ -179,7 +178,7 @@ public class RLController<A extends Enum> implements LearningListener {
public void onEpisodeEnd(List<Double> rewardHistory) { public void onEpisodeEnd(List<Double> rewardHistory) {
latestRewardsHistory = rewardHistory; latestRewardsHistory = rewardHistory;
if(printNextEpisode) { if(printNextEpisode) {
System.out.println("Episode " + ((EpisodicLearning<A>) learning).getCurrentEpisode() + " Latest Reward: " + rewardHistory.get(rewardHistory.size() - 1)); System.out.println("Episode " + learning.getCurrentEpisode() + " Latest Reward: " + rewardHistory.get(rewardHistory.size() - 1));
System.out.println("Eps/sec: " + ((EpisodicLearning<A>) learning).getEpisodePerSecond()); System.out.println("Eps/sec: " + ((EpisodicLearning<A>) learning).getEpisodePerSecond());
printNextEpisode = false; printNextEpisode = false;
} }

View File

@ -13,6 +13,7 @@ import java.util.List;
public class RLControllerGUI<A extends Enum> extends RLController<A> implements ViewListener { public class RLControllerGUI<A extends Enum> extends RLController<A> implements ViewListener {
private LearningView learningView; private LearningView learningView;
@SafeVarargs
public RLControllerGUI(Environment<A> env, Method method, A... actions) { public RLControllerGUI(Environment<A> env, Method method, A... actions) {
super(env, method, actions); super(env, method, actions);
} }
@ -102,7 +103,7 @@ public class RLControllerGUI<A extends Enum> extends RLController<A> implements
@Override @Override
public void onLearningEnd() { public void onLearningEnd() {
super.onLearningEnd(); super.onLearningEnd();
onSaveState(method.toString() + System.currentTimeMillis() / 1000 + (learning instanceof EpisodicLearning ? "e" + ((EpisodicLearning) learning).getCurrentEpisode() : "")); onSaveState(method.toString() + System.currentTimeMillis() / 1000 + (learning instanceof EpisodicLearning ? "e" + learning.getCurrentEpisode() : ""));
SwingUtilities.invokeLater(() -> learningView.updateRewardGraph(latestRewardsHistory)); SwingUtilities.invokeLater(() -> learningView.updateRewardGraph(latestRewardsHistory));
} }
} }

View File

@ -85,9 +85,7 @@ public class LearningInfoPanel extends JPanel {
add(learnMoreEpisodesButton); add(learnMoreEpisodesButton);
} }
showQTableButton = new JButton("Show Q-Table"); showQTableButton = new JButton("Show Q-Table");
showQTableButton.addActionListener(e -> { showQTableButton.addActionListener(e -> viewListener.onShowQTable());
viewListener.onShowQTable();
});
add(drawEnvironmentCheckbox); add(drawEnvironmentCheckbox);
add(smoothGraphCheckbox); add(smoothGraphCheckbox);
add(last100Checkbox); add(last100Checkbox);

View File

@ -1,7 +1,6 @@
package evironment.antGame; package evironment.antGame;
import lombok.AccessLevel; import lombok.AccessLevel;
import lombok.AllArgsConstructor;
import lombok.Getter; import lombok.Getter;
import lombok.Setter; import lombok.Setter;

View File

@ -86,12 +86,12 @@ public class AntState implements State, Visualizable {
public JComponent visualize() { public JComponent visualize() {
return new JScrollPane() { return new JScrollPane() {
private int cellSize; private int cellSize;
private final int paneWidth = 500;
private final int paneHeight = 500;
private Font font; private Font font;
{ {
int paneWidth = 500;
int paneHeight = 500;
setPreferredSize(new Dimension(paneWidth, paneHeight)); setPreferredSize(new Dimension(paneWidth, paneHeight));
cellSize = (paneWidth- knownWorld.length) /knownWorld.length; cellSize = (paneWidth - knownWorld.length) / knownWorld.length;
font = new Font("plain", Font.BOLD, cellSize); font = new Font("plain", Font.BOLD, cellSize);
JPanel worldPanel = new JPanel(){ JPanel worldPanel = new JPanel(){
{ {

View File

@ -139,11 +139,9 @@ public class AntWorld implements Environment<AntAction>, Visualizable {
// valid movement // valid movement
if(!sc.stayOnCell) { if(!sc.stayOnCell) {
myAnt.getPos().setLocation(sc.potentialNextPos); myAnt.getPos().setLocation(sc.potentialNextPos);
if(antAgent.getCell(myAnt.getPos()).getType() == CellType.UNKNOWN){ antAgent.getCell(myAnt.getPos());// the ant will move to a cell that was previously unknown
// the ant will move to a cell that was previously unknown // TODO: not optimal for going straight for food
// TODO: not optimal for going straight for food // sc.reward = Reward.UNKNOWN_FIELD_EXPLORED;
// sc.reward = Reward.UNKNOWN_FIELD_EXPLORED;
}
} }

View File

@ -33,7 +33,6 @@ public class Grid {
spawnNewFood(initialGrid); spawnNewFood(initialGrid);
spawnObstacles(); spawnObstacles();
initialGrid[start.x][start.y] = new Cell(new Point(start.x, start.y), CellType.START); initialGrid[start.x][start.y] = new Cell(new Point(start.x, start.y), CellType.START);
;
} }

View File

@ -7,10 +7,8 @@ import javax.swing.*;
import java.awt.*; import java.awt.*;
public class AntWorldComponent extends JComponent { public class AntWorldComponent extends JComponent {
private AntWorld antWorld;
public AntWorldComponent(AntWorld antWorld, AntAgent antAgent){ public AntWorldComponent(AntWorld antWorld, AntAgent antAgent){
this.antWorld = antWorld;
setLayout(new BorderLayout()); setLayout(new BorderLayout());
CellsScrollPane worldPane = new CellsScrollPane(antWorld.getCellArray(), antWorld.getAnt(), 10); CellsScrollPane worldPane = new CellsScrollPane(antWorld.getCellArray(), antWorld.getAnt(), 10);
CellsScrollPane antBrainPane = new CellsScrollPane(antAgent.getKnownWorld(), antWorld.getAnt(), 10); CellsScrollPane antBrainPane = new CellsScrollPane(antAgent.getKnownWorld(), antWorld.getAnt(), 10);

View File

@ -2,21 +2,20 @@ package evironment.jumpingDino;
import core.State; import core.State;
import core.gui.Visualizable; import core.gui.Visualizable;
import lombok.AllArgsConstructor;
import lombok.Getter; import lombok.Getter;
import javax.swing.*;
import java.awt.*; import java.awt.*;
import java.io.Serializable; import java.io.Serializable;
import java.util.Objects; import java.util.Objects;
@AllArgsConstructor
@Getter @Getter
public class DinoState implements State, Serializable, Visualizable { public class DinoState extends DinoStateSimple implements State, Serializable, Visualizable {
private int xDistanceToObstacle;
private boolean isJumping; private boolean isJumping;
protected final double scale = 0.5; public DinoState(int xDistanceToObstacle, boolean isJumping) {
super(xDistanceToObstacle);
this.isJumping = isJumping;
}
@Override @Override
public String toString() { public String toString() {
@ -40,29 +39,15 @@ public class DinoState implements State, Serializable, Visualizable {
} }
@Override @Override
public JComponent visualize() { protected void drawDinoInfo(Graphics g) {
return new JComponent() { int dinoY;
{ if(!isJumping) {
setPreferredSize(new Dimension(Config.FRAME_WIDTH, (int)(scale * Config.FRAME_HEIGHT))); dinoY = Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE;
setVisible(true); g.fillRect((int) (scale * Config.DINO_STARTING_X), (int) (scale * (dinoY)), (int) (scale * Config.DINO_SIZE), (int) (scale * Config.DINO_SIZE));
} else {
dinoY = Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE - (int) (scale * Config.MAX_JUMP_HEIGHT);
g.fillRect((int) (scale * Config.DINO_STARTING_X), (int) (scale * (dinoY)), (int) (scale * Config.DINO_SIZE), (int) (scale * Config.DINO_SIZE));
} }
g.drawString("Distance: " + xDistanceToObstacle + " inJump: " + isJumping, (int) (scale * Config.DINO_STARTING_X), (int) (scale * (dinoY - 20)));
@Override
protected void paintComponent(Graphics g) {
super.paintComponents(g);
drawObjects(g);
}
};
}
public void drawObjects(Graphics g){
g.setColor(Color.BLACK);
g.fillRect(0, (int) (scale * (Config.FRAME_HEIGHT - Config.GROUND_Y)), Config.FRAME_WIDTH, 2);
g.fillRect((int)(scale * Config.DINO_STARTING_X), (int)(scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE)), (int)(scale * Config.DINO_SIZE), (int) (scale * Config.DINO_SIZE));
g.drawString("Distance: " + xDistanceToObstacle, (int)(scale * Config.DINO_STARTING_X),(int)(scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE - 40) ));
g.fillRect((int)(scale * (Config.DINO_STARTING_X + getXDistanceToObstacle())), (int)(scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE)), (int)(scale * Config.OBSTACLE_SIZE), (int)(scale *Config.OBSTACLE_SIZE));
} }
} }

View File

@ -14,7 +14,7 @@ import java.util.Objects;
@Getter @Getter
public class DinoStateSimple implements State, Serializable, Visualizable { public class DinoStateSimple implements State, Serializable, Visualizable {
protected final double scale = 0.5; protected final double scale = 0.5;
private int xDistanceToObstacle; protected int xDistanceToObstacle;
@Override @Override
public String toString() { public String toString() {
@ -40,7 +40,7 @@ public class DinoStateSimple implements State, Serializable, Visualizable {
public JComponent visualize() { public JComponent visualize() {
return new JComponent() { return new JComponent() {
{ {
setPreferredSize(new Dimension(Config.FRAME_WIDTH, (int) (scale * Config.FRAME_HEIGHT))); setPreferredSize(new Dimension((int) (scale * Config.FRAME_WIDTH), (int) (scale * Config.FRAME_HEIGHT)));
setVisible(true); setVisible(true);
} }
@ -52,14 +52,15 @@ public class DinoStateSimple implements State, Serializable, Visualizable {
}; };
} }
protected void drawDinoInfo(Graphics g) {
g.fillRect((int) (scale * Config.DINO_STARTING_X), (int) (scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE)), (int) (scale * Config.DINO_SIZE), (int) (scale * Config.DINO_SIZE));
g.drawString("Distance: " + xDistanceToObstacle, (int) (scale * Config.DINO_STARTING_X), (int) (scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE - 40)));
}
public void drawObjects(Graphics g) { public void drawObjects(Graphics g) {
g.setColor(Color.BLACK); g.setColor(Color.BLACK);
g.fillRect(0, (int) (scale * (Config.FRAME_HEIGHT - Config.GROUND_Y)), Config.FRAME_WIDTH, 2); g.fillRect(0, (int) (scale * (Config.FRAME_HEIGHT - Config.GROUND_Y)), Config.FRAME_WIDTH, 2);
g.fillRect((int) (scale * Config.DINO_STARTING_X), (int) (scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE)), (int) (scale * Config.DINO_SIZE), (int) (scale * Config.DINO_SIZE));
g.drawString("Distance: " + xDistanceToObstacle, (int) (scale * Config.DINO_STARTING_X), (int) (scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE - 40)));
g.fillRect((int) (scale * (Config.DINO_STARTING_X + getXDistanceToObstacle())), (int) (scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE)), (int) (scale * Config.OBSTACLE_SIZE), (int) (scale * Config.OBSTACLE_SIZE)); g.fillRect((int) (scale * (Config.DINO_STARTING_X + getXDistanceToObstacle())), (int) (scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE)), (int) (scale * Config.OBSTACLE_SIZE), (int) (scale * Config.OBSTACLE_SIZE));
drawDinoInfo(g);
} }
} }

View File

@ -62,9 +62,6 @@ public class DinoWorld implements Environment<DinoAction>, Visualizable {
return new StepResultEnvironment(generateReturnState(), reward, done, ""); return new StepResultEnvironment(generateReturnState(), reward, done, "");
} }
protected State generateReturnState(){
return new DinoStateSimple(getDistanceToObstacle());
}
protected State generateReturnState(){ protected State generateReturnState(){
return new DinoState(getDistanceToObstacle(), dino.isInJump()); return new DinoState(getDistanceToObstacle(), dino.isInJump());
} }

View File

@ -13,13 +13,13 @@ public class JumpingDino {
RLController<DinoAction> rl = new RLControllerGUI<>( RLController<DinoAction> rl = new RLControllerGUI<>(
new DinoWorldAdvanced(), new DinoWorldAdvanced(),
Method.MC_CONTROL_FIRST_VISIT, Method.MC_CONTROL_EVERY_VISIT,
DinoAction.values()); DinoAction.values());
rl.setDelay(200); rl.setDelay(200);
rl.setDiscountFactor(9f); rl.setDiscountFactor(1f);
rl.setEpsilon(0.05f); rl.setEpsilon(0.05f);
rl.setLearningRate(0.8f); rl.setLearningRate(1f);
rl.setNrOfEpisodes(100000); rl.setNrOfEpisodes(100000);
rl.start(); rl.start();
} }