add QTableFrame and clickable states that display a gui

- remove org.javaTuple in favour of org.apache.common for tuples and circleQueue
- remove ViewListener from non-GUI Controller
- stateActionTable saves the last 10 states that changed. They will get displayed in QTable Frame
in JTextAreas
This commit is contained in:
Jan Löwenstrom 2020-01-01 23:54:18 +01:00
parent a8f8af1102
commit f4f1f7bd37
16 changed files with 334 additions and 136 deletions

View File

@ -20,9 +20,10 @@ dependencies {
testCompile group: 'junit', name: 'junit', version: '4.12'
compileOnly 'org.projectlombok:lombok:1.18.10'
annotationProcessor 'org.projectlombok:lombok:1.18.10'
compile 'org.javatuples:javatuples:1.2'
// https://mvnrepository.com/artifact/org.javatuples/javatuples
compile group: 'org.javatuples', name: 'javatuples', version: '1.2'
// https://mvnrepository.com/artifact/org.apache.commons/commons-lang3
compile group: 'org.apache.commons', name: 'commons-lang3', version: '3.0'
// https://mvnrepository.com/artifact/org.apache.commons/commons-collections4
compile group: 'org.apache.commons', name: 'commons-collections4', version: '4.1'
}

View File

@ -1,9 +1,9 @@
package core;
import java.io.Serializable;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.*;
import org.apache.commons.collections4.queue.CircularFifoQueue;
/**
* Premise: All states have the complete action space
*/
@ -11,10 +11,12 @@ public class DeterministicStateActionTable<A extends Enum> implements StateActio
private final Map<State, Map<A, Double>> table;
private DiscreteActionSpace<A> discreteActionSpace;
private Queue<Map.Entry<State, Map<A, Double>>> latestChanges;
public DeterministicStateActionTable(DiscreteActionSpace<A> discreteActionSpace){
table = new LinkedHashMap<>();
this.discreteActionSpace = discreteActionSpace;
latestChanges = new CircularFifoQueue<>(10);
}
/**
@ -57,6 +59,7 @@ public class DeterministicStateActionTable<A extends Enum> implements StateActio
actionValues = createDefaultActionValues();
table.put(state, actionValues);
}
latestChanges.add(new AbstractMap.SimpleEntry<>(state, actionValues));
actionValues.put(action, value);
}
@ -72,6 +75,11 @@ public class DeterministicStateActionTable<A extends Enum> implements StateActio
return table.get(state);
}
@Override
public Queue<Map.Entry<State, Map<A, Double>>> getFirstStateEntriesForView() {
return latestChanges;
}
/**
* @return Map with initial values for every available action
*/

View File

@ -1,6 +1,7 @@
package core;
import java.util.Map;
import java.util.Queue;
/**
* Q-Table which saves all seen states, all available actions for each state
@ -15,4 +16,6 @@ public interface StateActionTable<A extends Enum> {
void setValue(State state, A action, double value);
int getStateCount();
Map<A, Double> getActionValues(State state);
Queue<Map.Entry<State, Map<A, Double>>> getFirstStateEntriesForView();
}

View File

@ -3,7 +3,8 @@ package core.algo.mc;
import core.*;
import core.algo.EpisodicLearning;
import core.policy.EpsilonGreedyPolicy;
import org.javatuples.Pair;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import java.io.IOException;
import java.io.ObjectInputStream;
@ -80,7 +81,7 @@ public class MonteCarloOnPolicyEGreedy<A extends Enum> extends EpisodicLearning<
Set<Pair<State, A>> stateActionPairs = new LinkedHashSet<>();
for (StepResult<A> sr : episode) {
stateActionPairs.add(new Pair<>(sr.getState(), sr.getAction()));
stateActionPairs.add(new ImmutablePair<>(sr.getState(), sr.getAction()));
}
//System.out.println("stateActionPairs " + stateActionPairs.size());
@ -88,7 +89,7 @@ public class MonteCarloOnPolicyEGreedy<A extends Enum> extends EpisodicLearning<
int firstOccurenceIndex = 0;
// find first occurance of state action pair
for (StepResult<A> sr : episode) {
if (stateActionPair.getValue0().equals(sr.getState()) && stateActionPair.getValue1().equals(sr.getAction())) {
if (stateActionPair.getKey().equals(sr.getState()) && stateActionPair.getValue().equals(sr.getAction())) {
break;
}
firstOccurenceIndex++;
@ -102,7 +103,7 @@ public class MonteCarloOnPolicyEGreedy<A extends Enum> extends EpisodicLearning<
// if the key does not exists, it will create a new entry with G as default value
returnSum.merge(stateActionPair, G, Double::sum);
returnCount.merge(stateActionPair, 1, Integer::sum);
stateActionTable.setValue(stateActionPair.getValue0(), stateActionPair.getValue1(), returnSum.get(stateActionPair) / returnCount.get(stateActionPair));
stateActionTable.setValue(stateActionPair.getKey(), stateActionPair.getValue(), returnSum.get(stateActionPair) / returnCount.get(stateActionPair));
}
}

View File

@ -15,10 +15,9 @@ import lombok.Setter;
import javax.swing.*;
import java.io.*;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
public class RLController<A extends Enum> implements ViewListener, LearningListener {
public class RLController<A extends Enum> implements LearningListener {
protected final String folderPrefix = "learningStates" + File.separator;
protected Environment<A> environment;
protected DiscreteActionSpace<A> discreteActionSpace;
@ -37,15 +36,15 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
protected int prevDelay;
protected volatile boolean printNextEpisode;
public RLController(Environment<A> env, Method method, A... actions){
public RLController(Environment<A> env, Method method, A... actions) {
setEnvironment(env);
setMethod(method);
setAllowedActions(actions);
printNextEpisode = true;
}
public void start(){
switch (method){
public void start() {
switch(method) {
case MC_ONPOLICY_EGREEDY:
learning = new MonteCarloOnPolicyEGreedy<>(environment, discreteActionSpace, discountFactor, epsilon, delay);
break;
@ -60,13 +59,13 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
initLearning();
}
protected void initListeners(){
protected void initListeners() {
learning.addListener(this);
new Thread(() -> {
while (true){
while(true) {
printNextEpisode = true;
try {
Thread.sleep(30*1000);
Thread.sleep(30 * 1000);
} catch (InterruptedException e) {
e.printStackTrace();
}
@ -74,35 +73,42 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
}).start();
}
private void initLearning(){
if(learning instanceof EpisodicLearning){
private void initLearning() {
if(learning instanceof EpisodicLearning) {
System.out.println("Starting learning of <" + nrOfEpisodes + "> episodes");
((EpisodicLearning) learning).learn(nrOfEpisodes);
}else{
((EpisodicLearning) learning).learn(nrOfEpisodes);
} else {
learning.learn();
}
}
/*************************************************
** VIEW LISTENERS **
*************************************************/
@Override
public void onLearnMoreEpisodes(int nrOfEpisodes){
if(learning instanceof EpisodicLearning){
protected void changeLearningDelay(int delay) {
learning.setDelay(delay);
}
protected void learnMoreEpisodes(int nrOfEpisodes) {
if(learning instanceof EpisodicLearning) {
((EpisodicLearning) learning).learn(nrOfEpisodes);
}else{
} else {
throw new RuntimeException("Triggering onLearnMoreEpisodes on non-episodic learning!");
}
}
@Override
public void onLoadState(String fileName) {
protected void changeEpsilon(float epsilon) {
if(learning.getPolicy() instanceof EpsilonPolicy) {
((EpsilonPolicy<A>) learning.getPolicy()).setEpsilon(epsilon);
} else {
System.out.println("Trying to call inEpsilonChange on non-epsilon policy");
}
}
protected void saveState(String fileName) {
FileInputStream fis;
ObjectInputStream in;
try {
fis = new FileInputStream(fileName);
in = new ObjectInputStream(fis);
System.out.println("interrupt" + Thread.currentThread().getId());
System.out.println("interrup" + Thread.currentThread().getId());
learning.interruptLearning();
learning.load(in);
in.close();
@ -111,46 +117,26 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
}
}
@Override
public void onSaveState(String fileName) {
protected void loadState(String fileName) {
FileOutputStream fos;
ObjectOutputStream out;
try{
fos = new FileOutputStream(folderPrefix + fileName);
try {
fos = new FileOutputStream(folderPrefix + fileName);
out = new ObjectOutputStream(fos);
learning.interruptLearning();
learning.save(out);
out.close();
}catch (IOException e){
} catch (IOException e) {
e.printStackTrace();
}
}
@Override
public void onEpsilonChange(float epsilon) {
if(learning.getPolicy() instanceof EpsilonPolicy){
((EpsilonPolicy<A>) learning.getPolicy()).setEpsilon(epsilon);
}else{
System.out.println("Trying to call inEpsilonChange on non-epsilon policy");
}
}
@Override
public void onDelayChange(int delay) {
changeLearningDelay(delay);
}
protected void changeLearningDelay(int delay){
learning.setDelay(delay);
}
@Override
public void onFastLearnChange(boolean fastLearn) {
protected void changeFastLearning(boolean fastLearn) {
this.fastLearning = fastLearn;
if(fastLearn){
if(fastLearn) {
prevDelay = learning.getDelay();
changeLearningDelay(0);
}else{
} else {
changeLearningDelay(prevDelay);
}
}
@ -165,7 +151,6 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
@Override
public void onLearningEnd() {
System.out.println("Learning finished");
onSaveState( method.toString() + System.currentTimeMillis()/1000 + (learning instanceof EpisodicLearning ? "e" + ((EpisodicLearning) learning).getCurrentEpisode() : ""));
}
@Override
@ -176,9 +161,9 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
@Override
public void onEpisodeEnd(List<Double> rewardHistory) {
latestRewardsHistory = rewardHistory;
if(printNextEpisode){
System.out.println("Episode " + ((EpisodicLearning) learning).getCurrentEpisode() + " Latest Reward: " + rewardHistory.get(rewardHistory.size()-1));
System.out.println("Eps/sec: " + ((EpisodicLearning) learning).getEpisodePerSecond());
if(printNextEpisode) {
System.out.println("Episode " + ((EpisodicLearning) learning).getCurrentEpisode() + " Latest Reward: " + rewardHistory.get(rewardHistory.size() - 1));
System.out.println("Eps/sec: " + ((EpisodicLearning) learning).getEpisodePerSecond());
printNextEpisode = false;
}
}
@ -192,22 +177,22 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
** SETTERS **
*************************************************/
private void setEnvironment(Environment<A> environment){
if(environment == null){
private void setEnvironment(Environment<A> environment) {
if(environment == null) {
throw new IllegalArgumentException("Environment cannot be null");
}
this.environment = environment;
}
private void setMethod(Method method){
if(method == null){
private void setMethod(Method method) {
if(method == null) {
throw new IllegalArgumentException("Method cannot be null");
}
this.method = method;
}
private void setAllowedActions(A[] actions){
if(actions == null || actions.length == 0){
private void setAllowedActions(A[] actions) {
if(actions == null || actions.length == 0) {
throw new IllegalArgumentException("There has to be at least one action");
}
this.discreteActionSpace = new ListDiscreteActionSpace<>(actions);

View File

@ -1,14 +1,16 @@
package core.controller;
import core.Environment;
import core.algo.EpisodicLearning;
import core.algo.Method;
import core.gui.LearningView;
import core.gui.View;
import core.listener.ViewListener;
import javax.swing.*;
import java.util.List;
public class RLControllerGUI<A extends Enum> extends RLController<A> {
public class RLControllerGUI<A extends Enum> extends RLController<A> implements ViewListener {
private LearningView learningView;
public RLControllerGUI(Environment<A> env, Method method, A... actions) {
@ -23,21 +25,41 @@ public class RLControllerGUI<A extends Enum> extends RLController<A> {
});
}
@Override
public void onLearnMoreEpisodes(int nrOfEpisodes) {
super.onLearnMoreEpisodes(nrOfEpisodes);
learningView.updateLearningInfoPanel();
}
/*************************************************
** View LISTENERS **
*************************************************/
@Override
public void onLoadState(String fileName) {
super.onLoadState(fileName);
super.loadState(fileName);
SwingUtilities.invokeLater(() -> learningView.updateLearningInfoPanel());
}
@Override
public void onSaveState(String fileName) {
super.saveState(fileName);
}
@Override
public void onShowQTable() {
learningView.showQTableFrame();
}
@Override
public void onEpsilonChange(float epsilon) {
super.onEpsilonChange(epsilon);
super.changeEpsilon(epsilon);
SwingUtilities.invokeLater(() -> learningView.updateLearningInfoPanel());
}
@Override
public void onDelayChange(int delay) {
super.changeLearningDelay(delay);
SwingUtilities.invokeLater(() -> learningView.updateLearningInfoPanel());
}
@Override
public void onFastLearnChange(boolean isFastLearn) {
super.changeFastLearning(isFastLearn);
SwingUtilities.invokeLater(() -> learningView.updateLearningInfoPanel());
}
@ -48,17 +70,23 @@ public class RLControllerGUI<A extends Enum> extends RLController<A> {
}
@Override
public void onLearningEnd() {
super.onLearningEnd();
SwingUtilities.invokeLater(() -> learningView.updateRewardGraph(latestRewardsHistory));
public void onLearnMoreEpisodes(int nrOfEpisodes) {
super.learnMoreEpisodes(nrOfEpisodes);
learningView.updateLearningInfoPanel();
}
/*************************************************
** LEARNING LISTENERS **
*************************************************/
@Override
public void onEpisodeEnd(List<Double> rewardHistory) {
super.onEpisodeEnd(rewardHistory);
SwingUtilities.invokeLater(() -> {
if (!fastLearning) {
if(!fastLearning) {
learningView.updateRewardGraph(latestRewardsHistory);
learningView.updateQTable();
}
learningView.updateLearningInfoPanel();
});
@ -66,8 +94,15 @@ public class RLControllerGUI<A extends Enum> extends RLController<A> {
@Override
public void onStepEnd() {
if (!fastLearning) {
if(!fastLearning) {
SwingUtilities.invokeLater(() -> learningView.repaintEnvironment());
}
}
@Override
public void onLearningEnd() {
super.onLearningEnd();
onSaveState(method.toString() + System.currentTimeMillis() / 1000 + (learning instanceof EpisodicLearning ? "e" + ((EpisodicLearning) learning).getCurrentEpisode() : ""));
SwingUtilities.invokeLater(() -> learningView.updateRewardGraph(latestRewardsHistory));
}
}

View File

@ -26,24 +26,25 @@ public class LearningInfoPanel extends JPanel {
private JCheckBox drawEnvironmentCheckbox;
private JTextField learnMoreEpisodesInput;
private JButton learnMoreEpisodesButton;
private JButton showQTableButton;
public LearningInfoPanel(Learning learning, ViewListener viewListener){
public LearningInfoPanel(Learning learning, ViewListener viewListener) {
this.learning = learning;
setLayout(new BoxLayout(this, BoxLayout.Y_AXIS));
policyLabel = new JLabel();
discountLabel = new JLabel();
delayLabel = new JLabel();
if(learning instanceof Episodic){
if(learning instanceof Episodic) {
episodeLabel = new JLabel();
add(episodeLabel);
}
delaySlider = new JSlider(0,1000, learning.getDelay());
delaySlider = new JSlider(0, 1000, learning.getDelay());
delaySlider.addChangeListener(e -> viewListener.onDelayChange(delaySlider.getValue()));
add(policyLabel);
add(discountLabel);
if(learning.getPolicy() instanceof EpsilonPolicy){
if(learning.getPolicy() instanceof EpsilonPolicy) {
epsilonLabel = new JLabel();
epsilonSlider = new JSlider(0, 100, (int)((EpsilonPolicy)learning.getPolicy()).getEpsilon() * 100);
epsilonSlider = new JSlider(0, 100, (int) ((EpsilonPolicy) learning.getPolicy()).getEpsilon() * 100);
epsilonSlider.addChangeListener(e -> viewListener.onEpsilonChange(epsilonSlider.getValue() / 100f));
add(epsilonLabel);
add(epsilonSlider);
@ -51,7 +52,7 @@ public class LearningInfoPanel extends JPanel {
toggleFastLearningButton = new JButton("Enable fast-learn");
fastLearning = false;
toggleFastLearningButton.addActionListener(e->{
toggleFastLearningButton.addActionListener(e -> {
fastLearning = !fastLearning;
delaySlider.setEnabled(!fastLearning);
epsilonSlider.setEnabled(!fastLearning);
@ -71,10 +72,10 @@ public class LearningInfoPanel extends JPanel {
if(learning instanceof EpisodicLearning) {
learnMoreEpisodesInput = new JTextField();
learnMoreEpisodesInput.setMaximumSize(new Dimension(200,20));
learnMoreEpisodesInput.setMaximumSize(new Dimension(200, 20));
learnMoreEpisodesButton = new JButton("Learn More Episodes");
learnMoreEpisodesButton.addActionListener(e -> {
if (Util.isNumeric(learnMoreEpisodesInput.getText())) {
if(Util.isNumeric(learnMoreEpisodesInput.getText())) {
viewListener.onLearnMoreEpisodes(Integer.parseInt(learnMoreEpisodesInput.getText()));
} else {
learnMoreEpisodesInput.setText("");
@ -83,9 +84,14 @@ public class LearningInfoPanel extends JPanel {
add(learnMoreEpisodesInput);
add(learnMoreEpisodesButton);
}
showQTableButton = new JButton("Show Q-Table");
showQTableButton.addActionListener(e -> {
viewListener.onShowQTable();
});
add(drawEnvironmentCheckbox);
add(smoothGraphCheckbox);
add(last100Checkbox);
add(showQTableButton);
refreshLabels();
setVisible(true);
}
@ -93,17 +99,17 @@ public class LearningInfoPanel extends JPanel {
public void refreshLabels() {
policyLabel.setText("Policy: " + learning.getPolicy().getClass());
discountLabel.setText("Discount factor: " + learning.getDiscountFactor());
if(learning instanceof Episodic){
episodeLabel.setText("Episode: " + ((Episodic)(learning)).getCurrentEpisode() +
"\t Episodes to go: " + ((Episodic)(learning)).getEpisodesToGo() +
"\t Eps/Sec: " + ((Episodic)(learning)).getEpisodesPerSecond());
if(learning instanceof Episodic) {
episodeLabel.setText("Episode: " + ((Episodic) (learning)).getCurrentEpisode() +
"\t Episodes to go: " + ((Episodic) (learning)).getEpisodesToGo() +
"\t Eps/Sec: " + ((Episodic) (learning)).getEpisodesPerSecond());
}
if (learning.getPolicy() instanceof EpsilonPolicy) {
if(learning.getPolicy() instanceof EpsilonPolicy) {
epsilonLabel.setText("Exploration (Epsilon): " + ((EpsilonPolicy) learning.getPolicy()).getEpsilon());
epsilonSlider.setValue((int)(((EpsilonPolicy) learning.getPolicy()).getEpsilon() * 100));
epsilonSlider.setValue((int) (((EpsilonPolicy) learning.getPolicy()).getEpsilon() * 100));
}
delayLabel.setText("Delay (ms): " + learning.getDelay());
if(delaySlider.isEnabled()){
if(delaySlider.isEnabled()) {
delaySlider.setValue(learning.getDelay());
}
toggleFastLearningButton.setText(fastLearning ? "Disable fast-learning" : "Enable fast-learning");
@ -112,11 +118,12 @@ public class LearningInfoPanel extends JPanel {
protected boolean isSmoothenGraphSelected() {
return smoothGraphCheckbox.isSelected();
}
protected boolean isLast100Selected(){
protected boolean isLast100Selected() {
return last100Checkbox.isSelected();
}
protected boolean isDrawEnvironmentSelected(){
protected boolean isDrawEnvironmentSelected() {
return drawEnvironmentCheckbox.isSelected();
}
}

View File

@ -8,5 +8,7 @@ import java.util.List;
public interface LearningView {
void repaintEnvironment();
void updateLearningInfoPanel();
void updateQTable();
void updateRewardGraph(final List<Double> rewardHistory);
void showQTableFrame();
}

View File

@ -0,0 +1,57 @@
package core.gui;
import core.State;
import core.StateActionTable;
import javax.swing.*;
import java.awt.*;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
public class QTableFrame<A extends Enum> extends JFrame {
private JLabel stateCountLabel;
private StateActionTable<A> stateActionTable;
private List<StateActionRow<A>> rows;
private JPanel areaWrapper;
public QTableFrame(StateActionTable<A> stateActionTable) {
super("Q-Table");
this.stateActionTable = stateActionTable;
rows = new ArrayList<>(10);
setDefaultCloseOperation(WindowConstants.HIDE_ON_CLOSE);
setLayout(new BorderLayout());
setPreferredSize(new Dimension(500, 500));
stateCountLabel = new JLabel();
add(BorderLayout.NORTH, stateCountLabel);
areaWrapper = new JPanel();
areaWrapper.setLayout(new BoxLayout(areaWrapper, BoxLayout.Y_AXIS));
for(int i = 0; i < 10; ++i) {
StateActionRow<A> a = new StateActionRow<>();
rows.add(a);
areaWrapper.add(a);
}
add(BorderLayout.CENTER, areaWrapper);
setVisible(false);
pack();
}
private void refreshAllTextAreas(){
for(StateActionRow<A> row : rows){
row.refreshLabels();
}
}
protected void refreshQTable() {
System.out.println("ref");
int stateCount = stateActionTable.getStateCount();
stateCountLabel.setText("Total states: " + stateCount);
int idx = -1;
for(Map.Entry<State, Map<A, Double>> entry : stateActionTable.getFirstStateEntriesForView()) {
if(++idx > rows.size() -1) break;
StateActionRow<A> row = rows.get(idx);
row.setState(entry.getKey());
row.setActionValues(entry.getValue());
}
refreshAllTextAreas();
}
}

View File

@ -0,0 +1,55 @@
package core.gui;
import core.State;
import lombok.Setter;
import javax.swing.*;
import java.awt.*;
import java.awt.event.MouseAdapter;
import java.awt.event.MouseEvent;
import java.util.Map;
@Setter
public class StateActionRow<A extends Enum> extends JTextArea {
private State state;
private Map<A, Double> actionValues;
public StateActionRow(){
this.state = null;
this.actionValues = null;
setMaximumSize(new Dimension(600, 100));
setEditable(false);
addMouseListener(new MouseAdapter() {
@Override
public void mousePressed(MouseEvent e) {
super.mousePressed(e);
showState();
}
});
}
protected void refreshLabels(){
if(state == null || actionValues == null) return;
System.out.println("refreshing");
StringBuilder sb = new StringBuilder(state.toString()).append("\n");
for(Map.Entry<A, Double> actionValue: actionValues.entrySet()){
sb.append("\t").append(actionValue.getKey()).append("\t").append(actionValue.getValue()).append("\n");
}
setText(sb.toString());
}
private void showState() {
if(state != null && state instanceof Visualizable){
new JFrame() {
{
JComponent stateComponent = ((Visualizable)state).visualize();
setPreferredSize(stateComponent.getPreferredSize());
setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
add(stateComponent);
pack();
setVisible(true);
}
};
}
}
}

View File

@ -4,7 +4,8 @@ import core.Environment;
import core.algo.Learning;
import core.listener.ViewListener;
import lombok.Getter;
import org.javatuples.Pair;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.knowm.xchart.QuickChart;
import org.knowm.xchart.XChartPanel;
import org.knowm.xchart.XYChart;
@ -16,7 +17,7 @@ import java.io.File;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
public class View<A extends Enum> implements LearningView{
public class View<A extends Enum> implements LearningView {
private Learning<A> learning;
private Environment<A> environment;
@Getter
@ -26,6 +27,7 @@ public class View<A extends Enum> implements LearningView{
@Getter
private JFrame mainFrame;
private JFrame environmentFrame;
private QTableFrame<A> qTableFrame;
private XChartPanel<XYChart> rewardChartPanel;
private ViewListener viewListener;
private JMenuBar menuBar;
@ -36,11 +38,12 @@ public class View<A extends Enum> implements LearningView{
this.environment = environment;
this.viewListener = viewListener;
initMainFrame();
initQTableFrame();
}
private void initMainFrame() {
mainFrame = new JFrame();
mainFrame.setPreferredSize(new Dimension(1280, 720));
mainFrame.setPreferredSize(new Dimension(1000, 400));
mainFrame.setLayout(new BorderLayout());
menuBar = new JMenuBar();
fileMenu = new JMenu("File");
@ -52,7 +55,7 @@ public class View<A extends Enum> implements LearningView{
fc.setCurrentDirectory(new File(System.getProperty("user.dir")));
int returnVal = fc.showOpenDialog(mainFrame);
if (returnVal == JFileChooser.APPROVE_OPTION) {
if(returnVal == JFileChooser.APPROVE_OPTION) {
viewListener.onLoadState(fc.getSelectedFile().toString());
}
}
@ -62,7 +65,7 @@ public class View<A extends Enum> implements LearningView{
@Override
public void actionPerformed(ActionEvent e) {
String fileName = JOptionPane.showInputDialog("Enter file name", "path/to/file");
if(fileName != null){
if(fileName != null) {
viewListener.onSaveState(fileName);
}
}
@ -78,7 +81,7 @@ public class View<A extends Enum> implements LearningView{
mainFrame.pack();
mainFrame.setVisible(true);
if (environment instanceof Visualizable) {
if(environment instanceof Visualizable) {
environmentFrame = new JFrame() {
{
add(((Visualizable) environment).visualize());
@ -86,9 +89,21 @@ public class View<A extends Enum> implements LearningView{
setVisible(true);
}
};
}
}
private void initQTableFrame(){
qTableFrame = new QTableFrame<>(learning.getStateActionTable());
}
@Override
public void updateQTable() {
qTableFrame.refreshQTable();
}
public void showQTableFrame(){
updateQTable();
qTableFrame.setVisible(true);
}
private void initLearningInfoPanel() {
learningInfoPanel = new LearningInfoPanel(learning, viewListener);
@ -109,32 +124,21 @@ public class View<A extends Enum> implements LearningView{
rewardChartPanel.setPreferredSize(new Dimension(300, 300));
}
public void showState(Visualizable state) {
new JFrame() {
{
JComponent stateComponent = state.visualize();
setPreferredSize(new Dimension(stateComponent.getWidth(), stateComponent.getHeight()));
add(stateComponent);
setVisible(true);
}
};
}
public void updateRewardGraph(final List<Double> rewardHistory) {
List<Integer> xValues;
List<Double> yValues;
if(learningInfoPanel.isLast100Selected()){
if(learningInfoPanel.isLast100Selected()) {
yValues = new CopyOnWriteArrayList<>(rewardHistory.subList(rewardHistory.size() - Math.min(rewardHistory.size(), 100), rewardHistory.size()));
xValues = new CopyOnWriteArrayList<>();
for(int i = rewardHistory.size() - Math.min(rewardHistory.size(), 100); i <rewardHistory.size(); ++i){
for(int i = rewardHistory.size() - Math.min(rewardHistory.size(), 100); i < rewardHistory.size(); ++i) {
xValues.add(i);
}
}else{
if(learningInfoPanel.isSmoothenGraphSelected()){
} else {
if(learningInfoPanel.isSmoothenGraphSelected()) {
Pair<List<Integer>, List<Double>> XYvalues = smoothenGraph(rewardHistory);
xValues = XYvalues.getValue0();
yValues = XYvalues.getValue1();
}else{
xValues = XYvalues.getKey();
yValues = XYvalues.getValue();
} else {
xValues = null;
yValues = rewardHistory;
}
@ -145,37 +149,37 @@ public class View<A extends Enum> implements LearningView{
rewardChartPanel.repaint();
}
private Pair<List<Integer>, List<Double>> smoothenGraph(List<Double> original){
private Pair<List<Integer>, List<Double>> smoothenGraph(List<Double> original) {
int totalXPoints = 100;
List<Integer> xValues = new CopyOnWriteArrayList<>();
List<Double> tmp = new CopyOnWriteArrayList<>();
int meanBatch = original.size() / totalXPoints;
if(meanBatch < 1){
if(meanBatch < 1) {
meanBatch = 1;
}
int idx = 0;
int batchIdx = 0;
double batchSum = 0;
for(Double x: original) {
for(Double x : original) {
++idx;
batchSum += x;
if (idx == 1 || ++batchIdx % meanBatch == 0) {
if(idx == 1 || ++batchIdx % meanBatch == 0) {
tmp.add(batchSum / meanBatch);
xValues.add(idx);
batchSum = 0;
}
}
return new Pair<>(xValues, tmp);
return new ImmutablePair<>(xValues, tmp);
}
public void updateLearningInfoPanel() {
this.learningInfoPanel.refreshLabels();
}
public void repaintEnvironment(){
if (environmentFrame != null && learningInfoPanel.isDrawEnvironmentSelected()) {
public void repaintEnvironment() {
if(environmentFrame != null && learningInfoPanel.isDrawEnvironmentSelected()) {
environmentFrame.repaint();
}
}

View File

@ -12,4 +12,5 @@ public interface ViewListener {
void onLearnMoreEpisodes(int nrOfEpisodes);
void onLoadState(String fileName);
void onSaveState(String fileName);
void onShowQTable();
}

View File

@ -1,8 +1,8 @@
package evironment.jumpingDino;
public class Config {
public static final int FRAME_WIDTH = 1280;
public static final int FRAME_HEIGHT = 720;
public static final int FRAME_WIDTH = 800;
public static final int FRAME_HEIGHT = 300;
public static final int GROUND_Y = 50;
public static final int DINO_STARTING_X = 50;
public static final int DINO_SIZE = 50;

View File

@ -1,16 +1,20 @@
package evironment.jumpingDino;
import core.State;
import core.gui.Visualizable;
import lombok.AllArgsConstructor;
import lombok.Getter;
import javax.swing.*;
import java.awt.*;
import java.io.Serializable;
import java.util.Objects;
@AllArgsConstructor
@Getter
public class DinoState implements State, Serializable {
public class DinoState implements State, Serializable, Visualizable {
private int xDistanceToObstacle;
protected final double scale = 0.5;
@Override
public String toString() {
@ -31,4 +35,31 @@ public class DinoState implements State, Serializable {
public int hashCode() {
return Objects.hash(xDistanceToObstacle);
}
@Override
public JComponent visualize() {
return new JComponent() {
{
setPreferredSize(new Dimension(Config.FRAME_WIDTH, (int)(scale * Config.FRAME_HEIGHT)));
setVisible(true);
}
@Override
protected void paintComponent(Graphics g) {
super.paintComponents(g);
drawObjects(g);
}
};
}
public void drawObjects(Graphics g){
g.setColor(Color.BLACK);
g.fillRect(0, (int) (scale * (Config.FRAME_HEIGHT - Config.GROUND_Y)), Config.FRAME_WIDTH, 2);
g.fillRect((int)(scale * Config.DINO_STARTING_X), (int)(scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE)), (int)(scale * Config.DINO_SIZE), (int) (scale * Config.DINO_SIZE));
g.drawString("Distance: " + xDistanceToObstacle, (int)(scale * Config.DINO_STARTING_X),(int)(scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE - 40) ));
g.fillRect((int)(scale * (Config.DINO_STARTING_X + getXDistanceToObstacle())), (int)(scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE)), (int)(scale * Config.OBSTACLE_SIZE), (int)(scale *Config.OBSTACLE_SIZE));
}
}

View File

@ -1,11 +1,13 @@
package evironment.jumpingDino;
import core.gui.Visualizable;
import lombok.Getter;
import java.awt.*;
import java.util.Objects;
@Getter
public class DinoStateWithSpeed extends DinoState{
public class DinoStateWithSpeed extends DinoState implements Visualizable {
private int obstacleSpeed;
public DinoStateWithSpeed(int xDistanceToObstacle, int obstacleSpeed) {
@ -33,4 +35,10 @@ public class DinoStateWithSpeed extends DinoState{
public int hashCode() {
return Objects.hash(super.hashCode(), getObstacleSpeed());
}
@Override
public void drawObjects(Graphics g) {
super.drawObjects(g);
g.drawString("Speed: " + obstacleSpeed, (int)(scale * (Config.DINO_STARTING_X + getXDistanceToObstacle())),(int)(scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE - 40)) );
}
}

View File

@ -16,7 +16,7 @@ public class JumpingDino {
Method.MC_ONPOLICY_EGREEDY,
DinoAction.values());
rl.setDelay(0);
rl.setDelay(100);
rl.setDiscountFactor(1f);
rl.setEpsilon(0.15f);
rl.setNrOfEpisodes(100000);