From 34e7e3fdd6ceb861fa48f1f9f0dffb4e9bdbd1d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20L=C3=B6wenstrom?= Date: Sat, 21 Dec 2019 00:23:09 +0100 Subject: [PATCH] distinguish learning and episodic learning, enable fast-learning without drawing every step to reduce lag - repainting every step on no time delay will certainly freeze the app, so "fast-learning" will disable it, only refreshing current episode label - Added new abstract class "Episodic Learning". Maybe just use an interface instead?! Important because TD learning is not episodic, needs another way to represent the rewards received (maybe mean of last X rewards or sth) - Opening two JFrames, one with learning infos and one with environment --- src/main/java/core/RNG.java | 4 - src/main/java/core/algo/Episodic.java | 5 ++ src/main/java/core/algo/EpisodicLearning.java | 29 ++++++ src/main/java/core/algo/Learning.java | 8 +- .../algo/MC/MonteCarloOnPolicyEGreedy.java | 20 ++++- .../java/core/controller/RLController.java | 25 ++++-- src/main/java/core/gui/LearningInfoPanel.java | 34 +++++-- src/main/java/core/gui/View.java | 88 +++++++++++++------ .../java/core/listener/LearningListener.java | 1 + .../ViewListener.java | 3 +- .../java/core/policy/EpsilonGreedyPolicy.java | 1 - .../java/evironment/antGame/AntWorld.java | 25 +++--- ...{MainFrame.java => AntWorldComponent.java} | 16 ++-- src/main/java/example/RunningAnt.java | 4 +- 14 files changed, 188 insertions(+), 75 deletions(-) create mode 100644 src/main/java/core/algo/Episodic.java create mode 100644 src/main/java/core/algo/EpisodicLearning.java rename src/main/java/core/{controller => listener}/ViewListener.java (60%) rename src/main/java/evironment/antGame/gui/{MainFrame.java => AntWorldComponent.java} (57%) diff --git a/src/main/java/core/RNG.java b/src/main/java/core/RNG.java index 8813ded..2fe8929 100644 --- a/src/main/java/core/RNG.java +++ b/src/main/java/core/RNG.java @@ -13,10 +13,6 @@ public class RNG { return rng; } - public static void reseed(){ - rng.setSeed(seed); - } - public static void setSeed(int seed){ RNG.seed = seed; rng.setSeed(seed); diff --git a/src/main/java/core/algo/Episodic.java b/src/main/java/core/algo/Episodic.java new file mode 100644 index 0000000..f846c84 --- /dev/null +++ b/src/main/java/core/algo/Episodic.java @@ -0,0 +1,5 @@ +package core.algo; + +public interface Episodic { + int getCurrentEpisode(); +} diff --git a/src/main/java/core/algo/EpisodicLearning.java b/src/main/java/core/algo/EpisodicLearning.java new file mode 100644 index 0000000..7f6af1b --- /dev/null +++ b/src/main/java/core/algo/EpisodicLearning.java @@ -0,0 +1,29 @@ +package core.algo; + +import core.DiscreteActionSpace; +import core.Environment; + +public abstract class EpisodicLearning extends Learning implements Episodic{ + protected int currentEpisode; + + public EpisodicLearning(Environment environment, DiscreteActionSpace actionSpace, float discountFactor, int delay) { + super(environment, actionSpace, discountFactor, delay); + } + + public EpisodicLearning(Environment environment, DiscreteActionSpace actionSpace, float discountFactor) { + super(environment, actionSpace, discountFactor); + } + + public EpisodicLearning(Environment environment, DiscreteActionSpace actionSpace, int delay) { + super(environment, actionSpace, delay); + } + + public EpisodicLearning(Environment environment, DiscreteActionSpace actionSpace) { + super(environment, actionSpace); + } + + @Override + public int getCurrentEpisode(){ + return currentEpisode; + } +} diff --git a/src/main/java/core/algo/Learning.java b/src/main/java/core/algo/Learning.java index a825bd5..c7058fb 100644 --- a/src/main/java/core/algo/Learning.java +++ b/src/main/java/core/algo/Learning.java @@ -9,8 +9,6 @@ import core.policy.Policy; import lombok.Getter; import lombok.Setter; -import javax.swing.*; -import java.util.ArrayList; import java.util.HashSet; import java.util.List; import java.util.Set; @@ -68,4 +66,10 @@ public abstract class Learning { l.onEpisodeStart(); } } + + protected void dispatchStepEnd(){ + for(LearningListener l: learningListeners){ + l.onStepEnd(); + } + } } diff --git a/src/main/java/core/algo/MC/MonteCarloOnPolicyEGreedy.java b/src/main/java/core/algo/MC/MonteCarloOnPolicyEGreedy.java index 1bc1f11..a9a20bc 100644 --- a/src/main/java/core/algo/MC/MonteCarloOnPolicyEGreedy.java +++ b/src/main/java/core/algo/MC/MonteCarloOnPolicyEGreedy.java @@ -1,10 +1,9 @@ package core.algo.mc; import core.*; -import core.algo.Learning; +import core.algo.EpisodicLearning; import core.policy.EpsilonGreedyPolicy; import javafx.util.Pair; -import lombok.Setter; import java.util.*; @@ -26,11 +25,11 @@ import java.util.*; * How to encounter this problem? * @param */ -public class MonteCarloOnPolicyEGreedy extends Learning { +public class MonteCarloOnPolicyEGreedy extends EpisodicLearning { public MonteCarloOnPolicyEGreedy(Environment environment, DiscreteActionSpace actionSpace, float discountFactor, float epsilon, int delay) { super(environment, actionSpace, discountFactor, delay); - + currentEpisode = 0; this.policy = new EpsilonGreedyPolicy<>(epsilon); this.stateActionTable = new StateActionHashTable<>(this.actionSpace); } @@ -47,8 +46,15 @@ public class MonteCarloOnPolicyEGreedy extends Learning { Map, Integer> returnCount = new HashMap<>(); for(int i = 0; i < nrOfEpisodes; ++i) { + ++currentEpisode; List> episode = new ArrayList<>(); State state = environment.reset(); + dispatchEpisodeStart(); + try { + Thread.sleep(delay); + } catch (InterruptedException e) { + e.printStackTrace(); + } double sumOfRewards = 0; for(int j=0; j < 10; ++j){ Map actionValues = stateActionTable.getActionValues(state); @@ -67,6 +73,7 @@ public class MonteCarloOnPolicyEGreedy extends Learning { } catch (InterruptedException e) { e.printStackTrace(); } + dispatchStepEnd(); } dispatchEpisodeEnd(sumOfRewards); @@ -100,4 +107,9 @@ public class MonteCarloOnPolicyEGreedy extends Learning { } } } + + @Override + public int getCurrentEpisode() { + return currentEpisode; + } } diff --git a/src/main/java/core/controller/RLController.java b/src/main/java/core/controller/RLController.java index 80b4375..8ab5094 100644 --- a/src/main/java/core/controller/RLController.java +++ b/src/main/java/core/controller/RLController.java @@ -7,11 +7,12 @@ import core.algo.Learning; import core.algo.Method; import core.algo.mc.MonteCarloOnPolicyEGreedy; import core.gui.View; +import core.listener.ViewListener; import core.policy.EpsilonPolicy; import javax.swing.*; -public class RLController implements ViewListener{ +public class RLController implements ViewListener { protected Environment environment; protected Learning learning; protected DiscreteActionSpace discreteActionSpace; @@ -19,6 +20,7 @@ public class RLController implements ViewListener{ private int delay; private int nrOfEpisodes; private Method method; + private int prevDelay; public RLController(){ } @@ -41,7 +43,7 @@ public class RLController implements ViewListener{ not using SwingUtilities here on purpose to ensure the view is fully initialized and can be passed as LearningListener. */ - view = new View<>(learning, this); + view = new View<>(learning, environment, this); learning.addListener(view); learning.learn(nrOfEpisodes); } @@ -58,10 +60,23 @@ public class RLController implements ViewListener{ @Override public void onDelayChange(int delay) { + changeLearningDelay(delay); + } + + private void changeLearningDelay(int delay){ learning.setDelay(delay); - SwingUtilities.invokeLater(() -> { - view.updateLearningInfoPanel(); - }); + SwingUtilities.invokeLater(() -> view.updateLearningInfoPanel()); + } + + @Override + public void onFastLearnChange(boolean fastLearn) { + view.setDrawEveryStep(!fastLearn); + if(fastLearn){ + prevDelay = learning.getDelay(); + changeLearningDelay(0); + }else{ + changeLearningDelay(prevDelay); + } } public RLController setMethod(Method method){ diff --git a/src/main/java/core/gui/LearningInfoPanel.java b/src/main/java/core/gui/LearningInfoPanel.java index 8c67589..4ed4acb 100644 --- a/src/main/java/core/gui/LearningInfoPanel.java +++ b/src/main/java/core/gui/LearningInfoPanel.java @@ -1,7 +1,8 @@ package core.gui; +import core.algo.Episodic; import core.algo.Learning; -import core.controller.ViewListener; +import core.listener.ViewListener; import core.policy.EpsilonPolicy; import javax.swing.*; @@ -11,39 +12,62 @@ public class LearningInfoPanel extends JPanel { private JLabel policyLabel; private JLabel discountLabel; private JLabel epsilonLabel; + private JLabel episodeLabel; private JSlider epsilonSlider; private JLabel delayLabel; private JSlider delaySlider; + private JButton toggleFastLearningButton; + private boolean fastLearning; public LearningInfoPanel(Learning learning, ViewListener viewListener){ this.learning = learning; setLayout(new BoxLayout(this, BoxLayout.Y_AXIS)); policyLabel = new JLabel(); discountLabel = new JLabel(); - epsilonLabel = new JLabel(); delayLabel = new JLabel(); + if(learning instanceof Episodic){ + episodeLabel = new JLabel(); + add(episodeLabel); + } delaySlider = new JSlider(0,1000, learning.getDelay()); delaySlider.addChangeListener(e -> viewListener.onDelayChange(delaySlider.getValue())); add(policyLabel); add(discountLabel); if(learning.getPolicy() instanceof EpsilonPolicy){ + epsilonLabel = new JLabel(); epsilonSlider = new JSlider(0, 100, (int)((EpsilonPolicy)learning.getPolicy()).getEpsilon() * 100); epsilonSlider.addChangeListener(e -> viewListener.onEpsilonChange(epsilonSlider.getValue() / 100f)); add(epsilonLabel); add(epsilonSlider); } + + toggleFastLearningButton = new JButton("Enable fast-learn"); + fastLearning = false; + toggleFastLearningButton.addActionListener(e->{ + fastLearning = !fastLearning; + delaySlider.setEnabled(!fastLearning); + epsilonSlider.setEnabled(!fastLearning); + viewListener.onFastLearnChange(fastLearning); + }); add(delayLabel); add(delaySlider); + add(toggleFastLearningButton); refreshLabels(); setVisible(true); } - public void refreshLabels(){ + public void refreshLabels() { policyLabel.setText("Policy: " + learning.getPolicy().getClass()); discountLabel.setText("Discount factor: " + learning.getDiscountFactor()); - if(learning.getPolicy() instanceof EpsilonPolicy){ - epsilonLabel.setText("Exploration (Epsilon): " + ((EpsilonPolicy)learning.getPolicy()).getEpsilon()); + if(learning instanceof Episodic){ + episodeLabel.setText("Episode: " + ((Episodic)(learning)).getCurrentEpisode()); + } + if (learning.getPolicy() instanceof EpsilonPolicy) { + epsilonLabel.setText("Exploration (Epsilon): " + ((EpsilonPolicy) learning.getPolicy()).getEpsilon()); + epsilonSlider.setValue((int)(((EpsilonPolicy) learning.getPolicy()).getEpsilon() * 100)); } delayLabel.setText("Delay (ms): " + learning.getDelay()); + delaySlider.setValue(learning.getDelay()); + toggleFastLearningButton.setText(fastLearning ? "Disable fast-learning" : "Enable fast-learning"); } } diff --git a/src/main/java/core/gui/View.java b/src/main/java/core/gui/View.java index d031451..5dff7e6 100644 --- a/src/main/java/core/gui/View.java +++ b/src/main/java/core/gui/View.java @@ -1,7 +1,8 @@ package core.gui; +import core.Environment; import core.algo.Learning; -import core.controller.ViewListener; +import core.listener.ViewListener; import core.listener.LearningListener; import lombok.Getter; import org.knowm.xchart.QuickChart; @@ -10,27 +11,31 @@ import org.knowm.xchart.XYChart; import javax.swing.*; import java.awt.*; -import java.util.ArrayList; import java.util.List; public class View implements LearningListener { private Learning learning; + private Environment environment; @Getter - private XYChart chart; + private XYChart rewardChart; @Getter private LearningInfoPanel learningInfoPanel; @Getter private JFrame mainFrame; + private JFrame environmentFrame; private XChartPanel rewardChartPanel; private ViewListener viewListener; + private boolean drawEveryStep; - public View(Learning learning, ViewListener viewListener){ + public View(Learning learning, Environment environment, ViewListener viewListener) { this.learning = learning; + this.environment = environment; this.viewListener = viewListener; - this.initMainFrame(); + drawEveryStep = true; + SwingUtilities.invokeLater(this::initMainFrame); } - private void initMainFrame(){ + private void initMainFrame() { mainFrame = new JFrame(); mainFrame.setPreferredSize(new Dimension(1280, 720)); mainFrame.setLayout(new BorderLayout()); @@ -44,29 +49,40 @@ public class View implements LearningListener { mainFrame.setDefaultCloseOperation(WindowConstants.EXIT_ON_CLOSE); mainFrame.pack(); mainFrame.setVisible(true); + + if (environment instanceof Visualizable) { + environmentFrame = new JFrame() { + { + add(((Visualizable) environment).visualize()); + pack(); + setVisible(true); + } + }; + + } } - private void initLearningInfoPanel(){ + private void initLearningInfoPanel() { learningInfoPanel = new LearningInfoPanel(learning, viewListener); } - private void initRewardChart(){ - chart = + private void initRewardChart() { + rewardChart = QuickChart.getChart( - "Rewards per Episode", + "Sum of Rewards per Episode", "Episode", "Reward", - "randomWalk", - new double[] {0}, - new double[] {0}); - chart.getStyler().setLegendVisible(true); - chart.getStyler().setXAxisTicksVisible(true); - rewardChartPanel = new XChartPanel<>(chart); - rewardChartPanel.setPreferredSize(new Dimension(300,300)); + "rewardHistory", + new double[]{0}, + new double[]{0}); + rewardChart.getStyler().setLegendVisible(true); + rewardChart.getStyler().setXAxisTicksVisible(true); + rewardChartPanel = new XChartPanel<>(rewardChart); + rewardChartPanel.setPreferredSize(new Dimension(300, 300)); } - public void showState(Visualizable state){ - new JFrame(){ + public void showState(Visualizable state) { + new JFrame() { { JComponent stateComponent = state.visualize(); setPreferredSize(new Dimension(stateComponent.getWidth(), stateComponent.getHeight())); @@ -76,25 +92,47 @@ public class View implements LearningListener { }; } - public void updateRewardGraph(List rewardHistory){ - chart.updateXYSeries("randomWalk", null, rewardHistory, null); + public void setDrawEveryStep(boolean drawEveryStep){ + this.drawEveryStep = drawEveryStep; + } + + public void updateRewardGraph(List rewardHistory) { + rewardChart.updateXYSeries("rewardHistory", null, rewardHistory, null); rewardChartPanel.revalidate(); rewardChartPanel.repaint(); } - public void updateLearningInfoPanel(){ + public void updateLearningInfoPanel() { this.learningInfoPanel.refreshLabels(); } @Override public void onEpisodeEnd(List rewardHistory) { - SwingUtilities.invokeLater(()->{ - updateRewardGraph(rewardHistory); - }); + SwingUtilities.invokeLater(() ->{ + if(drawEveryStep){ + updateRewardGraph(rewardHistory); + } + updateLearningInfoPanel(); + }); } @Override public void onEpisodeStart() { + if(drawEveryStep) { + SwingUtilities.invokeLater(this::repaintEnvironment); + } + } + @Override + public void onStepEnd() { + if(drawEveryStep){ + SwingUtilities.invokeLater(this::repaintEnvironment); + } + } + + private void repaintEnvironment(){ + if (environmentFrame != null) { + environmentFrame.repaint(); + } } } diff --git a/src/main/java/core/listener/LearningListener.java b/src/main/java/core/listener/LearningListener.java index 4147897..add1fbd 100644 --- a/src/main/java/core/listener/LearningListener.java +++ b/src/main/java/core/listener/LearningListener.java @@ -5,4 +5,5 @@ import java.util.List; public interface LearningListener{ void onEpisodeEnd(List rewardHistory); void onEpisodeStart(); + void onStepEnd(); } diff --git a/src/main/java/core/controller/ViewListener.java b/src/main/java/core/listener/ViewListener.java similarity index 60% rename from src/main/java/core/controller/ViewListener.java rename to src/main/java/core/listener/ViewListener.java index 578512a..f27d7b4 100644 --- a/src/main/java/core/controller/ViewListener.java +++ b/src/main/java/core/listener/ViewListener.java @@ -1,6 +1,7 @@ -package core.controller; +package core.listener; public interface ViewListener { void onEpsilonChange(float epsilon); void onDelayChange(int delay); + void onFastLearnChange(boolean isFastLearn); } diff --git a/src/main/java/core/policy/EpsilonGreedyPolicy.java b/src/main/java/core/policy/EpsilonGreedyPolicy.java index 1288aed..550889f 100644 --- a/src/main/java/core/policy/EpsilonGreedyPolicy.java +++ b/src/main/java/core/policy/EpsilonGreedyPolicy.java @@ -29,7 +29,6 @@ public class EpsilonGreedyPolicy implements EpsilonPolicy{ @Override public A chooseAction(Map actionValues) { - System.out.println("current epsilon " + epsilon); if(RNG.getRandom().nextFloat() < epsilon){ // Take random action return randomPolicy.chooseAction(actionValues); diff --git a/src/main/java/evironment/antGame/AntWorld.java b/src/main/java/evironment/antGame/AntWorld.java index e512ad7..ebd0de6 100644 --- a/src/main/java/evironment/antGame/AntWorld.java +++ b/src/main/java/evironment/antGame/AntWorld.java @@ -1,14 +1,15 @@ package evironment.antGame; -import core.*; -import core.algo.Learning; -import core.algo.mc.MonteCarloOnPolicyEGreedy; -import evironment.antGame.gui.MainFrame; - +import core.Environment; +import core.State; +import core.StepResultEnvironment; +import core.gui.Visualizable; +import evironment.antGame.gui.AntWorldComponent; +import javax.swing.*; import java.awt.*; -public class AntWorld implements Environment{ +public class AntWorld implements Environment, Visualizable { /** * */ @@ -204,14 +205,8 @@ public class AntWorld implements Environment{ return myAnt; } - public static void main(String[] args) { - RNG.setSeed(1993); - - Learning monteCarlo = new MonteCarloOnPolicyEGreedy<>( - new AntWorld(3, 3, 0.1), - new ListDiscreteActionSpace<>(AntAction.values()), - 5 - ); - monteCarlo.learn(20000); + @Override + public JComponent visualize() { + return new AntWorldComponent(this, this.antAgent); } } diff --git a/src/main/java/evironment/antGame/gui/MainFrame.java b/src/main/java/evironment/antGame/gui/AntWorldComponent.java similarity index 57% rename from src/main/java/evironment/antGame/gui/MainFrame.java rename to src/main/java/evironment/antGame/gui/AntWorldComponent.java index c299d78..7edb62b 100644 --- a/src/main/java/evironment/antGame/gui/MainFrame.java +++ b/src/main/java/evironment/antGame/gui/AntWorldComponent.java @@ -1,21 +1,18 @@ package evironment.antGame.gui; -import core.StepResultEnvironment; -import evironment.antGame.AntAction; import evironment.antGame.AntAgent; import evironment.antGame.AntWorld; import javax.swing.*; import java.awt.*; -public class MainFrame extends JFrame { +public class AntWorldComponent extends JComponent { private AntWorld antWorld; private HistoryPanel historyPanel; - public MainFrame(AntWorld antWorld, AntAgent antAgent){ + public AntWorldComponent(AntWorld antWorld, AntAgent antAgent){ this.antWorld = antWorld; setLayout(new BorderLayout()); - setDefaultCloseOperation(WindowConstants.EXIT_ON_CLOSE); CellsScrollPane worldPane = new CellsScrollPane(antWorld.getCellArray(), antWorld.getAnt(), 10); CellsScrollPane antBrainPane = new CellsScrollPane(antAgent.getKnownWorld(), antWorld.getAnt(), 10); historyPanel = new HistoryPanel(); @@ -29,14 +26,11 @@ public class MainFrame extends JFrame { add(BorderLayout.CENTER, mapComponent); add(BorderLayout.SOUTH, historyPanel); - pack(); setVisible(true); } - public void update(AntAction lastAction, StepResultEnvironment stepResultEnvironment){ - historyPanel.addText(String.format("Tick %d: \t Selected action: %s \t Reward: %f \t Info: %s \n totalPoints: %d \t hasFood: %b \t ", - antWorld.getTick(), lastAction.toString(), stepResultEnvironment.getReward(), stepResultEnvironment.getInfo(), antWorld.getAnt().getPoints(), antWorld.getAnt().hasFood())); - - repaint(); + @Override + protected void paintComponent(Graphics g) { + super.paintComponent(g); } } diff --git a/src/main/java/example/RunningAnt.java b/src/main/java/example/RunningAnt.java index dc22cc0..0106c30 100644 --- a/src/main/java/example/RunningAnt.java +++ b/src/main/java/example/RunningAnt.java @@ -8,14 +8,14 @@ import evironment.antGame.AntWorld; public class RunningAnt { public static void main(String[] args) { - RNG.setSeed(1234); + RNG.setSeed(123); RLController rl = new RLController() .setEnvironment(new AntWorld(3,3,0.1)) .setAllowedActions(AntAction.values()) .setMethod(Method.MC_ONPOLICY_EGREEDY) .setDelay(10) - .setEpisodes(10000); + .setEpisodes(100000); rl.start(); } }