From e0160ca1dfbc33e237c54499b6fdb5b916f99970 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20L=C3=B6wenstrom?= Date: Wed, 18 Dec 2019 16:48:24 +0100 Subject: [PATCH] adopt MVC pattern and add real time graph interface --- .idea/compiler.xml | 2 +- build.gradle | 3 + src/main/java/core/DiscreteActionSpace.java | 2 +- src/main/java/core/LearningConfig.java | 7 ++ .../java/core/ListDiscreteActionSpace.java | 2 +- src/main/java/core/algo/Learning.java | 46 +++++++- .../algo/MC/MonteCarloOnPolicyEGreedy.java | 12 +-- src/main/java/core/algo/Method.java | 5 + .../java/core/controller/RLController.java | 81 ++++++++++++++ .../java/core/controller/ViewListener.java | 6 ++ src/main/java/core/gui/LearningInfoPanel.java | 41 +++++++ src/main/java/core/gui/View.java | 102 ++++++++++++++++++ src/main/java/core/gui/Visualizable.java | 7 ++ .../java/core/listener/LearningListener.java | 6 ++ .../java/evironment/antGame/AntState.java | 75 +++++++++++-- .../java/evironment/antGame/AntWorld.java | 9 +- src/main/java/example/RunningAnt.java | 21 ++++ src/main/java/example/Test.java | 52 +++++++++ 18 files changed, 450 insertions(+), 29 deletions(-) create mode 100644 src/main/java/core/LearningConfig.java create mode 100644 src/main/java/core/algo/Method.java create mode 100644 src/main/java/core/controller/RLController.java create mode 100644 src/main/java/core/controller/ViewListener.java create mode 100644 src/main/java/core/gui/LearningInfoPanel.java create mode 100644 src/main/java/core/gui/View.java create mode 100644 src/main/java/core/gui/Visualizable.java create mode 100644 src/main/java/core/listener/LearningListener.java create mode 100644 src/main/java/example/RunningAnt.java create mode 100644 src/main/java/example/Test.java diff --git a/.idea/compiler.xml b/.idea/compiler.xml index a1757ae..95a88ae 100644 --- a/.idea/compiler.xml +++ b/.idea/compiler.xml @@ -5,4 +5,4 @@ - \ No newline at end of file + diff --git a/build.gradle b/build.gradle index a28b89e..7ea77b1 100644 --- a/build.gradle +++ b/build.gradle @@ -13,6 +13,9 @@ repositories { } dependencies { + // https://mvnrepository.com/artifact/org.jfree/jfreechart + // https://mvnrepository.com/artifact/org.knowm.xchart/xchart + compile group: 'org.knowm.xchart', name: 'xchart', version: '3.2.2' testCompile group: 'junit', name: 'junit', version: '4.12' compileOnly 'org.projectlombok:lombok:1.18.10' annotationProcessor 'org.projectlombok:lombok:1.18.10' diff --git a/src/main/java/core/DiscreteActionSpace.java b/src/main/java/core/DiscreteActionSpace.java index a6b38fe..a5caf2b 100644 --- a/src/main/java/core/DiscreteActionSpace.java +++ b/src/main/java/core/DiscreteActionSpace.java @@ -3,7 +3,7 @@ package core; import java.util.List; public interface DiscreteActionSpace { - int getNumberOfAction(); + int getNumberOfActions(); void addAction(A a); void addActions(A... as); List getAllActions(); diff --git a/src/main/java/core/LearningConfig.java b/src/main/java/core/LearningConfig.java new file mode 100644 index 0000000..916de16 --- /dev/null +++ b/src/main/java/core/LearningConfig.java @@ -0,0 +1,7 @@ +package core; + +public class LearningConfig { + public static final int DEFAULT_DELAY = 1; + public static final float DEFAULT_EPSILON = 0.1f; + public static final float DEFAULT_DISCOUNT_FACTOR = 1.0f; +} diff --git a/src/main/java/core/ListDiscreteActionSpace.java b/src/main/java/core/ListDiscreteActionSpace.java index 76babaf..42de87a 100644 --- a/src/main/java/core/ListDiscreteActionSpace.java +++ b/src/main/java/core/ListDiscreteActionSpace.java @@ -32,7 +32,7 @@ public class ListDiscreteActionSpace implements DiscreteActionSp } @Override - public int getNumberOfAction(){ + public int getNumberOfActions(){ return actions.size(); } } diff --git a/src/main/java/core/algo/Learning.java b/src/main/java/core/algo/Learning.java index 8d71a78..58a285a 100644 --- a/src/main/java/core/algo/Learning.java +++ b/src/main/java/core/algo/Learning.java @@ -2,26 +2,62 @@ package core.algo; import core.DiscreteActionSpace; import core.Environment; +import core.LearningConfig; import core.StateActionTable; +import core.listener.LearningListener; import core.policy.Policy; +import lombok.Getter; +import lombok.Setter; +import javax.swing.*; +import java.util.HashSet; +import java.util.Set; + +@Getter public abstract class Learning { protected Policy policy; protected DiscreteActionSpace actionSpace; protected StateActionTable stateActionTable; protected Environment environment; protected float discountFactor; + @Setter protected float epsilon; + protected Set learningListeners; + @Setter + protected int delay; - public Learning(Environment environment, DiscreteActionSpace actionSpace, float discountFactor, float epsilon){ + public Learning(Environment environment, DiscreteActionSpace actionSpace, float discountFactor, float epsilon, int delay){ this.environment = environment; this.actionSpace = actionSpace; this.discountFactor = discountFactor; this.epsilon = epsilon; - } - public Learning(Environment environment, DiscreteActionSpace actionSpace){ - this(environment, actionSpace, 1.0f, 0.1f); + this.delay = delay; + learningListeners = new HashSet<>(); } - public abstract void learn(int nrOfEpisodes, int delay); + public Learning(Environment environment, DiscreteActionSpace actionSpace, float discountFactor, float epsilon){ + this(environment, actionSpace, discountFactor, epsilon, LearningConfig.DEFAULT_DELAY); + } + + public Learning(Environment environment, DiscreteActionSpace actionSpace){ + this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, LearningConfig.DEFAULT_EPSILON, LearningConfig.DEFAULT_DELAY); + } + + public abstract void learn(int nrOfEpisodes); + + public void addListener(LearningListener learningListener){ + learningListeners.add(learningListener); + } + + protected void dispatchEpisodeEnd(double sum){ + for(LearningListener l: learningListeners) { + l.onEpisodeEnd(sum); + } + } + + protected void dispatchEpisodeStart(){ + for(LearningListener l: learningListeners){ + l.onEpisodeStart(); + } + } } diff --git a/src/main/java/core/algo/MC/MonteCarloOnPolicyEGreedy.java b/src/main/java/core/algo/MC/MonteCarloOnPolicyEGreedy.java index 54450e8..d608b80 100644 --- a/src/main/java/core/algo/MC/MonteCarloOnPolicyEGreedy.java +++ b/src/main/java/core/algo/MC/MonteCarloOnPolicyEGreedy.java @@ -34,22 +34,21 @@ public class MonteCarloOnPolicyEGreedy extends Learning { } @Override - public void learn(int nrOfEpisodes, int delay) { + public void learn(int nrOfEpisodes) { Map, Double> returnSum = new HashMap<>(); Map, Integer> returnCount = new HashMap<>(); - State startingState = environment.reset(); for(int i = 0; i < nrOfEpisodes; ++i) { List> episode = new ArrayList<>(); State state = environment.reset(); - double rewardSum = 0; + double sumOfRewards = 0; for(int j=0; j < 10; ++j){ Map actionValues = stateActionTable.getActionValues(state); A chosenAction = policy.chooseAction(actionValues); StepResultEnvironment envResult = environment.step(chosenAction); State nextState = envResult.getState(); - rewardSum += envResult.getReward(); + sumOfRewards += envResult.getReward(); episode.add(new StepResult<>(state, chosenAction, envResult.getReward())); if(envResult.isDone()) break; @@ -57,13 +56,14 @@ public class MonteCarloOnPolicyEGreedy extends Learning { state = nextState; try { - Thread.sleep(1); + Thread.sleep(delay); } catch (InterruptedException e) { e.printStackTrace(); } } - System.out.printf("Episode %d \t Reward: %f \n", i, rewardSum); + dispatchEpisodeEnd(sumOfRewards); + System.out.printf("Episode %d \t Reward: %f \n", i, sumOfRewards); Set> stateActionPairs = new HashSet<>(); for(StepResult sr: episode){ diff --git a/src/main/java/core/algo/Method.java b/src/main/java/core/algo/Method.java new file mode 100644 index 0000000..b2da8cb --- /dev/null +++ b/src/main/java/core/algo/Method.java @@ -0,0 +1,5 @@ +package core.algo; + +public enum Method { + MC_ONPOLICY_EGREEDY, TD_ONPOLICY +} diff --git a/src/main/java/core/controller/RLController.java b/src/main/java/core/controller/RLController.java new file mode 100644 index 0000000..c65e62d --- /dev/null +++ b/src/main/java/core/controller/RLController.java @@ -0,0 +1,81 @@ +package core.controller; + +import core.DiscreteActionSpace; +import core.Environment; +import core.ListDiscreteActionSpace; +import core.algo.Learning; +import core.algo.Method; +import core.algo.mc.MonteCarloOnPolicyEGreedy; +import core.gui.View; + +import javax.swing.*; +import java.util.Optional; + +public class RLController implements ViewListener{ + protected Environment environment; + protected Learning learning; + protected DiscreteActionSpace discreteActionSpace; + protected View view; + private int delay; + private int nrOfEpisodes; + private Method method; + + public RLController(){ + } + + public void start(){ + if(environment == null || discreteActionSpace == null || method == null){ + throw new RuntimeException("Set environment, discreteActionSpace and method before calling .start()"); + } + + switch (method){ + case MC_ONPOLICY_EGREEDY: + learning = new MonteCarloOnPolicyEGreedy<>(environment, discreteActionSpace); + break; + case TD_ONPOLICY: + break; + default: + throw new RuntimeException("Undefined method"); + } + SwingUtilities.invokeLater(() ->{ + view = new View<>(learning, this); + learning.addListener(view); + }); + learning.learn(nrOfEpisodes); + } + + @Override + public void onEpsilonChange(float epsilon) { + learning.setEpsilon(epsilon); + SwingUtilities.invokeLater(() -> view.updateLearningInfoPanel()); + } + + @Override + public void onDelayChange(int delay) { + } + + public RLController setMethod(Method method){ + this.method = method; + return this; + } + public RLController setEnvironment(Environment environment){ + this.environment = environment; + return this; + } + @SafeVarargs + public final RLController setAllowedActions(A... actions){ + this.discreteActionSpace = new ListDiscreteActionSpace<>(actions); + return this; + } + + public RLController setDelay(int delay){ + this.delay = delay; + return this; + } + + public RLController setEpisodes(int nrOfEpisodes){ + this.nrOfEpisodes = nrOfEpisodes; + return this; + } + +} diff --git a/src/main/java/core/controller/ViewListener.java b/src/main/java/core/controller/ViewListener.java new file mode 100644 index 0000000..578512a --- /dev/null +++ b/src/main/java/core/controller/ViewListener.java @@ -0,0 +1,6 @@ +package core.controller; + +public interface ViewListener { + void onEpsilonChange(float epsilon); + void onDelayChange(int delay); +} diff --git a/src/main/java/core/gui/LearningInfoPanel.java b/src/main/java/core/gui/LearningInfoPanel.java new file mode 100644 index 0000000..cbbd6ef --- /dev/null +++ b/src/main/java/core/gui/LearningInfoPanel.java @@ -0,0 +1,41 @@ +package core.gui; + +import core.algo.Learning; +import core.controller.ViewListener; + +import javax.swing.*; + +public class LearningInfoPanel extends JPanel { + private Learning learning; + private JLabel policyLabel; + private JLabel discountLabel; + private JLabel epsilonLabel; + private JSlider epsilonSlider; + private JSlider delaySlider; + + public LearningInfoPanel(Learning learning, ViewListener viewListener){ + this.learning = learning; + setLayout(new BoxLayout(this, BoxLayout.Y_AXIS)); + policyLabel = new JLabel(); + discountLabel = new JLabel(); + epsilonLabel = new JLabel(); + epsilonSlider = new JSlider(0, 100, (int)(learning.getEpsilon() * 100)); + epsilonSlider.addChangeListener(e -> viewListener.onEpsilonChange(epsilonSlider.getValue() / 100f)); + add(policyLabel); + add(discountLabel); + add(epsilonLabel); + add(epsilonSlider); + refreshLabels(); + setVisible(true); + } + + public void refreshLabels(){ + policyLabel.setText("Policy: " + learning.getPolicy().getClass()); + discountLabel.setText("Discount factor: " + learning.getDiscountFactor()); + epsilonLabel.setText("Exploration (Epsilon): " + learning.getEpsilon()); + } + + protected JSlider getEpsilonSlider(){ + return epsilonSlider; + } +} diff --git a/src/main/java/core/gui/View.java b/src/main/java/core/gui/View.java new file mode 100644 index 0000000..8939ac2 --- /dev/null +++ b/src/main/java/core/gui/View.java @@ -0,0 +1,102 @@ +package core.gui; + +import core.algo.Learning; +import core.controller.ViewListener; +import core.listener.LearningListener; +import lombok.Getter; +import org.knowm.xchart.QuickChart; +import org.knowm.xchart.XChartPanel; +import org.knowm.xchart.XYChart; + +import javax.swing.*; +import java.awt.*; +import java.util.ArrayList; +import java.util.List; + +public class View implements LearningListener { + private Learning learning; + @Getter + private XYChart chart; + @Getter + private LearningInfoPanel learningInfoPanel; + @Getter + private JFrame mainFrame; + private XChartPanel rewardChartPanel; + private ViewListener viewListener; + private List rewardHistory; + + public View(Learning learning, ViewListener viewListener){ + this.learning = learning; + this.viewListener = viewListener; + rewardHistory = new ArrayList<>(); + this.initMainFrame(); + } + + private void initMainFrame(){ + mainFrame = new JFrame(); + mainFrame.setPreferredSize(new Dimension(1280, 720)); + mainFrame.setLayout(new BorderLayout()); + + initLearningInfoPanel(); + initRewardChart(); + + mainFrame.add(BorderLayout.WEST, learningInfoPanel); + mainFrame.add(BorderLayout.CENTER, rewardChartPanel); + + mainFrame.setDefaultCloseOperation(WindowConstants.EXIT_ON_CLOSE); + mainFrame.pack(); + mainFrame.setVisible(true); + } + + private void initLearningInfoPanel(){ + learningInfoPanel = new LearningInfoPanel(learning, viewListener); + } + + private void initRewardChart(){ + chart = + QuickChart.getChart( + "Rewards per Episode", + "Episode", + "Reward", + "randomWalk", + new double[] {0}, + new double[] {0}); + chart.getStyler().setLegendVisible(true); + chart.getStyler().setXAxisTicksVisible(true); + rewardChartPanel = new XChartPanel<>(chart); + rewardChartPanel.setPreferredSize(new Dimension(300,300)); + } + + public void showState(Visualizable state){ + new JFrame(){ + { + JComponent stateComponent = state.visualize(); + setPreferredSize(new Dimension(stateComponent.getWidth(), stateComponent.getHeight())); + add(stateComponent); + setVisible(true); + } + }; + } + + public void updateRewardGraph(double recentReward){ + rewardHistory.add(recentReward); + chart.updateXYSeries("randomWalk", null, rewardHistory, null); + rewardChartPanel.revalidate(); + rewardChartPanel.repaint(); + } + + public void updateLearningInfoPanel(){ + this.learningInfoPanel.refreshLabels(); + } + + + @Override + public void onEpisodeEnd(double sumOfRewards) { + SwingUtilities.invokeLater(()->updateRewardGraph(sumOfRewards)); + } + + @Override + public void onEpisodeStart() { + + } +} diff --git a/src/main/java/core/gui/Visualizable.java b/src/main/java/core/gui/Visualizable.java new file mode 100644 index 0000000..e144a40 --- /dev/null +++ b/src/main/java/core/gui/Visualizable.java @@ -0,0 +1,7 @@ +package core.gui; + +import javax.swing.*; + +public interface Visualizable { + JComponent visualize(); +} diff --git a/src/main/java/core/listener/LearningListener.java b/src/main/java/core/listener/LearningListener.java new file mode 100644 index 0000000..5a9d287 --- /dev/null +++ b/src/main/java/core/listener/LearningListener.java @@ -0,0 +1,6 @@ +package core.listener; + +public interface LearningListener{ + void onEpisodeEnd(double sumOfRewards); + void onEpisodeStart(); +} diff --git a/src/main/java/evironment/antGame/AntState.java b/src/main/java/evironment/antGame/AntState.java index d1a31b7..8c5bda7 100644 --- a/src/main/java/evironment/antGame/AntState.java +++ b/src/main/java/evironment/antGame/AntState.java @@ -1,7 +1,10 @@ package evironment.antGame; import core.State; +import core.gui.Visualizable; +import evironment.antGame.gui.CellColor; +import javax.swing.*; import java.awt.*; import java.util.Arrays; @@ -10,7 +13,7 @@ import java.util.Arrays; * Essentially a snapshot of the current Ant Agent * and therefor has to be deep copied */ -public class AntState implements State { +public class AntState implements State, Visualizable { private final Cell[][] knownWorld; private final Point pos; private final boolean hasFood; @@ -29,12 +32,12 @@ public class AntState implements State { int unknown = 0; int diff = 0; - for (int i = 0; i < knownWorld.length; i++) { - for (int j = 0; j < knownWorld[i].length; j++) { - if(knownWorld[i][j].getType() == CellType.UNKNOWN){ + for (Cell[] cells : knownWorld) { + for (Cell cell : cells) { + if (cell.getType() == CellType.UNKNOWN) { unknown += 1; - }else{ - diff +=1; + } else { + diff += 1; } } } @@ -62,7 +65,7 @@ public class AntState implements State { @Override public String toString(){ return String.format("Pos: %s, hasFood: %b, knownWorld: %s", pos.toString(), hasFood, Arrays.toString(knownWorld)); -} + } //TODO: make this a utility function to generate hash Code based upon 2 prime numbers @Override @@ -89,4 +92,62 @@ public class AntState implements State { } return super.equals(obj); } + + @Override + public JComponent visualize() { + return new JScrollPane() { + private int cellSize; + private final int paneWidth = 500; + private final int paneHeight = 500; + private Font font; + { + setPreferredSize(new Dimension(paneWidth, paneHeight)); + cellSize = (paneWidth- knownWorld.length) /knownWorld.length; + font = new Font("plain", Font.BOLD, cellSize); + JPanel worldPanel = new JPanel(){ + { + setPreferredSize(new Dimension(knownWorld.length * cellSize, knownWorld[0].length * cellSize)); + setVisible(true); + + addMouseWheelListener(e -> { + if(e.getWheelRotation() > 0){ + cellSize -= 1; + }else { + cellSize += 1; + } + font = new Font("plain", Font.BOLD, cellSize); + setPreferredSize(new Dimension(knownWorld.length * cellSize, knownWorld[0].length * cellSize)); + revalidate(); + repaint(); + }); + } + + @Override + public void paintComponent(Graphics g) { + super.paintComponent(g); + for (int i = 0; i < knownWorld.length; i++) { + for (int j = 0; j < knownWorld[0].length; j++) { + g.setColor(Color.BLACK); + g.drawRect(i*cellSize, j*cellSize, cellSize, cellSize); + g.setColor(CellColor.map.get(knownWorld[i][j].getType())); + if(knownWorld[i][j].getFood() > 0){ + g.setColor(Color.YELLOW); + } + g.fillRect(i*cellSize+1, j*cellSize+1, cellSize -1, cellSize-1); + } + } + if(hasFocus()){ + g.setColor(Color.RED); + }else { + g.setColor(Color.BLACK); + } + g.setFont(font); + g.drawString("A", pos.x * cellSize, (pos.y + 1) * cellSize); + } + }; + getViewport().add(worldPanel); + setVisible(true); + } + }; + } } diff --git a/src/main/java/evironment/antGame/AntWorld.java b/src/main/java/evironment/antGame/AntWorld.java index 1d656e6..d68c597 100644 --- a/src/main/java/evironment/antGame/AntWorld.java +++ b/src/main/java/evironment/antGame/AntWorld.java @@ -35,21 +35,15 @@ public class AntWorld implements Environment{ private int tick; private int maxEpisodeTicks; - MainFrame gui; public AntWorld(int width, int height, double foodDensity){ grid = new Grid(width, height, foodDensity); antAgent = new AntAgent(width, height); myAnt = new Ant(); - gui = new MainFrame(this, antAgent); maxEpisodeTicks = 1000; reset(); } - public MainFrame getGui(){ - return gui; - } - public AntWorld(){ this(Constants.DEFAULT_GRID_WIDTH, Constants.DEFAULT_GRID_HEIGHT, Constants.DEFAULT_FOOD_DENSITY); } @@ -166,7 +160,6 @@ public class AntWorld implements Environment{ StepResultEnvironment result = new StepResultEnvironment(newState, reward, done, info); - getGui().update(action, result); return result; } @@ -216,6 +209,6 @@ public class AntWorld implements Environment{ new AntWorld(3, 3, 0.1), new ListDiscreteActionSpace<>(AntAction.values()) ); - monteCarlo.learn(20000,5); + monteCarlo.learn(20000); } } diff --git a/src/main/java/example/RunningAnt.java b/src/main/java/example/RunningAnt.java new file mode 100644 index 0000000..19311d0 --- /dev/null +++ b/src/main/java/example/RunningAnt.java @@ -0,0 +1,21 @@ +package example; + +import core.RNG; +import core.algo.Method; +import core.controller.RLController; +import evironment.antGame.AntAction; +import evironment.antGame.AntWorld; + +public class RunningAnt { + public static void main(String[] args) { + RNG.setSeed(1234); + + RLController rl = new RLController() + .setEnvironment(new AntWorld(3,3,0.1)) + .setAllowedActions(AntAction.values()) + .setMethod(Method.MC_ONPOLICY_EGREEDY) + .setDelay(10) + .setEpisodes(1000); + rl.start(); + } +} diff --git a/src/main/java/example/Test.java b/src/main/java/example/Test.java new file mode 100644 index 0000000..c7ee691 --- /dev/null +++ b/src/main/java/example/Test.java @@ -0,0 +1,52 @@ +package example; + +public class Test { + interface Drawable{ + void draw(); + } + interface State{ + int getInt(); + } + + static class A implements Drawable, State{ + private int k; + public A(int a){ + k = a; + } + @Override + public void draw() { + System.out.println("draw " + k); + } + + @Override + public int getInt() { + System.out.println("getInt" + k); + return k; + } + } + + static class B implements State{ + @Override + public int getInt() { + return 0; + } + } + + public static void main(String[] args) { + State state = new A(24); + State state2 = new B(); + state.getInt(); + + System.out.println(state2 instanceof Drawable); + drawState(state2); + } + + static void drawState(State s){ + if(s instanceof Drawable){ + Drawable d = (Drawable) s; + d.draw(); + }else{ + System.out.println("invalid"); + } + } +}