apply code improvements suggested by intelliJ
This commit is contained in:
parent
94ad976a1f
commit
9d1f8dfd46
|
@ -1,7 +1,10 @@
|
||||||
package core;
|
package core;
|
||||||
|
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
import java.util.*;
|
import java.util.ArrayList;
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.Iterator;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Implementation of a discrete action space.
|
* Implementation of a discrete action space.
|
||||||
|
@ -18,6 +21,7 @@ public class ListDiscreteActionSpace<A extends Enum> implements DiscreteActionSp
|
||||||
actions = new ArrayList<>();
|
actions = new ArrayList<>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@SafeVarargs
|
||||||
public ListDiscreteActionSpace(A... actions){
|
public ListDiscreteActionSpace(A... actions){
|
||||||
this.actions = new ArrayList<>(Arrays.asList(actions));
|
this.actions = new ArrayList<>(Arrays.asList(actions));
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,7 +21,7 @@ public class MonteCarloControlEGreedy<A extends Enum> extends EpisodicLearning<A
|
||||||
|
|
||||||
private Map<Pair<State, A>, Double> returnSum;
|
private Map<Pair<State, A>, Double> returnSum;
|
||||||
private Map<Pair<State, A>, Integer> returnCount;
|
private Map<Pair<State, A>, Integer> returnCount;
|
||||||
private boolean isEveryVisit;
|
private final boolean isEveryVisit;
|
||||||
|
|
||||||
|
|
||||||
public MonteCarloControlEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay, boolean useEveryVisit) {
|
public MonteCarloControlEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay, boolean useEveryVisit) {
|
||||||
|
|
|
@ -34,7 +34,7 @@ public class QLearningOffPolicyTDControl<A extends Enum> extends EpisodicLearnin
|
||||||
}
|
}
|
||||||
|
|
||||||
StepResultEnvironment envResult = null;
|
StepResultEnvironment envResult = null;
|
||||||
Map<A, Double> actionValues = null;
|
Map<A, Double> actionValues;
|
||||||
|
|
||||||
sumOfRewards = 0;
|
sumOfRewards = 0;
|
||||||
while(envResult == null || !envResult.isDone()) {
|
while(envResult == null || !envResult.isDone()) {
|
||||||
|
|
|
@ -39,6 +39,7 @@ public class RLController<A extends Enum> implements LearningListener {
|
||||||
protected int prevDelay;
|
protected int prevDelay;
|
||||||
protected volatile boolean printNextEpisode;
|
protected volatile boolean printNextEpisode;
|
||||||
|
|
||||||
|
@SafeVarargs
|
||||||
public RLController(Environment<A> env, Method method, A... actions) {
|
public RLController(Environment<A> env, Method method, A... actions) {
|
||||||
setEnvironment(env);
|
setEnvironment(env);
|
||||||
setMethod(method);
|
setMethod(method);
|
||||||
|
@ -102,9 +103,7 @@ public class RLController<A extends Enum> implements LearningListener {
|
||||||
if(learning.isCurrentlyLearning()){
|
if(learning.isCurrentlyLearning()){
|
||||||
((EpisodicLearning<A>) learning).learnMoreEpisodes(nrOfEpisodes);
|
((EpisodicLearning<A>) learning).learnMoreEpisodes(nrOfEpisodes);
|
||||||
}else{
|
}else{
|
||||||
new Thread(() -> {
|
new Thread(() -> ((EpisodicLearning<A>) learning).learn(nrOfEpisodes)).start();
|
||||||
((EpisodicLearning<A>) learning).learn(nrOfEpisodes);
|
|
||||||
}).start();
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
throw new RuntimeException("Triggering onLearnMoreEpisodes on non-episodic learning!");
|
throw new RuntimeException("Triggering onLearnMoreEpisodes on non-episodic learning!");
|
||||||
|
@ -179,7 +178,7 @@ public class RLController<A extends Enum> implements LearningListener {
|
||||||
public void onEpisodeEnd(List<Double> rewardHistory) {
|
public void onEpisodeEnd(List<Double> rewardHistory) {
|
||||||
latestRewardsHistory = rewardHistory;
|
latestRewardsHistory = rewardHistory;
|
||||||
if(printNextEpisode) {
|
if(printNextEpisode) {
|
||||||
System.out.println("Episode " + ((EpisodicLearning<A>) learning).getCurrentEpisode() + " Latest Reward: " + rewardHistory.get(rewardHistory.size() - 1));
|
System.out.println("Episode " + learning.getCurrentEpisode() + " Latest Reward: " + rewardHistory.get(rewardHistory.size() - 1));
|
||||||
System.out.println("Eps/sec: " + ((EpisodicLearning<A>) learning).getEpisodePerSecond());
|
System.out.println("Eps/sec: " + ((EpisodicLearning<A>) learning).getEpisodePerSecond());
|
||||||
printNextEpisode = false;
|
printNextEpisode = false;
|
||||||
}
|
}
|
||||||
|
|
|
@ -13,6 +13,7 @@ import java.util.List;
|
||||||
public class RLControllerGUI<A extends Enum> extends RLController<A> implements ViewListener {
|
public class RLControllerGUI<A extends Enum> extends RLController<A> implements ViewListener {
|
||||||
private LearningView learningView;
|
private LearningView learningView;
|
||||||
|
|
||||||
|
@SafeVarargs
|
||||||
public RLControllerGUI(Environment<A> env, Method method, A... actions) {
|
public RLControllerGUI(Environment<A> env, Method method, A... actions) {
|
||||||
super(env, method, actions);
|
super(env, method, actions);
|
||||||
}
|
}
|
||||||
|
@ -102,7 +103,7 @@ public class RLControllerGUI<A extends Enum> extends RLController<A> implements
|
||||||
@Override
|
@Override
|
||||||
public void onLearningEnd() {
|
public void onLearningEnd() {
|
||||||
super.onLearningEnd();
|
super.onLearningEnd();
|
||||||
onSaveState(method.toString() + System.currentTimeMillis() / 1000 + (learning instanceof EpisodicLearning ? "e" + ((EpisodicLearning) learning).getCurrentEpisode() : ""));
|
onSaveState(method.toString() + System.currentTimeMillis() / 1000 + (learning instanceof EpisodicLearning ? "e" + learning.getCurrentEpisode() : ""));
|
||||||
SwingUtilities.invokeLater(() -> learningView.updateRewardGraph(latestRewardsHistory));
|
SwingUtilities.invokeLater(() -> learningView.updateRewardGraph(latestRewardsHistory));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -85,9 +85,7 @@ public class LearningInfoPanel extends JPanel {
|
||||||
add(learnMoreEpisodesButton);
|
add(learnMoreEpisodesButton);
|
||||||
}
|
}
|
||||||
showQTableButton = new JButton("Show Q-Table");
|
showQTableButton = new JButton("Show Q-Table");
|
||||||
showQTableButton.addActionListener(e -> {
|
showQTableButton.addActionListener(e -> viewListener.onShowQTable());
|
||||||
viewListener.onShowQTable();
|
|
||||||
});
|
|
||||||
add(drawEnvironmentCheckbox);
|
add(drawEnvironmentCheckbox);
|
||||||
add(smoothGraphCheckbox);
|
add(smoothGraphCheckbox);
|
||||||
add(last100Checkbox);
|
add(last100Checkbox);
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
package evironment.antGame;
|
package evironment.antGame;
|
||||||
|
|
||||||
import lombok.AccessLevel;
|
import lombok.AccessLevel;
|
||||||
import lombok.AllArgsConstructor;
|
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import lombok.Setter;
|
import lombok.Setter;
|
||||||
|
|
||||||
|
|
|
@ -86,12 +86,12 @@ public class AntState implements State, Visualizable {
|
||||||
public JComponent visualize() {
|
public JComponent visualize() {
|
||||||
return new JScrollPane() {
|
return new JScrollPane() {
|
||||||
private int cellSize;
|
private int cellSize;
|
||||||
private final int paneWidth = 500;
|
|
||||||
private final int paneHeight = 500;
|
|
||||||
private Font font;
|
private Font font;
|
||||||
{
|
{
|
||||||
|
int paneWidth = 500;
|
||||||
|
int paneHeight = 500;
|
||||||
setPreferredSize(new Dimension(paneWidth, paneHeight));
|
setPreferredSize(new Dimension(paneWidth, paneHeight));
|
||||||
cellSize = (paneWidth- knownWorld.length) /knownWorld.length;
|
cellSize = (paneWidth - knownWorld.length) / knownWorld.length;
|
||||||
font = new Font("plain", Font.BOLD, cellSize);
|
font = new Font("plain", Font.BOLD, cellSize);
|
||||||
JPanel worldPanel = new JPanel(){
|
JPanel worldPanel = new JPanel(){
|
||||||
{
|
{
|
||||||
|
|
|
@ -139,11 +139,9 @@ public class AntWorld implements Environment<AntAction>, Visualizable {
|
||||||
// valid movement
|
// valid movement
|
||||||
if(!sc.stayOnCell) {
|
if(!sc.stayOnCell) {
|
||||||
myAnt.getPos().setLocation(sc.potentialNextPos);
|
myAnt.getPos().setLocation(sc.potentialNextPos);
|
||||||
if(antAgent.getCell(myAnt.getPos()).getType() == CellType.UNKNOWN){
|
antAgent.getCell(myAnt.getPos());// the ant will move to a cell that was previously unknown
|
||||||
// the ant will move to a cell that was previously unknown
|
// TODO: not optimal for going straight for food
|
||||||
// TODO: not optimal for going straight for food
|
// sc.reward = Reward.UNKNOWN_FIELD_EXPLORED;
|
||||||
// sc.reward = Reward.UNKNOWN_FIELD_EXPLORED;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -33,7 +33,6 @@ public class Grid {
|
||||||
spawnNewFood(initialGrid);
|
spawnNewFood(initialGrid);
|
||||||
spawnObstacles();
|
spawnObstacles();
|
||||||
initialGrid[start.x][start.y] = new Cell(new Point(start.x, start.y), CellType.START);
|
initialGrid[start.x][start.y] = new Cell(new Point(start.x, start.y), CellType.START);
|
||||||
;
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -7,10 +7,8 @@ import javax.swing.*;
|
||||||
import java.awt.*;
|
import java.awt.*;
|
||||||
|
|
||||||
public class AntWorldComponent extends JComponent {
|
public class AntWorldComponent extends JComponent {
|
||||||
private AntWorld antWorld;
|
|
||||||
|
|
||||||
public AntWorldComponent(AntWorld antWorld, AntAgent antAgent){
|
public AntWorldComponent(AntWorld antWorld, AntAgent antAgent){
|
||||||
this.antWorld = antWorld;
|
|
||||||
setLayout(new BorderLayout());
|
setLayout(new BorderLayout());
|
||||||
CellsScrollPane worldPane = new CellsScrollPane(antWorld.getCellArray(), antWorld.getAnt(), 10);
|
CellsScrollPane worldPane = new CellsScrollPane(antWorld.getCellArray(), antWorld.getAnt(), 10);
|
||||||
CellsScrollPane antBrainPane = new CellsScrollPane(antAgent.getKnownWorld(), antWorld.getAnt(), 10);
|
CellsScrollPane antBrainPane = new CellsScrollPane(antAgent.getKnownWorld(), antWorld.getAnt(), 10);
|
||||||
|
|
|
@ -2,21 +2,20 @@ package evironment.jumpingDino;
|
||||||
|
|
||||||
import core.State;
|
import core.State;
|
||||||
import core.gui.Visualizable;
|
import core.gui.Visualizable;
|
||||||
import lombok.AllArgsConstructor;
|
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
|
|
||||||
import javax.swing.*;
|
|
||||||
import java.awt.*;
|
import java.awt.*;
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
|
|
||||||
@AllArgsConstructor
|
|
||||||
@Getter
|
@Getter
|
||||||
public class DinoState implements State, Serializable, Visualizable {
|
public class DinoState extends DinoStateSimple implements State, Serializable, Visualizable {
|
||||||
private int xDistanceToObstacle;
|
|
||||||
private boolean isJumping;
|
private boolean isJumping;
|
||||||
|
|
||||||
protected final double scale = 0.5;
|
public DinoState(int xDistanceToObstacle, boolean isJumping) {
|
||||||
|
super(xDistanceToObstacle);
|
||||||
|
this.isJumping = isJumping;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String toString() {
|
public String toString() {
|
||||||
|
@ -40,29 +39,15 @@ public class DinoState implements State, Serializable, Visualizable {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public JComponent visualize() {
|
protected void drawDinoInfo(Graphics g) {
|
||||||
return new JComponent() {
|
int dinoY;
|
||||||
{
|
if(!isJumping) {
|
||||||
setPreferredSize(new Dimension(Config.FRAME_WIDTH, (int)(scale * Config.FRAME_HEIGHT)));
|
dinoY = Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE;
|
||||||
setVisible(true);
|
g.fillRect((int) (scale * Config.DINO_STARTING_X), (int) (scale * (dinoY)), (int) (scale * Config.DINO_SIZE), (int) (scale * Config.DINO_SIZE));
|
||||||
}
|
} else {
|
||||||
|
dinoY = Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE - (int) (scale * Config.MAX_JUMP_HEIGHT);
|
||||||
@Override
|
g.fillRect((int) (scale * Config.DINO_STARTING_X), (int) (scale * (dinoY)), (int) (scale * Config.DINO_SIZE), (int) (scale * Config.DINO_SIZE));
|
||||||
protected void paintComponent(Graphics g) {
|
}
|
||||||
super.paintComponents(g);
|
g.drawString("Distance: " + xDistanceToObstacle + " inJump: " + isJumping, (int) (scale * Config.DINO_STARTING_X), (int) (scale * (dinoY - 20)));
|
||||||
drawObjects(g);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
public void drawObjects(Graphics g){
|
|
||||||
g.setColor(Color.BLACK);
|
|
||||||
g.fillRect(0, (int) (scale * (Config.FRAME_HEIGHT - Config.GROUND_Y)), Config.FRAME_WIDTH, 2);
|
|
||||||
|
|
||||||
g.fillRect((int)(scale * Config.DINO_STARTING_X), (int)(scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE)), (int)(scale * Config.DINO_SIZE), (int) (scale * Config.DINO_SIZE));
|
|
||||||
g.drawString("Distance: " + xDistanceToObstacle, (int)(scale * Config.DINO_STARTING_X),(int)(scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE - 40) ));
|
|
||||||
|
|
||||||
g.fillRect((int)(scale * (Config.DINO_STARTING_X + getXDistanceToObstacle())), (int)(scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE)), (int)(scale * Config.OBSTACLE_SIZE), (int)(scale *Config.OBSTACLE_SIZE));
|
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,7 +14,7 @@ import java.util.Objects;
|
||||||
@Getter
|
@Getter
|
||||||
public class DinoStateSimple implements State, Serializable, Visualizable {
|
public class DinoStateSimple implements State, Serializable, Visualizable {
|
||||||
protected final double scale = 0.5;
|
protected final double scale = 0.5;
|
||||||
private int xDistanceToObstacle;
|
protected int xDistanceToObstacle;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String toString() {
|
public String toString() {
|
||||||
|
@ -40,7 +40,7 @@ public class DinoStateSimple implements State, Serializable, Visualizable {
|
||||||
public JComponent visualize() {
|
public JComponent visualize() {
|
||||||
return new JComponent() {
|
return new JComponent() {
|
||||||
{
|
{
|
||||||
setPreferredSize(new Dimension(Config.FRAME_WIDTH, (int) (scale * Config.FRAME_HEIGHT)));
|
setPreferredSize(new Dimension((int) (scale * Config.FRAME_WIDTH), (int) (scale * Config.FRAME_HEIGHT)));
|
||||||
setVisible(true);
|
setVisible(true);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -52,14 +52,15 @@ public class DinoStateSimple implements State, Serializable, Visualizable {
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
protected void drawDinoInfo(Graphics g) {
|
||||||
|
g.fillRect((int) (scale * Config.DINO_STARTING_X), (int) (scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE)), (int) (scale * Config.DINO_SIZE), (int) (scale * Config.DINO_SIZE));
|
||||||
|
g.drawString("Distance: " + xDistanceToObstacle, (int) (scale * Config.DINO_STARTING_X), (int) (scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE - 40)));
|
||||||
|
}
|
||||||
|
|
||||||
public void drawObjects(Graphics g) {
|
public void drawObjects(Graphics g) {
|
||||||
g.setColor(Color.BLACK);
|
g.setColor(Color.BLACK);
|
||||||
g.fillRect(0, (int) (scale * (Config.FRAME_HEIGHT - Config.GROUND_Y)), Config.FRAME_WIDTH, 2);
|
g.fillRect(0, (int) (scale * (Config.FRAME_HEIGHT - Config.GROUND_Y)), Config.FRAME_WIDTH, 2);
|
||||||
|
|
||||||
g.fillRect((int) (scale * Config.DINO_STARTING_X), (int) (scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE)), (int) (scale * Config.DINO_SIZE), (int) (scale * Config.DINO_SIZE));
|
|
||||||
g.drawString("Distance: " + xDistanceToObstacle, (int) (scale * Config.DINO_STARTING_X), (int) (scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE - 40)));
|
|
||||||
|
|
||||||
g.fillRect((int) (scale * (Config.DINO_STARTING_X + getXDistanceToObstacle())), (int) (scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE)), (int) (scale * Config.OBSTACLE_SIZE), (int) (scale * Config.OBSTACLE_SIZE));
|
g.fillRect((int) (scale * (Config.DINO_STARTING_X + getXDistanceToObstacle())), (int) (scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE)), (int) (scale * Config.OBSTACLE_SIZE), (int) (scale * Config.OBSTACLE_SIZE));
|
||||||
|
drawDinoInfo(g);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -62,9 +62,6 @@ public class DinoWorld implements Environment<DinoAction>, Visualizable {
|
||||||
return new StepResultEnvironment(generateReturnState(), reward, done, "");
|
return new StepResultEnvironment(generateReturnState(), reward, done, "");
|
||||||
}
|
}
|
||||||
|
|
||||||
protected State generateReturnState(){
|
|
||||||
return new DinoStateSimple(getDistanceToObstacle());
|
|
||||||
}
|
|
||||||
protected State generateReturnState(){
|
protected State generateReturnState(){
|
||||||
return new DinoState(getDistanceToObstacle(), dino.isInJump());
|
return new DinoState(getDistanceToObstacle(), dino.isInJump());
|
||||||
}
|
}
|
||||||
|
|
|
@ -13,13 +13,13 @@ public class JumpingDino {
|
||||||
|
|
||||||
RLController<DinoAction> rl = new RLControllerGUI<>(
|
RLController<DinoAction> rl = new RLControllerGUI<>(
|
||||||
new DinoWorldAdvanced(),
|
new DinoWorldAdvanced(),
|
||||||
Method.MC_CONTROL_FIRST_VISIT,
|
Method.MC_CONTROL_EVERY_VISIT,
|
||||||
DinoAction.values());
|
DinoAction.values());
|
||||||
|
|
||||||
rl.setDelay(200);
|
rl.setDelay(200);
|
||||||
rl.setDiscountFactor(9f);
|
rl.setDiscountFactor(1f);
|
||||||
rl.setEpsilon(0.05f);
|
rl.setEpsilon(0.05f);
|
||||||
rl.setLearningRate(0.8f);
|
rl.setLearningRate(1f);
|
||||||
rl.setNrOfEpisodes(100000);
|
rl.setNrOfEpisodes(100000);
|
||||||
rl.start();
|
rl.start();
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue