Apply code improvements suggested by IntelliJ

This commit is contained in:
Jan Löwenstrom 2020-04-05 14:44:48 +02:00
parent 94ad976a1f
commit 9d1f8dfd46
15 changed files with 45 additions and 66 deletions

View File

@ -1,7 +1,10 @@
package core;
import java.io.Serializable;
import java.util.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
/**
* Implementation of a discrete action space.
@ -18,6 +21,7 @@ public class ListDiscreteActionSpace<A extends Enum> implements DiscreteActionSp
actions = new ArrayList<>();
}
@SafeVarargs
public ListDiscreteActionSpace(A... actions){
this.actions = new ArrayList<>(Arrays.asList(actions));
}

View File

@ -21,7 +21,7 @@ public class MonteCarloControlEGreedy<A extends Enum> extends EpisodicLearning<A
private Map<Pair<State, A>, Double> returnSum;
private Map<Pair<State, A>, Integer> returnCount;
private boolean isEveryVisit;
private final boolean isEveryVisit;
public MonteCarloControlEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay, boolean useEveryVisit) {

View File

@ -34,7 +34,7 @@ public class QLearningOffPolicyTDControl<A extends Enum> extends EpisodicLearnin
}
StepResultEnvironment envResult = null;
Map<A, Double> actionValues = null;
Map<A, Double> actionValues;
sumOfRewards = 0;
while(envResult == null || !envResult.isDone()) {

View File

@ -39,6 +39,7 @@ public class RLController<A extends Enum> implements LearningListener {
protected int prevDelay;
protected volatile boolean printNextEpisode;
@SafeVarargs
public RLController(Environment<A> env, Method method, A... actions) {
setEnvironment(env);
setMethod(method);
@ -102,9 +103,7 @@ public class RLController<A extends Enum> implements LearningListener {
if(learning.isCurrentlyLearning()){
((EpisodicLearning<A>) learning).learnMoreEpisodes(nrOfEpisodes);
}else{
new Thread(() -> {
((EpisodicLearning<A>) learning).learn(nrOfEpisodes);
}).start();
new Thread(() -> ((EpisodicLearning<A>) learning).learn(nrOfEpisodes)).start();
}
} else {
throw new RuntimeException("Triggering onLearnMoreEpisodes on non-episodic learning!");
@ -179,7 +178,7 @@ public class RLController<A extends Enum> implements LearningListener {
public void onEpisodeEnd(List<Double> rewardHistory) {
latestRewardsHistory = rewardHistory;
if(printNextEpisode) {
System.out.println("Episode " + ((EpisodicLearning<A>) learning).getCurrentEpisode() + " Latest Reward: " + rewardHistory.get(rewardHistory.size() - 1));
System.out.println("Episode " + learning.getCurrentEpisode() + " Latest Reward: " + rewardHistory.get(rewardHistory.size() - 1));
System.out.println("Eps/sec: " + ((EpisodicLearning<A>) learning).getEpisodePerSecond());
printNextEpisode = false;
}

View File

@ -13,6 +13,7 @@ import java.util.List;
public class RLControllerGUI<A extends Enum> extends RLController<A> implements ViewListener {
private LearningView learningView;
/**
 * Creates a GUI-backed controller by forwarding all arguments unchanged to
 * the {@code RLController} constructor.
 *
 * @param env     environment passed on to the superclass constructor
 * @param method  method passed on to the superclass constructor
 * @param actions actions passed on to the superclass constructor
 */
@SafeVarargs
public RLControllerGUI(Environment<A> env, Method method, A... actions) {
super(env, method, actions);
}
@ -102,7 +103,7 @@ public class RLControllerGUI<A extends Enum> extends RLController<A> implements
@Override
public void onLearningEnd() {
super.onLearningEnd();
onSaveState(method.toString() + System.currentTimeMillis() / 1000 + (learning instanceof EpisodicLearning ? "e" + ((EpisodicLearning) learning).getCurrentEpisode() : ""));
onSaveState(method.toString() + System.currentTimeMillis() / 1000 + (learning instanceof EpisodicLearning ? "e" + learning.getCurrentEpisode() : ""));
SwingUtilities.invokeLater(() -> learningView.updateRewardGraph(latestRewardsHistory));
}
}

View File

@ -85,9 +85,7 @@ public class LearningInfoPanel extends JPanel {
add(learnMoreEpisodesButton);
}
showQTableButton = new JButton("Show Q-Table");
showQTableButton.addActionListener(e -> {
viewListener.onShowQTable();
});
showQTableButton.addActionListener(e -> viewListener.onShowQTable());
add(drawEnvironmentCheckbox);
add(smoothGraphCheckbox);
add(last100Checkbox);

View File

@ -1,7 +1,6 @@
package evironment.antGame;
import lombok.AccessLevel;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.Setter;

View File

@ -86,12 +86,12 @@ public class AntState implements State, Visualizable {
public JComponent visualize() {
return new JScrollPane() {
private int cellSize;
private final int paneWidth = 500;
private final int paneHeight = 500;
private Font font;
{
int paneWidth = 500;
int paneHeight = 500;
setPreferredSize(new Dimension(paneWidth, paneHeight));
cellSize = (paneWidth- knownWorld.length) /knownWorld.length;
cellSize = (paneWidth - knownWorld.length) / knownWorld.length;
font = new Font("plain", Font.BOLD, cellSize);
JPanel worldPanel = new JPanel(){
{

View File

@ -139,11 +139,9 @@ public class AntWorld implements Environment<AntAction>, Visualizable {
// valid movement
if(!sc.stayOnCell) {
myAnt.getPos().setLocation(sc.potentialNextPos);
if(antAgent.getCell(myAnt.getPos()).getType() == CellType.UNKNOWN){
// the ant will move to a cell that was previously unknown
// TODO: not optimal for going straight for food
// sc.reward = Reward.UNKNOWN_FIELD_EXPLORED;
}
antAgent.getCell(myAnt.getPos());// the ant will move to a cell that was previously unknown
// TODO: not optimal for going straight for food
// sc.reward = Reward.UNKNOWN_FIELD_EXPLORED;
}

View File

@ -33,7 +33,6 @@ public class Grid {
spawnNewFood(initialGrid);
spawnObstacles();
initialGrid[start.x][start.y] = new Cell(new Point(start.x, start.y), CellType.START);
;
}

View File

@ -7,10 +7,8 @@ import javax.swing.*;
import java.awt.*;
public class AntWorldComponent extends JComponent {
private AntWorld antWorld;
public AntWorldComponent(AntWorld antWorld, AntAgent antAgent){
this.antWorld = antWorld;
setLayout(new BorderLayout());
CellsScrollPane worldPane = new CellsScrollPane(antWorld.getCellArray(), antWorld.getAnt(), 10);
CellsScrollPane antBrainPane = new CellsScrollPane(antAgent.getKnownWorld(), antWorld.getAnt(), 10);

View File

@ -2,21 +2,20 @@ package evironment.jumpingDino;
import core.State;
import core.gui.Visualizable;
import lombok.AllArgsConstructor;
import lombok.Getter;
import javax.swing.*;
import java.awt.*;
import java.io.Serializable;
import java.util.Objects;
@AllArgsConstructor
@Getter
public class DinoState implements State, Serializable, Visualizable {
private int xDistanceToObstacle;
public class DinoState extends DinoStateSimple implements State, Serializable, Visualizable {
private boolean isJumping;
protected final double scale = 0.5;
/**
 * Creates a state from the distance to the obstacle and the jump flag.
 *
 * @param xDistanceToObstacle forwarded to the {@code DinoStateSimple} constructor
 * @param isJumping           stored in this state's {@code isJumping} field
 */
public DinoState(int xDistanceToObstacle, boolean isJumping) {
super(xDistanceToObstacle);
this.isJumping = isJumping;
}
@Override
public String toString() {
@ -40,29 +39,15 @@ public class DinoState implements State, Serializable, Visualizable {
}
@Override
public JComponent visualize() {
return new JComponent() {
{
setPreferredSize(new Dimension(Config.FRAME_WIDTH, (int)(scale * Config.FRAME_HEIGHT)));
setVisible(true);
}
@Override
protected void paintComponent(Graphics g) {
super.paintComponents(g);
drawObjects(g);
}
};
}
public void drawObjects(Graphics g){
g.setColor(Color.BLACK);
g.fillRect(0, (int) (scale * (Config.FRAME_HEIGHT - Config.GROUND_Y)), Config.FRAME_WIDTH, 2);
g.fillRect((int)(scale * Config.DINO_STARTING_X), (int)(scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE)), (int)(scale * Config.DINO_SIZE), (int) (scale * Config.DINO_SIZE));
g.drawString("Distance: " + xDistanceToObstacle, (int)(scale * Config.DINO_STARTING_X),(int)(scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE - 40) ));
g.fillRect((int)(scale * (Config.DINO_STARTING_X + getXDistanceToObstacle())), (int)(scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE)), (int)(scale * Config.OBSTACLE_SIZE), (int)(scale *Config.OBSTACLE_SIZE));
/**
 * Draws the dino as a filled square and a text label showing the obstacle
 * distance and the jump flag. While {@code isJumping} is set, the square is
 * drawn higher up by an offset derived from {@code Config.MAX_JUMP_HEIGHT}.
 *
 * @param g graphics context to draw on
 */
protected void drawDinoInfo(Graphics g) {
    // Top edge of the dino (unscaled) while it stands on the ground.
    final int groundedY = Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE;
    // NOTE(review): the jump offset is multiplied by scale here and again when
    // dinoY is scaled for drawing below — confirm this is intended.
    final int dinoY = isJumping
            ? groundedY - (int) (scale * Config.MAX_JUMP_HEIGHT)
            : groundedY;

    final int dinoX = (int) (scale * Config.DINO_STARTING_X);
    final int dinoSide = (int) (scale * Config.DINO_SIZE);

    g.fillRect(dinoX, (int) (scale * (dinoY)), dinoSide, dinoSide);
    // The label sits 20 unscaled units above the dino's top edge.
    g.drawString("Distance: " + xDistanceToObstacle + " inJump: " + isJumping, dinoX, (int) (scale * (dinoY - 20)));
}
}

View File

@ -14,7 +14,7 @@ import java.util.Objects;
@Getter
public class DinoStateSimple implements State, Serializable, Visualizable {
protected final double scale = 0.5;
private int xDistanceToObstacle;
protected int xDistanceToObstacle;
@Override
public String toString() {
@ -40,7 +40,7 @@ public class DinoStateSimple implements State, Serializable, Visualizable {
public JComponent visualize() {
return new JComponent() {
{
setPreferredSize(new Dimension(Config.FRAME_WIDTH, (int) (scale * Config.FRAME_HEIGHT)));
setPreferredSize(new Dimension((int) (scale * Config.FRAME_WIDTH), (int) (scale * Config.FRAME_HEIGHT)));
setVisible(true);
}
@ -52,14 +52,15 @@ public class DinoStateSimple implements State, Serializable, Visualizable {
};
}
/**
 * Draws the dino as a filled square at its ground position together with a
 * text label showing the distance to the obstacle.
 *
 * @param g graphics context to draw on
 */
protected void drawDinoInfo(Graphics g) {
    final int dinoX = (int) (scale * Config.DINO_STARTING_X);
    final int dinoY = (int) (scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE));
    final int dinoSide = (int) (scale * Config.DINO_SIZE);
    g.fillRect(dinoX, dinoY, dinoSide, dinoSide);

    // The label is placed 40 unscaled units above the obstacle's top edge.
    final int labelY = (int) (scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE - 40));
    g.drawString("Distance: " + xDistanceToObstacle, dinoX, labelY);
}
public void drawObjects(Graphics g) {
g.setColor(Color.BLACK);
g.fillRect(0, (int) (scale * (Config.FRAME_HEIGHT - Config.GROUND_Y)), Config.FRAME_WIDTH, 2);
g.fillRect((int) (scale * Config.DINO_STARTING_X), (int) (scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE)), (int) (scale * Config.DINO_SIZE), (int) (scale * Config.DINO_SIZE));
g.drawString("Distance: " + xDistanceToObstacle, (int) (scale * Config.DINO_STARTING_X), (int) (scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE - 40)));
g.fillRect((int) (scale * (Config.DINO_STARTING_X + getXDistanceToObstacle())), (int) (scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE)), (int) (scale * Config.OBSTACLE_SIZE), (int) (scale * Config.OBSTACLE_SIZE));
drawDinoInfo(g);
}
}

View File

@ -62,9 +62,6 @@ public class DinoWorld implements Environment<DinoAction>, Visualizable {
return new StepResultEnvironment(generateReturnState(), reward, done, "");
}
protected State generateReturnState(){
return new DinoStateSimple(getDistanceToObstacle());
}
protected State generateReturnState(){
return new DinoState(getDistanceToObstacle(), dino.isInJump());
}

View File

@ -13,13 +13,13 @@ public class JumpingDino {
RLController<DinoAction> rl = new RLControllerGUI<>(
new DinoWorldAdvanced(),
Method.MC_CONTROL_FIRST_VISIT,
Method.MC_CONTROL_EVERY_VISIT,
DinoAction.values());
rl.setDelay(200);
rl.setDiscountFactor(9f);
rl.setDiscountFactor(1f);
rl.setEpsilon(0.05f);
rl.setLearningRate(0.8f);
rl.setLearningRate(1f);
rl.setNrOfEpisodes(100000);
rl.start();
}