diff --git a/.idea/compiler.xml b/.idea/compiler.xml
index a1757ae..95a88ae 100644
--- a/.idea/compiler.xml
+++ b/.idea/compiler.xml
@@ -5,4 +5,4 @@
-
\ No newline at end of file
+
diff --git a/build.gradle b/build.gradle
index a28b89e..7ea77b1 100644
--- a/build.gradle
+++ b/build.gradle
@@ -13,6 +13,9 @@ repositories {
}
dependencies {
+ // https://mvnrepository.com/artifact/org.jfree/jfreechart
+ // https://mvnrepository.com/artifact/org.knowm.xchart/xchart
+ compile group: 'org.knowm.xchart', name: 'xchart', version: '3.2.2'
testCompile group: 'junit', name: 'junit', version: '4.12'
compileOnly 'org.projectlombok:lombok:1.18.10'
annotationProcessor 'org.projectlombok:lombok:1.18.10'
diff --git a/src/main/java/core/DiscreteActionSpace.java b/src/main/java/core/DiscreteActionSpace.java
index a6b38fe..a5caf2b 100644
--- a/src/main/java/core/DiscreteActionSpace.java
+++ b/src/main/java/core/DiscreteActionSpace.java
@@ -3,7 +3,7 @@ package core;
import java.util.List;
public interface DiscreteActionSpace {
- int getNumberOfAction();
+ int getNumberOfActions();
void addAction(A a);
void addActions(A... as);
List getAllActions();
diff --git a/src/main/java/core/LearningConfig.java b/src/main/java/core/LearningConfig.java
new file mode 100644
index 0000000..916de16
--- /dev/null
+++ b/src/main/java/core/LearningConfig.java
@@ -0,0 +1,7 @@
+package core;
+
+/**
+ * Default hyper-parameters used by the learning algorithms when the caller
+ * does not supply explicit values.
+ *
+ * <p>This is a pure constants holder: it is final and cannot be instantiated.
+ */
+public final class LearningConfig {
+    /** Default pause between two steps; the value is handed to {@code Thread.sleep}, i.e. milliseconds. */
+    public static final int DEFAULT_DELAY = 1;
+    /** Default epsilon (exploration) value used when none is passed to a learner. */
+    public static final float DEFAULT_EPSILON = 0.1f;
+    /** Default discount factor used when none is passed to a learner. */
+    public static final float DEFAULT_DISCOUNT_FACTOR = 1.0f;
+
+    private LearningConfig() {
+        // constants only - no instances
+    }
+}
diff --git a/src/main/java/core/ListDiscreteActionSpace.java b/src/main/java/core/ListDiscreteActionSpace.java
index 76babaf..42de87a 100644
--- a/src/main/java/core/ListDiscreteActionSpace.java
+++ b/src/main/java/core/ListDiscreteActionSpace.java
@@ -32,7 +32,7 @@ public class ListDiscreteActionSpace implements DiscreteActionSp
}
@Override
- public int getNumberOfAction(){
+ public int getNumberOfActions(){
return actions.size();
}
}
diff --git a/src/main/java/core/algo/Learning.java b/src/main/java/core/algo/Learning.java
index 8d71a78..58a285a 100644
--- a/src/main/java/core/algo/Learning.java
+++ b/src/main/java/core/algo/Learning.java
@@ -2,26 +2,62 @@ package core.algo;
import core.DiscreteActionSpace;
import core.Environment;
+import core.LearningConfig;
import core.StateActionTable;
+import core.listener.LearningListener;
import core.policy.Policy;
+import lombok.Getter;
+import lombok.Setter;
+import javax.swing.*;
+import java.util.HashSet;
+import java.util.Set;
+
+@Getter
public abstract class Learning {
protected Policy policy;
protected DiscreteActionSpace actionSpace;
protected StateActionTable stateActionTable;
protected Environment environment;
protected float discountFactor;
+ @Setter
protected float epsilon;
+ protected Set learningListeners;
+ @Setter
+ protected int delay;
- public Learning(Environment environment, DiscreteActionSpace actionSpace, float discountFactor, float epsilon){
+ public Learning(Environment environment, DiscreteActionSpace actionSpace, float discountFactor, float epsilon, int delay){
this.environment = environment;
this.actionSpace = actionSpace;
this.discountFactor = discountFactor;
this.epsilon = epsilon;
- }
- public Learning(Environment environment, DiscreteActionSpace actionSpace){
- this(environment, actionSpace, 1.0f, 0.1f);
+ this.delay = delay;
+ learningListeners = new HashSet<>();
}
- public abstract void learn(int nrOfEpisodes, int delay);
+ /**
+  * Creates a learner with an explicit discount factor and epsilon; the step
+  * delay falls back to {@link LearningConfig#DEFAULT_DELAY}.
+  */
+ public Learning(Environment environment, DiscreteActionSpace actionSpace, float discountFactor, float epsilon){
+ this(environment, actionSpace, discountFactor, epsilon, LearningConfig.DEFAULT_DELAY);
+ }
+
+ /**
+  * Creates a learner that uses the defaults from {@link LearningConfig} for
+  * discount factor, epsilon and delay.
+  */
+ public Learning(Environment environment, DiscreteActionSpace actionSpace){
+ this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, LearningConfig.DEFAULT_EPSILON, LearningConfig.DEFAULT_DELAY);
+ }
+
+ /**
+  * Runs the learning algorithm; implemented by the concrete algorithm.
+  * The per-step delay is no longer a parameter here - it is taken from the
+  * {@code delay} field of this class.
+  *
+  * @param nrOfEpisodes number of episodes to run
+  */
+ public abstract void learn(int nrOfEpisodes);
+
+ /**
+  * Registers a listener that is notified by the {@code dispatch*} methods.
+  * Listeners are kept in a set, so adding an equal listener twice has no effect.
+  *
+  * @param learningListener listener to register
+  */
+ public void addListener(LearningListener learningListener){
+ learningListeners.add(learningListener);
+ }
+
+    /**
+     * Tells every registered listener that an episode has finished.
+     *
+     * @param sum value passed on unchanged to each listener's {@code onEpisodeEnd}
+     */
+    protected void dispatchEpisodeEnd(double sum){
+        learningListeners.forEach(listener -> listener.onEpisodeEnd(sum));
+    }
+
+    /**
+     * Tells every registered listener that a new episode is about to begin.
+     */
+    protected void dispatchEpisodeStart(){
+        learningListeners.forEach(LearningListener::onEpisodeStart);
+    }
}
diff --git a/src/main/java/core/algo/MC/MonteCarloOnPolicyEGreedy.java b/src/main/java/core/algo/MC/MonteCarloOnPolicyEGreedy.java
index 54450e8..d608b80 100644
--- a/src/main/java/core/algo/MC/MonteCarloOnPolicyEGreedy.java
+++ b/src/main/java/core/algo/MC/MonteCarloOnPolicyEGreedy.java
@@ -34,22 +34,21 @@ public class MonteCarloOnPolicyEGreedy extends Learning {
}
@Override
- public void learn(int nrOfEpisodes, int delay) {
+ public void learn(int nrOfEpisodes) {
Map, Double> returnSum = new HashMap<>();
Map, Integer> returnCount = new HashMap<>();
- State startingState = environment.reset();
for(int i = 0; i < nrOfEpisodes; ++i) {
List> episode = new ArrayList<>();
State state = environment.reset();
- double rewardSum = 0;
+ double sumOfRewards = 0;
for(int j=0; j < 10; ++j){
Map actionValues = stateActionTable.getActionValues(state);
A chosenAction = policy.chooseAction(actionValues);
StepResultEnvironment envResult = environment.step(chosenAction);
State nextState = envResult.getState();
- rewardSum += envResult.getReward();
+ sumOfRewards += envResult.getReward();
episode.add(new StepResult<>(state, chosenAction, envResult.getReward()));
if(envResult.isDone()) break;
@@ -57,13 +56,14 @@ public class MonteCarloOnPolicyEGreedy extends Learning {
state = nextState;
try {
- Thread.sleep(1);
+ Thread.sleep(delay);
} catch (InterruptedException e) {
e.printStackTrace();
}
}
- System.out.printf("Episode %d \t Reward: %f \n", i, rewardSum);
+ dispatchEpisodeEnd(sumOfRewards);
+ System.out.printf("Episode %d \t Reward: %f \n", i, sumOfRewards);
Set> stateActionPairs = new HashSet<>();
for(StepResult sr: episode){
diff --git a/src/main/java/core/algo/Method.java b/src/main/java/core/algo/Method.java
new file mode 100644
index 0000000..b2da8cb
--- /dev/null
+++ b/src/main/java/core/algo/Method.java
@@ -0,0 +1,5 @@
+package core.algo;
+
+/**
+ * Identifies which learning algorithm a controller should run.
+ */
+public enum Method {
+    /** Selects the {@code MonteCarloOnPolicyEGreedy} learner. */
+    MC_ONPOLICY_EGREEDY,
+    /** Reserved identifier; no learner implementation is wired to it in the controller yet. */
+    TD_ONPOLICY
+}
diff --git a/src/main/java/core/controller/RLController.java b/src/main/java/core/controller/RLController.java
new file mode 100644
index 0000000..c65e62d
--- /dev/null
+++ b/src/main/java/core/controller/RLController.java
@@ -0,0 +1,81 @@
+package core.controller;
+
+import core.DiscreteActionSpace;
+import core.Environment;
+import core.ListDiscreteActionSpace;
+import core.algo.Learning;
+import core.algo.Method;
+import core.algo.mc.MonteCarloOnPolicyEGreedy;
+import core.gui.View;
+
+import javax.swing.*;
+import java.util.Optional;
+
+/**
+ * Fluent entry point that ties together an environment, the allowed actions and
+ * a learning method, opens the GUI and runs the learning loop.
+ *
+ * <p>Typical use: call the {@code set*} methods, then {@link #start()}.
+ * {@code start()} blocks the calling thread until learning has finished.
+ * If {@link #setEpisodes(int)} is never called, zero episodes are requested.
+ */
+public class RLController implements ViewListener{
+    protected Environment environment;
+    protected Learning learning;
+    protected DiscreteActionSpace discreteActionSpace;
+    protected View view;
+    // Same default the learner would pick on its own; fully qualified so that
+    // no extra import is needed.
+    private int delay = core.LearningConfig.DEFAULT_DELAY;
+    private int nrOfEpisodes;
+    private Method method;
+
+    public RLController(){
+    }
+
+    /**
+     * Creates the learner for the configured method, applies the configured
+     * delay, builds the view and then runs the learning loop on the calling thread.
+     *
+     * @throws RuntimeException if environment, action space or method is missing
+     * @throws UnsupportedOperationException if the configured method has no learner implementation yet
+     * @throws IllegalStateException if the view could not be created
+     */
+    public void start(){
+        if(environment == null || discreteActionSpace == null || method == null){
+            throw new RuntimeException("Set environment, discreteActionSpace and method before calling .start()");
+        }
+
+        switch (method){
+            case MC_ONPOLICY_EGREEDY:
+                learning = new MonteCarloOnPolicyEGreedy<>(environment, discreteActionSpace);
+                break;
+            case TD_ONPOLICY:
+                // Previously fell through with learning == null and crashed later with a NullPointerException.
+                throw new UnsupportedOperationException("TD_ONPOLICY is not implemented yet");
+            default:
+                throw new RuntimeException("Undefined method");
+        }
+        // The value from setDelay() used to be dropped; hand it to the learner.
+        learning.setDelay(delay);
+        initView();
+        learning.learn(nrOfEpisodes);
+    }
+
+    /**
+     * Builds the view on the event dispatch thread and registers it as listener
+     * before returning, so the listener set is not modified while the
+     * learning loop is already dispatching events from another thread.
+     */
+    private void initView(){
+        Runnable buildView = () -> {
+            view = new View<>(learning, this);
+            learning.addListener(view);
+        };
+        if(SwingUtilities.isEventDispatchThread()){
+            // invokeAndWait must not be called from the event dispatch thread itself.
+            buildView.run();
+            return;
+        }
+        try {
+            SwingUtilities.invokeAndWait(buildView);
+        } catch (InterruptedException e) {
+            Thread.currentThread().interrupt();
+            throw new IllegalStateException("Interrupted while creating the view", e);
+        } catch (java.lang.reflect.InvocationTargetException e) {
+            throw new IllegalStateException("Could not create the view", e.getCause());
+        }
+    }
+
+    /**
+     * Applies a new epsilon to the learner and refreshes the info panel.
+     */
+    @Override
+    public void onEpsilonChange(float epsilon) {
+        learning.setEpsilon(epsilon);
+        SwingUtilities.invokeLater(() -> view.updateLearningInfoPanel());
+    }
+
+    /**
+     * Stores the new delay and, if a learner already exists, applies it to the learner.
+     *
+     * @throws IllegalArgumentException if {@code delay} is negative
+     */
+    @Override
+    public void onDelayChange(int delay) {
+        this.delay = requireNonNegativeDelay(delay);
+        if(learning != null){
+            learning.setDelay(delay);
+        }
+    }
+
+    public RLController setMethod(Method method){
+        this.method = method;
+        return this;
+    }
+    public RLController setEnvironment(Environment environment){
+        this.environment = environment;
+        return this;
+    }
+    @SafeVarargs
+    public final RLController setAllowedActions(A... actions){
+        this.discreteActionSpace = new ListDiscreteActionSpace<>(actions);
+        return this;
+    }
+
+    /**
+     * Sets the delay that {@link #start()} passes to the learner.
+     *
+     * @throws IllegalArgumentException if {@code delay} is negative
+     */
+    public RLController setDelay(int delay){
+        this.delay = requireNonNegativeDelay(delay);
+        return this;
+    }
+
+    public RLController setEpisodes(int nrOfEpisodes){
+        this.nrOfEpisodes = nrOfEpisodes;
+        return this;
+    }
+
+    // The delay ends up in Thread.sleep, which rejects negative values; fail early instead.
+    private static int requireNonNegativeDelay(int delay){
+        if(delay < 0){
+            throw new IllegalArgumentException("delay must not be negative: " + delay);
+        }
+        return delay;
+    }
+
+}
diff --git a/src/main/java/core/controller/ViewListener.java b/src/main/java/core/controller/ViewListener.java
new file mode 100644
index 0000000..578512a
--- /dev/null
+++ b/src/main/java/core/controller/ViewListener.java
@@ -0,0 +1,6 @@
+package core.controller;
+
+/**
+ * Callbacks through which the GUI reports user-driven parameter changes.
+ */
+public interface ViewListener {
+ /** Called with the newly chosen epsilon value. */
+ void onEpsilonChange(float epsilon);
+ /** Called with the newly chosen delay. NOTE(review): nothing in the visible GUI code invokes this yet. */
+ void onDelayChange(int delay);
+}
diff --git a/src/main/java/core/gui/LearningInfoPanel.java b/src/main/java/core/gui/LearningInfoPanel.java
new file mode 100644
index 0000000..cbbd6ef
--- /dev/null
+++ b/src/main/java/core/gui/LearningInfoPanel.java
@@ -0,0 +1,41 @@
+package core.gui;
+
+import core.algo.Learning;
+import core.controller.ViewListener;
+
+import javax.swing.*;
+
+/**
+ * Vertical panel that shows the learner's policy class, discount factor and
+ * epsilon, plus a slider to change epsilon.
+ */
+public class LearningInfoPanel extends JPanel {
+    private Learning learning;
+    private JLabel policyLabel;
+    private JLabel discountLabel;
+    private JLabel epsilonLabel;
+    private JSlider epsilonSlider;
+    // Declared but never assigned in this class.
+    private JSlider delaySlider;
+
+    /**
+     * @param learning     learner whose parameters are displayed
+     * @param viewListener receives the new epsilon whenever the slider changes
+     */
+    public LearningInfoPanel(Learning learning, ViewListener viewListener){
+        this.learning = learning;
+        this.policyLabel = new JLabel();
+        this.discountLabel = new JLabel();
+        this.epsilonLabel = new JLabel();
+        this.epsilonSlider = createEpsilonSlider(viewListener);
+
+        setLayout(new BoxLayout(this, BoxLayout.Y_AXIS));
+        JComponent[] children = {policyLabel, discountLabel, epsilonLabel, epsilonSlider};
+        for (JComponent child : children) {
+            add(child);
+        }
+        refreshLabels();
+        setVisible(true);
+    }
+
+    // The slider works in whole percent (0..100); the listener gets the value divided by 100.
+    private JSlider createEpsilonSlider(ViewListener viewListener){
+        int initialPercent = (int)(learning.getEpsilon() * 100);
+        JSlider slider = new JSlider(0, 100, initialPercent);
+        slider.addChangeListener(event -> {
+            float chosenEpsilon = slider.getValue() / 100f;
+            viewListener.onEpsilonChange(chosenEpsilon);
+        });
+        return slider;
+    }
+
+    /**
+     * Re-reads policy class, discount factor and epsilon from the learner and
+     * writes them into the labels.
+     */
+    public void refreshLabels(){
+        String policyText = "Policy: " + learning.getPolicy().getClass();
+        String discountText = "Discount factor: " + learning.getDiscountFactor();
+        String epsilonText = "Exploration (Epsilon): " + learning.getEpsilon();
+        policyLabel.setText(policyText);
+        discountLabel.setText(discountText);
+        epsilonLabel.setText(epsilonText);
+    }
+
+    protected JSlider getEpsilonSlider(){
+        return this.epsilonSlider;
+    }
+}
diff --git a/src/main/java/core/gui/View.java b/src/main/java/core/gui/View.java
new file mode 100644
index 0000000..8939ac2
--- /dev/null
+++ b/src/main/java/core/gui/View.java
@@ -0,0 +1,102 @@
+package core.gui;
+
+import core.algo.Learning;
+import core.controller.ViewListener;
+import core.listener.LearningListener;
+import lombok.Getter;
+import org.knowm.xchart.QuickChart;
+import org.knowm.xchart.XChartPanel;
+import org.knowm.xchart.XYChart;
+
+import javax.swing.*;
+import java.awt.*;
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Main window of the application: an info panel for the learner on the left
+ * and a chart of the reward per episode in the centre.
+ *
+ * <p>The constructor builds and shows the frame. {@link #onEpisodeEnd(double)}
+ * forwards its work to the event dispatch thread via {@code invokeLater}.
+ */
+public class View implements LearningListener {
+    /** Name of the only chart series; creation and update must use the same name. */
+    private static final String REWARD_SERIES_NAME = "randomWalk";
+
+    private Learning learning;
+    @Getter
+    private XYChart chart;
+    @Getter
+    private LearningInfoPanel learningInfoPanel;
+    @Getter
+    private JFrame mainFrame;
+    private XChartPanel rewardChartPanel;
+    private ViewListener viewListener;
+    private List rewardHistory;
+
+    /**
+     * @param learning     learner whose parameters are shown
+     * @param viewListener receives parameter changes made in the GUI
+     */
+    public View(Learning learning, ViewListener viewListener){
+        this.learning = learning;
+        this.viewListener = viewListener;
+        rewardHistory = new ArrayList<>();
+        this.initMainFrame();
+    }
+
+    // Creates the frame, adds info panel and chart, then packs and shows it.
+    private void initMainFrame(){
+        mainFrame = new JFrame();
+        mainFrame.setPreferredSize(new Dimension(1280, 720));
+        mainFrame.setLayout(new BorderLayout());
+
+        initLearningInfoPanel();
+        initRewardChart();
+
+        mainFrame.add(BorderLayout.WEST, learningInfoPanel);
+        mainFrame.add(BorderLayout.CENTER, rewardChartPanel);
+
+        mainFrame.setDefaultCloseOperation(WindowConstants.EXIT_ON_CLOSE);
+        mainFrame.pack();
+        mainFrame.setVisible(true);
+    }
+
+    private void initLearningInfoPanel(){
+        learningInfoPanel = new LearningInfoPanel(learning, viewListener);
+    }
+
+    // Creates the chart with one placeholder point (0, 0) and wraps it in a panel.
+    private void initRewardChart(){
+        chart =
+                QuickChart.getChart(
+                        "Rewards per Episode",
+                        "Episode",
+                        "Reward",
+                        REWARD_SERIES_NAME,
+                        new double[] {0},
+                        new double[] {0});
+        chart.getStyler().setLegendVisible(true);
+        chart.getStyler().setXAxisTicksVisible(true);
+        rewardChartPanel = new XChartPanel<>(chart);
+        rewardChartPanel.setPreferredSize(new Dimension(300,300));
+    }
+
+    /**
+     * Opens a separate window showing the component returned by
+     * {@code state.visualize()}.
+     *
+     * @param state object to display
+     */
+    public void showState(Visualizable state){
+        JComponent stateComponent = state.visualize();
+        JFrame stateFrame = new JFrame();
+        // Closing this secondary window should release it, not just hide it.
+        stateFrame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
+        stateFrame.add(stateComponent);
+        // A component that has not been laid out yet reports width/height 0,
+        // so size the frame from the component's preferred size via pack().
+        stateFrame.pack();
+        stateFrame.setVisible(true);
+    }
+
+    /**
+     * Appends a reward to the history and redraws the chart with the whole history.
+     *
+     * @param recentReward reward to append
+     */
+    public void updateRewardGraph(double recentReward){
+        rewardHistory.add(recentReward);
+        chart.updateXYSeries(REWARD_SERIES_NAME, null, rewardHistory, null);
+        rewardChartPanel.revalidate();
+        rewardChartPanel.repaint();
+    }
+
+    /** Refreshes the labels of the info panel from the learner. */
+    public void updateLearningInfoPanel(){
+        this.learningInfoPanel.refreshLabels();
+    }
+
+
+    /** Schedules a chart update with the given reward sum on the event dispatch thread. */
+    @Override
+    public void onEpisodeEnd(double sumOfRewards) {
+        SwingUtilities.invokeLater(()->updateRewardGraph(sumOfRewards));
+    }
+
+    /** Intentionally does nothing. */
+    @Override
+    public void onEpisodeStart() {
+
+    }
+}
diff --git a/src/main/java/core/gui/Visualizable.java b/src/main/java/core/gui/Visualizable.java
new file mode 100644
index 0000000..e144a40
--- /dev/null
+++ b/src/main/java/core/gui/Visualizable.java
@@ -0,0 +1,7 @@
+package core.gui;
+
+import javax.swing.*;
+
+/**
+ * Implemented by objects that can render themselves as a Swing component.
+ */
+public interface Visualizable {
+ /** @return a Swing component representing this object */
+ JComponent visualize();
+}
diff --git a/src/main/java/core/listener/LearningListener.java b/src/main/java/core/listener/LearningListener.java
new file mode 100644
index 0000000..5a9d287
--- /dev/null
+++ b/src/main/java/core/listener/LearningListener.java
@@ -0,0 +1,6 @@
+package core.listener;
+
+/**
+ * Receives episode life-cycle notifications from a learner.
+ */
+public interface LearningListener{
+ /** Called when an episode has finished; {@code sumOfRewards} is the value the learner dispatched for it. */
+ void onEpisodeEnd(double sumOfRewards);
+ /** Called when an episode starts. */
+ void onEpisodeStart();
+}
diff --git a/src/main/java/evironment/antGame/AntState.java b/src/main/java/evironment/antGame/AntState.java
index d1a31b7..8c5bda7 100644
--- a/src/main/java/evironment/antGame/AntState.java
+++ b/src/main/java/evironment/antGame/AntState.java
@@ -1,7 +1,10 @@
package evironment.antGame;
import core.State;
+import core.gui.Visualizable;
+import evironment.antGame.gui.CellColor;
+import javax.swing.*;
import java.awt.*;
import java.util.Arrays;
@@ -10,7 +13,7 @@ import java.util.Arrays;
* Essentially a snapshot of the current Ant Agent
* and therefor has to be deep copied
*/
-public class AntState implements State {
+public class AntState implements State, Visualizable {
private final Cell[][] knownWorld;
private final Point pos;
private final boolean hasFood;
@@ -29,12 +32,12 @@ public class AntState implements State {
int unknown = 0;
int diff = 0;
- for (int i = 0; i < knownWorld.length; i++) {
- for (int j = 0; j < knownWorld[i].length; j++) {
- if(knownWorld[i][j].getType() == CellType.UNKNOWN){
+ for (Cell[] cells : knownWorld) {
+ for (Cell cell : cells) {
+ if (cell.getType() == CellType.UNKNOWN) {
unknown += 1;
- }else{
- diff +=1;
+ } else {
+ diff += 1;
}
}
}
@@ -62,7 +65,7 @@ public class AntState implements State {
@Override
public String toString(){
return String.format("Pos: %s, hasFood: %b, knownWorld: %s", pos.toString(), hasFood, Arrays.toString(knownWorld));
-}
+ }
//TODO: make this a utility function to generate hash Code based upon 2 prime numbers
@Override
@@ -89,4 +92,62 @@ public class AntState implements State {
}
return super.equals(obj);
}
+
+    /**
+     * Builds a scrollable Swing component that paints the known world as a grid
+     * and marks the ant's position with the letter "A".
+     *
+     * <p>Cells that contain food are painted yellow, all others in the colour
+     * mapped to their cell type. The mouse wheel changes the cell size by one
+     * pixel per notch; the cell size never drops below one pixel.
+     *
+     * @return a scroll pane with a preferred size of 500x500 containing the grid
+     */
+    @Override
+    public JComponent visualize() {
+        return new JScrollPane() {
+            // Lower bound for zooming; 0 or negative sizes would give degenerate
+            // rectangles and a degenerate font size.
+            private final int minCellSize = 1;
+            private int cellSize;
+            private final int paneWidth = 500;
+            private final int paneHeight = 500;
+            private Font font;
+            {
+                setPreferredSize(new Dimension(paneWidth, paneHeight));
+                // Integer division can yield 0 for wide worlds, hence the clamp.
+                cellSize = Math.max(minCellSize, (paneWidth - knownWorld.length) / knownWorld.length);
+                font = new Font("plain", Font.BOLD, cellSize);
+                JPanel worldPanel = new JPanel(){
+                    {
+                        setPreferredSize(new Dimension(knownWorld.length * cellSize, knownWorld[0].length * cellSize));
+                        setVisible(true);
+
+                        addMouseWheelListener(e -> {
+                            if(e.getWheelRotation() > 0){
+                                // zoom out, but never below the minimum
+                                cellSize = Math.max(minCellSize, cellSize - 1);
+                            }else {
+                                cellSize += 1;
+                            }
+                            font = new Font("plain", Font.BOLD, cellSize);
+                            setPreferredSize(new Dimension(knownWorld.length * cellSize, knownWorld[0].length * cellSize));
+                            revalidate();
+                            repaint();
+                        });
+                    }
+
+                    @Override
+                    public void paintComponent(Graphics g) {
+                        super.paintComponent(g);
+                        // First array index is drawn along x, second along y.
+                        for (int i = 0; i < knownWorld.length; i++) {
+                            for (int j = 0; j < knownWorld[0].length; j++) {
+                                g.setColor(Color.BLACK);
+                                g.drawRect(i*cellSize, j*cellSize, cellSize, cellSize);
+                                g.setColor(CellColor.map.get(knownWorld[i][j].getType()));
+                                if(knownWorld[i][j].getFood() > 0){
+                                    g.setColor(Color.YELLOW);
+                                }
+                                // inset by one pixel so the black outline stays visible
+                                g.fillRect(i*cellSize+1, j*cellSize+1, cellSize -1, cellSize-1);
+                            }
+                        }
+                        // Marker colour depends on whether this panel has focus.
+                        if(hasFocus()){
+                            g.setColor(Color.RED);
+                        }else {
+                            g.setColor(Color.BLACK);
+                        }
+                        g.setFont(font);
+                        // drawString takes the baseline, so use the bottom edge of the cell
+                        g.drawString("A", pos.x * cellSize, (pos.y + 1) * cellSize);
+                    }
+                };
+                getViewport().add(worldPanel);
+                setVisible(true);
+            }
+        };
+    }
}
diff --git a/src/main/java/evironment/antGame/AntWorld.java b/src/main/java/evironment/antGame/AntWorld.java
index 1d656e6..d68c597 100644
--- a/src/main/java/evironment/antGame/AntWorld.java
+++ b/src/main/java/evironment/antGame/AntWorld.java
@@ -35,21 +35,15 @@ public class AntWorld implements Environment{
private int tick;
private int maxEpisodeTicks;
- MainFrame gui;
public AntWorld(int width, int height, double foodDensity){
grid = new Grid(width, height, foodDensity);
antAgent = new AntAgent(width, height);
myAnt = new Ant();
- gui = new MainFrame(this, antAgent);
maxEpisodeTicks = 1000;
reset();
}
- public MainFrame getGui(){
- return gui;
- }
-
public AntWorld(){
this(Constants.DEFAULT_GRID_WIDTH, Constants.DEFAULT_GRID_HEIGHT, Constants.DEFAULT_FOOD_DENSITY);
}
@@ -166,7 +160,6 @@ public class AntWorld implements Environment{
StepResultEnvironment result = new StepResultEnvironment(newState, reward, done, info);
- getGui().update(action, result);
return result;
}
@@ -216,6 +209,6 @@ public class AntWorld implements Environment{
new AntWorld(3, 3, 0.1),
new ListDiscreteActionSpace<>(AntAction.values())
);
- monteCarlo.learn(20000,5);
+ monteCarlo.learn(20000);
}
}
diff --git a/src/main/java/example/RunningAnt.java b/src/main/java/example/RunningAnt.java
new file mode 100644
index 0000000..19311d0
--- /dev/null
+++ b/src/main/java/example/RunningAnt.java
@@ -0,0 +1,21 @@
+package example;
+
+import core.RNG;
+import core.algo.Method;
+import core.controller.RLController;
+import evironment.antGame.AntAction;
+import evironment.antGame.AntWorld;
+
+/**
+ * Example program: configures a controller for the ant world and starts it.
+ */
+public class RunningAnt {
+    public static void main(String[] args) {
+        // fixed seed for the project's random number generator
+        RNG.setSeed(1234);
+
+        RLController rl = new RLController();
+        rl.setEnvironment(new AntWorld(3,3,0.1));
+        rl.setAllowedActions(AntAction.values());
+        rl.setMethod(Method.MC_ONPOLICY_EGREEDY);
+        rl.setDelay(10);
+        rl.setEpisodes(1000);
+
+        rl.start();
+    }
+}
diff --git a/src/main/java/example/Test.java b/src/main/java/example/Test.java
new file mode 100644
index 0000000..c7ee691
--- /dev/null
+++ b/src/main/java/example/Test.java
@@ -0,0 +1,52 @@
+package example;
+
+/**
+ * Scratch program demonstrating a runtime {@code instanceof} check against a
+ * second interface that only some implementations provide.
+ */
+public class Test {
+    /** Something that can print itself. */
+    interface Drawable{
+        void draw();
+    }
+    /** Something that exposes an int value. */
+    interface State{
+        int getInt();
+    }
+
+    /** Implements both interfaces; every call prints the stored value. */
+    static class A implements Drawable, State{
+        private int k;
+        public A(int a){
+            this.k = a;
+        }
+        @Override
+        public void draw() {
+            System.out.println("draw " + this.k);
+        }
+
+        @Override
+        public int getInt() {
+            System.out.println("getInt" + this.k);
+            return this.k;
+        }
+    }
+
+    /** Implements only {@link State}. */
+    static class B implements State{
+        @Override
+        public int getInt() {
+            return 0;
+        }
+    }
+
+    public static void main(String[] args) {
+        State drawableState = new A(24);
+        State plainState = new B();
+        drawableState.getInt();
+
+        boolean plainIsDrawable = plainState instanceof Drawable;
+        System.out.println(plainIsDrawable);
+        drawState(plainState);
+    }
+
+    /** Draws {@code s} if it is also a {@link Drawable}; otherwise prints "invalid". */
+    static void drawState(State s){
+        if(!(s instanceof Drawable)){
+            System.out.println("invalid");
+            return;
+        }
+        ((Drawable) s).draw();
+    }
+}