Add QTableFrame and clickable state rows that open a GUI visualization of the state
- remove org.javatuples in favour of org.apache.commons (commons-lang3 for tuples, commons-collections4 for CircularFifoQueue) - remove ViewListener from the non-GUI RLController - DeterministicStateActionTable keeps the last 10 state entries that changed; they are displayed in the QTableFrame as JTextAreas
This commit is contained in:
parent
a8f8af1102
commit
f4f1f7bd37
|
@ -20,9 +20,10 @@ dependencies {
|
||||||
testCompile group: 'junit', name: 'junit', version: '4.12'
|
testCompile group: 'junit', name: 'junit', version: '4.12'
|
||||||
compileOnly 'org.projectlombok:lombok:1.18.10'
|
compileOnly 'org.projectlombok:lombok:1.18.10'
|
||||||
annotationProcessor 'org.projectlombok:lombok:1.18.10'
|
annotationProcessor 'org.projectlombok:lombok:1.18.10'
|
||||||
compile 'org.javatuples:javatuples:1.2'
|
// https://mvnrepository.com/artifact/org.apache.commons/commons-lang3
|
||||||
// https://mvnrepository.com/artifact/org.javatuples/javatuples
|
compile group: 'org.apache.commons', name: 'commons-lang3', version: '3.0'
|
||||||
compile group: 'org.javatuples', name: 'javatuples', version: '1.2'
|
// https://mvnrepository.com/artifact/org.apache.commons/commons-collections4
|
||||||
|
compile group: 'org.apache.commons', name: 'commons-collections4', version: '4.1'
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,9 +1,9 @@
|
||||||
package core;
|
package core;
|
||||||
|
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
import java.util.LinkedHashMap;
|
import java.util.*;
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
|
import org.apache.commons.collections4.queue.CircularFifoQueue;
|
||||||
/**
|
/**
|
||||||
* Premise: All states have the complete action space
|
* Premise: All states have the complete action space
|
||||||
*/
|
*/
|
||||||
|
@ -11,10 +11,12 @@ public class DeterministicStateActionTable<A extends Enum> implements StateActio
|
||||||
|
|
||||||
private final Map<State, Map<A, Double>> table;
|
private final Map<State, Map<A, Double>> table;
|
||||||
private DiscreteActionSpace<A> discreteActionSpace;
|
private DiscreteActionSpace<A> discreteActionSpace;
|
||||||
|
private Queue<Map.Entry<State, Map<A, Double>>> latestChanges;
|
||||||
|
|
||||||
public DeterministicStateActionTable(DiscreteActionSpace<A> discreteActionSpace){
|
public DeterministicStateActionTable(DiscreteActionSpace<A> discreteActionSpace){
|
||||||
table = new LinkedHashMap<>();
|
table = new LinkedHashMap<>();
|
||||||
this.discreteActionSpace = discreteActionSpace;
|
this.discreteActionSpace = discreteActionSpace;
|
||||||
|
latestChanges = new CircularFifoQueue<>(10);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -57,6 +59,7 @@ public class DeterministicStateActionTable<A extends Enum> implements StateActio
|
||||||
actionValues = createDefaultActionValues();
|
actionValues = createDefaultActionValues();
|
||||||
table.put(state, actionValues);
|
table.put(state, actionValues);
|
||||||
}
|
}
|
||||||
|
latestChanges.add(new AbstractMap.SimpleEntry<>(state, actionValues));
|
||||||
actionValues.put(action, value);
|
actionValues.put(action, value);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -72,6 +75,11 @@ public class DeterministicStateActionTable<A extends Enum> implements StateActio
|
||||||
return table.get(state);
|
return table.get(state);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Queue<Map.Entry<State, Map<A, Double>>> getFirstStateEntriesForView() {
|
||||||
|
return latestChanges;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @return Map with initial values for every available action
|
* @return Map with initial values for every available action
|
||||||
*/
|
*/
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package core;
|
package core;
|
||||||
|
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
import java.util.Queue;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Q-Table which saves all seen states, all available actions for each state
|
* Q-Table which saves all seen states, all available actions for each state
|
||||||
|
@ -15,4 +16,6 @@ public interface StateActionTable<A extends Enum> {
|
||||||
void setValue(State state, A action, double value);
|
void setValue(State state, A action, double value);
|
||||||
int getStateCount();
|
int getStateCount();
|
||||||
Map<A, Double> getActionValues(State state);
|
Map<A, Double> getActionValues(State state);
|
||||||
|
|
||||||
|
Queue<Map.Entry<State, Map<A, Double>>> getFirstStateEntriesForView();
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,7 +3,8 @@ package core.algo.mc;
|
||||||
import core.*;
|
import core.*;
|
||||||
import core.algo.EpisodicLearning;
|
import core.algo.EpisodicLearning;
|
||||||
import core.policy.EpsilonGreedyPolicy;
|
import core.policy.EpsilonGreedyPolicy;
|
||||||
import org.javatuples.Pair;
|
import org.apache.commons.lang3.tuple.ImmutablePair;
|
||||||
|
import org.apache.commons.lang3.tuple.Pair;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.io.ObjectInputStream;
|
import java.io.ObjectInputStream;
|
||||||
|
@ -80,7 +81,7 @@ public class MonteCarloOnPolicyEGreedy<A extends Enum> extends EpisodicLearning<
|
||||||
Set<Pair<State, A>> stateActionPairs = new LinkedHashSet<>();
|
Set<Pair<State, A>> stateActionPairs = new LinkedHashSet<>();
|
||||||
|
|
||||||
for (StepResult<A> sr : episode) {
|
for (StepResult<A> sr : episode) {
|
||||||
stateActionPairs.add(new Pair<>(sr.getState(), sr.getAction()));
|
stateActionPairs.add(new ImmutablePair<>(sr.getState(), sr.getAction()));
|
||||||
}
|
}
|
||||||
|
|
||||||
//System.out.println("stateActionPairs " + stateActionPairs.size());
|
//System.out.println("stateActionPairs " + stateActionPairs.size());
|
||||||
|
@ -88,7 +89,7 @@ public class MonteCarloOnPolicyEGreedy<A extends Enum> extends EpisodicLearning<
|
||||||
int firstOccurenceIndex = 0;
|
int firstOccurenceIndex = 0;
|
||||||
// find first occurance of state action pair
|
// find first occurance of state action pair
|
||||||
for (StepResult<A> sr : episode) {
|
for (StepResult<A> sr : episode) {
|
||||||
if (stateActionPair.getValue0().equals(sr.getState()) && stateActionPair.getValue1().equals(sr.getAction())) {
|
if (stateActionPair.getKey().equals(sr.getState()) && stateActionPair.getValue().equals(sr.getAction())) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
firstOccurenceIndex++;
|
firstOccurenceIndex++;
|
||||||
|
@ -102,7 +103,7 @@ public class MonteCarloOnPolicyEGreedy<A extends Enum> extends EpisodicLearning<
|
||||||
// if the key does not exists, it will create a new entry with G as default value
|
// if the key does not exists, it will create a new entry with G as default value
|
||||||
returnSum.merge(stateActionPair, G, Double::sum);
|
returnSum.merge(stateActionPair, G, Double::sum);
|
||||||
returnCount.merge(stateActionPair, 1, Integer::sum);
|
returnCount.merge(stateActionPair, 1, Integer::sum);
|
||||||
stateActionTable.setValue(stateActionPair.getValue0(), stateActionPair.getValue1(), returnSum.get(stateActionPair) / returnCount.get(stateActionPair));
|
stateActionTable.setValue(stateActionPair.getKey(), stateActionPair.getValue(), returnSum.get(stateActionPair) / returnCount.get(stateActionPair));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -15,10 +15,9 @@ import lombok.Setter;
|
||||||
import javax.swing.*;
|
import javax.swing.*;
|
||||||
import java.io.*;
|
import java.io.*;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.concurrent.ExecutorService;
|
|
||||||
import java.util.concurrent.Executors;
|
|
||||||
|
|
||||||
public class RLController<A extends Enum> implements ViewListener, LearningListener {
|
|
||||||
|
public class RLController<A extends Enum> implements LearningListener {
|
||||||
protected final String folderPrefix = "learningStates" + File.separator;
|
protected final String folderPrefix = "learningStates" + File.separator;
|
||||||
protected Environment<A> environment;
|
protected Environment<A> environment;
|
||||||
protected DiscreteActionSpace<A> discreteActionSpace;
|
protected DiscreteActionSpace<A> discreteActionSpace;
|
||||||
|
@ -83,11 +82,11 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/*************************************************
|
protected void changeLearningDelay(int delay) {
|
||||||
** VIEW LISTENERS **
|
learning.setDelay(delay);
|
||||||
*************************************************/
|
}
|
||||||
@Override
|
|
||||||
public void onLearnMoreEpisodes(int nrOfEpisodes){
|
protected void learnMoreEpisodes(int nrOfEpisodes) {
|
||||||
if(learning instanceof EpisodicLearning) {
|
if(learning instanceof EpisodicLearning) {
|
||||||
((EpisodicLearning) learning).learn(nrOfEpisodes);
|
((EpisodicLearning) learning).learn(nrOfEpisodes);
|
||||||
} else {
|
} else {
|
||||||
|
@ -95,14 +94,21 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
protected void changeEpsilon(float epsilon) {
|
||||||
public void onLoadState(String fileName) {
|
if(learning.getPolicy() instanceof EpsilonPolicy) {
|
||||||
|
((EpsilonPolicy<A>) learning.getPolicy()).setEpsilon(epsilon);
|
||||||
|
} else {
|
||||||
|
System.out.println("Trying to call inEpsilonChange on non-epsilon policy");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
protected void saveState(String fileName) {
|
||||||
FileInputStream fis;
|
FileInputStream fis;
|
||||||
ObjectInputStream in;
|
ObjectInputStream in;
|
||||||
try {
|
try {
|
||||||
fis = new FileInputStream(fileName);
|
fis = new FileInputStream(fileName);
|
||||||
in = new ObjectInputStream(fis);
|
in = new ObjectInputStream(fis);
|
||||||
System.out.println("interrupt" + Thread.currentThread().getId());
|
System.out.println("interrup" + Thread.currentThread().getId());
|
||||||
learning.interruptLearning();
|
learning.interruptLearning();
|
||||||
learning.load(in);
|
learning.load(in);
|
||||||
in.close();
|
in.close();
|
||||||
|
@ -111,8 +117,7 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
protected void loadState(String fileName) {
|
||||||
public void onSaveState(String fileName) {
|
|
||||||
FileOutputStream fos;
|
FileOutputStream fos;
|
||||||
ObjectOutputStream out;
|
ObjectOutputStream out;
|
||||||
try {
|
try {
|
||||||
|
@ -126,26 +131,7 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
protected void changeFastLearning(boolean fastLearn) {
|
||||||
public void onEpsilonChange(float epsilon) {
|
|
||||||
if(learning.getPolicy() instanceof EpsilonPolicy){
|
|
||||||
((EpsilonPolicy<A>) learning.getPolicy()).setEpsilon(epsilon);
|
|
||||||
}else{
|
|
||||||
System.out.println("Trying to call inEpsilonChange on non-epsilon policy");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void onDelayChange(int delay) {
|
|
||||||
changeLearningDelay(delay);
|
|
||||||
}
|
|
||||||
|
|
||||||
protected void changeLearningDelay(int delay){
|
|
||||||
learning.setDelay(delay);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void onFastLearnChange(boolean fastLearn) {
|
|
||||||
this.fastLearning = fastLearn;
|
this.fastLearning = fastLearn;
|
||||||
if(fastLearn) {
|
if(fastLearn) {
|
||||||
prevDelay = learning.getDelay();
|
prevDelay = learning.getDelay();
|
||||||
|
@ -165,7 +151,6 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
|
||||||
@Override
|
@Override
|
||||||
public void onLearningEnd() {
|
public void onLearningEnd() {
|
||||||
System.out.println("Learning finished");
|
System.out.println("Learning finished");
|
||||||
onSaveState( method.toString() + System.currentTimeMillis()/1000 + (learning instanceof EpisodicLearning ? "e" + ((EpisodicLearning) learning).getCurrentEpisode() : ""));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -1,14 +1,16 @@
|
||||||
package core.controller;
|
package core.controller;
|
||||||
|
|
||||||
import core.Environment;
|
import core.Environment;
|
||||||
|
import core.algo.EpisodicLearning;
|
||||||
import core.algo.Method;
|
import core.algo.Method;
|
||||||
import core.gui.LearningView;
|
import core.gui.LearningView;
|
||||||
import core.gui.View;
|
import core.gui.View;
|
||||||
|
import core.listener.ViewListener;
|
||||||
|
|
||||||
import javax.swing.*;
|
import javax.swing.*;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
public class RLControllerGUI<A extends Enum> extends RLController<A> {
|
public class RLControllerGUI<A extends Enum> extends RLController<A> implements ViewListener {
|
||||||
private LearningView learningView;
|
private LearningView learningView;
|
||||||
|
|
||||||
public RLControllerGUI(Environment<A> env, Method method, A... actions) {
|
public RLControllerGUI(Environment<A> env, Method method, A... actions) {
|
||||||
|
@ -23,21 +25,41 @@ public class RLControllerGUI<A extends Enum> extends RLController<A> {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
/*************************************************
|
||||||
public void onLearnMoreEpisodes(int nrOfEpisodes) {
|
** View LISTENERS **
|
||||||
super.onLearnMoreEpisodes(nrOfEpisodes);
|
*************************************************/
|
||||||
learningView.updateLearningInfoPanel();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void onLoadState(String fileName) {
|
public void onLoadState(String fileName) {
|
||||||
super.onLoadState(fileName);
|
super.loadState(fileName);
|
||||||
SwingUtilities.invokeLater(() -> learningView.updateLearningInfoPanel());
|
SwingUtilities.invokeLater(() -> learningView.updateLearningInfoPanel());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onSaveState(String fileName) {
|
||||||
|
super.saveState(fileName);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onShowQTable() {
|
||||||
|
learningView.showQTableFrame();
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void onEpsilonChange(float epsilon) {
|
public void onEpsilonChange(float epsilon) {
|
||||||
super.onEpsilonChange(epsilon);
|
super.changeEpsilon(epsilon);
|
||||||
|
SwingUtilities.invokeLater(() -> learningView.updateLearningInfoPanel());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onDelayChange(int delay) {
|
||||||
|
super.changeLearningDelay(delay);
|
||||||
|
SwingUtilities.invokeLater(() -> learningView.updateLearningInfoPanel());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onFastLearnChange(boolean isFastLearn) {
|
||||||
|
super.changeFastLearning(isFastLearn);
|
||||||
SwingUtilities.invokeLater(() -> learningView.updateLearningInfoPanel());
|
SwingUtilities.invokeLater(() -> learningView.updateLearningInfoPanel());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -48,17 +70,23 @@ public class RLControllerGUI<A extends Enum> extends RLController<A> {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void onLearningEnd() {
|
public void onLearnMoreEpisodes(int nrOfEpisodes) {
|
||||||
super.onLearningEnd();
|
super.learnMoreEpisodes(nrOfEpisodes);
|
||||||
SwingUtilities.invokeLater(() -> learningView.updateRewardGraph(latestRewardsHistory));
|
learningView.updateLearningInfoPanel();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/*************************************************
|
||||||
|
** LEARNING LISTENERS **
|
||||||
|
*************************************************/
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void onEpisodeEnd(List<Double> rewardHistory) {
|
public void onEpisodeEnd(List<Double> rewardHistory) {
|
||||||
super.onEpisodeEnd(rewardHistory);
|
super.onEpisodeEnd(rewardHistory);
|
||||||
SwingUtilities.invokeLater(() -> {
|
SwingUtilities.invokeLater(() -> {
|
||||||
if(!fastLearning) {
|
if(!fastLearning) {
|
||||||
learningView.updateRewardGraph(latestRewardsHistory);
|
learningView.updateRewardGraph(latestRewardsHistory);
|
||||||
|
learningView.updateQTable();
|
||||||
}
|
}
|
||||||
learningView.updateLearningInfoPanel();
|
learningView.updateLearningInfoPanel();
|
||||||
});
|
});
|
||||||
|
@ -70,4 +98,11 @@ public class RLControllerGUI<A extends Enum> extends RLController<A> {
|
||||||
SwingUtilities.invokeLater(() -> learningView.repaintEnvironment());
|
SwingUtilities.invokeLater(() -> learningView.repaintEnvironment());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onLearningEnd() {
|
||||||
|
super.onLearningEnd();
|
||||||
|
onSaveState(method.toString() + System.currentTimeMillis() / 1000 + (learning instanceof EpisodicLearning ? "e" + ((EpisodicLearning) learning).getCurrentEpisode() : ""));
|
||||||
|
SwingUtilities.invokeLater(() -> learningView.updateRewardGraph(latestRewardsHistory));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -26,6 +26,7 @@ public class LearningInfoPanel extends JPanel {
|
||||||
private JCheckBox drawEnvironmentCheckbox;
|
private JCheckBox drawEnvironmentCheckbox;
|
||||||
private JTextField learnMoreEpisodesInput;
|
private JTextField learnMoreEpisodesInput;
|
||||||
private JButton learnMoreEpisodesButton;
|
private JButton learnMoreEpisodesButton;
|
||||||
|
private JButton showQTableButton;
|
||||||
|
|
||||||
public LearningInfoPanel(Learning learning, ViewListener viewListener) {
|
public LearningInfoPanel(Learning learning, ViewListener viewListener) {
|
||||||
this.learning = learning;
|
this.learning = learning;
|
||||||
|
@ -83,9 +84,14 @@ public class LearningInfoPanel extends JPanel {
|
||||||
add(learnMoreEpisodesInput);
|
add(learnMoreEpisodesInput);
|
||||||
add(learnMoreEpisodesButton);
|
add(learnMoreEpisodesButton);
|
||||||
}
|
}
|
||||||
|
showQTableButton = new JButton("Show Q-Table");
|
||||||
|
showQTableButton.addActionListener(e -> {
|
||||||
|
viewListener.onShowQTable();
|
||||||
|
});
|
||||||
add(drawEnvironmentCheckbox);
|
add(drawEnvironmentCheckbox);
|
||||||
add(smoothGraphCheckbox);
|
add(smoothGraphCheckbox);
|
||||||
add(last100Checkbox);
|
add(last100Checkbox);
|
||||||
|
add(showQTableButton);
|
||||||
refreshLabels();
|
refreshLabels();
|
||||||
setVisible(true);
|
setVisible(true);
|
||||||
}
|
}
|
||||||
|
@ -112,6 +118,7 @@ public class LearningInfoPanel extends JPanel {
|
||||||
protected boolean isSmoothenGraphSelected() {
|
protected boolean isSmoothenGraphSelected() {
|
||||||
return smoothGraphCheckbox.isSelected();
|
return smoothGraphCheckbox.isSelected();
|
||||||
}
|
}
|
||||||
|
|
||||||
protected boolean isLast100Selected() {
|
protected boolean isLast100Selected() {
|
||||||
return last100Checkbox.isSelected();
|
return last100Checkbox.isSelected();
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,5 +8,7 @@ import java.util.List;
|
||||||
public interface LearningView {
|
public interface LearningView {
|
||||||
void repaintEnvironment();
|
void repaintEnvironment();
|
||||||
void updateLearningInfoPanel();
|
void updateLearningInfoPanel();
|
||||||
|
void updateQTable();
|
||||||
void updateRewardGraph(final List<Double> rewardHistory);
|
void updateRewardGraph(final List<Double> rewardHistory);
|
||||||
|
void showQTableFrame();
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,57 @@
|
||||||
|
package core.gui;
|
||||||
|
|
||||||
|
import core.State;
|
||||||
|
import core.StateActionTable;
|
||||||
|
|
||||||
|
import javax.swing.*;
|
||||||
|
import java.awt.*;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
public class QTableFrame<A extends Enum> extends JFrame {
|
||||||
|
private JLabel stateCountLabel;
|
||||||
|
private StateActionTable<A> stateActionTable;
|
||||||
|
private List<StateActionRow<A>> rows;
|
||||||
|
private JPanel areaWrapper;
|
||||||
|
|
||||||
|
public QTableFrame(StateActionTable<A> stateActionTable) {
|
||||||
|
super("Q-Table");
|
||||||
|
this.stateActionTable = stateActionTable;
|
||||||
|
rows = new ArrayList<>(10);
|
||||||
|
setDefaultCloseOperation(WindowConstants.HIDE_ON_CLOSE);
|
||||||
|
setLayout(new BorderLayout());
|
||||||
|
setPreferredSize(new Dimension(500, 500));
|
||||||
|
stateCountLabel = new JLabel();
|
||||||
|
add(BorderLayout.NORTH, stateCountLabel);
|
||||||
|
areaWrapper = new JPanel();
|
||||||
|
areaWrapper.setLayout(new BoxLayout(areaWrapper, BoxLayout.Y_AXIS));
|
||||||
|
for(int i = 0; i < 10; ++i) {
|
||||||
|
StateActionRow<A> a = new StateActionRow<>();
|
||||||
|
rows.add(a);
|
||||||
|
areaWrapper.add(a);
|
||||||
|
}
|
||||||
|
add(BorderLayout.CENTER, areaWrapper);
|
||||||
|
setVisible(false);
|
||||||
|
pack();
|
||||||
|
}
|
||||||
|
|
||||||
|
private void refreshAllTextAreas(){
|
||||||
|
for(StateActionRow<A> row : rows){
|
||||||
|
row.refreshLabels();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
protected void refreshQTable() {
|
||||||
|
System.out.println("ref");
|
||||||
|
int stateCount = stateActionTable.getStateCount();
|
||||||
|
stateCountLabel.setText("Total states: " + stateCount);
|
||||||
|
int idx = -1;
|
||||||
|
for(Map.Entry<State, Map<A, Double>> entry : stateActionTable.getFirstStateEntriesForView()) {
|
||||||
|
if(++idx > rows.size() -1) break;
|
||||||
|
StateActionRow<A> row = rows.get(idx);
|
||||||
|
row.setState(entry.getKey());
|
||||||
|
row.setActionValues(entry.getValue());
|
||||||
|
}
|
||||||
|
refreshAllTextAreas();
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,55 @@
|
||||||
|
package core.gui;
|
||||||
|
|
||||||
|
import core.State;
|
||||||
|
import lombok.Setter;
|
||||||
|
|
||||||
|
import javax.swing.*;
|
||||||
|
import java.awt.*;
|
||||||
|
import java.awt.event.MouseAdapter;
|
||||||
|
import java.awt.event.MouseEvent;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
@Setter
|
||||||
|
public class StateActionRow<A extends Enum> extends JTextArea {
|
||||||
|
private State state;
|
||||||
|
private Map<A, Double> actionValues;
|
||||||
|
|
||||||
|
public StateActionRow(){
|
||||||
|
this.state = null;
|
||||||
|
this.actionValues = null;
|
||||||
|
setMaximumSize(new Dimension(600, 100));
|
||||||
|
setEditable(false);
|
||||||
|
addMouseListener(new MouseAdapter() {
|
||||||
|
@Override
|
||||||
|
public void mousePressed(MouseEvent e) {
|
||||||
|
super.mousePressed(e);
|
||||||
|
showState();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
protected void refreshLabels(){
|
||||||
|
if(state == null || actionValues == null) return;
|
||||||
|
System.out.println("refreshing");
|
||||||
|
StringBuilder sb = new StringBuilder(state.toString()).append("\n");
|
||||||
|
for(Map.Entry<A, Double> actionValue: actionValues.entrySet()){
|
||||||
|
sb.append("\t").append(actionValue.getKey()).append("\t").append(actionValue.getValue()).append("\n");
|
||||||
|
}
|
||||||
|
setText(sb.toString());
|
||||||
|
}
|
||||||
|
|
||||||
|
private void showState() {
|
||||||
|
if(state != null && state instanceof Visualizable){
|
||||||
|
new JFrame() {
|
||||||
|
{
|
||||||
|
JComponent stateComponent = ((Visualizable)state).visualize();
|
||||||
|
setPreferredSize(stateComponent.getPreferredSize());
|
||||||
|
setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
|
||||||
|
add(stateComponent);
|
||||||
|
pack();
|
||||||
|
setVisible(true);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -4,7 +4,8 @@ import core.Environment;
|
||||||
import core.algo.Learning;
|
import core.algo.Learning;
|
||||||
import core.listener.ViewListener;
|
import core.listener.ViewListener;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import org.javatuples.Pair;
|
import org.apache.commons.lang3.tuple.ImmutablePair;
|
||||||
|
import org.apache.commons.lang3.tuple.Pair;
|
||||||
import org.knowm.xchart.QuickChart;
|
import org.knowm.xchart.QuickChart;
|
||||||
import org.knowm.xchart.XChartPanel;
|
import org.knowm.xchart.XChartPanel;
|
||||||
import org.knowm.xchart.XYChart;
|
import org.knowm.xchart.XYChart;
|
||||||
|
@ -26,6 +27,7 @@ public class View<A extends Enum> implements LearningView{
|
||||||
@Getter
|
@Getter
|
||||||
private JFrame mainFrame;
|
private JFrame mainFrame;
|
||||||
private JFrame environmentFrame;
|
private JFrame environmentFrame;
|
||||||
|
private QTableFrame<A> qTableFrame;
|
||||||
private XChartPanel<XYChart> rewardChartPanel;
|
private XChartPanel<XYChart> rewardChartPanel;
|
||||||
private ViewListener viewListener;
|
private ViewListener viewListener;
|
||||||
private JMenuBar menuBar;
|
private JMenuBar menuBar;
|
||||||
|
@ -36,11 +38,12 @@ public class View<A extends Enum> implements LearningView{
|
||||||
this.environment = environment;
|
this.environment = environment;
|
||||||
this.viewListener = viewListener;
|
this.viewListener = viewListener;
|
||||||
initMainFrame();
|
initMainFrame();
|
||||||
|
initQTableFrame();
|
||||||
}
|
}
|
||||||
|
|
||||||
private void initMainFrame() {
|
private void initMainFrame() {
|
||||||
mainFrame = new JFrame();
|
mainFrame = new JFrame();
|
||||||
mainFrame.setPreferredSize(new Dimension(1280, 720));
|
mainFrame.setPreferredSize(new Dimension(1000, 400));
|
||||||
mainFrame.setLayout(new BorderLayout());
|
mainFrame.setLayout(new BorderLayout());
|
||||||
menuBar = new JMenuBar();
|
menuBar = new JMenuBar();
|
||||||
fileMenu = new JMenu("File");
|
fileMenu = new JMenu("File");
|
||||||
|
@ -86,9 +89,21 @@ public class View<A extends Enum> implements LearningView{
|
||||||
setVisible(true);
|
setVisible(true);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
private void initQTableFrame(){
|
||||||
|
qTableFrame = new QTableFrame<>(learning.getStateActionTable());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void updateQTable() {
|
||||||
|
qTableFrame.refreshQTable();
|
||||||
|
}
|
||||||
|
|
||||||
|
public void showQTableFrame(){
|
||||||
|
updateQTable();
|
||||||
|
qTableFrame.setVisible(true);
|
||||||
|
}
|
||||||
|
|
||||||
private void initLearningInfoPanel() {
|
private void initLearningInfoPanel() {
|
||||||
learningInfoPanel = new LearningInfoPanel(learning, viewListener);
|
learningInfoPanel = new LearningInfoPanel(learning, viewListener);
|
||||||
|
@ -109,17 +124,6 @@ public class View<A extends Enum> implements LearningView{
|
||||||
rewardChartPanel.setPreferredSize(new Dimension(300, 300));
|
rewardChartPanel.setPreferredSize(new Dimension(300, 300));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void showState(Visualizable state) {
|
|
||||||
new JFrame() {
|
|
||||||
{
|
|
||||||
JComponent stateComponent = state.visualize();
|
|
||||||
setPreferredSize(new Dimension(stateComponent.getWidth(), stateComponent.getHeight()));
|
|
||||||
add(stateComponent);
|
|
||||||
setVisible(true);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
public void updateRewardGraph(final List<Double> rewardHistory) {
|
public void updateRewardGraph(final List<Double> rewardHistory) {
|
||||||
List<Integer> xValues;
|
List<Integer> xValues;
|
||||||
List<Double> yValues;
|
List<Double> yValues;
|
||||||
|
@ -132,8 +136,8 @@ public class View<A extends Enum> implements LearningView{
|
||||||
} else {
|
} else {
|
||||||
if(learningInfoPanel.isSmoothenGraphSelected()) {
|
if(learningInfoPanel.isSmoothenGraphSelected()) {
|
||||||
Pair<List<Integer>, List<Double>> XYvalues = smoothenGraph(rewardHistory);
|
Pair<List<Integer>, List<Double>> XYvalues = smoothenGraph(rewardHistory);
|
||||||
xValues = XYvalues.getValue0();
|
xValues = XYvalues.getKey();
|
||||||
yValues = XYvalues.getValue1();
|
yValues = XYvalues.getValue();
|
||||||
} else {
|
} else {
|
||||||
xValues = null;
|
xValues = null;
|
||||||
yValues = rewardHistory;
|
yValues = rewardHistory;
|
||||||
|
@ -167,7 +171,7 @@ public class View<A extends Enum> implements LearningView{
|
||||||
batchSum = 0;
|
batchSum = 0;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return new Pair<>(xValues, tmp);
|
return new ImmutablePair<>(xValues, tmp);
|
||||||
}
|
}
|
||||||
|
|
||||||
public void updateLearningInfoPanel() {
|
public void updateLearningInfoPanel() {
|
||||||
|
|
|
@ -12,4 +12,5 @@ public interface ViewListener {
|
||||||
void onLearnMoreEpisodes(int nrOfEpisodes);
|
void onLearnMoreEpisodes(int nrOfEpisodes);
|
||||||
void onLoadState(String fileName);
|
void onLoadState(String fileName);
|
||||||
void onSaveState(String fileName);
|
void onSaveState(String fileName);
|
||||||
|
void onShowQTable();
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
package evironment.jumpingDino;
|
package evironment.jumpingDino;
|
||||||
|
|
||||||
public class Config {
|
public class Config {
|
||||||
public static final int FRAME_WIDTH = 1280;
|
public static final int FRAME_WIDTH = 800;
|
||||||
public static final int FRAME_HEIGHT = 720;
|
public static final int FRAME_HEIGHT = 300;
|
||||||
public static final int GROUND_Y = 50;
|
public static final int GROUND_Y = 50;
|
||||||
public static final int DINO_STARTING_X = 50;
|
public static final int DINO_STARTING_X = 50;
|
||||||
public static final int DINO_SIZE = 50;
|
public static final int DINO_SIZE = 50;
|
||||||
|
|
|
@ -1,16 +1,20 @@
|
||||||
package evironment.jumpingDino;
|
package evironment.jumpingDino;
|
||||||
|
|
||||||
import core.State;
|
import core.State;
|
||||||
|
import core.gui.Visualizable;
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.AllArgsConstructor;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
|
|
||||||
|
import javax.swing.*;
|
||||||
|
import java.awt.*;
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
|
|
||||||
@AllArgsConstructor
|
@AllArgsConstructor
|
||||||
@Getter
|
@Getter
|
||||||
public class DinoState implements State, Serializable {
|
public class DinoState implements State, Serializable, Visualizable {
|
||||||
private int xDistanceToObstacle;
|
private int xDistanceToObstacle;
|
||||||
|
protected final double scale = 0.5;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String toString() {
|
public String toString() {
|
||||||
|
@ -31,4 +35,31 @@ public class DinoState implements State, Serializable {
|
||||||
public int hashCode() {
|
public int hashCode() {
|
||||||
return Objects.hash(xDistanceToObstacle);
|
return Objects.hash(xDistanceToObstacle);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public JComponent visualize() {
|
||||||
|
return new JComponent() {
|
||||||
|
{
|
||||||
|
setPreferredSize(new Dimension(Config.FRAME_WIDTH, (int)(scale * Config.FRAME_HEIGHT)));
|
||||||
|
setVisible(true);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected void paintComponent(Graphics g) {
|
||||||
|
super.paintComponents(g);
|
||||||
|
drawObjects(g);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
public void drawObjects(Graphics g){
|
||||||
|
g.setColor(Color.BLACK);
|
||||||
|
g.fillRect(0, (int) (scale * (Config.FRAME_HEIGHT - Config.GROUND_Y)), Config.FRAME_WIDTH, 2);
|
||||||
|
|
||||||
|
g.fillRect((int)(scale * Config.DINO_STARTING_X), (int)(scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE)), (int)(scale * Config.DINO_SIZE), (int) (scale * Config.DINO_SIZE));
|
||||||
|
g.drawString("Distance: " + xDistanceToObstacle, (int)(scale * Config.DINO_STARTING_X),(int)(scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE - 40) ));
|
||||||
|
|
||||||
|
g.fillRect((int)(scale * (Config.DINO_STARTING_X + getXDistanceToObstacle())), (int)(scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE)), (int)(scale * Config.OBSTACLE_SIZE), (int)(scale *Config.OBSTACLE_SIZE));
|
||||||
|
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,11 +1,13 @@
|
||||||
package evironment.jumpingDino;
|
package evironment.jumpingDino;
|
||||||
|
|
||||||
|
import core.gui.Visualizable;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
|
|
||||||
|
import java.awt.*;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
|
|
||||||
@Getter
|
@Getter
|
||||||
public class DinoStateWithSpeed extends DinoState{
|
public class DinoStateWithSpeed extends DinoState implements Visualizable {
|
||||||
private int obstacleSpeed;
|
private int obstacleSpeed;
|
||||||
|
|
||||||
public DinoStateWithSpeed(int xDistanceToObstacle, int obstacleSpeed) {
|
public DinoStateWithSpeed(int xDistanceToObstacle, int obstacleSpeed) {
|
||||||
|
@ -33,4 +35,10 @@ public class DinoStateWithSpeed extends DinoState{
|
||||||
public int hashCode() {
|
public int hashCode() {
|
||||||
return Objects.hash(super.hashCode(), getObstacleSpeed());
|
return Objects.hash(super.hashCode(), getObstacleSpeed());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void drawObjects(Graphics g) {
|
||||||
|
super.drawObjects(g);
|
||||||
|
g.drawString("Speed: " + obstacleSpeed, (int)(scale * (Config.DINO_STARTING_X + getXDistanceToObstacle())),(int)(scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE - 40)) );
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,7 +16,7 @@ public class JumpingDino {
|
||||||
Method.MC_ONPOLICY_EGREEDY,
|
Method.MC_ONPOLICY_EGREEDY,
|
||||||
DinoAction.values());
|
DinoAction.values());
|
||||||
|
|
||||||
rl.setDelay(0);
|
rl.setDelay(100);
|
||||||
rl.setDiscountFactor(1f);
|
rl.setDiscountFactor(1f);
|
||||||
rl.setEpsilon(0.15f);
|
rl.setEpsilon(0.15f);
|
||||||
rl.setNrOfEpisodes(100000);
|
rl.setNrOfEpisodes(100000);
|
||||||
|
|
Loading…
Reference in New Issue