diff --git a/build.gradle b/build.gradle index c44307d..48a4181 100644 --- a/build.gradle +++ b/build.gradle @@ -20,9 +20,10 @@ dependencies { testCompile group: 'junit', name: 'junit', version: '4.12' compileOnly 'org.projectlombok:lombok:1.18.10' annotationProcessor 'org.projectlombok:lombok:1.18.10' - compile 'org.javatuples:javatuples:1.2' - // https://mvnrepository.com/artifact/org.javatuples/javatuples - compile group: 'org.javatuples', name: 'javatuples', version: '1.2' + // https://mvnrepository.com/artifact/org.apache.commons/commons-lang3 + compile group: 'org.apache.commons', name: 'commons-lang3', version: '3.0' + // https://mvnrepository.com/artifact/org.apache.commons/commons-collections4 + compile group: 'org.apache.commons', name: 'commons-collections4', version: '4.1' } diff --git a/src/main/java/core/DeterministicStateActionTable.java b/src/main/java/core/DeterministicStateActionTable.java index fbc44ca..af492cb 100644 --- a/src/main/java/core/DeterministicStateActionTable.java +++ b/src/main/java/core/DeterministicStateActionTable.java @@ -1,9 +1,9 @@ package core; import java.io.Serializable; -import java.util.LinkedHashMap; -import java.util.Map; +import java.util.*; +import org.apache.commons.collections4.queue.CircularFifoQueue; /** * Premise: All states have the complete action space */ @@ -11,10 +11,12 @@ public class DeterministicStateActionTable implements StateActio private final Map> table; private DiscreteActionSpace discreteActionSpace; + private Queue>> latestChanges; public DeterministicStateActionTable(DiscreteActionSpace discreteActionSpace){ table = new LinkedHashMap<>(); this.discreteActionSpace = discreteActionSpace; + latestChanges = new CircularFifoQueue<>(10); } /** @@ -57,6 +59,7 @@ public class DeterministicStateActionTable implements StateActio actionValues = createDefaultActionValues(); table.put(state, actionValues); } + latestChanges.add(new AbstractMap.SimpleEntry<>(state, actionValues)); actionValues.put(action, value); } @@ -72,6 +75,11 @@ public class DeterministicStateActionTable implements StateActio return table.get(state); } + @Override + public Queue>> getFirstStateEntriesForView() { + return latestChanges; + } + /** * @return Map with initial values for every available action */ diff --git a/src/main/java/core/StateActionTable.java b/src/main/java/core/StateActionTable.java index 306f7d1..7e93479 100644 --- a/src/main/java/core/StateActionTable.java +++ b/src/main/java/core/StateActionTable.java @@ -1,6 +1,7 @@ package core; import java.util.Map; +import java.util.Queue; /** * Q-Table which saves all seen states, all available actions for each state @@ -15,4 +16,6 @@ public interface StateActionTable { void setValue(State state, A action, double value); int getStateCount(); Map getActionValues(State state); + + Queue>> getFirstStateEntriesForView(); } diff --git a/src/main/java/core/algo/MC/MonteCarloOnPolicyEGreedy.java b/src/main/java/core/algo/MC/MonteCarloOnPolicyEGreedy.java index f1a6d89..c4ca4bc 100644 --- a/src/main/java/core/algo/MC/MonteCarloOnPolicyEGreedy.java +++ b/src/main/java/core/algo/MC/MonteCarloOnPolicyEGreedy.java @@ -3,7 +3,8 @@ package core.algo.mc; import core.*; import core.algo.EpisodicLearning; import core.policy.EpsilonGreedyPolicy; -import org.javatuples.Pair; +import org.apache.commons.lang3.tuple.ImmutablePair; +import org.apache.commons.lang3.tuple.Pair; import java.io.IOException; import java.io.ObjectInputStream; @@ -80,7 +81,7 @@ public class MonteCarloOnPolicyEGreedy extends EpisodicLearning< Set> stateActionPairs = new LinkedHashSet<>(); for (StepResult sr : episode) { - stateActionPairs.add(new Pair<>(sr.getState(), sr.getAction())); + stateActionPairs.add(new ImmutablePair<>(sr.getState(), sr.getAction())); } //System.out.println("stateActionPairs " + stateActionPairs.size()); @@ -88,7 +89,7 @@ public class MonteCarloOnPolicyEGreedy extends EpisodicLearning< int firstOccurenceIndex = 0; // find first occurance of state action pair for (StepResult sr : episode) { - if (stateActionPair.getValue0().equals(sr.getState()) && stateActionPair.getValue1().equals(sr.getAction())) { + if (stateActionPair.getKey().equals(sr.getState()) && stateActionPair.getValue().equals(sr.getAction())) { break; } firstOccurenceIndex++; @@ -102,7 +103,7 @@ public class MonteCarloOnPolicyEGreedy extends EpisodicLearning< // if the key does not exists, it will create a new entry with G as default value returnSum.merge(stateActionPair, G, Double::sum); returnCount.merge(stateActionPair, 1, Integer::sum); - stateActionTable.setValue(stateActionPair.getValue0(), stateActionPair.getValue1(), returnSum.get(stateActionPair) / returnCount.get(stateActionPair)); + stateActionTable.setValue(stateActionPair.getKey(), stateActionPair.getValue(), returnSum.get(stateActionPair) / returnCount.get(stateActionPair)); } } diff --git a/src/main/java/core/controller/RLController.java b/src/main/java/core/controller/RLController.java index e2b9ec3..126ee2a 100644 --- a/src/main/java/core/controller/RLController.java +++ b/src/main/java/core/controller/RLController.java @@ -15,10 +15,9 @@ import lombok.Setter; import javax.swing.*; import java.io.*; import java.util.List; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; -public class RLController implements ViewListener, LearningListener { + +public class RLController implements LearningListener { protected final String folderPrefix = "learningStates" + File.separator; protected Environment environment; protected DiscreteActionSpace discreteActionSpace; @@ -37,15 +36,15 @@ public class RLController implements ViewListener, LearningListe protected int prevDelay; protected volatile boolean printNextEpisode; - public RLController(Environment env, Method method, A... actions){ + public RLController(Environment env, Method method, A... actions) { setEnvironment(env); setMethod(method); setAllowedActions(actions); printNextEpisode = true; } - public void start(){ - switch (method){ + public void start() { + switch(method) { case MC_ONPOLICY_EGREEDY: learning = new MonteCarloOnPolicyEGreedy<>(environment, discreteActionSpace, discountFactor, epsilon, delay); break; @@ -60,13 +59,13 @@ public class RLController implements ViewListener, LearningListe initLearning(); } - protected void initListeners(){ + protected void initListeners() { learning.addListener(this); new Thread(() -> { - while (true){ + while(true) { printNextEpisode = true; try { - Thread.sleep(30*1000); + Thread.sleep(30 * 1000); } catch (InterruptedException e) { e.printStackTrace(); } @@ -74,35 +73,42 @@ public class RLController implements ViewListener, LearningListe }).start(); } - private void initLearning(){ - if(learning instanceof EpisodicLearning){ + private void initLearning() { + if(learning instanceof EpisodicLearning) { System.out.println("Starting learning of <" + nrOfEpisodes + "> episodes"); - ((EpisodicLearning) learning).learn(nrOfEpisodes); - }else{ + ((EpisodicLearning) learning).learn(nrOfEpisodes); + } else { learning.learn(); } } - /************************************************* - ** VIEW LISTENERS ** - *************************************************/ - @Override - public void onLearnMoreEpisodes(int nrOfEpisodes){ - if(learning instanceof EpisodicLearning){ + protected void changeLearningDelay(int delay) { + learning.setDelay(delay); + } + + protected void learnMoreEpisodes(int nrOfEpisodes) { + if(learning instanceof EpisodicLearning) { ((EpisodicLearning) learning).learn(nrOfEpisodes); - }else{ + } else { throw new RuntimeException("Triggering onLearnMoreEpisodes on non-episodic learning!"); } } - @Override - public void onLoadState(String fileName) { + protected void changeEpsilon(float epsilon) { + if(learning.getPolicy() instanceof EpsilonPolicy) { + ((EpsilonPolicy) learning.getPolicy()).setEpsilon(epsilon); + } else { + System.out.println("Trying to call inEpsilonChange on non-epsilon policy"); + } + } + + protected void saveState(String fileName) { FileInputStream fis; ObjectInputStream in; try { fis = new FileInputStream(fileName); in = new ObjectInputStream(fis); - System.out.println("interrupt" + Thread.currentThread().getId()); + System.out.println("interrup" + Thread.currentThread().getId()); learning.interruptLearning(); learning.load(in); in.close(); @@ -111,46 +117,26 @@ public class RLController implements ViewListener, LearningListe } } - @Override - public void onSaveState(String fileName) { + protected void loadState(String fileName) { FileOutputStream fos; ObjectOutputStream out; - try{ - fos = new FileOutputStream(folderPrefix + fileName); + try { + fos = new FileOutputStream(folderPrefix + fileName); out = new ObjectOutputStream(fos); learning.interruptLearning(); learning.save(out); out.close(); - }catch (IOException e){ + } catch (IOException e) { e.printStackTrace(); } } - @Override - public void onEpsilonChange(float epsilon) { - if(learning.getPolicy() instanceof EpsilonPolicy){ - ((EpsilonPolicy) learning.getPolicy()).setEpsilon(epsilon); - }else{ - System.out.println("Trying to call inEpsilonChange on non-epsilon policy"); - } - } - - @Override - public void onDelayChange(int delay) { - changeLearningDelay(delay); - } - - protected void changeLearningDelay(int delay){ - learning.setDelay(delay); - } - - @Override - public void onFastLearnChange(boolean fastLearn) { + protected void changeFastLearning(boolean fastLearn) { this.fastLearning = fastLearn; - if(fastLearn){ + if(fastLearn) { prevDelay = learning.getDelay(); changeLearningDelay(0); - }else{ + } else { changeLearningDelay(prevDelay); } } @@ -165,7 +151,6 @@ public class RLController implements ViewListener, LearningListe @Override public void onLearningEnd() { System.out.println("Learning finished"); - onSaveState( method.toString() + System.currentTimeMillis()/1000 + (learning instanceof EpisodicLearning ? "e" + ((EpisodicLearning) learning).getCurrentEpisode() : "")); } @Override @@ -176,9 +161,9 @@ public class RLController implements ViewListener, LearningListe @Override public void onEpisodeEnd(List rewardHistory) { latestRewardsHistory = rewardHistory; - if(printNextEpisode){ - System.out.println("Episode " + ((EpisodicLearning) learning).getCurrentEpisode() + " Latest Reward: " + rewardHistory.get(rewardHistory.size()-1)); - System.out.println("Eps/sec: " + ((EpisodicLearning) learning).getEpisodePerSecond()); + if(printNextEpisode) { + System.out.println("Episode " + ((EpisodicLearning) learning).getCurrentEpisode() + " Latest Reward: " + rewardHistory.get(rewardHistory.size() - 1)); + System.out.println("Eps/sec: " + ((EpisodicLearning) learning).getEpisodePerSecond()); printNextEpisode = false; } } @@ -192,22 +177,22 @@ public class RLController implements ViewListener, LearningListe ** SETTERS ** *************************************************/ - private void setEnvironment(Environment environment){ - if(environment == null){ + private void setEnvironment(Environment environment) { + if(environment == null) { throw new IllegalArgumentException("Environment cannot be null"); } this.environment = environment; } - private void setMethod(Method method){ - if(method == null){ + private void setMethod(Method method) { + if(method == null) { throw new IllegalArgumentException("Method cannot be null"); } this.method = method; } - private void setAllowedActions(A[] actions){ - if(actions == null || actions.length == 0){ + private void setAllowedActions(A[] actions) { + if(actions == null || actions.length == 0) { throw new IllegalArgumentException("There has to be at least one action"); } this.discreteActionSpace = new ListDiscreteActionSpace<>(actions); diff --git a/src/main/java/core/controller/RLControllerGUI.java b/src/main/java/core/controller/RLControllerGUI.java index 567ad54..be2ea56 100644 --- a/src/main/java/core/controller/RLControllerGUI.java +++ b/src/main/java/core/controller/RLControllerGUI.java @@ -1,14 +1,16 @@ package core.controller; import core.Environment; +import core.algo.EpisodicLearning; import core.algo.Method; import core.gui.LearningView; import core.gui.View; +import core.listener.ViewListener; import javax.swing.*; import java.util.List; -public class RLControllerGUI extends RLController { +public class RLControllerGUI extends RLController implements ViewListener { private LearningView learningView; public RLControllerGUI(Environment env, Method method, A... actions) { @@ -23,21 +25,41 @@ public class RLControllerGUI extends RLController { }); } - @Override - public void onLearnMoreEpisodes(int nrOfEpisodes) { - super.onLearnMoreEpisodes(nrOfEpisodes); - learningView.updateLearningInfoPanel(); - } + /************************************************* + ** View LISTENERS ** + *************************************************/ @Override public void onLoadState(String fileName) { - super.onLoadState(fileName); + super.loadState(fileName); SwingUtilities.invokeLater(() -> learningView.updateLearningInfoPanel()); } + @Override + public void onSaveState(String fileName) { + super.saveState(fileName); + } + + @Override + public void onShowQTable() { + learningView.showQTableFrame(); + } + @Override public void onEpsilonChange(float epsilon) { - super.onEpsilonChange(epsilon); + super.changeEpsilon(epsilon); + SwingUtilities.invokeLater(() -> learningView.updateLearningInfoPanel()); + } + + @Override + public void onDelayChange(int delay) { + super.changeLearningDelay(delay); + SwingUtilities.invokeLater(() -> learningView.updateLearningInfoPanel()); + } + + @Override + public void onFastLearnChange(boolean isFastLearn) { + super.changeFastLearning(isFastLearn); SwingUtilities.invokeLater(() -> learningView.updateLearningInfoPanel()); } @@ -48,17 +70,23 @@ public class RLControllerGUI extends RLController { } @Override - public void onLearningEnd() { - super.onLearningEnd(); - SwingUtilities.invokeLater(() -> learningView.updateRewardGraph(latestRewardsHistory)); + public void onLearnMoreEpisodes(int nrOfEpisodes) { + super.learnMoreEpisodes(nrOfEpisodes); + learningView.updateLearningInfoPanel(); } + + /************************************************* + ** LEARNING LISTENERS ** + *************************************************/ + @Override public void onEpisodeEnd(List rewardHistory) { super.onEpisodeEnd(rewardHistory); SwingUtilities.invokeLater(() -> { - if (!fastLearning) { + if(!fastLearning) { learningView.updateRewardGraph(latestRewardsHistory); + learningView.updateQTable(); } learningView.updateLearningInfoPanel(); }); @@ -66,8 +94,15 @@ public class RLControllerGUI extends RLController { @Override public void onStepEnd() { - if (!fastLearning) { + if(!fastLearning) { SwingUtilities.invokeLater(() -> learningView.repaintEnvironment()); } } + + @Override + public void onLearningEnd() { + super.onLearningEnd(); + onSaveState(method.toString() + System.currentTimeMillis() / 1000 + (learning instanceof EpisodicLearning ? "e" + ((EpisodicLearning) learning).getCurrentEpisode() : "")); + SwingUtilities.invokeLater(() -> learningView.updateRewardGraph(latestRewardsHistory)); + } } diff --git a/src/main/java/core/gui/LearningInfoPanel.java b/src/main/java/core/gui/LearningInfoPanel.java index 9e25b51..4959cee 100644 --- a/src/main/java/core/gui/LearningInfoPanel.java +++ b/src/main/java/core/gui/LearningInfoPanel.java @@ -26,24 +26,25 @@ public class LearningInfoPanel extends JPanel { private JCheckBox drawEnvironmentCheckbox; private JTextField learnMoreEpisodesInput; private JButton learnMoreEpisodesButton; + private JButton showQTableButton; - public LearningInfoPanel(Learning learning, ViewListener viewListener){ + public LearningInfoPanel(Learning learning, ViewListener viewListener) { this.learning = learning; setLayout(new BoxLayout(this, BoxLayout.Y_AXIS)); policyLabel = new JLabel(); discountLabel = new JLabel(); delayLabel = new JLabel(); - if(learning instanceof Episodic){ + if(learning instanceof Episodic) { episodeLabel = new JLabel(); add(episodeLabel); } - delaySlider = new JSlider(0,1000, learning.getDelay()); + delaySlider = new JSlider(0, 1000, learning.getDelay()); delaySlider.addChangeListener(e -> viewListener.onDelayChange(delaySlider.getValue())); add(policyLabel); add(discountLabel); - if(learning.getPolicy() instanceof EpsilonPolicy){ + if(learning.getPolicy() instanceof EpsilonPolicy) { epsilonLabel = new JLabel(); - epsilonSlider = new JSlider(0, 100, (int)((EpsilonPolicy)learning.getPolicy()).getEpsilon() * 100); + epsilonSlider = new JSlider(0, 100, (int) ((EpsilonPolicy) learning.getPolicy()).getEpsilon() * 100); epsilonSlider.addChangeListener(e -> viewListener.onEpsilonChange(epsilonSlider.getValue() / 100f)); add(epsilonLabel); add(epsilonSlider); @@ -51,7 +52,7 @@ public class LearningInfoPanel extends JPanel { toggleFastLearningButton = new JButton("Enable fast-learn"); fastLearning = false; - toggleFastLearningButton.addActionListener(e->{ + toggleFastLearningButton.addActionListener(e -> { fastLearning = !fastLearning; delaySlider.setEnabled(!fastLearning); epsilonSlider.setEnabled(!fastLearning); @@ -71,10 +72,10 @@ public class LearningInfoPanel extends JPanel { if(learning instanceof EpisodicLearning) { learnMoreEpisodesInput = new JTextField(); - learnMoreEpisodesInput.setMaximumSize(new Dimension(200,20)); + learnMoreEpisodesInput.setMaximumSize(new Dimension(200, 20)); learnMoreEpisodesButton = new JButton("Learn More Episodes"); learnMoreEpisodesButton.addActionListener(e -> { - if (Util.isNumeric(learnMoreEpisodesInput.getText())) { + if(Util.isNumeric(learnMoreEpisodesInput.getText())) { viewListener.onLearnMoreEpisodes(Integer.parseInt(learnMoreEpisodesInput.getText())); } else { learnMoreEpisodesInput.setText(""); @@ -83,9 +84,14 @@ public class LearningInfoPanel extends JPanel { add(learnMoreEpisodesInput); add(learnMoreEpisodesButton); } + showQTableButton = new JButton("Show Q-Table"); + showQTableButton.addActionListener(e -> { + viewListener.onShowQTable(); + }); add(drawEnvironmentCheckbox); add(smoothGraphCheckbox); add(last100Checkbox); + add(showQTableButton); refreshLabels(); setVisible(true); } @@ -93,17 +99,17 @@ public class LearningInfoPanel extends JPanel { public void refreshLabels() { policyLabel.setText("Policy: " + learning.getPolicy().getClass()); discountLabel.setText("Discount factor: " + learning.getDiscountFactor()); - if(learning instanceof Episodic){ - episodeLabel.setText("Episode: " + ((Episodic)(learning)).getCurrentEpisode() + - "\t Episodes to go: " + ((Episodic)(learning)).getEpisodesToGo() + - "\t Eps/Sec: " + ((Episodic)(learning)).getEpisodesPerSecond()); + if(learning instanceof Episodic) { + episodeLabel.setText("Episode: " + ((Episodic) (learning)).getCurrentEpisode() + + "\t Episodes to go: " + ((Episodic) (learning)).getEpisodesToGo() + + "\t Eps/Sec: " + ((Episodic) (learning)).getEpisodesPerSecond()); } - if (learning.getPolicy() instanceof EpsilonPolicy) { + if(learning.getPolicy() instanceof EpsilonPolicy) { epsilonLabel.setText("Exploration (Epsilon): " + ((EpsilonPolicy) learning.getPolicy()).getEpsilon()); - epsilonSlider.setValue((int)(((EpsilonPolicy) learning.getPolicy()).getEpsilon() * 100)); + epsilonSlider.setValue((int) (((EpsilonPolicy) learning.getPolicy()).getEpsilon() * 100)); } delayLabel.setText("Delay (ms): " + learning.getDelay()); - if(delaySlider.isEnabled()){ + if(delaySlider.isEnabled()) { delaySlider.setValue(learning.getDelay()); } toggleFastLearningButton.setText(fastLearning ? "Disable fast-learning" : "Enable fast-learning"); @@ -112,11 +118,12 @@ public class LearningInfoPanel extends JPanel { protected boolean isSmoothenGraphSelected() { return smoothGraphCheckbox.isSelected(); } - protected boolean isLast100Selected(){ + + protected boolean isLast100Selected() { return last100Checkbox.isSelected(); } - protected boolean isDrawEnvironmentSelected(){ + protected boolean isDrawEnvironmentSelected() { return drawEnvironmentCheckbox.isSelected(); } } diff --git a/src/main/java/core/gui/LearningView.java b/src/main/java/core/gui/LearningView.java index 92e3d1f..fc39426 100644 --- a/src/main/java/core/gui/LearningView.java +++ b/src/main/java/core/gui/LearningView.java @@ -8,5 +8,7 @@ import java.util.List; public interface LearningView { void repaintEnvironment(); void updateLearningInfoPanel(); + void updateQTable(); void updateRewardGraph(final List rewardHistory); + void showQTableFrame(); } diff --git a/src/main/java/core/gui/QTableFrame.java b/src/main/java/core/gui/QTableFrame.java new file mode 100644 index 0000000..c820921 --- /dev/null +++ b/src/main/java/core/gui/QTableFrame.java @@ -0,0 +1,57 @@ +package core.gui; + +import core.State; +import core.StateActionTable; + +import javax.swing.*; +import java.awt.*; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +public class QTableFrame extends JFrame { + private JLabel stateCountLabel; + private StateActionTable stateActionTable; + private List> rows; + private JPanel areaWrapper; + + public QTableFrame(StateActionTable stateActionTable) { + super("Q-Table"); + this.stateActionTable = stateActionTable; + rows = new ArrayList<>(10); + setDefaultCloseOperation(WindowConstants.HIDE_ON_CLOSE); + setLayout(new BorderLayout()); + setPreferredSize(new Dimension(500, 500)); + stateCountLabel = new JLabel(); + add(BorderLayout.NORTH, stateCountLabel); + areaWrapper = new JPanel(); + areaWrapper.setLayout(new BoxLayout(areaWrapper, BoxLayout.Y_AXIS)); + for(int i = 0; i < 10; ++i) { + StateActionRow a = new StateActionRow<>(); + rows.add(a); + areaWrapper.add(a); + } + add(BorderLayout.CENTER, areaWrapper); + setVisible(false); + pack(); + } + + private void refreshAllTextAreas(){ + for(StateActionRow row : rows){ + row.refreshLabels(); + } + } + protected void refreshQTable() { + System.out.println("ref"); + int stateCount = stateActionTable.getStateCount(); + stateCountLabel.setText("Total states: " + stateCount); + int idx = -1; + for(Map.Entry> entry : stateActionTable.getFirstStateEntriesForView()) { + if(++idx > rows.size() -1) break; + StateActionRow row = rows.get(idx); + row.setState(entry.getKey()); + row.setActionValues(entry.getValue()); + } + refreshAllTextAreas(); + } +} diff --git a/src/main/java/core/gui/StateActionRow.java b/src/main/java/core/gui/StateActionRow.java new file mode 100644 index 0000000..53e8e36 --- /dev/null +++ b/src/main/java/core/gui/StateActionRow.java @@ -0,0 +1,55 @@ +package core.gui; + +import core.State; +import lombok.Setter; + +import javax.swing.*; +import java.awt.*; +import java.awt.event.MouseAdapter; +import java.awt.event.MouseEvent; +import java.util.Map; + +@Setter +public class StateActionRow extends JTextArea { + private State state; + private Map actionValues; + + public StateActionRow(){ + this.state = null; + this.actionValues = null; + setMaximumSize(new Dimension(600, 100)); + setEditable(false); + addMouseListener(new MouseAdapter() { + @Override + public void mousePressed(MouseEvent e) { + super.mousePressed(e); + showState(); + } + }); + } + + protected void refreshLabels(){ + if(state == null || actionValues == null) return; + System.out.println("refreshing"); + StringBuilder sb = new StringBuilder(state.toString()).append("\n"); + for(Map.Entry actionValue: actionValues.entrySet()){ + sb.append("\t").append(actionValue.getKey()).append("\t").append(actionValue.getValue()).append("\n"); + } + setText(sb.toString()); + } + + private void showState() { + if(state != null && state instanceof Visualizable){ + new JFrame() { + { + JComponent stateComponent = ((Visualizable)state).visualize(); + setPreferredSize(stateComponent.getPreferredSize()); + setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE); + add(stateComponent); + pack(); + setVisible(true); + } + }; + } + } +} diff --git a/src/main/java/core/gui/View.java b/src/main/java/core/gui/View.java index dce75c6..83c52f9 100644 --- a/src/main/java/core/gui/View.java +++ b/src/main/java/core/gui/View.java @@ -4,7 +4,8 @@ import core.Environment; import core.algo.Learning; import core.listener.ViewListener; import lombok.Getter; -import org.javatuples.Pair; +import org.apache.commons.lang3.tuple.ImmutablePair; +import org.apache.commons.lang3.tuple.Pair; import org.knowm.xchart.QuickChart; import org.knowm.xchart.XChartPanel; import org.knowm.xchart.XYChart; @@ -16,7 +17,7 @@ import java.io.File; import java.util.List; import java.util.concurrent.CopyOnWriteArrayList; -public class View implements LearningView{ +public class View implements LearningView { private Learning learning; private Environment environment; @Getter @@ -26,6 +27,7 @@ public class View implements LearningView{ @Getter private JFrame mainFrame; private JFrame environmentFrame; + private QTableFrame qTableFrame; private XChartPanel rewardChartPanel; private ViewListener viewListener; private JMenuBar menuBar; @@ -36,11 +38,12 @@ public class View implements LearningView{ this.environment = environment; this.viewListener = viewListener; initMainFrame(); + initQTableFrame(); } private void initMainFrame() { mainFrame = new JFrame(); - mainFrame.setPreferredSize(new Dimension(1280, 720)); + mainFrame.setPreferredSize(new Dimension(1000, 400)); mainFrame.setLayout(new BorderLayout()); menuBar = new JMenuBar(); fileMenu = new JMenu("File"); @@ -52,7 +55,7 @@ public class View implements LearningView{ fc.setCurrentDirectory(new File(System.getProperty("user.dir"))); int returnVal = fc.showOpenDialog(mainFrame); - if (returnVal == JFileChooser.APPROVE_OPTION) { + if(returnVal == JFileChooser.APPROVE_OPTION) { viewListener.onLoadState(fc.getSelectedFile().toString()); } } @@ -62,7 +65,7 @@ public class View implements LearningView{ @Override public void actionPerformed(ActionEvent e) { String fileName = JOptionPane.showInputDialog("Enter file name", "path/to/file"); - if(fileName != null){ + if(fileName != null) { viewListener.onSaveState(fileName); } } @@ -78,7 +81,7 @@ public class View implements LearningView{ mainFrame.pack(); mainFrame.setVisible(true); - if (environment instanceof Visualizable) { + if(environment instanceof Visualizable) { environmentFrame = new JFrame() { { add(((Visualizable) environment).visualize()); @@ -86,9 +89,21 @@ public class View implements LearningView{ setVisible(true); } }; - } } + private void initQTableFrame(){ + qTableFrame = new QTableFrame<>(learning.getStateActionTable()); + } + + @Override + public void updateQTable() { + qTableFrame.refreshQTable(); + } + + public void showQTableFrame(){ + updateQTable(); + qTableFrame.setVisible(true); + } private void initLearningInfoPanel() { learningInfoPanel = new LearningInfoPanel(learning, viewListener); @@ -109,32 +124,21 @@ public class View implements LearningView{ rewardChartPanel.setPreferredSize(new Dimension(300, 300)); } - public void showState(Visualizable state) { - new JFrame() { - { - JComponent stateComponent = state.visualize(); - setPreferredSize(new Dimension(stateComponent.getWidth(), stateComponent.getHeight())); - add(stateComponent); - setVisible(true); - } - }; - } - public void updateRewardGraph(final List rewardHistory) { List xValues; List yValues; - if(learningInfoPanel.isLast100Selected()){ + if(learningInfoPanel.isLast100Selected()) { yValues = new CopyOnWriteArrayList<>(rewardHistory.subList(rewardHistory.size() - Math.min(rewardHistory.size(), 100), rewardHistory.size())); xValues = new CopyOnWriteArrayList<>(); - for(int i = rewardHistory.size() - Math.min(rewardHistory.size(), 100); i , List> XYvalues = smoothenGraph(rewardHistory); - xValues = XYvalues.getValue0(); - yValues = XYvalues.getValue1(); - }else{ + xValues = XYvalues.getKey(); + yValues = XYvalues.getValue(); + } else { xValues = null; yValues = rewardHistory; } @@ -145,37 +149,37 @@ public class View implements LearningView{ rewardChartPanel.repaint(); } - private Pair, List> smoothenGraph(List original){ + private Pair, List> smoothenGraph(List original) { int totalXPoints = 100; List xValues = new CopyOnWriteArrayList<>(); List tmp = new CopyOnWriteArrayList<>(); int meanBatch = original.size() / totalXPoints; - if(meanBatch < 1){ + if(meanBatch < 1) { meanBatch = 1; } int idx = 0; int batchIdx = 0; double batchSum = 0; - for(Double x: original) { + for(Double x : original) { ++idx; batchSum += x; - if (idx == 1 || ++batchIdx % meanBatch == 0) { + if(idx == 1 || ++batchIdx % meanBatch == 0) { tmp.add(batchSum / meanBatch); xValues.add(idx); batchSum = 0; } } - return new Pair<>(xValues, tmp); + return new ImmutablePair<>(xValues, tmp); } public void updateLearningInfoPanel() { this.learningInfoPanel.refreshLabels(); } - public void repaintEnvironment(){ - if (environmentFrame != null && learningInfoPanel.isDrawEnvironmentSelected()) { + public void repaintEnvironment() { + if(environmentFrame != null && learningInfoPanel.isDrawEnvironmentSelected()) { environmentFrame.repaint(); } } diff --git a/src/main/java/core/listener/ViewListener.java b/src/main/java/core/listener/ViewListener.java index 7651abe..487dc5f 100644 --- a/src/main/java/core/listener/ViewListener.java +++ b/src/main/java/core/listener/ViewListener.java @@ -12,4 +12,5 @@ public interface ViewListener { void onLearnMoreEpisodes(int nrOfEpisodes); void onLoadState(String fileName); void onSaveState(String fileName); + void onShowQTable(); } diff --git a/src/main/java/evironment/jumpingDino/Config.java b/src/main/java/evironment/jumpingDino/Config.java index 29af520..03dca1e 100644 --- a/src/main/java/evironment/jumpingDino/Config.java +++ b/src/main/java/evironment/jumpingDino/Config.java @@ -1,8 +1,8 @@ package evironment.jumpingDino; public class Config { - public static final int FRAME_WIDTH = 1280; - public static final int FRAME_HEIGHT = 720; + public static final int FRAME_WIDTH = 800; + public static final int FRAME_HEIGHT = 300; public static final int GROUND_Y = 50; public static final int DINO_STARTING_X = 50; public static final int DINO_SIZE = 50; diff --git a/src/main/java/evironment/jumpingDino/DinoState.java b/src/main/java/evironment/jumpingDino/DinoState.java index 699d99e..8f9a783 100644 --- a/src/main/java/evironment/jumpingDino/DinoState.java +++ b/src/main/java/evironment/jumpingDino/DinoState.java @@ -1,16 +1,20 @@ package evironment.jumpingDino; import core.State; +import core.gui.Visualizable; import lombok.AllArgsConstructor; import lombok.Getter; +import javax.swing.*; +import java.awt.*; import java.io.Serializable; import java.util.Objects; @AllArgsConstructor @Getter -public class DinoState implements State, Serializable { +public class DinoState implements State, Serializable, Visualizable { private int xDistanceToObstacle; + protected final double scale = 0.5; @Override public String toString() { @@ -31,4 +35,31 @@ public class DinoState implements State, Serializable { public int hashCode() { return Objects.hash(xDistanceToObstacle); } + + @Override + public JComponent visualize() { + return new JComponent() { + { + setPreferredSize(new Dimension(Config.FRAME_WIDTH, (int)(scale * Config.FRAME_HEIGHT))); + setVisible(true); + } + + @Override + protected void paintComponent(Graphics g) { + super.paintComponents(g); + drawObjects(g); + } + }; + } + + public void drawObjects(Graphics g){ + g.setColor(Color.BLACK); + g.fillRect(0, (int) (scale * (Config.FRAME_HEIGHT - Config.GROUND_Y)), Config.FRAME_WIDTH, 2); + + g.fillRect((int)(scale * Config.DINO_STARTING_X), (int)(scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE)), (int)(scale * Config.DINO_SIZE), (int) (scale * Config.DINO_SIZE)); + g.drawString("Distance: " + xDistanceToObstacle, (int)(scale * Config.DINO_STARTING_X),(int)(scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE - 40) )); + + g.fillRect((int)(scale * (Config.DINO_STARTING_X + getXDistanceToObstacle())), (int)(scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE)), (int)(scale * Config.OBSTACLE_SIZE), (int)(scale *Config.OBSTACLE_SIZE)); + + } } diff --git a/src/main/java/evironment/jumpingDino/DinoStateWithSpeed.java b/src/main/java/evironment/jumpingDino/DinoStateWithSpeed.java index a110f38..4e3dd08 100644 --- a/src/main/java/evironment/jumpingDino/DinoStateWithSpeed.java +++ b/src/main/java/evironment/jumpingDino/DinoStateWithSpeed.java @@ -1,11 +1,13 @@ package evironment.jumpingDino; +import core.gui.Visualizable; import lombok.Getter; +import java.awt.*; import java.util.Objects; @Getter -public class DinoStateWithSpeed extends DinoState{ +public class DinoStateWithSpeed extends DinoState implements Visualizable { private int obstacleSpeed; public DinoStateWithSpeed(int xDistanceToObstacle, int obstacleSpeed) { @@ -33,4 +35,10 @@ public class DinoStateWithSpeed extends DinoState{ public int hashCode() { return Objects.hash(super.hashCode(), getObstacleSpeed()); } + + @Override + public void drawObjects(Graphics g) { + super.drawObjects(g); + g.drawString("Speed: " + obstacleSpeed, (int)(scale * (Config.DINO_STARTING_X + getXDistanceToObstacle())),(int)(scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE - 40)) ); + } } diff --git a/src/main/java/example/JumpingDino.java b/src/main/java/example/JumpingDino.java index 7cdd445..84e3c7e 100644 --- a/src/main/java/example/JumpingDino.java +++ b/src/main/java/example/JumpingDino.java @@ -16,7 +16,7 @@ public class JumpingDino { Method.MC_ONPOLICY_EGREEDY, DinoAction.values()); - rl.setDelay(0); + rl.setDelay(100); rl.setDiscountFactor(1f); rl.setEpsilon(0.15f); rl.setNrOfEpisodes(100000);