apply code improvements suggested by intelliJ
This commit is contained in:
parent
94ad976a1f
commit
9d1f8dfd46
|
@ -1,7 +1,10 @@
|
|||
package core;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.*;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Iterator;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Implementation of a discrete action space.
|
||||
|
@ -18,6 +21,7 @@ public class ListDiscreteActionSpace<A extends Enum> implements DiscreteActionSp
|
|||
actions = new ArrayList<>();
|
||||
}
|
||||
|
||||
@SafeVarargs
|
||||
public ListDiscreteActionSpace(A... actions){
|
||||
this.actions = new ArrayList<>(Arrays.asList(actions));
|
||||
}
|
||||
|
|
|
@ -21,7 +21,7 @@ public class MonteCarloControlEGreedy<A extends Enum> extends EpisodicLearning<A
|
|||
|
||||
private Map<Pair<State, A>, Double> returnSum;
|
||||
private Map<Pair<State, A>, Integer> returnCount;
|
||||
private boolean isEveryVisit;
|
||||
private final boolean isEveryVisit;
|
||||
|
||||
|
||||
public MonteCarloControlEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay, boolean useEveryVisit) {
|
||||
|
|
|
@ -34,7 +34,7 @@ public class QLearningOffPolicyTDControl<A extends Enum> extends EpisodicLearnin
|
|||
}
|
||||
|
||||
StepResultEnvironment envResult = null;
|
||||
Map<A, Double> actionValues = null;
|
||||
Map<A, Double> actionValues;
|
||||
|
||||
sumOfRewards = 0;
|
||||
while(envResult == null || !envResult.isDone()) {
|
||||
|
|
|
@ -39,6 +39,7 @@ public class RLController<A extends Enum> implements LearningListener {
|
|||
protected int prevDelay;
|
||||
protected volatile boolean printNextEpisode;
|
||||
|
||||
@SafeVarargs
|
||||
public RLController(Environment<A> env, Method method, A... actions) {
|
||||
setEnvironment(env);
|
||||
setMethod(method);
|
||||
|
@ -102,9 +103,7 @@ public class RLController<A extends Enum> implements LearningListener {
|
|||
if(learning.isCurrentlyLearning()){
|
||||
((EpisodicLearning<A>) learning).learnMoreEpisodes(nrOfEpisodes);
|
||||
}else{
|
||||
new Thread(() -> {
|
||||
((EpisodicLearning<A>) learning).learn(nrOfEpisodes);
|
||||
}).start();
|
||||
new Thread(() -> ((EpisodicLearning<A>) learning).learn(nrOfEpisodes)).start();
|
||||
}
|
||||
} else {
|
||||
throw new RuntimeException("Triggering onLearnMoreEpisodes on non-episodic learning!");
|
||||
|
@ -179,7 +178,7 @@ public class RLController<A extends Enum> implements LearningListener {
|
|||
public void onEpisodeEnd(List<Double> rewardHistory) {
|
||||
latestRewardsHistory = rewardHistory;
|
||||
if(printNextEpisode) {
|
||||
System.out.println("Episode " + ((EpisodicLearning<A>) learning).getCurrentEpisode() + " Latest Reward: " + rewardHistory.get(rewardHistory.size() - 1));
|
||||
System.out.println("Episode " + learning.getCurrentEpisode() + " Latest Reward: " + rewardHistory.get(rewardHistory.size() - 1));
|
||||
System.out.println("Eps/sec: " + ((EpisodicLearning<A>) learning).getEpisodePerSecond());
|
||||
printNextEpisode = false;
|
||||
}
|
||||
|
|
|
@ -13,6 +13,7 @@ import java.util.List;
|
|||
public class RLControllerGUI<A extends Enum> extends RLController<A> implements ViewListener {
|
||||
private LearningView learningView;
|
||||
|
||||
@SafeVarargs
|
||||
public RLControllerGUI(Environment<A> env, Method method, A... actions) {
|
||||
super(env, method, actions);
|
||||
}
|
||||
|
@ -102,7 +103,7 @@ public class RLControllerGUI<A extends Enum> extends RLController<A> implements
|
|||
@Override
|
||||
public void onLearningEnd() {
|
||||
super.onLearningEnd();
|
||||
onSaveState(method.toString() + System.currentTimeMillis() / 1000 + (learning instanceof EpisodicLearning ? "e" + ((EpisodicLearning) learning).getCurrentEpisode() : ""));
|
||||
onSaveState(method.toString() + System.currentTimeMillis() / 1000 + (learning instanceof EpisodicLearning ? "e" + learning.getCurrentEpisode() : ""));
|
||||
SwingUtilities.invokeLater(() -> learningView.updateRewardGraph(latestRewardsHistory));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -85,9 +85,7 @@ public class LearningInfoPanel extends JPanel {
|
|||
add(learnMoreEpisodesButton);
|
||||
}
|
||||
showQTableButton = new JButton("Show Q-Table");
|
||||
showQTableButton.addActionListener(e -> {
|
||||
viewListener.onShowQTable();
|
||||
});
|
||||
showQTableButton.addActionListener(e -> viewListener.onShowQTable());
|
||||
add(drawEnvironmentCheckbox);
|
||||
add(smoothGraphCheckbox);
|
||||
add(last100Checkbox);
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package evironment.antGame;
|
||||
|
||||
import lombok.AccessLevel;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
|
||||
|
|
|
@ -86,12 +86,12 @@ public class AntState implements State, Visualizable {
|
|||
public JComponent visualize() {
|
||||
return new JScrollPane() {
|
||||
private int cellSize;
|
||||
private final int paneWidth = 500;
|
||||
private final int paneHeight = 500;
|
||||
private Font font;
|
||||
{
|
||||
int paneWidth = 500;
|
||||
int paneHeight = 500;
|
||||
setPreferredSize(new Dimension(paneWidth, paneHeight));
|
||||
cellSize = (paneWidth- knownWorld.length) /knownWorld.length;
|
||||
cellSize = (paneWidth - knownWorld.length) / knownWorld.length;
|
||||
font = new Font("plain", Font.BOLD, cellSize);
|
||||
JPanel worldPanel = new JPanel(){
|
||||
{
|
||||
|
|
|
@ -139,11 +139,9 @@ public class AntWorld implements Environment<AntAction>, Visualizable {
|
|||
// valid movement
|
||||
if(!sc.stayOnCell) {
|
||||
myAnt.getPos().setLocation(sc.potentialNextPos);
|
||||
if(antAgent.getCell(myAnt.getPos()).getType() == CellType.UNKNOWN){
|
||||
// the ant will move to a cell that was previously unknown
|
||||
// TODO: not optimal for going straight for food
|
||||
// sc.reward = Reward.UNKNOWN_FIELD_EXPLORED;
|
||||
}
|
||||
antAgent.getCell(myAnt.getPos());// the ant will move to a cell that was previously unknown
|
||||
// TODO: not optimal for going straight for food
|
||||
// sc.reward = Reward.UNKNOWN_FIELD_EXPLORED;
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -33,7 +33,6 @@ public class Grid {
|
|||
spawnNewFood(initialGrid);
|
||||
spawnObstacles();
|
||||
initialGrid[start.x][start.y] = new Cell(new Point(start.x, start.y), CellType.START);
|
||||
;
|
||||
|
||||
}
|
||||
|
||||
|
|
|
@ -7,10 +7,8 @@ import javax.swing.*;
|
|||
import java.awt.*;
|
||||
|
||||
public class AntWorldComponent extends JComponent {
|
||||
private AntWorld antWorld;
|
||||
|
||||
public AntWorldComponent(AntWorld antWorld, AntAgent antAgent){
|
||||
this.antWorld = antWorld;
|
||||
setLayout(new BorderLayout());
|
||||
CellsScrollPane worldPane = new CellsScrollPane(antWorld.getCellArray(), antWorld.getAnt(), 10);
|
||||
CellsScrollPane antBrainPane = new CellsScrollPane(antAgent.getKnownWorld(), antWorld.getAnt(), 10);
|
||||
|
|
|
@ -2,21 +2,20 @@ package evironment.jumpingDino;
|
|||
|
||||
import core.State;
|
||||
import core.gui.Visualizable;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Getter;
|
||||
|
||||
import javax.swing.*;
|
||||
import java.awt.*;
|
||||
import java.io.Serializable;
|
||||
import java.util.Objects;
|
||||
|
||||
@AllArgsConstructor
|
||||
@Getter
|
||||
public class DinoState implements State, Serializable, Visualizable {
|
||||
private int xDistanceToObstacle;
|
||||
public class DinoState extends DinoStateSimple implements State, Serializable, Visualizable {
|
||||
private boolean isJumping;
|
||||
|
||||
protected final double scale = 0.5;
|
||||
public DinoState(int xDistanceToObstacle, boolean isJumping) {
|
||||
super(xDistanceToObstacle);
|
||||
this.isJumping = isJumping;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
|
@ -40,29 +39,15 @@ public class DinoState implements State, Serializable, Visualizable {
|
|||
}
|
||||
|
||||
@Override
|
||||
public JComponent visualize() {
|
||||
return new JComponent() {
|
||||
{
|
||||
setPreferredSize(new Dimension(Config.FRAME_WIDTH, (int)(scale * Config.FRAME_HEIGHT)));
|
||||
setVisible(true);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void paintComponent(Graphics g) {
|
||||
super.paintComponents(g);
|
||||
drawObjects(g);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
public void drawObjects(Graphics g){
|
||||
g.setColor(Color.BLACK);
|
||||
g.fillRect(0, (int) (scale * (Config.FRAME_HEIGHT - Config.GROUND_Y)), Config.FRAME_WIDTH, 2);
|
||||
|
||||
g.fillRect((int)(scale * Config.DINO_STARTING_X), (int)(scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE)), (int)(scale * Config.DINO_SIZE), (int) (scale * Config.DINO_SIZE));
|
||||
g.drawString("Distance: " + xDistanceToObstacle, (int)(scale * Config.DINO_STARTING_X),(int)(scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE - 40) ));
|
||||
|
||||
g.fillRect((int)(scale * (Config.DINO_STARTING_X + getXDistanceToObstacle())), (int)(scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE)), (int)(scale * Config.OBSTACLE_SIZE), (int)(scale *Config.OBSTACLE_SIZE));
|
||||
|
||||
protected void drawDinoInfo(Graphics g) {
|
||||
int dinoY;
|
||||
if(!isJumping) {
|
||||
dinoY = Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE;
|
||||
g.fillRect((int) (scale * Config.DINO_STARTING_X), (int) (scale * (dinoY)), (int) (scale * Config.DINO_SIZE), (int) (scale * Config.DINO_SIZE));
|
||||
} else {
|
||||
dinoY = Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE - (int) (scale * Config.MAX_JUMP_HEIGHT);
|
||||
g.fillRect((int) (scale * Config.DINO_STARTING_X), (int) (scale * (dinoY)), (int) (scale * Config.DINO_SIZE), (int) (scale * Config.DINO_SIZE));
|
||||
}
|
||||
g.drawString("Distance: " + xDistanceToObstacle + " inJump: " + isJumping, (int) (scale * Config.DINO_STARTING_X), (int) (scale * (dinoY - 20)));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -14,7 +14,7 @@ import java.util.Objects;
|
|||
@Getter
|
||||
public class DinoStateSimple implements State, Serializable, Visualizable {
|
||||
protected final double scale = 0.5;
|
||||
private int xDistanceToObstacle;
|
||||
protected int xDistanceToObstacle;
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
|
@ -40,7 +40,7 @@ public class DinoStateSimple implements State, Serializable, Visualizable {
|
|||
public JComponent visualize() {
|
||||
return new JComponent() {
|
||||
{
|
||||
setPreferredSize(new Dimension(Config.FRAME_WIDTH, (int) (scale * Config.FRAME_HEIGHT)));
|
||||
setPreferredSize(new Dimension((int) (scale * Config.FRAME_WIDTH), (int) (scale * Config.FRAME_HEIGHT)));
|
||||
setVisible(true);
|
||||
}
|
||||
|
||||
|
@ -52,14 +52,15 @@ public class DinoStateSimple implements State, Serializable, Visualizable {
|
|||
};
|
||||
}
|
||||
|
||||
protected void drawDinoInfo(Graphics g) {
|
||||
g.fillRect((int) (scale * Config.DINO_STARTING_X), (int) (scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE)), (int) (scale * Config.DINO_SIZE), (int) (scale * Config.DINO_SIZE));
|
||||
g.drawString("Distance: " + xDistanceToObstacle, (int) (scale * Config.DINO_STARTING_X), (int) (scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE - 40)));
|
||||
}
|
||||
|
||||
public void drawObjects(Graphics g) {
|
||||
g.setColor(Color.BLACK);
|
||||
g.fillRect(0, (int) (scale * (Config.FRAME_HEIGHT - Config.GROUND_Y)), Config.FRAME_WIDTH, 2);
|
||||
|
||||
g.fillRect((int) (scale * Config.DINO_STARTING_X), (int) (scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE)), (int) (scale * Config.DINO_SIZE), (int) (scale * Config.DINO_SIZE));
|
||||
g.drawString("Distance: " + xDistanceToObstacle, (int) (scale * Config.DINO_STARTING_X), (int) (scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE - 40)));
|
||||
|
||||
g.fillRect((int) (scale * (Config.DINO_STARTING_X + getXDistanceToObstacle())), (int) (scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE)), (int) (scale * Config.OBSTACLE_SIZE), (int) (scale * Config.OBSTACLE_SIZE));
|
||||
|
||||
drawDinoInfo(g);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -62,9 +62,6 @@ public class DinoWorld implements Environment<DinoAction>, Visualizable {
|
|||
return new StepResultEnvironment(generateReturnState(), reward, done, "");
|
||||
}
|
||||
|
||||
protected State generateReturnState(){
|
||||
return new DinoStateSimple(getDistanceToObstacle());
|
||||
}
|
||||
protected State generateReturnState(){
|
||||
return new DinoState(getDistanceToObstacle(), dino.isInJump());
|
||||
}
|
||||
|
|
|
@ -13,13 +13,13 @@ public class JumpingDino {
|
|||
|
||||
RLController<DinoAction> rl = new RLControllerGUI<>(
|
||||
new DinoWorldAdvanced(),
|
||||
Method.MC_CONTROL_FIRST_VISIT,
|
||||
Method.MC_CONTROL_EVERY_VISIT,
|
||||
DinoAction.values());
|
||||
|
||||
rl.setDelay(200);
|
||||
rl.setDiscountFactor(9f);
|
||||
rl.setDiscountFactor(1f);
|
||||
rl.setEpsilon(0.05f);
|
||||
rl.setLearningRate(0.8f);
|
||||
rl.setLearningRate(1f);
|
||||
rl.setNrOfEpisodes(100000);
|
||||
rl.start();
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue