From b1246f62cc9a58a019da5ba6b37de8739f7ab7fd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20L=C3=B6wenstrom?= Date: Sun, 22 Dec 2019 17:06:54 +0100 Subject: [PATCH] add features to gui to control learning and moving learning listener interface to controller - Add metric to display episodes per second - view not implementing learning listener anymore, controller does. Controller is controlling all view actions based upon learning events. Reacts to view events via viewListener - add executor service for learning task - using instance of to distinguish between episodic learning and td learning - add feature to trigger more episodes - add checkboxes for smoothing graph, displaying last 100 rewards only and drawing environment - remove history panel from antworld gui --- src/main/java/core/Util.java | 15 +++ src/main/java/core/algo/Episodic.java | 2 + src/main/java/core/algo/EpisodicLearning.java | 53 ++++++++ src/main/java/core/algo/Learning.java | 30 +++-- .../algo/MC/MonteCarloOnPolicyEGreedy.java | 114 +++++++++--------- .../java/core/controller/RLController.java | 92 ++++++++++++-- src/main/java/core/gui/LearningInfoPanel.java | 53 +++++++- src/main/java/core/gui/LearningView.java | 9 ++ src/main/java/core/gui/View.java | 87 +++++++------ .../java/core/listener/LearningListener.java | 2 + src/main/java/core/listener/ViewListener.java | 1 + .../antGame/gui/AntWorldComponent.java | 4 - .../evironment/antGame/gui/HistoryPanel.java | 28 ----- src/main/java/example/RunningAnt.java | 2 +- 14 files changed, 337 insertions(+), 155 deletions(-) create mode 100644 src/main/java/core/Util.java create mode 100644 src/main/java/core/gui/LearningView.java delete mode 100644 src/main/java/evironment/antGame/gui/HistoryPanel.java diff --git a/src/main/java/core/Util.java b/src/main/java/core/Util.java new file mode 100644 index 0000000..3c1324f --- /dev/null +++ b/src/main/java/core/Util.java @@ -0,0 +1,15 @@ +package core; + +public class Util { + public static boolean isNumeric(String strNum) { + if (strNum == null) { + return false; + } + try { + double d = Double.parseDouble(strNum); + } catch (NumberFormatException nfe) { + return false; + } + return true; + } +} diff --git a/src/main/java/core/algo/Episodic.java b/src/main/java/core/algo/Episodic.java index f846c84..cf56205 100644 --- a/src/main/java/core/algo/Episodic.java +++ b/src/main/java/core/algo/Episodic.java @@ -2,4 +2,6 @@ package core.algo; public interface Episodic { int getCurrentEpisode(); + int getEpisodesToGo(); + int getEpisodesPerSecond(); } diff --git a/src/main/java/core/algo/EpisodicLearning.java b/src/main/java/core/algo/EpisodicLearning.java index 7f6af1b..f9aa580 100644 --- a/src/main/java/core/algo/EpisodicLearning.java +++ b/src/main/java/core/algo/EpisodicLearning.java @@ -2,9 +2,14 @@ package core.algo; import core.DiscreteActionSpace; import core.Environment; +import core.listener.LearningListener; public abstract class EpisodicLearning extends Learning implements Episodic{ protected int currentEpisode; + protected int episodesToLearn; + protected volatile int episodePerSecond; + protected int episodeSumCurrentSecond; + private volatile boolean meseaureEpisodeBenchMark; public EpisodicLearning(Environment environment, DiscreteActionSpace actionSpace, float discountFactor, int delay) { super(environment, actionSpace, discountFactor, delay); @@ -22,8 +27,56 @@ public abstract class EpisodicLearning extends Learning imple super(environment, actionSpace); } + protected void dispatchEpisodeEnd(double recentSumOfRewards){ + ++episodeSumCurrentSecond; + rewardHistory.add(recentSumOfRewards); + for(LearningListener l: learningListeners) { + l.onEpisodeEnd(rewardHistory); + } + } + + protected void dispatchEpisodeStart(){ + for(LearningListener l: learningListeners){ + l.onEpisodeStart(); + } + } + + @Override + public void learn(){ + learn(0); + } + + public void learn(int nrOfEpisodes){ + meseaureEpisodeBenchMark = true; + new Thread(()->{ + while(meseaureEpisodeBenchMark){ + episodePerSecond = episodeSumCurrentSecond; + episodeSumCurrentSecond = 0; + try { + Thread.sleep(1000); + } catch (InterruptedException e) { + e.printStackTrace(); + } + } + }).start(); + episodesToLearn += nrOfEpisodes; + dispatchLearningStart(); + for(int i=0; i < nrOfEpisodes; ++i){ + nextEpisode(); + } + dispatchLearningEnd(); + meseaureEpisodeBenchMark = false; + } + + protected abstract void nextEpisode(); + @Override public int getCurrentEpisode(){ return currentEpisode; } + + @Override + public int getEpisodesToGo(){ + return episodesToLearn - currentEpisode; + } } diff --git a/src/main/java/core/algo/Learning.java b/src/main/java/core/algo/Learning.java index c7058fb..1999f46 100644 --- a/src/main/java/core/algo/Learning.java +++ b/src/main/java/core/algo/Learning.java @@ -24,7 +24,7 @@ public abstract class Learning { protected Set learningListeners; @Setter protected int delay; - private List rewardHistory; + protected List rewardHistory; public Learning(Environment environment, DiscreteActionSpace actionSpace, float discountFactor, int delay){ this.environment = environment; @@ -47,29 +47,27 @@ public abstract class Learning { this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, LearningConfig.DEFAULT_DELAY); } - - public abstract void learn(int nrOfEpisodes); + public abstract void learn(); public void addListener(LearningListener learningListener){ learningListeners.add(learningListener); } - protected void dispatchEpisodeEnd(double recentSumOfRewards){ - rewardHistory.add(recentSumOfRewards); - for(LearningListener l: learningListeners) { - l.onEpisodeEnd(rewardHistory); - } - } - - protected void dispatchEpisodeStart(){ - for(LearningListener l: learningListeners){ - l.onEpisodeStart(); - } - } - protected void dispatchStepEnd(){ for(LearningListener l: learningListeners){ l.onStepEnd(); } } + + protected void dispatchLearningStart(){ + for(LearningListener l: learningListeners){ + l.onLearningStart(); + } + } + + protected void dispatchLearningEnd(){ + for(LearningListener l: learningListeners){ + l.onLearningEnd(); + } + } } diff --git a/src/main/java/core/algo/MC/MonteCarloOnPolicyEGreedy.java b/src/main/java/core/algo/MC/MonteCarloOnPolicyEGreedy.java index a9a20bc..544fd1b 100644 --- a/src/main/java/core/algo/MC/MonteCarloOnPolicyEGreedy.java +++ b/src/main/java/core/algo/MC/MonteCarloOnPolicyEGreedy.java @@ -11,27 +11,33 @@ import java.util.*; * TODO: Major problem: * StateActionPairs are only unique accounting for their position in the episode. * For example: - * + *

* startingState -> MOVE_LEFT : very first state action in the episode i = 1 * image the agent does not collect the food and drops it to the start, the agent will receive * -1 for every timestamp hence (startingState -> MOVE_LEFT) will get a value of -10; - * + *

* BUT image moving left from the starting position will have no impact on the state because * the agent ran into a wall. The known world stays the same. * Taking an action after that will have the exact same state but a different action * making the value of this stateActionPair -9 because the stateAction pair took place on the second * timestamp, summing up all remaining rewards will be -9... - * + *

* How to encounter this problem? + * * @param */ public class MonteCarloOnPolicyEGreedy extends EpisodicLearning { + private Map, Double> returnSum; + private Map, Integer> returnCount; + public MonteCarloOnPolicyEGreedy(Environment environment, DiscreteActionSpace actionSpace, float discountFactor, float epsilon, int delay) { super(environment, actionSpace, discountFactor, delay); currentEpisode = 0; this.policy = new EpsilonGreedyPolicy<>(epsilon); this.stateActionTable = new StateActionHashTable<>(this.actionSpace); + returnSum = new HashMap<>(); + returnCount = new HashMap<>(); } public MonteCarloOnPolicyEGreedy(Environment environment, DiscreteActionSpace actionSpace, int delay) { @@ -40,71 +46,64 @@ public class MonteCarloOnPolicyEGreedy extends EpisodicLearning< @Override - public void learn(int nrOfEpisodes) { + public void nextEpisode() { + ++currentEpisode; + List> episode = new ArrayList<>(); + State state = environment.reset(); + dispatchEpisodeStart(); + try { + Thread.sleep(delay); + } catch (InterruptedException e) { + e.printStackTrace(); + } + double sumOfRewards = 0; + for (int j = 0; j < 10; ++j) { + Map actionValues = stateActionTable.getActionValues(state); + A chosenAction = policy.chooseAction(actionValues); + StepResultEnvironment envResult = environment.step(chosenAction); + State nextState = envResult.getState(); + sumOfRewards += envResult.getReward(); + episode.add(new StepResult<>(state, chosenAction, envResult.getReward())); - Map, Double> returnSum = new HashMap<>(); - Map, Integer> returnCount = new HashMap<>(); + if (envResult.isDone()) break; + + state = nextState; - for(int i = 0; i < nrOfEpisodes; ++i) { - ++currentEpisode; - List> episode = new ArrayList<>(); - State state = environment.reset(); - dispatchEpisodeStart(); try { Thread.sleep(delay); } catch (InterruptedException e) { e.printStackTrace(); } - double sumOfRewards = 0; - for(int j=0; j < 10; ++j){ - Map actionValues = stateActionTable.getActionValues(state); - A chosenAction = policy.chooseAction(actionValues); - StepResultEnvironment envResult = environment.step(chosenAction); - State nextState = envResult.getState(); - sumOfRewards += envResult.getReward(); - episode.add(new StepResult<>(state, chosenAction, envResult.getReward())); + dispatchStepEnd(); + } - if(envResult.isDone()) break; + dispatchEpisodeEnd(sumOfRewards); + System.out.printf("Episode %d \t Reward: %f \n", currentEpisode, sumOfRewards); + Set> stateActionPairs = new HashSet<>(); - state = nextState; - - try { - Thread.sleep(delay); - } catch (InterruptedException e) { - e.printStackTrace(); + for (StepResult sr : episode) { + stateActionPairs.add(new Pair<>(sr.getState(), sr.getAction())); + } + System.out.println("stateActionPairs " + stateActionPairs.size()); + for (Pair stateActionPair : stateActionPairs) { + int firstOccurenceIndex = 0; + // find first occurance of state action pair + for (StepResult sr : episode) { + if (stateActionPair.getKey().equals(sr.getState()) && stateActionPair.getValue().equals(sr.getAction())) { + break; } - dispatchStepEnd(); + firstOccurenceIndex++; } - dispatchEpisodeEnd(sumOfRewards); - System.out.printf("Episode %d \t Reward: %f \n", i, sumOfRewards); - Set> stateActionPairs = new HashSet<>(); - - for(StepResult sr: episode){ - stateActionPairs.add(new Pair<>(sr.getState(), sr.getAction())); - } - System.out.println("stateActionPairs " + stateActionPairs.size()); - for(Pair stateActionPair: stateActionPairs){ - int firstOccurenceIndex = 0; - // find first occurance of state action pair - for(StepResult sr: episode){ - if(stateActionPair.getKey().equals(sr.getState()) && stateActionPair.getValue().equals(sr.getAction())){ -; - break; - } - firstOccurenceIndex++; - } - - double G = 0; - for(int l = firstOccurenceIndex; l < episode.size(); ++l){ - G += episode.get(l).getReward() * (Math.pow(discountFactor, l - firstOccurenceIndex)); - } - // slick trick to add G to the entry. - // if the key does not exists, it will create a new entry with G as default value - returnSum.merge(stateActionPair, G, Double::sum); - returnCount.merge(stateActionPair, 1, Integer::sum); - stateActionTable.setValue(stateActionPair.getKey(), stateActionPair.getValue(), returnSum.get(stateActionPair) / returnCount.get(stateActionPair)); + double G = 0; + for (int l = firstOccurenceIndex; l < episode.size(); ++l) { + G += episode.get(l).getReward() * (Math.pow(discountFactor, l - firstOccurenceIndex)); } + // slick trick to add G to the entry. + // if the key does not exists, it will create a new entry with G as default value + returnSum.merge(stateActionPair, G, Double::sum); + returnCount.merge(stateActionPair, 1, Integer::sum); + stateActionTable.setValue(stateActionPair.getKey(), stateActionPair.getValue(), returnSum.get(stateActionPair) / returnCount.get(stateActionPair)); } } @@ -112,4 +111,9 @@ public class MonteCarloOnPolicyEGreedy extends EpisodicLearning< public int getCurrentEpisode() { return currentEpisode; } + + @Override + public int getEpisodesPerSecond(){ + return episodePerSecond; + } } diff --git a/src/main/java/core/controller/RLController.java b/src/main/java/core/controller/RLController.java index 8ab5094..4a149e5 100644 --- a/src/main/java/core/controller/RLController.java +++ b/src/main/java/core/controller/RLController.java @@ -3,26 +3,37 @@ package core.controller; import core.DiscreteActionSpace; import core.Environment; import core.ListDiscreteActionSpace; +import core.algo.EpisodicLearning; import core.algo.Learning; import core.algo.Method; import core.algo.mc.MonteCarloOnPolicyEGreedy; +import core.gui.LearningView; import core.gui.View; +import core.listener.LearningListener; import core.listener.ViewListener; import core.policy.EpsilonPolicy; import javax.swing.*; +import java.util.List; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; -public class RLController implements ViewListener { +public class RLController implements ViewListener, LearningListener { protected Environment environment; protected Learning learning; protected DiscreteActionSpace discreteActionSpace; - protected View view; + protected LearningView learningView; private int delay; private int nrOfEpisodes; private Method method; private int prevDelay; + private boolean fastLearning; + private boolean currentlyLearning; + private ExecutorService learningExecutor; + private List latestRewardsHistory; public RLController(){ + learningExecutor = Executors.newSingleThreadExecutor(); } public void start(){ @@ -39,20 +50,37 @@ public class RLController implements ViewListener { default: throw new RuntimeException("Undefined method"); } - /* - not using SwingUtilities here on purpose to ensure the view is fully - initialized and can be passed as LearningListener. - */ - view = new View<>(learning, environment, this); - learning.addListener(view); - learning.learn(nrOfEpisodes); + SwingUtilities.invokeLater(()->{ + learningView = new View<>(learning, environment, this); + learning.addListener(this); + }); + + if(learning instanceof EpisodicLearning){ + learningExecutor.submit(()->((EpisodicLearning) learning).learn(nrOfEpisodes)); + }else{ + learningExecutor.submit(()->learning.learn()); + } + } + + /************************************************* + * VIEW LISTENERS * + *************************************************/ + @Override + public void onLearnMoreEpisodes(int nrOfEpisodes){ + if(!currentlyLearning){ + if(learning instanceof EpisodicLearning){ + learningExecutor.submit(()->((EpisodicLearning) learning).learn(nrOfEpisodes)); + }else{ + throw new RuntimeException("Triggering onLearnMoreEpisodes on non-episodic learning!"); + } + } } @Override public void onEpsilonChange(float epsilon) { if(learning.getPolicy() instanceof EpsilonPolicy){ ((EpsilonPolicy) learning.getPolicy()).setEpsilon(epsilon); - SwingUtilities.invokeLater(() -> view.updateLearningInfoPanel()); + SwingUtilities.invokeLater(() -> learningView.updateLearningInfoPanel()); }else{ System.out.println("Trying to call inEpsilonChange on non-epsilon policy"); } @@ -65,12 +93,12 @@ public class RLController implements ViewListener { private void changeLearningDelay(int delay){ learning.setDelay(delay); - SwingUtilities.invokeLater(() -> view.updateLearningInfoPanel()); + SwingUtilities.invokeLater(() -> learningView.updateLearningInfoPanel()); } @Override public void onFastLearnChange(boolean fastLearn) { - view.setDrawEveryStep(!fastLearn); + this.fastLearning = fastLearn; if(fastLearn){ prevDelay = learning.getDelay(); changeLearningDelay(0); @@ -79,6 +107,45 @@ public class RLController implements ViewListener { } } + /************************************************* + * LEARNING LISTENERS * + *************************************************/ + @Override + public void onLearningStart() { + currentlyLearning = true; + } + + @Override + public void onLearningEnd() { + currentlyLearning = false; + SwingUtilities.invokeLater(()-> learningView.updateRewardGraph(latestRewardsHistory)); + } + + @Override + public void onEpisodeEnd(List rewardHistory) { + latestRewardsHistory = rewardHistory; + SwingUtilities.invokeLater(() ->{ + if(!fastLearning){ + learningView.updateRewardGraph(latestRewardsHistory); + } + learningView.updateLearningInfoPanel(); + }); + } + + @Override + public void onEpisodeStart() { + + } + + @Override + public void onStepEnd() { + if(!fastLearning){ + SwingUtilities.invokeLater(() -> learningView.repaintEnvironment()); + } + } + + + public RLController setMethod(Method method){ this.method = method; return this; @@ -102,5 +169,4 @@ public class RLController implements ViewListener { this.nrOfEpisodes = nrOfEpisodes; return this; } - } diff --git a/src/main/java/core/gui/LearningInfoPanel.java b/src/main/java/core/gui/LearningInfoPanel.java index 4ed4acb..9e25b51 100644 --- a/src/main/java/core/gui/LearningInfoPanel.java +++ b/src/main/java/core/gui/LearningInfoPanel.java @@ -1,11 +1,14 @@ package core.gui; +import core.Util; import core.algo.Episodic; +import core.algo.EpisodicLearning; import core.algo.Learning; import core.listener.ViewListener; import core.policy.EpsilonPolicy; import javax.swing.*; +import java.awt.*; public class LearningInfoPanel extends JPanel { private Learning learning; @@ -18,6 +21,11 @@ public class LearningInfoPanel extends JPanel { private JSlider delaySlider; private JButton toggleFastLearningButton; private boolean fastLearning; + private JCheckBox smoothGraphCheckbox; + private JCheckBox last100Checkbox; + private JCheckBox drawEnvironmentCheckbox; + private JTextField learnMoreEpisodesInput; + private JButton learnMoreEpisodesButton; public LearningInfoPanel(Learning learning, ViewListener viewListener){ this.learning = learning; @@ -47,11 +55,37 @@ public class LearningInfoPanel extends JPanel { fastLearning = !fastLearning; delaySlider.setEnabled(!fastLearning); epsilonSlider.setEnabled(!fastLearning); + drawEnvironmentCheckbox.setSelected(!fastLearning); viewListener.onFastLearnChange(fastLearning); }); + smoothGraphCheckbox = new JCheckBox("Smoothen Graph"); + smoothGraphCheckbox.setSelected(false); + last100Checkbox = new JCheckBox("Only show last 100 Rewards"); + last100Checkbox.setSelected(true); + drawEnvironmentCheckbox = new JCheckBox("Update Environment"); + drawEnvironmentCheckbox.setSelected(true); + add(delayLabel); add(delaySlider); add(toggleFastLearningButton); + + if(learning instanceof EpisodicLearning) { + learnMoreEpisodesInput = new JTextField(); + learnMoreEpisodesInput.setMaximumSize(new Dimension(200,20)); + learnMoreEpisodesButton = new JButton("Learn More Episodes"); + learnMoreEpisodesButton.addActionListener(e -> { + if (Util.isNumeric(learnMoreEpisodesInput.getText())) { + viewListener.onLearnMoreEpisodes(Integer.parseInt(learnMoreEpisodesInput.getText())); + } else { + learnMoreEpisodesInput.setText(""); + } + }); + add(learnMoreEpisodesInput); + add(learnMoreEpisodesButton); + } + add(drawEnvironmentCheckbox); + add(smoothGraphCheckbox); + add(last100Checkbox); refreshLabels(); setVisible(true); } @@ -60,14 +94,29 @@ public class LearningInfoPanel extends JPanel { policyLabel.setText("Policy: " + learning.getPolicy().getClass()); discountLabel.setText("Discount factor: " + learning.getDiscountFactor()); if(learning instanceof Episodic){ - episodeLabel.setText("Episode: " + ((Episodic)(learning)).getCurrentEpisode()); + episodeLabel.setText("Episode: " + ((Episodic)(learning)).getCurrentEpisode() + + "\t Episodes to go: " + ((Episodic)(learning)).getEpisodesToGo() + + "\t Eps/Sec: " + ((Episodic)(learning)).getEpisodesPerSecond()); } if (learning.getPolicy() instanceof EpsilonPolicy) { epsilonLabel.setText("Exploration (Epsilon): " + ((EpsilonPolicy) learning.getPolicy()).getEpsilon()); epsilonSlider.setValue((int)(((EpsilonPolicy) learning.getPolicy()).getEpsilon() * 100)); } delayLabel.setText("Delay (ms): " + learning.getDelay()); - delaySlider.setValue(learning.getDelay()); + if(delaySlider.isEnabled()){ + delaySlider.setValue(learning.getDelay()); + } toggleFastLearningButton.setText(fastLearning ? "Disable fast-learning" : "Enable fast-learning"); } + + protected boolean isSmoothenGraphSelected() { + return smoothGraphCheckbox.isSelected(); + } + protected boolean isLast100Selected(){ + return last100Checkbox.isSelected(); + } + + protected boolean isDrawEnvironmentSelected(){ + return drawEnvironmentCheckbox.isSelected(); + } } diff --git a/src/main/java/core/gui/LearningView.java b/src/main/java/core/gui/LearningView.java new file mode 100644 index 0000000..6a4ceaa --- /dev/null +++ b/src/main/java/core/gui/LearningView.java @@ -0,0 +1,9 @@ +package core.gui; + +import java.util.List; + +public interface LearningView { + void repaintEnvironment(); + void updateLearningInfoPanel(); + void updateRewardGraph(final List rewardHistory); +} diff --git a/src/main/java/core/gui/View.java b/src/main/java/core/gui/View.java index 5dff7e6..d89cc97 100644 --- a/src/main/java/core/gui/View.java +++ b/src/main/java/core/gui/View.java @@ -3,7 +3,7 @@ package core.gui; import core.Environment; import core.algo.Learning; import core.listener.ViewListener; -import core.listener.LearningListener; +import javafx.util.Pair; import lombok.Getter; import org.knowm.xchart.QuickChart; import org.knowm.xchart.XChartPanel; @@ -12,8 +12,9 @@ import org.knowm.xchart.XYChart; import javax.swing.*; import java.awt.*; import java.util.List; +import java.util.concurrent.CopyOnWriteArrayList; -public class View implements LearningListener { +public class View implements LearningView{ private Learning learning; private Environment environment; @Getter @@ -25,14 +26,12 @@ public class View implements LearningListener { private JFrame environmentFrame; private XChartPanel rewardChartPanel; private ViewListener viewListener; - private boolean drawEveryStep; public View(Learning learning, Environment environment, ViewListener viewListener) { this.learning = learning; this.environment = environment; this.viewListener = viewListener; - drawEveryStep = true; - SwingUtilities.invokeLater(this::initMainFrame); + initMainFrame(); } private void initMainFrame() { @@ -92,46 +91,62 @@ public class View implements LearningListener { }; } - public void setDrawEveryStep(boolean drawEveryStep){ - this.drawEveryStep = drawEveryStep; - } + public void updateRewardGraph(final List rewardHistory) { + List xValues; + List yValues; + if(learningInfoPanel.isLast100Selected()){ + yValues = new CopyOnWriteArrayList<>(rewardHistory.subList(rewardHistory.size() - Math.min(rewardHistory.size(), 100), rewardHistory.size())); + xValues = new CopyOnWriteArrayList<>(); + for(int i = rewardHistory.size() - Math.min(rewardHistory.size(), 100); i , List> XYvalues = smoothenGraph(rewardHistory); + xValues = XYvalues.getKey(); + yValues = XYvalues.getValue(); + }else{ + xValues = null; + yValues = rewardHistory; + } + } - public void updateRewardGraph(List rewardHistory) { - rewardChart.updateXYSeries("rewardHistory", null, rewardHistory, null); + rewardChart.updateXYSeries("rewardHistory", xValues, yValues, null); rewardChartPanel.revalidate(); rewardChartPanel.repaint(); } + private Pair, List> smoothenGraph(List original){ + int totalXPoints = 100; + + List xValues = new CopyOnWriteArrayList<>(); + List tmp = new CopyOnWriteArrayList<>(); + int meanBatch = original.size() / totalXPoints; + if(meanBatch < 1){ + meanBatch = 1; + } + + int idx = 0; + int batchIdx = 0; + double batchSum = 0; + for(Double x: original) { + ++idx; + batchSum += x; + if (idx == 1 || ++batchIdx % meanBatch == 0) { + tmp.add(batchSum / meanBatch); + xValues.add(idx); + batchSum = 0; + } + } + return new Pair<>(xValues, tmp); + } + public void updateLearningInfoPanel() { this.learningInfoPanel.refreshLabels(); } - @Override - public void onEpisodeEnd(List rewardHistory) { - SwingUtilities.invokeLater(() ->{ - if(drawEveryStep){ - updateRewardGraph(rewardHistory); - } - updateLearningInfoPanel(); - }); - } - - @Override - public void onEpisodeStart() { - if(drawEveryStep) { - SwingUtilities.invokeLater(this::repaintEnvironment); - } - } - - @Override - public void onStepEnd() { - if(drawEveryStep){ - SwingUtilities.invokeLater(this::repaintEnvironment); - } - } - - private void repaintEnvironment(){ - if (environmentFrame != null) { + public void repaintEnvironment(){ + if (environmentFrame != null && learningInfoPanel.isDrawEnvironmentSelected()) { environmentFrame.repaint(); } } diff --git a/src/main/java/core/listener/LearningListener.java b/src/main/java/core/listener/LearningListener.java index add1fbd..2891d16 100644 --- a/src/main/java/core/listener/LearningListener.java +++ b/src/main/java/core/listener/LearningListener.java @@ -3,6 +3,8 @@ package core.listener; import java.util.List; public interface LearningListener{ + void onLearningStart(); + void onLearningEnd(); void onEpisodeEnd(List rewardHistory); void onEpisodeStart(); void onStepEnd(); diff --git a/src/main/java/core/listener/ViewListener.java b/src/main/java/core/listener/ViewListener.java index f27d7b4..dbf01d4 100644 --- a/src/main/java/core/listener/ViewListener.java +++ b/src/main/java/core/listener/ViewListener.java @@ -4,4 +4,5 @@ public interface ViewListener { void onEpsilonChange(float epsilon); void onDelayChange(int delay); void onFastLearnChange(boolean isFastLearn); + void onLearnMoreEpisodes(int nrOfEpisodes); } diff --git a/src/main/java/evironment/antGame/gui/AntWorldComponent.java b/src/main/java/evironment/antGame/gui/AntWorldComponent.java index 7edb62b..4b148b5 100644 --- a/src/main/java/evironment/antGame/gui/AntWorldComponent.java +++ b/src/main/java/evironment/antGame/gui/AntWorldComponent.java @@ -8,14 +8,12 @@ import java.awt.*; public class AntWorldComponent extends JComponent { private AntWorld antWorld; - private HistoryPanel historyPanel; public AntWorldComponent(AntWorld antWorld, AntAgent antAgent){ this.antWorld = antWorld; setLayout(new BorderLayout()); CellsScrollPane worldPane = new CellsScrollPane(antWorld.getCellArray(), antWorld.getAnt(), 10); CellsScrollPane antBrainPane = new CellsScrollPane(antAgent.getKnownWorld(), antWorld.getAnt(), 10); - historyPanel = new HistoryPanel(); JComponent mapComponent = new JPanel(); FlowLayout flowLayout = new FlowLayout(); @@ -23,9 +21,7 @@ public class AntWorldComponent extends JComponent { mapComponent.setLayout(flowLayout); mapComponent.add(worldPane); mapComponent.add(antBrainPane); - add(BorderLayout.CENTER, mapComponent); - add(BorderLayout.SOUTH, historyPanel); setVisible(true); } diff --git a/src/main/java/evironment/antGame/gui/HistoryPanel.java b/src/main/java/evironment/antGame/gui/HistoryPanel.java deleted file mode 100644 index d7106a6..0000000 --- a/src/main/java/evironment/antGame/gui/HistoryPanel.java +++ /dev/null @@ -1,28 +0,0 @@ -package evironment.antGame.gui; - -import javax.swing.*; -import java.awt.*; - -public class HistoryPanel extends JPanel { - private final int panelWidth = 1000; - private final int panelHeight = 300; - private JTextArea textArea; - - public HistoryPanel(){ - setPreferredSize(new Dimension(panelWidth, panelHeight)); - textArea = new JTextArea(); - textArea.setLineWrap(true); - textArea.setWrapStyleWord(true); - textArea.setEditable(false); - JScrollPane scrollBar = new JScrollPane(textArea, JScrollPane.VERTICAL_SCROLLBAR_ALWAYS, JScrollPane.HORIZONTAL_SCROLLBAR_ALWAYS); - scrollBar.setPreferredSize(new Dimension(panelWidth, panelHeight)); - add(scrollBar); - setVisible(true); - } - - public void addText(String toAppend){ - textArea.append(toAppend); - textArea.append("\n\n"); - revalidate(); - } -} diff --git a/src/main/java/example/RunningAnt.java b/src/main/java/example/RunningAnt.java index 0106c30..307f107 100644 --- a/src/main/java/example/RunningAnt.java +++ b/src/main/java/example/RunningAnt.java @@ -14,7 +14,7 @@ public class RunningAnt { .setEnvironment(new AntWorld(3,3,0.1)) .setAllowedActions(AntAction.values()) .setMethod(Method.MC_ONPOLICY_EGREEDY) - .setDelay(10) + .setDelay(200) .setEpisodes(100000); rl.start(); }