adopt MVC pattern and add real-time graph interface

Jan Löwenstrom 2019-12-18 16:48:24 +01:00
parent 7f18a66e98
commit e0160ca1df
18 changed files with 450 additions and 29 deletions

View File

@@ -5,4 +5,4 @@
       <profile default="true" name="Default" enabled="true" />
     </annotationProcessing>
   </component>
 </project>

View File

@@ -13,6 +13,9 @@ repositories {
 }
 dependencies {
+    // https://mvnrepository.com/artifact/org.jfree/jfreechart
+    // https://mvnrepository.com/artifact/org.knowm.xchart/xchart
+    compile group: 'org.knowm.xchart', name: 'xchart', version: '3.2.2'
     testCompile group: 'junit', name: 'junit', version: '4.12'
     compileOnly 'org.projectlombok:lombok:1.18.10'
     annotationProcessor 'org.projectlombok:lombok:1.18.10'
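(Side note: the `compile` and `testCompile` configurations used here were deprecated in later Gradle releases and removed in Gradle 7. On a current Gradle, the equivalent declaration would read as follows — a sketch, not part of this commit:

    dependencies {
        // xchart for the real-time reward chart
        implementation 'org.knowm.xchart:xchart:3.2.2'
        testImplementation 'junit:junit:4.12'
        compileOnly 'org.projectlombok:lombok:1.18.10'
        annotationProcessor 'org.projectlombok:lombok:1.18.10'
    })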

View File

@@ -3,7 +3,7 @@ package core;
 import java.util.List;
 
 public interface DiscreteActionSpace<A extends Enum> {
-    int getNumberOfAction();
+    int getNumberOfActions();
     void addAction(A a);
     void addActions(A... as);
     List<A> getAllActions();

View File

@@ -0,0 +1,7 @@
+package core;
+
+public class LearningConfig {
+    public static final int DEFAULT_DELAY = 1;
+    public static final float DEFAULT_EPSILON = 0.1f;
+    public static final float DEFAULT_DISCOUNT_FACTOR = 1.0f;
+}

View File

@@ -32,7 +32,7 @@ public class ListDiscreteActionSpace<A extends Enum> implements DiscreteActionSpace<A> {
     }
 
     @Override
-    public int getNumberOfAction(){
+    public int getNumberOfActions(){
         return actions.size();
     }
 }

View File

@@ -2,26 +2,62 @@ package core.algo;
 
 import core.DiscreteActionSpace;
 import core.Environment;
+import core.LearningConfig;
 import core.StateActionTable;
+import core.listener.LearningListener;
 import core.policy.Policy;
+import lombok.Getter;
+import lombok.Setter;
+
+import javax.swing.*;
+import java.util.HashSet;
+import java.util.Set;
 
+@Getter
 public abstract class Learning<A extends Enum> {
     protected Policy<A> policy;
     protected DiscreteActionSpace<A> actionSpace;
     protected StateActionTable<A> stateActionTable;
     protected Environment<A> environment;
     protected float discountFactor;
+    @Setter
     protected float epsilon;
+    protected Set<LearningListener> learningListeners;
+    @Setter
+    protected int delay;
 
-    public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon){
+    public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay){
         this.environment = environment;
         this.actionSpace = actionSpace;
         this.discountFactor = discountFactor;
         this.epsilon = epsilon;
+        this.delay = delay;
+        learningListeners = new HashSet<>();
     }
 
-    public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace){
-        this(environment, actionSpace, 1.0f, 0.1f);
+    public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon){
+        this(environment, actionSpace, discountFactor, epsilon, LearningConfig.DEFAULT_DELAY);
     }
 
-    public abstract void learn(int nrOfEpisodes, int delay);
+    public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace){
+        this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, LearningConfig.DEFAULT_EPSILON, LearningConfig.DEFAULT_DELAY);
+    }
+
+    public abstract void learn(int nrOfEpisodes);
+
+    public void addListener(LearningListener learningListener){
+        learningListeners.add(learningListener);
+    }
+
+    protected void dispatchEpisodeEnd(double sum){
+        for(LearningListener l: learningListeners) {
+            l.onEpisodeEnd(sum);
+        }
+    }
+
+    protected void dispatchEpisodeStart(){
+        for(LearningListener l: learningListeners){
+            l.onEpisodeStart();
+        }
+    }
 }
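The listener hooks above are a plain observer pattern: Learning dispatches onEpisodeStart()/onEpisodeEnd(sum) to every subscriber registered via addListener(). A minimal sketch of a subscriber — a hypothetical class, not part of this commit:

    package example;

    import core.listener.LearningListener;

    // Hypothetical subscriber; the View class added in this commit implements
    // the same interface to feed the reward chart.
    public class ConsoleLearningListener implements LearningListener {
        @Override
        public void onEpisodeStart() {
            System.out.println("episode started");
        }

        @Override
        public void onEpisodeEnd(double sumOfRewards) {
            // invoked from Learning.dispatchEpisodeEnd(sum)
            System.out.println("episode reward sum: " + sumOfRewards);
        }
    }

It would be registered with learning.addListener(new ConsoleLearningListener()).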

View File

@@ -34,22 +34,21 @@ public class MonteCarloOnPolicyEGreedy<A extends Enum> extends Learning<A> {
     }
 
     @Override
-    public void learn(int nrOfEpisodes, int delay) {
+    public void learn(int nrOfEpisodes) {
         Map<Pair<State, A>, Double> returnSum = new HashMap<>();
         Map<Pair<State, A>, Integer> returnCount = new HashMap<>();
-        State startingState = environment.reset();
 
         for(int i = 0; i < nrOfEpisodes; ++i) {
             List<StepResult<A>> episode = new ArrayList<>();
             State state = environment.reset();
-            double rewardSum = 0;
+            double sumOfRewards = 0;
             for(int j=0; j < 10; ++j){
                 Map<A, Double> actionValues = stateActionTable.getActionValues(state);
                 A chosenAction = policy.chooseAction(actionValues);
                 StepResultEnvironment envResult = environment.step(chosenAction);
                 State nextState = envResult.getState();
-                rewardSum += envResult.getReward();
+                sumOfRewards += envResult.getReward();
                 episode.add(new StepResult<>(state, chosenAction, envResult.getReward()));
                 if(envResult.isDone()) break;
@@ -57,13 +56,14 @@ public class MonteCarloOnPolicyEGreedy<A extends Enum> extends Learning<A> {
                 state = nextState;
 
                 try {
-                    Thread.sleep(1);
+                    Thread.sleep(delay);
                 } catch (InterruptedException e) {
                     e.printStackTrace();
                 }
             }
-            System.out.printf("Episode %d \t Reward: %f \n", i, rewardSum);
+            dispatchEpisodeEnd(sumOfRewards);
+            System.out.printf("Episode %d \t Reward: %f \n", i, sumOfRewards);
 
             Set<Pair<State, A>> stateActionPairs = new HashSet<>();
             for(StepResult<A> sr: episode){
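The hunk ends before the update step, but the returnSum/returnCount maps declared at the top of learn() point to the standard first-visit Monte Carlo estimate Q(s,a) = returnSum(s,a) / returnCount(s,a). A sketch of the discounted-return computation such an update needs, assuming the usual algorithm rather than the commit's exact (truncated) code:

    // Hypothetical helper: discounted return G following the first visit of a
    // state-action pair; averaging G over episodes yields Q(s,a).
    static double firstVisitReturn(java.util.List<Double> rewards, int firstVisitIndex, float discountFactor) {
        double g = 0;
        // walk backwards so each reward is discounted by its distance from the visit
        for (int k = rewards.size() - 1; k >= firstVisitIndex; --k) {
            g = discountFactor * g + rewards.get(k);
        }
        return g;
    }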

View File

@@ -0,0 +1,5 @@
+package core.algo;
+
+public enum Method {
+    MC_ONPOLICY_EGREEDY, TD_ONPOLICY
+}

View File

@@ -0,0 +1,81 @@
+package core.controller;
+
+import core.DiscreteActionSpace;
+import core.Environment;
+import core.ListDiscreteActionSpace;
+import core.algo.Learning;
+import core.algo.Method;
+import core.algo.mc.MonteCarloOnPolicyEGreedy;
+import core.gui.View;
+
+import javax.swing.*;
+import java.util.Optional;
+
+public class RLController<A extends Enum> implements ViewListener {
+    protected Environment<A> environment;
+    protected Learning<A> learning;
+    protected DiscreteActionSpace<A> discreteActionSpace;
+    protected View<A> view;
+    private int delay;
+    private int nrOfEpisodes;
+    private Method method;
+
+    public RLController(){
+    }
+
+    public void start(){
+        if(environment == null || discreteActionSpace == null || method == null){
+            throw new RuntimeException("Set environment, discreteActionSpace and method before calling .start()");
+        }
+
+        switch (method){
+            case MC_ONPOLICY_EGREEDY:
+                learning = new MonteCarloOnPolicyEGreedy<>(environment, discreteActionSpace);
+                break;
+            case TD_ONPOLICY:
+                break;
+            default:
+                throw new RuntimeException("Undefined method");
+        }
+
+        SwingUtilities.invokeLater(() -> {
+            view = new View<>(learning, this);
+            learning.addListener(view);
+        });
+        learning.learn(nrOfEpisodes);
+    }
+
+    @Override
+    public void onEpsilonChange(float epsilon) {
+        learning.setEpsilon(epsilon);
+        SwingUtilities.invokeLater(() -> view.updateLearningInfoPanel());
+    }
+
+    @Override
+    public void onDelayChange(int delay) {
+    }
+
+    public RLController<A> setMethod(Method method){
+        this.method = method;
+        return this;
+    }
+
+    public RLController<A> setEnvironment(Environment<A> environment){
+        this.environment = environment;
+        return this;
+    }
+
+    @SafeVarargs
+    public final RLController<A> setAllowedActions(A... actions){
+        this.discreteActionSpace = new ListDiscreteActionSpace<>(actions);
+        return this;
+    }
+
+    public RLController<A> setDelay(int delay){
+        this.delay = delay;
+        return this;
+    }
+
+    public RLController<A> setEpisodes(int nrOfEpisodes){
+        this.nrOfEpisodes = nrOfEpisodes;
+        return this;
+    }
+}

View File

@@ -0,0 +1,6 @@
+package core.controller;
+
+public interface ViewListener {
+    void onEpsilonChange(float epsilon);
+    void onDelayChange(int delay);
+}

View File

@@ -0,0 +1,41 @@
+package core.gui;
+
+import core.algo.Learning;
+import core.controller.ViewListener;
+
+import javax.swing.*;
+
+public class LearningInfoPanel extends JPanel {
+    private Learning learning;
+    private JLabel policyLabel;
+    private JLabel discountLabel;
+    private JLabel epsilonLabel;
+    private JSlider epsilonSlider;
+    private JSlider delaySlider;
+
+    public LearningInfoPanel(Learning learning, ViewListener viewListener){
+        this.learning = learning;
+        setLayout(new BoxLayout(this, BoxLayout.Y_AXIS));
+        policyLabel = new JLabel();
+        discountLabel = new JLabel();
+        epsilonLabel = new JLabel();
+        epsilonSlider = new JSlider(0, 100, (int)(learning.getEpsilon() * 100));
+        epsilonSlider.addChangeListener(e -> viewListener.onEpsilonChange(epsilonSlider.getValue() / 100f));
+        add(policyLabel);
+        add(discountLabel);
+        add(epsilonLabel);
+        add(epsilonSlider);
+        refreshLabels();
+        setVisible(true);
+    }
+
+    public void refreshLabels(){
+        policyLabel.setText("Policy: " + learning.getPolicy().getClass());
+        discountLabel.setText("Discount factor: " + learning.getDiscountFactor());
+        epsilonLabel.setText("Exploration (Epsilon): " + learning.getEpsilon());
+    }
+
+    protected JSlider getEpsilonSlider(){
+        return epsilonSlider;
+    }
+}
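Note that delaySlider is declared here but never initialized or wired up, even though ViewListener already exposes onDelayChange(int). A hypothetical hookup mirroring the epsilon slider — not part of this commit — could go in the constructor:

    // hypothetical: forward delay changes (0..1000 ms) to the controller,
    // analogous to the epsilonSlider wiring above
    delaySlider = new JSlider(0, 1000, learning.getDelay());
    delaySlider.addChangeListener(e -> viewListener.onDelayChange(delaySlider.getValue()));
    add(delaySlider);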

View File

@@ -0,0 +1,102 @@
+package core.gui;
+
+import core.algo.Learning;
+import core.controller.ViewListener;
+import core.listener.LearningListener;
+import lombok.Getter;
+import org.knowm.xchart.QuickChart;
+import org.knowm.xchart.XChartPanel;
+import org.knowm.xchart.XYChart;
+
+import javax.swing.*;
+import java.awt.*;
+import java.util.ArrayList;
+import java.util.List;
+
+public class View<A extends Enum> implements LearningListener {
+    private Learning<A> learning;
+    @Getter
+    private XYChart chart;
+    @Getter
+    private LearningInfoPanel learningInfoPanel;
+    @Getter
+    private JFrame mainFrame;
+    private XChartPanel<XYChart> rewardChartPanel;
+    private ViewListener viewListener;
+    private List<Double> rewardHistory;
+
+    public View(Learning<A> learning, ViewListener viewListener){
+        this.learning = learning;
+        this.viewListener = viewListener;
+        rewardHistory = new ArrayList<>();
+        this.initMainFrame();
+    }
+
+    private void initMainFrame(){
+        mainFrame = new JFrame();
+        mainFrame.setPreferredSize(new Dimension(1280, 720));
+        mainFrame.setLayout(new BorderLayout());
+        initLearningInfoPanel();
+        initRewardChart();
+        mainFrame.add(BorderLayout.WEST, learningInfoPanel);
+        mainFrame.add(BorderLayout.CENTER, rewardChartPanel);
+        mainFrame.setDefaultCloseOperation(WindowConstants.EXIT_ON_CLOSE);
+        mainFrame.pack();
+        mainFrame.setVisible(true);
+    }
+
+    private void initLearningInfoPanel(){
+        learningInfoPanel = new LearningInfoPanel(learning, viewListener);
+    }
+
+    private void initRewardChart(){
+        chart =
+            QuickChart.getChart(
+                "Rewards per Episode",
+                "Episode",
+                "Reward",
+                "randomWalk",
+                new double[]{0},
+                new double[]{0});
+        chart.getStyler().setLegendVisible(true);
+        chart.getStyler().setXAxisTicksVisible(true);
+        rewardChartPanel = new XChartPanel<>(chart);
+        rewardChartPanel.setPreferredSize(new Dimension(300, 300));
+    }
+
+    public void showState(Visualizable state){
+        new JFrame(){
+            {
+                JComponent stateComponent = state.visualize();
+                setPreferredSize(new Dimension(stateComponent.getWidth(), stateComponent.getHeight()));
+                add(stateComponent);
+                setVisible(true);
+            }
+        };
+    }
+
+    public void updateRewardGraph(double recentReward){
+        rewardHistory.add(recentReward);
+        chart.updateXYSeries("randomWalk", null, rewardHistory, null);
+        rewardChartPanel.revalidate();
+        rewardChartPanel.repaint();
+    }
+
+    public void updateLearningInfoPanel(){
+        this.learningInfoPanel.refreshLabels();
+    }
+
+    @Override
+    public void onEpisodeEnd(double sumOfRewards) {
+        SwingUtilities.invokeLater(() -> updateRewardGraph(sumOfRewards));
+    }
+
+    @Override
+    public void onEpisodeStart() {
+    }
+}
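View drives the live graph by mutating the series data and repainting; onEpisodeEnd hops onto the Swing EDT first. A self-contained sketch of that XChart update pattern, assuming xchart 3.2.2 as declared in build.gradle (the Timer and random data are stand-ins for the learner):

    package example;

    import org.knowm.xchart.QuickChart;
    import org.knowm.xchart.XChartPanel;
    import org.knowm.xchart.XYChart;

    import javax.swing.*;
    import java.util.ArrayList;
    import java.util.List;

    public class LiveChartSketch {
        public static void main(String[] args) {
            XYChart chart = QuickChart.getChart(
                    "Rewards per Episode", "Episode", "Reward",
                    "randomWalk", new double[]{0}, new double[]{0});
            XChartPanel<XYChart> panel = new XChartPanel<>(chart);

            JFrame frame = new JFrame();
            frame.setDefaultCloseOperation(WindowConstants.EXIT_ON_CLOSE);
            frame.add(panel);
            frame.pack();
            frame.setVisible(true);

            List<Double> history = new ArrayList<>();
            // stand-in for onEpisodeEnd(): append a value, update the series, repaint;
            // javax.swing.Timer fires on the EDT, like SwingUtilities.invokeLater
            new Timer(100, e -> {
                history.add(Math.random());
                // null x-data lets XChart generate the episode indices
                chart.updateXYSeries("randomWalk", null, history, null);
                panel.revalidate();
                panel.repaint();
            }).start();
        }
    }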

View File

@@ -0,0 +1,7 @@
+package core.gui;
+
+import javax.swing.*;
+
+public interface Visualizable {
+    JComponent visualize();
+}

View File

@@ -0,0 +1,6 @@
+package core.listener;
+
+public interface LearningListener {
+    void onEpisodeEnd(double sumOfRewards);
+    void onEpisodeStart();
+}

View File

@@ -1,7 +1,10 @@
 package evironment.antGame;
 
 import core.State;
+import core.gui.Visualizable;
+import evironment.antGame.gui.CellColor;
+
+import javax.swing.*;
 import java.awt.*;
 import java.util.Arrays;
@@ -10,7 +13,7 @@ import java.util.Arrays;
  * Essentially a snapshot of the current Ant Agent
  * and therefor has to be deep copied
  */
-public class AntState implements State {
+public class AntState implements State, Visualizable {
     private final Cell[][] knownWorld;
     private final Point pos;
     private final boolean hasFood;
@@ -29,12 +32,12 @@ public class AntState implements State {
         int unknown = 0;
         int diff = 0;
-        for (int i = 0; i < knownWorld.length; i++) {
-            for (int j = 0; j < knownWorld[i].length; j++) {
-                if(knownWorld[i][j].getType() == CellType.UNKNOWN){
+        for (Cell[] cells : knownWorld) {
+            for (Cell cell : cells) {
+                if (cell.getType() == CellType.UNKNOWN) {
                     unknown += 1;
-                }else{
-                    diff +=1;
+                } else {
+                    diff += 1;
                 }
             }
         }
@@ -62,7 +65,7 @@ public class AntState implements State {
     @Override
     public String toString(){
         return String.format("Pos: %s, hasFood: %b, knownWorld: %s", pos.toString(), hasFood, Arrays.toString(knownWorld));
     }
 
     //TODO: make this a utility function to generate hash Code based upon 2 prime numbers
     @Override
@@ -89,4 +92,62 @@ public class AntState implements State {
         }
         return super.equals(obj);
     }
+
+    @Override
+    public JComponent visualize() {
+        return new JScrollPane() {
+            private int cellSize;
+            private final int paneWidth = 500;
+            private final int paneHeight = 500;
+            private Font font;
+            {
+                setPreferredSize(new Dimension(paneWidth, paneHeight));
+                cellSize = (paneWidth - knownWorld.length) / knownWorld.length;
+                font = new Font("plain", Font.BOLD, cellSize);
+                JPanel worldPanel = new JPanel(){
+                    {
+                        setPreferredSize(new Dimension(knownWorld.length * cellSize, knownWorld[0].length * cellSize));
+                        setVisible(true);
+                        addMouseWheelListener(e -> {
+                            if(e.getWheelRotation() > 0){
+                                cellSize -= 1;
+                            }else {
+                                cellSize += 1;
+                            }
+                            font = new Font("plain", Font.BOLD, cellSize);
+                            setPreferredSize(new Dimension(knownWorld.length * cellSize, knownWorld[0].length * cellSize));
+                            revalidate();
+                            repaint();
+                        });
+                    }
+
+                    @Override
+                    public void paintComponent(Graphics g) {
+                        super.paintComponent(g);
+                        for (int i = 0; i < knownWorld.length; i++) {
+                            for (int j = 0; j < knownWorld[0].length; j++) {
+                                g.setColor(Color.BLACK);
+                                g.drawRect(i*cellSize, j*cellSize, cellSize, cellSize);
+                                g.setColor(CellColor.map.get(knownWorld[i][j].getType()));
+                                if(knownWorld[i][j].getFood() > 0){
+                                    g.setColor(Color.YELLOW);
+                                }
+                                g.fillRect(i*cellSize+1, j*cellSize+1, cellSize-1, cellSize-1);
+                            }
+                        }
+                        if(hasFocus()){
+                            g.setColor(Color.RED);
+                        }else {
+                            g.setColor(Color.BLACK);
+                        }
+                        g.setFont(font);
+                        g.drawString("A", pos.x * cellSize, (pos.y + 1) * cellSize);
+                    }
+                };
+                getViewport().add(worldPanel);
+                setVisible(true);
+            }
+        };
+    }
 }

View File

@@ -35,21 +35,15 @@ public class AntWorld implements Environment<AntAction>{
     private int tick;
     private int maxEpisodeTicks;
-    MainFrame gui;
 
     public AntWorld(int width, int height, double foodDensity){
         grid = new Grid(width, height, foodDensity);
        antAgent = new AntAgent(width, height);
         myAnt = new Ant();
-        gui = new MainFrame(this, antAgent);
         maxEpisodeTicks = 1000;
         reset();
     }
 
-    public MainFrame getGui(){
-        return gui;
-    }
-
     public AntWorld(){
         this(Constants.DEFAULT_GRID_WIDTH, Constants.DEFAULT_GRID_HEIGHT, Constants.DEFAULT_FOOD_DENSITY);
     }
@@ -166,7 +160,6 @@ public class AntWorld implements Environment<AntAction>{
 
         StepResultEnvironment result = new StepResultEnvironment(newState, reward, done, info);
-        getGui().update(action, result);
         return result;
     }
@@ -216,6 +209,6 @@ public class AntWorld implements Environment<AntAction>{
             new AntWorld(3, 3, 0.1),
             new ListDiscreteActionSpace<>(AntAction.values())
         );
-        monteCarlo.learn(20000,5);
+        monteCarlo.learn(20000);
     }
 }

View File

@@ -0,0 +1,21 @@
+package example;
+
+import core.RNG;
+import core.algo.Method;
+import core.controller.RLController;
+import evironment.antGame.AntAction;
+import evironment.antGame.AntWorld;
+
+public class RunningAnt {
+    public static void main(String[] args) {
+        RNG.setSeed(1234);
+        RLController<AntAction> rl = new RLController<AntAction>()
+                .setEnvironment(new AntWorld(3, 3, 0.1))
+                .setAllowedActions(AntAction.values())
+                .setMethod(Method.MC_ONPOLICY_EGREEDY)
+                .setDelay(10)
+                .setEpisodes(1000);
+
+        rl.start();
+    }
+}

View File

@@ -0,0 +1,52 @@
+package example;
+
+public class Test {
+    interface Drawable{
+        void draw();
+    }
+
+    interface State{
+        int getInt();
+    }
+
+    static class A implements Drawable, State{
+        private int k;
+
+        public A(int a){
+            k = a;
+        }
+
+        @Override
+        public void draw() {
+            System.out.println("draw " + k);
+        }
+
+        @Override
+        public int getInt() {
+            System.out.println("getInt" + k);
+            return k;
+        }
+    }
+
+    static class B implements State{
+        @Override
+        public int getInt() {
+            return 0;
+        }
+    }
+
+    public static void main(String[] args) {
+        State state = new A(24);
+        State state2 = new B();
+        state.getInt();
+        System.out.println(state2 instanceof Drawable);
+        drawState(state2);
+    }
+
+    static void drawState(State s){
+        if(s instanceof Drawable){
+            Drawable d = (Drawable) s;
+            d.draw();
+        }else{
+            System.out.println("invalid");
+        }
+    }
+}
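Traced from the code above, running Test.main prints:

    getInt24
    false
    invalid

B implements State but not Drawable, so the instanceof check is false and drawState falls through to the "invalid" branch.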