diff --git a/src/main/java/core/RNG.java b/src/main/java/core/RNG.java
index 8813ded..2fe8929 100644
--- a/src/main/java/core/RNG.java
+++ b/src/main/java/core/RNG.java
@@ -13,10 +13,6 @@ public class RNG {
return rng;
}
- public static void reseed(){
- rng.setSeed(seed);
- }
-
public static void setSeed(int seed){
RNG.seed = seed;
rng.setSeed(seed);
diff --git a/src/main/java/core/algo/Episodic.java b/src/main/java/core/algo/Episodic.java
new file mode 100644
index 0000000..f846c84
--- /dev/null
+++ b/src/main/java/core/algo/Episodic.java
@@ -0,0 +1,5 @@
+package core.algo;
+
+public interface Episodic {
+ int getCurrentEpisode();
+}
diff --git a/src/main/java/core/algo/EpisodicLearning.java b/src/main/java/core/algo/EpisodicLearning.java
new file mode 100644
index 0000000..7f6af1b
--- /dev/null
+++ b/src/main/java/core/algo/EpisodicLearning.java
@@ -0,0 +1,29 @@
+package core.algo;
+
+import core.DiscreteActionSpace;
+import core.Environment;
+
+public abstract class EpisodicLearning extends Learning implements Episodic{
+ protected int currentEpisode;
+
+ public EpisodicLearning(Environment environment, DiscreteActionSpace actionSpace, float discountFactor, int delay) {
+ super(environment, actionSpace, discountFactor, delay);
+ }
+
+ public EpisodicLearning(Environment environment, DiscreteActionSpace actionSpace, float discountFactor) {
+ super(environment, actionSpace, discountFactor);
+ }
+
+ public EpisodicLearning(Environment environment, DiscreteActionSpace actionSpace, int delay) {
+ super(environment, actionSpace, delay);
+ }
+
+ public EpisodicLearning(Environment environment, DiscreteActionSpace actionSpace) {
+ super(environment, actionSpace);
+ }
+
+ @Override
+ public int getCurrentEpisode(){
+ return currentEpisode;
+ }
+}
diff --git a/src/main/java/core/algo/Learning.java b/src/main/java/core/algo/Learning.java
index a825bd5..c7058fb 100644
--- a/src/main/java/core/algo/Learning.java
+++ b/src/main/java/core/algo/Learning.java
@@ -9,8 +9,6 @@ import core.policy.Policy;
import lombok.Getter;
import lombok.Setter;
-import javax.swing.*;
-import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
@@ -68,4 +66,10 @@ public abstract class Learning {
l.onEpisodeStart();
}
}
+
+ protected void dispatchStepEnd(){
+ for(LearningListener l: learningListeners){
+ l.onStepEnd();
+ }
+ }
}
diff --git a/src/main/java/core/algo/MC/MonteCarloOnPolicyEGreedy.java b/src/main/java/core/algo/MC/MonteCarloOnPolicyEGreedy.java
index 1bc1f11..a9a20bc 100644
--- a/src/main/java/core/algo/MC/MonteCarloOnPolicyEGreedy.java
+++ b/src/main/java/core/algo/MC/MonteCarloOnPolicyEGreedy.java
@@ -1,10 +1,9 @@
package core.algo.mc;
import core.*;
-import core.algo.Learning;
+import core.algo.EpisodicLearning;
import core.policy.EpsilonGreedyPolicy;
import javafx.util.Pair;
-import lombok.Setter;
import java.util.*;
@@ -26,11 +25,11 @@ import java.util.*;
* How to encounter this problem?
* @param
*/
-public class MonteCarloOnPolicyEGreedy extends Learning {
+public class MonteCarloOnPolicyEGreedy extends EpisodicLearning {
public MonteCarloOnPolicyEGreedy(Environment environment, DiscreteActionSpace actionSpace, float discountFactor, float epsilon, int delay) {
super(environment, actionSpace, discountFactor, delay);
-
+ currentEpisode = 0;
this.policy = new EpsilonGreedyPolicy<>(epsilon);
this.stateActionTable = new StateActionHashTable<>(this.actionSpace);
}
@@ -47,8 +46,15 @@ public class MonteCarloOnPolicyEGreedy extends Learning {
Map, Integer> returnCount = new HashMap<>();
for(int i = 0; i < nrOfEpisodes; ++i) {
+ ++currentEpisode;
List> episode = new ArrayList<>();
State state = environment.reset();
+ dispatchEpisodeStart();
+ try {
+ Thread.sleep(delay);
+ } catch (InterruptedException e) {
+ e.printStackTrace();
+ }
double sumOfRewards = 0;
for(int j=0; j < 10; ++j){
Map actionValues = stateActionTable.getActionValues(state);
@@ -67,6 +73,7 @@ public class MonteCarloOnPolicyEGreedy extends Learning {
} catch (InterruptedException e) {
e.printStackTrace();
}
+ dispatchStepEnd();
}
dispatchEpisodeEnd(sumOfRewards);
@@ -100,4 +107,9 @@ public class MonteCarloOnPolicyEGreedy extends Learning {
}
}
}
+
+ @Override
+ public int getCurrentEpisode() {
+ return currentEpisode;
+ }
}
diff --git a/src/main/java/core/controller/RLController.java b/src/main/java/core/controller/RLController.java
index 80b4375..8ab5094 100644
--- a/src/main/java/core/controller/RLController.java
+++ b/src/main/java/core/controller/RLController.java
@@ -7,11 +7,12 @@ import core.algo.Learning;
import core.algo.Method;
import core.algo.mc.MonteCarloOnPolicyEGreedy;
import core.gui.View;
+import core.listener.ViewListener;
import core.policy.EpsilonPolicy;
import javax.swing.*;
-public class RLController implements ViewListener{
+public class RLController implements ViewListener {
protected Environment environment;
protected Learning learning;
protected DiscreteActionSpace discreteActionSpace;
@@ -19,6 +20,7 @@ public class RLController implements ViewListener{
private int delay;
private int nrOfEpisodes;
private Method method;
+ private int prevDelay;
public RLController(){
}
@@ -41,7 +43,7 @@ public class RLController implements ViewListener{
not using SwingUtilities here on purpose to ensure the view is fully
initialized and can be passed as LearningListener.
*/
- view = new View<>(learning, this);
+ view = new View<>(learning, environment, this);
learning.addListener(view);
learning.learn(nrOfEpisodes);
}
@@ -58,10 +60,23 @@ public class RLController implements ViewListener{
@Override
public void onDelayChange(int delay) {
+ changeLearningDelay(delay);
+ }
+
+ private void changeLearningDelay(int delay){
learning.setDelay(delay);
- SwingUtilities.invokeLater(() -> {
- view.updateLearningInfoPanel();
- });
+ SwingUtilities.invokeLater(() -> view.updateLearningInfoPanel());
+ }
+
+ @Override
+ public void onFastLearnChange(boolean fastLearn) {
+ view.setDrawEveryStep(!fastLearn);
+ if(fastLearn){
+ prevDelay = learning.getDelay();
+ changeLearningDelay(0);
+ }else{
+ changeLearningDelay(prevDelay);
+ }
}
public RLController setMethod(Method method){
diff --git a/src/main/java/core/gui/LearningInfoPanel.java b/src/main/java/core/gui/LearningInfoPanel.java
index 8c67589..4ed4acb 100644
--- a/src/main/java/core/gui/LearningInfoPanel.java
+++ b/src/main/java/core/gui/LearningInfoPanel.java
@@ -1,7 +1,8 @@
package core.gui;
+import core.algo.Episodic;
import core.algo.Learning;
-import core.controller.ViewListener;
+import core.listener.ViewListener;
import core.policy.EpsilonPolicy;
import javax.swing.*;
@@ -11,39 +12,62 @@ public class LearningInfoPanel extends JPanel {
private JLabel policyLabel;
private JLabel discountLabel;
private JLabel epsilonLabel;
+ private JLabel episodeLabel;
private JSlider epsilonSlider;
private JLabel delayLabel;
private JSlider delaySlider;
+ private JButton toggleFastLearningButton;
+ private boolean fastLearning;
public LearningInfoPanel(Learning learning, ViewListener viewListener){
this.learning = learning;
setLayout(new BoxLayout(this, BoxLayout.Y_AXIS));
policyLabel = new JLabel();
discountLabel = new JLabel();
- epsilonLabel = new JLabel();
delayLabel = new JLabel();
+ if(learning instanceof Episodic){
+ episodeLabel = new JLabel();
+ add(episodeLabel);
+ }
delaySlider = new JSlider(0,1000, learning.getDelay());
delaySlider.addChangeListener(e -> viewListener.onDelayChange(delaySlider.getValue()));
add(policyLabel);
add(discountLabel);
if(learning.getPolicy() instanceof EpsilonPolicy){
+ epsilonLabel = new JLabel();
epsilonSlider = new JSlider(0, 100, (int)((EpsilonPolicy)learning.getPolicy()).getEpsilon() * 100);
epsilonSlider.addChangeListener(e -> viewListener.onEpsilonChange(epsilonSlider.getValue() / 100f));
add(epsilonLabel);
add(epsilonSlider);
}
+
+ toggleFastLearningButton = new JButton("Enable fast-learn");
+ fastLearning = false;
+ toggleFastLearningButton.addActionListener(e->{
+ fastLearning = !fastLearning;
+ delaySlider.setEnabled(!fastLearning);
+ epsilonSlider.setEnabled(!fastLearning);
+ viewListener.onFastLearnChange(fastLearning);
+ });
add(delayLabel);
add(delaySlider);
+ add(toggleFastLearningButton);
refreshLabels();
setVisible(true);
}
- public void refreshLabels(){
+ public void refreshLabels() {
policyLabel.setText("Policy: " + learning.getPolicy().getClass());
discountLabel.setText("Discount factor: " + learning.getDiscountFactor());
- if(learning.getPolicy() instanceof EpsilonPolicy){
- epsilonLabel.setText("Exploration (Epsilon): " + ((EpsilonPolicy)learning.getPolicy()).getEpsilon());
+ if(learning instanceof Episodic){
+ episodeLabel.setText("Episode: " + ((Episodic)(learning)).getCurrentEpisode());
+ }
+ if (learning.getPolicy() instanceof EpsilonPolicy) {
+ epsilonLabel.setText("Exploration (Epsilon): " + ((EpsilonPolicy) learning.getPolicy()).getEpsilon());
+ epsilonSlider.setValue((int)(((EpsilonPolicy) learning.getPolicy()).getEpsilon() * 100));
}
delayLabel.setText("Delay (ms): " + learning.getDelay());
+ delaySlider.setValue(learning.getDelay());
+ toggleFastLearningButton.setText(fastLearning ? "Disable fast-learning" : "Enable fast-learning");
}
}
diff --git a/src/main/java/core/gui/View.java b/src/main/java/core/gui/View.java
index d031451..5dff7e6 100644
--- a/src/main/java/core/gui/View.java
+++ b/src/main/java/core/gui/View.java
@@ -1,7 +1,8 @@
package core.gui;
+import core.Environment;
import core.algo.Learning;
-import core.controller.ViewListener;
+import core.listener.ViewListener;
import core.listener.LearningListener;
import lombok.Getter;
import org.knowm.xchart.QuickChart;
@@ -10,27 +11,31 @@ import org.knowm.xchart.XYChart;
import javax.swing.*;
import java.awt.*;
-import java.util.ArrayList;
import java.util.List;
public class View implements LearningListener {
private Learning learning;
+ private Environment environment;
@Getter
- private XYChart chart;
+ private XYChart rewardChart;
@Getter
private LearningInfoPanel learningInfoPanel;
@Getter
private JFrame mainFrame;
+ private JFrame environmentFrame;
private XChartPanel rewardChartPanel;
private ViewListener viewListener;
+ private boolean drawEveryStep;
- public View(Learning learning, ViewListener viewListener){
+ public View(Learning learning, Environment environment, ViewListener viewListener) {
this.learning = learning;
+ this.environment = environment;
this.viewListener = viewListener;
- this.initMainFrame();
+ drawEveryStep = true;
+ SwingUtilities.invokeLater(this::initMainFrame);
}
- private void initMainFrame(){
+ private void initMainFrame() {
mainFrame = new JFrame();
mainFrame.setPreferredSize(new Dimension(1280, 720));
mainFrame.setLayout(new BorderLayout());
@@ -44,29 +49,40 @@ public class View implements LearningListener {
mainFrame.setDefaultCloseOperation(WindowConstants.EXIT_ON_CLOSE);
mainFrame.pack();
mainFrame.setVisible(true);
+
+ if (environment instanceof Visualizable) {
+ environmentFrame = new JFrame() {
+ {
+ add(((Visualizable) environment).visualize());
+ pack();
+ setVisible(true);
+ }
+ };
+
+ }
}
- private void initLearningInfoPanel(){
+ private void initLearningInfoPanel() {
learningInfoPanel = new LearningInfoPanel(learning, viewListener);
}
- private void initRewardChart(){
- chart =
+ private void initRewardChart() {
+ rewardChart =
QuickChart.getChart(
- "Rewards per Episode",
+ "Sum of Rewards per Episode",
"Episode",
"Reward",
- "randomWalk",
- new double[] {0},
- new double[] {0});
- chart.getStyler().setLegendVisible(true);
- chart.getStyler().setXAxisTicksVisible(true);
- rewardChartPanel = new XChartPanel<>(chart);
- rewardChartPanel.setPreferredSize(new Dimension(300,300));
+ "rewardHistory",
+ new double[]{0},
+ new double[]{0});
+ rewardChart.getStyler().setLegendVisible(true);
+ rewardChart.getStyler().setXAxisTicksVisible(true);
+ rewardChartPanel = new XChartPanel<>(rewardChart);
+ rewardChartPanel.setPreferredSize(new Dimension(300, 300));
}
- public void showState(Visualizable state){
- new JFrame(){
+ public void showState(Visualizable state) {
+ new JFrame() {
{
JComponent stateComponent = state.visualize();
setPreferredSize(new Dimension(stateComponent.getWidth(), stateComponent.getHeight()));
@@ -76,25 +92,47 @@ public class View implements LearningListener {
};
}
- public void updateRewardGraph(List rewardHistory){
- chart.updateXYSeries("randomWalk", null, rewardHistory, null);
+ public void setDrawEveryStep(boolean drawEveryStep){
+ this.drawEveryStep = drawEveryStep;
+ }
+
+ public void updateRewardGraph(List rewardHistory) {
+ rewardChart.updateXYSeries("rewardHistory", null, rewardHistory, null);
rewardChartPanel.revalidate();
rewardChartPanel.repaint();
}
- public void updateLearningInfoPanel(){
+ public void updateLearningInfoPanel() {
this.learningInfoPanel.refreshLabels();
}
@Override
public void onEpisodeEnd(List rewardHistory) {
- SwingUtilities.invokeLater(()->{
- updateRewardGraph(rewardHistory);
- });
+ SwingUtilities.invokeLater(() ->{
+ if(drawEveryStep){
+ updateRewardGraph(rewardHistory);
+ }
+ updateLearningInfoPanel();
+ });
}
@Override
public void onEpisodeStart() {
+ if(drawEveryStep) {
+ SwingUtilities.invokeLater(this::repaintEnvironment);
+ }
+ }
+ @Override
+ public void onStepEnd() {
+ if(drawEveryStep){
+ SwingUtilities.invokeLater(this::repaintEnvironment);
+ }
+ }
+
+ private void repaintEnvironment(){
+ if (environmentFrame != null) {
+ environmentFrame.repaint();
+ }
}
}
diff --git a/src/main/java/core/listener/LearningListener.java b/src/main/java/core/listener/LearningListener.java
index 4147897..add1fbd 100644
--- a/src/main/java/core/listener/LearningListener.java
+++ b/src/main/java/core/listener/LearningListener.java
@@ -5,4 +5,5 @@ import java.util.List;
public interface LearningListener{
void onEpisodeEnd(List rewardHistory);
void onEpisodeStart();
+ void onStepEnd();
}
diff --git a/src/main/java/core/controller/ViewListener.java b/src/main/java/core/listener/ViewListener.java
similarity index 60%
rename from src/main/java/core/controller/ViewListener.java
rename to src/main/java/core/listener/ViewListener.java
index 578512a..f27d7b4 100644
--- a/src/main/java/core/controller/ViewListener.java
+++ b/src/main/java/core/listener/ViewListener.java
@@ -1,6 +1,7 @@
-package core.controller;
+package core.listener;
public interface ViewListener {
void onEpsilonChange(float epsilon);
void onDelayChange(int delay);
+ void onFastLearnChange(boolean isFastLearn);
}
diff --git a/src/main/java/core/policy/EpsilonGreedyPolicy.java b/src/main/java/core/policy/EpsilonGreedyPolicy.java
index 1288aed..550889f 100644
--- a/src/main/java/core/policy/EpsilonGreedyPolicy.java
+++ b/src/main/java/core/policy/EpsilonGreedyPolicy.java
@@ -29,7 +29,6 @@ public class EpsilonGreedyPolicy implements EpsilonPolicy{
@Override
public A chooseAction(Map actionValues) {
- System.out.println("current epsilon " + epsilon);
if(RNG.getRandom().nextFloat() < epsilon){
// Take random action
return randomPolicy.chooseAction(actionValues);
diff --git a/src/main/java/evironment/antGame/AntWorld.java b/src/main/java/evironment/antGame/AntWorld.java
index e512ad7..ebd0de6 100644
--- a/src/main/java/evironment/antGame/AntWorld.java
+++ b/src/main/java/evironment/antGame/AntWorld.java
@@ -1,14 +1,15 @@
package evironment.antGame;
-import core.*;
-import core.algo.Learning;
-import core.algo.mc.MonteCarloOnPolicyEGreedy;
-import evironment.antGame.gui.MainFrame;
-
+import core.Environment;
+import core.State;
+import core.StepResultEnvironment;
+import core.gui.Visualizable;
+import evironment.antGame.gui.AntWorldComponent;
+import javax.swing.*;
import java.awt.*;
-public class AntWorld implements Environment{
+public class AntWorld implements Environment, Visualizable {
/**
*
*/
@@ -204,14 +205,8 @@ public class AntWorld implements Environment{
return myAnt;
}
- public static void main(String[] args) {
- RNG.setSeed(1993);
-
- Learning monteCarlo = new MonteCarloOnPolicyEGreedy<>(
- new AntWorld(3, 3, 0.1),
- new ListDiscreteActionSpace<>(AntAction.values()),
- 5
- );
- monteCarlo.learn(20000);
+ @Override
+ public JComponent visualize() {
+ return new AntWorldComponent(this, this.antAgent);
}
}
diff --git a/src/main/java/evironment/antGame/gui/MainFrame.java b/src/main/java/evironment/antGame/gui/AntWorldComponent.java
similarity index 57%
rename from src/main/java/evironment/antGame/gui/MainFrame.java
rename to src/main/java/evironment/antGame/gui/AntWorldComponent.java
index c299d78..7edb62b 100644
--- a/src/main/java/evironment/antGame/gui/MainFrame.java
+++ b/src/main/java/evironment/antGame/gui/AntWorldComponent.java
@@ -1,21 +1,18 @@
package evironment.antGame.gui;
-import core.StepResultEnvironment;
-import evironment.antGame.AntAction;
import evironment.antGame.AntAgent;
import evironment.antGame.AntWorld;
import javax.swing.*;
import java.awt.*;
-public class MainFrame extends JFrame {
+public class AntWorldComponent extends JComponent {
private AntWorld antWorld;
private HistoryPanel historyPanel;
- public MainFrame(AntWorld antWorld, AntAgent antAgent){
+ public AntWorldComponent(AntWorld antWorld, AntAgent antAgent){
this.antWorld = antWorld;
setLayout(new BorderLayout());
- setDefaultCloseOperation(WindowConstants.EXIT_ON_CLOSE);
CellsScrollPane worldPane = new CellsScrollPane(antWorld.getCellArray(), antWorld.getAnt(), 10);
CellsScrollPane antBrainPane = new CellsScrollPane(antAgent.getKnownWorld(), antWorld.getAnt(), 10);
historyPanel = new HistoryPanel();
@@ -29,14 +26,11 @@ public class MainFrame extends JFrame {
add(BorderLayout.CENTER, mapComponent);
add(BorderLayout.SOUTH, historyPanel);
- pack();
setVisible(true);
}
- public void update(AntAction lastAction, StepResultEnvironment stepResultEnvironment){
- historyPanel.addText(String.format("Tick %d: \t Selected action: %s \t Reward: %f \t Info: %s \n totalPoints: %d \t hasFood: %b \t ",
- antWorld.getTick(), lastAction.toString(), stepResultEnvironment.getReward(), stepResultEnvironment.getInfo(), antWorld.getAnt().getPoints(), antWorld.getAnt().hasFood()));
-
- repaint();
+ @Override
+ protected void paintComponent(Graphics g) {
+ super.paintComponent(g);
}
}
diff --git a/src/main/java/example/RunningAnt.java b/src/main/java/example/RunningAnt.java
index dc22cc0..0106c30 100644
--- a/src/main/java/example/RunningAnt.java
+++ b/src/main/java/example/RunningAnt.java
@@ -8,14 +8,14 @@ import evironment.antGame.AntWorld;
public class RunningAnt {
public static void main(String[] args) {
- RNG.setSeed(1234);
+ RNG.setSeed(123);
RLController rl = new RLController()
.setEnvironment(new AntWorld(3,3,0.1))
.setAllowedActions(AntAction.values())
.setMethod(Method.MC_ONPOLICY_EGREEDY)
.setDelay(10)
- .setEpisodes(10000);
+ .setEpisodes(100000);
rl.start();
}
}