diff --git a/src/main/java/core/RNG.java b/src/main/java/core/RNG.java index 8813ded..2fe8929 100644 --- a/src/main/java/core/RNG.java +++ b/src/main/java/core/RNG.java @@ -13,10 +13,6 @@ public class RNG { return rng; } - public static void reseed(){ - rng.setSeed(seed); - } - public static void setSeed(int seed){ RNG.seed = seed; rng.setSeed(seed); diff --git a/src/main/java/core/algo/Episodic.java b/src/main/java/core/algo/Episodic.java new file mode 100644 index 0000000..f846c84 --- /dev/null +++ b/src/main/java/core/algo/Episodic.java @@ -0,0 +1,5 @@ +package core.algo; + +public interface Episodic { + int getCurrentEpisode(); +} diff --git a/src/main/java/core/algo/EpisodicLearning.java b/src/main/java/core/algo/EpisodicLearning.java new file mode 100644 index 0000000..7f6af1b --- /dev/null +++ b/src/main/java/core/algo/EpisodicLearning.java @@ -0,0 +1,29 @@ +package core.algo; + +import core.DiscreteActionSpace; +import core.Environment; + +public abstract class EpisodicLearning<A extends Enum> extends Learning<A> implements Episodic{ + protected int currentEpisode; + + public EpisodicLearning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, int delay) { + super(environment, actionSpace, discountFactor, delay); + } + + public EpisodicLearning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor) { + super(environment, actionSpace, discountFactor); + } + + public EpisodicLearning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, int delay) { + super(environment, actionSpace, delay); + } + + public EpisodicLearning(Environment<A> environment, DiscreteActionSpace<A> actionSpace) { + super(environment, actionSpace); + } + + @Override + public int getCurrentEpisode(){ + return currentEpisode; + } +} diff --git a/src/main/java/core/algo/Learning.java b/src/main/java/core/algo/Learning.java index a825bd5..c7058fb 100644 --- a/src/main/java/core/algo/Learning.java +++ b/src/main/java/core/algo/Learning.java @@ -9,8 +9,6 @@ import core.policy.Policy; import lombok.Getter; import lombok.Setter; -import javax.swing.*; -import java.util.ArrayList; import java.util.HashSet; import java.util.List; import java.util.Set; @@ -68,4 +66,10 @@ public abstract class Learning<A extends Enum> { l.onEpisodeStart(); } } + + protected void dispatchStepEnd(){ + for(LearningListener l: learningListeners){ + l.onStepEnd(); + } + } } diff --git a/src/main/java/core/algo/MC/MonteCarloOnPolicyEGreedy.java b/src/main/java/core/algo/MC/MonteCarloOnPolicyEGreedy.java index 1bc1f11..a9a20bc 100644 --- a/src/main/java/core/algo/MC/MonteCarloOnPolicyEGreedy.java +++ b/src/main/java/core/algo/MC/MonteCarloOnPolicyEGreedy.java @@ -1,10 +1,9 @@ package core.algo.mc; import core.*; -import core.algo.Learning; +import core.algo.EpisodicLearning; import core.policy.EpsilonGreedyPolicy; import javafx.util.Pair; -import lombok.Setter; import java.util.*; @@ -26,11 +25,11 @@ import java.util.*; * How to encounter this problem? * @param <A> */ -public class MonteCarloOnPolicyEGreedy<A extends Enum> extends Learning<A> { +public class MonteCarloOnPolicyEGreedy<A extends Enum> extends EpisodicLearning<A> { public MonteCarloOnPolicyEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay) { super(environment, actionSpace, discountFactor, delay); - + currentEpisode = 0; this.policy = new EpsilonGreedyPolicy<>(epsilon); this.stateActionTable = new StateActionHashTable<>(this.actionSpace); } @@ -47,8 +46,15 @@ public class MonteCarloOnPolicyEGreedy<A extends Enum> extends Learning<A> { Map<Pair<State, A>, Integer> returnCount = new HashMap<>(); for(int i = 0; i < nrOfEpisodes; ++i) { + ++currentEpisode; List<StepResult<A>> episode = new ArrayList<>(); State state = environment.reset(); + dispatchEpisodeStart(); + try { + Thread.sleep(delay); + } catch (InterruptedException e) { + e.printStackTrace(); + } double sumOfRewards = 0; for(int j=0; j < 10; ++j){ Map<A, Double> actionValues = stateActionTable.getActionValues(state); @@ -67,6 +73,7 @@ public class MonteCarloOnPolicyEGreedy<A extends Enum> extends Learning<A> { } catch (InterruptedException e) { e.printStackTrace(); } + dispatchStepEnd(); } dispatchEpisodeEnd(sumOfRewards); @@ -100,4 +107,9 @@ public class MonteCarloOnPolicyEGreedy<A extends Enum> extends Learning<A> { } } } + + @Override + public int getCurrentEpisode() { + return currentEpisode; + } } diff --git a/src/main/java/core/controller/RLController.java b/src/main/java/core/controller/RLController.java index 80b4375..8ab5094 100644 --- a/src/main/java/core/controller/RLController.java +++ b/src/main/java/core/controller/RLController.java @@ -7,11 +7,12 @@ import core.algo.Learning; import core.algo.Method; import core.algo.mc.MonteCarloOnPolicyEGreedy; import core.gui.View; +import core.listener.ViewListener; import core.policy.EpsilonPolicy; import javax.swing.*; -public class RLController<A extends Enum> implements ViewListener{ +public class RLController<A extends Enum> implements ViewListener { protected Environment<A> environment; protected Learning<A> learning; protected DiscreteActionSpace<A> discreteActionSpace; @@ -19,6 +20,7 @@ public class RLController<A extends Enum> implements ViewListener{ private int delay; private int nrOfEpisodes; private Method method; + private int prevDelay; public RLController(){ } @@ -41,7 +43,7 @@ public class RLController<A extends Enum> implements ViewListener{ not using SwingUtilities here on purpose to ensure the view is fully initialized and can be passed as LearningListener. */ - view = new View<>(learning, this); + view = new View<>(learning, environment, this); learning.addListener(view); learning.learn(nrOfEpisodes); } @@ -58,10 +60,23 @@ public class RLController<A extends Enum> implements ViewListener{ @Override public void onDelayChange(int delay) { + changeLearningDelay(delay); + } + + private void changeLearningDelay(int delay){ learning.setDelay(delay); - SwingUtilities.invokeLater(() -> { - view.updateLearningInfoPanel(); - }); + SwingUtilities.invokeLater(() -> view.updateLearningInfoPanel()); + } + + @Override + public void onFastLearnChange(boolean fastLearn) { + view.setDrawEveryStep(!fastLearn); + if(fastLearn){ + prevDelay = learning.getDelay(); + changeLearningDelay(0); + }else{ + changeLearningDelay(prevDelay); + } } public RLController<A> setMethod(Method method){ diff --git a/src/main/java/core/gui/LearningInfoPanel.java b/src/main/java/core/gui/LearningInfoPanel.java index 8c67589..4ed4acb 100644 --- a/src/main/java/core/gui/LearningInfoPanel.java +++ b/src/main/java/core/gui/LearningInfoPanel.java @@ -1,7 +1,8 @@ package core.gui; +import core.algo.Episodic; import core.algo.Learning; -import core.controller.ViewListener; +import core.listener.ViewListener; import core.policy.EpsilonPolicy; import javax.swing.*; @@ -11,39 +12,62 @@ public class LearningInfoPanel extends JPanel { private JLabel policyLabel; private JLabel discountLabel; private JLabel epsilonLabel; + private JLabel episodeLabel; private JSlider epsilonSlider; private JLabel delayLabel; private JSlider delaySlider; + private JButton toggleFastLearningButton; + private boolean fastLearning; public LearningInfoPanel(Learning learning, ViewListener viewListener){ this.learning = learning; setLayout(new BoxLayout(this, BoxLayout.Y_AXIS)); policyLabel = new JLabel(); discountLabel = new JLabel(); - epsilonLabel = new JLabel(); delayLabel = new JLabel(); + if(learning instanceof Episodic){ + episodeLabel = new JLabel(); + add(episodeLabel); + } delaySlider = new JSlider(0,1000, learning.getDelay()); delaySlider.addChangeListener(e -> viewListener.onDelayChange(delaySlider.getValue())); add(policyLabel); add(discountLabel); if(learning.getPolicy() instanceof EpsilonPolicy){ + epsilonLabel = new JLabel(); epsilonSlider = new JSlider(0, 100, (int)((EpsilonPolicy)learning.getPolicy()).getEpsilon() * 100); epsilonSlider.addChangeListener(e -> viewListener.onEpsilonChange(epsilonSlider.getValue() / 100f)); add(epsilonLabel); add(epsilonSlider); } + + toggleFastLearningButton = new JButton("Enable fast-learn"); + fastLearning = false; + toggleFastLearningButton.addActionListener(e->{ + fastLearning = !fastLearning; + delaySlider.setEnabled(!fastLearning); + epsilonSlider.setEnabled(!fastLearning); + viewListener.onFastLearnChange(fastLearning); + }); add(delayLabel); add(delaySlider); + add(toggleFastLearningButton); refreshLabels(); setVisible(true); } - public void refreshLabels(){ + public void refreshLabels() { policyLabel.setText("Policy: " + learning.getPolicy().getClass()); discountLabel.setText("Discount factor: " + learning.getDiscountFactor()); - if(learning.getPolicy() instanceof EpsilonPolicy){ - epsilonLabel.setText("Exploration (Epsilon): " + ((EpsilonPolicy)learning.getPolicy()).getEpsilon()); + if(learning instanceof Episodic){ + episodeLabel.setText("Episode: " + ((Episodic)(learning)).getCurrentEpisode()); + } + if (learning.getPolicy() instanceof EpsilonPolicy) { + epsilonLabel.setText("Exploration (Epsilon): " + ((EpsilonPolicy) learning.getPolicy()).getEpsilon()); + epsilonSlider.setValue((int)(((EpsilonPolicy) learning.getPolicy()).getEpsilon() * 100)); } delayLabel.setText("Delay (ms): " + learning.getDelay()); + delaySlider.setValue(learning.getDelay()); + toggleFastLearningButton.setText(fastLearning ? "Disable fast-learning" : "Enable fast-learning"); } } diff --git a/src/main/java/core/gui/View.java b/src/main/java/core/gui/View.java index d031451..5dff7e6 100644 --- a/src/main/java/core/gui/View.java +++ b/src/main/java/core/gui/View.java @@ -1,7 +1,8 @@ package core.gui; +import core.Environment; import core.algo.Learning; -import core.controller.ViewListener; +import core.listener.ViewListener; import core.listener.LearningListener; import lombok.Getter; import org.knowm.xchart.QuickChart; @@ -10,27 +11,31 @@ import org.knowm.xchart.XYChart; import javax.swing.*; import java.awt.*; -import java.util.ArrayList; import java.util.List; public class View<A extends Enum> implements LearningListener { private Learning<A> learning; + private Environment<A> environment; @Getter - private XYChart chart; + private XYChart rewardChart; @Getter private LearningInfoPanel learningInfoPanel; @Getter private JFrame mainFrame; + private JFrame environmentFrame; private XChartPanel<XYChart> rewardChartPanel; private ViewListener viewListener; + private boolean drawEveryStep; - public View(Learning<A> learning, ViewListener viewListener){ + public View(Learning<A> learning, Environment<A> environment, ViewListener viewListener) { this.learning = learning; + this.environment = environment; this.viewListener = viewListener; - this.initMainFrame(); + drawEveryStep = true; + SwingUtilities.invokeLater(this::initMainFrame); } - private void initMainFrame(){ + private void initMainFrame() { mainFrame = new JFrame(); mainFrame.setPreferredSize(new Dimension(1280, 720)); mainFrame.setLayout(new BorderLayout()); @@ -44,29 +49,40 @@ public class View<A extends Enum> implements LearningListener { mainFrame.setDefaultCloseOperation(WindowConstants.EXIT_ON_CLOSE); mainFrame.pack(); mainFrame.setVisible(true); + + if (environment instanceof Visualizable) { + environmentFrame = new JFrame() { + { + add(((Visualizable) environment).visualize()); + pack(); + setVisible(true); + } + }; + + } } - private void initLearningInfoPanel(){ + private void initLearningInfoPanel() { learningInfoPanel = new LearningInfoPanel(learning, viewListener); } - private void initRewardChart(){ - chart = + private void initRewardChart() { + rewardChart = QuickChart.getChart( - "Rewards per Episode", + "Sum of Rewards per Episode", "Episode", "Reward", - "randomWalk", - new double[] {0}, - new double[] {0}); - chart.getStyler().setLegendVisible(true); - chart.getStyler().setXAxisTicksVisible(true); - rewardChartPanel = new XChartPanel<>(chart); - rewardChartPanel.setPreferredSize(new Dimension(300,300)); + "rewardHistory", + new double[]{0}, + new double[]{0}); + rewardChart.getStyler().setLegendVisible(true); + rewardChart.getStyler().setXAxisTicksVisible(true); + rewardChartPanel = new XChartPanel<>(rewardChart); + rewardChartPanel.setPreferredSize(new Dimension(300, 300)); } - public void showState(Visualizable state){ - new JFrame(){ + public void showState(Visualizable state) { + new JFrame() { { JComponent stateComponent = state.visualize(); setPreferredSize(new Dimension(stateComponent.getWidth(), stateComponent.getHeight())); @@ -76,25 +92,47 @@ public class View<A extends Enum> implements LearningListener { }; } - public void updateRewardGraph(List<Double> rewardHistory){ - chart.updateXYSeries("randomWalk", null, rewardHistory, null); + public void setDrawEveryStep(boolean drawEveryStep){ + this.drawEveryStep = drawEveryStep; + } + + public void updateRewardGraph(List<Double> rewardHistory) { + rewardChart.updateXYSeries("rewardHistory", null, rewardHistory, null); rewardChartPanel.revalidate(); rewardChartPanel.repaint(); } - public void updateLearningInfoPanel(){ + public void updateLearningInfoPanel() { this.learningInfoPanel.refreshLabels(); } @Override public void onEpisodeEnd(List<Double> rewardHistory) { - SwingUtilities.invokeLater(()->{ - updateRewardGraph(rewardHistory); - }); + SwingUtilities.invokeLater(() ->{ + if(drawEveryStep){ + updateRewardGraph(rewardHistory); + } + updateLearningInfoPanel(); + }); } @Override public void onEpisodeStart() { + if(drawEveryStep) { + SwingUtilities.invokeLater(this::repaintEnvironment); + } + } + @Override + public void onStepEnd() { + if(drawEveryStep){ + SwingUtilities.invokeLater(this::repaintEnvironment); + } + } + + private void repaintEnvironment(){ + if (environmentFrame != null) { + environmentFrame.repaint(); + } } } diff --git a/src/main/java/core/listener/LearningListener.java b/src/main/java/core/listener/LearningListener.java index 4147897..add1fbd 100644 --- a/src/main/java/core/listener/LearningListener.java +++ b/src/main/java/core/listener/LearningListener.java @@ -5,4 +5,5 @@ import java.util.List; public interface LearningListener{ void onEpisodeEnd(List<Double> rewardHistory); void onEpisodeStart(); + void onStepEnd(); } diff --git a/src/main/java/core/controller/ViewListener.java b/src/main/java/core/listener/ViewListener.java similarity index 60% rename from src/main/java/core/controller/ViewListener.java rename to src/main/java/core/listener/ViewListener.java index 578512a..f27d7b4 100644 --- a/src/main/java/core/controller/ViewListener.java +++ b/src/main/java/core/listener/ViewListener.java @@ -1,6 +1,7 @@ -package core.controller; +package core.listener; public interface ViewListener { void onEpsilonChange(float epsilon); void onDelayChange(int delay); + void onFastLearnChange(boolean isFastLearn); } diff --git a/src/main/java/core/policy/EpsilonGreedyPolicy.java b/src/main/java/core/policy/EpsilonGreedyPolicy.java index 1288aed..550889f 100644 --- a/src/main/java/core/policy/EpsilonGreedyPolicy.java +++ b/src/main/java/core/policy/EpsilonGreedyPolicy.java @@ -29,7 +29,6 @@ public class EpsilonGreedyPolicy<A extends Enum> implements EpsilonPolicy<A>{ @Override public A chooseAction(Map<A, Double> actionValues) { - System.out.println("current epsilon " + epsilon); if(RNG.getRandom().nextFloat() < epsilon){ // Take random action return randomPolicy.chooseAction(actionValues); diff --git a/src/main/java/evironment/antGame/AntWorld.java b/src/main/java/evironment/antGame/AntWorld.java index e512ad7..ebd0de6 100644 --- a/src/main/java/evironment/antGame/AntWorld.java +++ b/src/main/java/evironment/antGame/AntWorld.java @@ -1,14 +1,15 @@ package evironment.antGame; -import core.*; -import core.algo.Learning; -import core.algo.mc.MonteCarloOnPolicyEGreedy; -import evironment.antGame.gui.MainFrame; - +import core.Environment; +import core.State; +import core.StepResultEnvironment; +import core.gui.Visualizable; +import evironment.antGame.gui.AntWorldComponent; +import javax.swing.*; import java.awt.*; -public class AntWorld implements Environment<AntAction>{ +public class AntWorld implements Environment<AntAction>, Visualizable { /** * */ @@ -204,14 +205,8 @@ public class AntWorld implements Environment<AntAction>{ return myAnt; } - public static void main(String[] args) { - RNG.setSeed(1993); - - Learning<AntAction> monteCarlo = new MonteCarloOnPolicyEGreedy<>( - new AntWorld(3, 3, 0.1), - new ListDiscreteActionSpace<>(AntAction.values()), - 5 - ); - monteCarlo.learn(20000); + @Override + public JComponent visualize() { + return new AntWorldComponent(this, this.antAgent); } } diff --git a/src/main/java/evironment/antGame/gui/MainFrame.java b/src/main/java/evironment/antGame/gui/AntWorldComponent.java similarity index 57% rename from src/main/java/evironment/antGame/gui/MainFrame.java rename to src/main/java/evironment/antGame/gui/AntWorldComponent.java index c299d78..7edb62b 100644 --- a/src/main/java/evironment/antGame/gui/MainFrame.java +++ b/src/main/java/evironment/antGame/gui/AntWorldComponent.java @@ -1,21 +1,18 @@ package evironment.antGame.gui; -import core.StepResultEnvironment; -import evironment.antGame.AntAction; import evironment.antGame.AntAgent; import evironment.antGame.AntWorld; import javax.swing.*; import java.awt.*; -public class MainFrame extends JFrame { +public class AntWorldComponent extends JComponent { private AntWorld antWorld; private HistoryPanel historyPanel; - public MainFrame(AntWorld antWorld, AntAgent antAgent){ + public AntWorldComponent(AntWorld antWorld, AntAgent antAgent){ this.antWorld = antWorld; setLayout(new BorderLayout()); - setDefaultCloseOperation(WindowConstants.EXIT_ON_CLOSE); CellsScrollPane worldPane = new CellsScrollPane(antWorld.getCellArray(), antWorld.getAnt(), 10); CellsScrollPane antBrainPane = new CellsScrollPane(antAgent.getKnownWorld(), antWorld.getAnt(), 10); historyPanel = new HistoryPanel(); @@ -29,14 +26,11 @@ public class MainFrame extends JFrame { add(BorderLayout.CENTER, mapComponent); add(BorderLayout.SOUTH, historyPanel); - pack(); setVisible(true); } - public void update(AntAction lastAction, StepResultEnvironment stepResultEnvironment){ - historyPanel.addText(String.format("Tick %d: \t Selected action: %s \t Reward: %f \t Info: %s \n totalPoints: %d \t hasFood: %b \t ", - antWorld.getTick(), lastAction.toString(), stepResultEnvironment.getReward(), stepResultEnvironment.getInfo(), antWorld.getAnt().getPoints(), antWorld.getAnt().hasFood())); - - repaint(); + @Override + protected void paintComponent(Graphics g) { + super.paintComponent(g); } } diff --git a/src/main/java/example/RunningAnt.java b/src/main/java/example/RunningAnt.java index dc22cc0..0106c30 100644 --- a/src/main/java/example/RunningAnt.java +++ b/src/main/java/example/RunningAnt.java @@ -8,14 +8,14 @@ import evironment.antGame.AntWorld; public class RunningAnt { public static void main(String[] args) { - RNG.setSeed(1234); + RNG.setSeed(123); RLController<AntAction> rl = new RLController<AntAction>() .setEnvironment(new AntWorld(3,3,0.1)) .setAllowedActions(AntAction.values()) .setMethod(Method.MC_ONPOLICY_EGREEDY) .setDelay(10) - .setEpisodes(10000); + .setEpisodes(100000); rl.start(); } }