adopt MVC pattern and add real time graph interface

Jan Löwenstrom 2019-12-18 16:48:24 +01:00
parent 7f18a66e98
commit e0160ca1df
18 changed files with 450 additions and 29 deletions
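In a nutshell, the new MVC wiring works like this: an RLController is configured fluently with an environment, an action space, a learning method and run parameters; start() instantiates the chosen Learning algorithm, builds the Swing View on the EDT and registers it as a LearningListener, so every episode's reward sum is pushed to the real-time chart. A minimal usage sketch, mirroring the example/RunningAnt.java added below:

RLController<AntAction> rl = new RLController<AntAction>()
        .setEnvironment(new AntWorld(3, 3, 0.1))
        .setAllowedActions(AntAction.values())
        .setMethod(Method.MC_ONPOLICY_EGREEDY)
        .setDelay(10)          // step delay in ms (cf. Thread.sleep(delay) in the MC learning loop)
        .setEpisodes(1000);
rl.start();                    // builds the Learning instance, opens the View, starts learning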

View File

@@ -13,6 +13,9 @@ repositories {
}
dependencies {
// https://mvnrepository.com/artifact/org.jfree/jfreechart
// https://mvnrepository.com/artifact/org.knowm.xchart/xchart
compile group: 'org.knowm.xchart', name: 'xchart', version: '3.2.2'
testCompile group: 'junit', name: 'junit', version: '4.12'
compileOnly 'org.projectlombok:lombok:1.18.10'
annotationProcessor 'org.projectlombok:lombok:1.18.10'

View File

@@ -3,7 +3,7 @@ package core;
import java.util.List;
public interface DiscreteActionSpace<A extends Enum> {
int getNumberOfAction();
int getNumberOfActions();
void addAction(A a);
void addActions(A... as);
List<A> getAllActions();

View File

@@ -0,0 +1,7 @@
package core;
public class LearningConfig {
public static final int DEFAULT_DELAY = 1;
public static final float DEFAULT_EPSILON = 0.1f;
public static final float DEFAULT_DISCOUNT_FACTOR = 1.0f;
}

View File

@@ -32,7 +32,7 @@ public class ListDiscreteActionSpace<A extends Enum> implements DiscreteActionSp
}
@Override
public int getNumberOfAction(){
public int getNumberOfActions(){
return actions.size();
}
}

View File

@@ -2,26 +2,62 @@ package core.algo;
import core.DiscreteActionSpace;
import core.Environment;
import core.LearningConfig;
import core.StateActionTable;
import core.listener.LearningListener;
import core.policy.Policy;
import lombok.Getter;
import lombok.Setter;
import javax.swing.*;
import java.util.HashSet;
import java.util.Set;
@Getter
public abstract class Learning<A extends Enum> {
protected Policy<A> policy;
protected DiscreteActionSpace<A> actionSpace;
protected StateActionTable<A> stateActionTable;
protected Environment<A> environment;
protected float discountFactor;
@Setter
protected float epsilon;
protected Set<LearningListener> learningListeners;
@Setter
protected int delay;
public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon){
public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay){
this.environment = environment;
this.actionSpace = actionSpace;
this.discountFactor = discountFactor;
this.epsilon = epsilon;
}
public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace){
this(environment, actionSpace, 1.0f, 0.1f);
this.delay = delay;
learningListeners = new HashSet<>();
}
public abstract void learn(int nrOfEpisodes, int delay);
public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon){
this(environment, actionSpace, discountFactor, epsilon, LearningConfig.DEFAULT_DELAY);
}
public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace){
this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, LearningConfig.DEFAULT_EPSILON, LearningConfig.DEFAULT_DELAY);
}
public abstract void learn(int nrOfEpisodes);
public void addListener(LearningListener learningListener){
learningListeners.add(learningListener);
}
protected void dispatchEpisodeEnd(double sum){
for(LearningListener l: learningListeners) {
l.onEpisodeEnd(sum);
}
}
protected void dispatchEpisodeStart(){
for(LearningListener l: learningListeners){
l.onEpisodeStart();
}
}
}
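Learning now exposes an observer hook: views or other consumers register via addListener(...) and are notified through dispatchEpisodeStart() / dispatchEpisodeEnd(double). A minimal sketch of a custom listener (a hypothetical console logger, not part of this commit) attached the same way the controller attaches the View:

// Hypothetical logging listener; any LearningListener can be registered alongside the View.
LearningListener logger = new LearningListener() {
    @Override
    public void onEpisodeEnd(double sumOfRewards) {
        System.out.println("episode done, sum of rewards = " + sumOfRewards);
    }
    @Override
    public void onEpisodeStart() {
        // no-op
    }
};
learning.addListener(logger);   // attach before calling learn(nrOfEpisodes)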

View File

@@ -34,22 +34,21 @@ public class MonteCarloOnPolicyEGreedy<A extends Enum> extends Learning<A> {
}
@Override
public void learn(int nrOfEpisodes, int delay) {
public void learn(int nrOfEpisodes) {
Map<Pair<State, A>, Double> returnSum = new HashMap<>();
Map<Pair<State, A>, Integer> returnCount = new HashMap<>();
State startingState = environment.reset();
for(int i = 0; i < nrOfEpisodes; ++i) {
List<StepResult<A>> episode = new ArrayList<>();
State state = environment.reset();
double rewardSum = 0;
double sumOfRewards = 0;
for(int j=0; j < 10; ++j){
Map<A, Double> actionValues = stateActionTable.getActionValues(state);
A chosenAction = policy.chooseAction(actionValues);
StepResultEnvironment envResult = environment.step(chosenAction);
State nextState = envResult.getState();
rewardSum += envResult.getReward();
sumOfRewards += envResult.getReward();
episode.add(new StepResult<>(state, chosenAction, envResult.getReward()));
if(envResult.isDone()) break;
@@ -57,13 +56,14 @@ public class MonteCarloOnPolicyEGreedy<A extends Enum> extends Learning<A> {
state = nextState;
try {
Thread.sleep(1);
Thread.sleep(delay);
} catch (InterruptedException e) {
e.printStackTrace();
}
}
System.out.printf("Episode %d \t Reward: %f \n", i, rewardSum);
dispatchEpisodeEnd(sumOfRewards);
System.out.printf("Episode %d \t Reward: %f \n", i, sumOfRewards);
Set<Pair<State, A>> stateActionPairs = new HashSet<>();
for(StepResult<A> sr: episode){

View File

@@ -0,0 +1,5 @@
package core.algo;
public enum Method {
MC_ONPOLICY_EGREEDY, TD_ONPOLICY
}

View File

@@ -0,0 +1,81 @@
package core.controller;
import core.DiscreteActionSpace;
import core.Environment;
import core.ListDiscreteActionSpace;
import core.algo.Learning;
import core.algo.Method;
import core.algo.mc.MonteCarloOnPolicyEGreedy;
import core.gui.View;
import javax.swing.*;
import java.util.Optional;
public class RLController<A extends Enum> implements ViewListener{
protected Environment<A> environment;
protected Learning<A> learning;
protected DiscreteActionSpace<A> discreteActionSpace;
protected View<A> view;
private int delay;
private int nrOfEpisodes;
private Method method;
public RLController(){
}
public void start(){
if(environment == null || discreteActionSpace == null || method == null){
throw new RuntimeException("Set environment, discreteActionSpace and method before calling .start()");
}
switch (method){
case MC_ONPOLICY_EGREEDY:
learning = new MonteCarloOnPolicyEGreedy<>(environment, discreteActionSpace);
break;
case TD_ONPOLICY:
break;
default:
throw new RuntimeException("Undefined method");
}
SwingUtilities.invokeLater(() ->{
view = new View<>(learning, this);
learning.addListener(view);
});
learning.learn(nrOfEpisodes);
}
@Override
public void onEpsilonChange(float epsilon) {
learning.setEpsilon(epsilon);
SwingUtilities.invokeLater(() -> view.updateLearningInfoPanel());
}
@Override
public void onDelayChange(int delay) {
}
public RLController<A> setMethod(Method method){
this.method = method;
return this;
}
public RLController<A> setEnvironment(Environment<A> environment){
this.environment = environment;
return this;
}
@SafeVarargs
public final RLController<A> setAllowedActions(A... actions){
this.discreteActionSpace = new ListDiscreteActionSpace<>(actions);
return this;
}
public RLController<A> setDelay(int delay){
this.delay = delay;
return this;
}
public RLController<A> setEpisodes(int nrOfEpisodes){
this.nrOfEpisodes = nrOfEpisodes;
return this;
}
}

View File

@@ -0,0 +1,6 @@
package core.controller;
public interface ViewListener {
void onEpsilonChange(float epsilon);
void onDelayChange(int delay);
}

View File

@@ -0,0 +1,41 @@
package core.gui;
import core.algo.Learning;
import core.controller.ViewListener;
import javax.swing.*;
public class LearningInfoPanel extends JPanel {
private Learning learning;
private JLabel policyLabel;
private JLabel discountLabel;
private JLabel epsilonLabel;
private JSlider epsilonSlider;
private JSlider delaySlider;
public LearningInfoPanel(Learning learning, ViewListener viewListener){
this.learning = learning;
setLayout(new BoxLayout(this, BoxLayout.Y_AXIS));
policyLabel = new JLabel();
discountLabel = new JLabel();
epsilonLabel = new JLabel();
epsilonSlider = new JSlider(0, 100, (int)(learning.getEpsilon() * 100));
epsilonSlider.addChangeListener(e -> viewListener.onEpsilonChange(epsilonSlider.getValue() / 100f));
add(policyLabel);
add(discountLabel);
add(epsilonLabel);
add(epsilonSlider);
refreshLabels();
setVisible(true);
}
public void refreshLabels(){
policyLabel.setText("Policy: " + learning.getPolicy().getClass());
discountLabel.setText("Discount factor: " + learning.getDiscountFactor());
epsilonLabel.setText("Exploration (Epsilon): " + learning.getEpsilon());
}
protected JSlider getEpsilonSlider(){
return epsilonSlider;
}
}
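Note that delaySlider is declared above and ViewListener already defines onDelayChange(int), but the slider is never initialized, added, or wired in this commit (RLController.onDelayChange is still empty). A hedged sketch of how it could be hooked up inside the constructor, by analogy with epsilonSlider — an assumption, not part of the commit:

// Assumption: wire the declared delaySlider the same way as epsilonSlider (constructor code).
delaySlider = new JSlider(0, 1000, learning.getDelay());   // delay in ms; getDelay() comes from @Getter on Learning
delaySlider.addChangeListener(e -> viewListener.onDelayChange(delaySlider.getValue()));
add(delaySlider);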

View File

@@ -0,0 +1,102 @@
package core.gui;
import core.algo.Learning;
import core.controller.ViewListener;
import core.listener.LearningListener;
import lombok.Getter;
import org.knowm.xchart.QuickChart;
import org.knowm.xchart.XChartPanel;
import org.knowm.xchart.XYChart;
import javax.swing.*;
import java.awt.*;
import java.util.ArrayList;
import java.util.List;
public class View<A extends Enum> implements LearningListener {
private Learning<A> learning;
@Getter
private XYChart chart;
@Getter
private LearningInfoPanel learningInfoPanel;
@Getter
private JFrame mainFrame;
private XChartPanel<XYChart> rewardChartPanel;
private ViewListener viewListener;
private List<Double> rewardHistory;
public View(Learning<A> learning, ViewListener viewListener){
this.learning = learning;
this.viewListener = viewListener;
rewardHistory = new ArrayList<>();
this.initMainFrame();
}
private void initMainFrame(){
mainFrame = new JFrame();
mainFrame.setPreferredSize(new Dimension(1280, 720));
mainFrame.setLayout(new BorderLayout());
initLearningInfoPanel();
initRewardChart();
mainFrame.add(BorderLayout.WEST, learningInfoPanel);
mainFrame.add(BorderLayout.CENTER, rewardChartPanel);
mainFrame.setDefaultCloseOperation(WindowConstants.EXIT_ON_CLOSE);
mainFrame.pack();
mainFrame.setVisible(true);
}
private void initLearningInfoPanel(){
learningInfoPanel = new LearningInfoPanel(learning, viewListener);
}
private void initRewardChart(){
chart =
QuickChart.getChart(
"Rewards per Episode",
"Episode",
"Reward",
"randomWalk",
new double[] {0},
new double[] {0});
chart.getStyler().setLegendVisible(true);
chart.getStyler().setXAxisTicksVisible(true);
rewardChartPanel = new XChartPanel<>(chart);
rewardChartPanel.setPreferredSize(new Dimension(300,300));
}
public void showState(Visualizable state){
new JFrame(){
{
JComponent stateComponent = state.visualize();
setPreferredSize(new Dimension(stateComponent.getWidth(), stateComponent.getHeight()));
add(stateComponent);
setVisible(true);
}
};
}
public void updateRewardGraph(double recentReward){
rewardHistory.add(recentReward);
chart.updateXYSeries("randomWalk", null, rewardHistory, null);
rewardChartPanel.revalidate();
rewardChartPanel.repaint();
}
public void updateLearningInfoPanel(){
this.learningInfoPanel.refreshLabels();
}
@Override
public void onEpisodeEnd(double sumOfRewards) {
SwingUtilities.invokeLater(()->updateRewardGraph(sumOfRewards));
}
@Override
public void onEpisodeStart() {
}
}

View File

@@ -0,0 +1,7 @@
package core.gui;
import javax.swing.*;
public interface Visualizable {
JComponent visualize();
}

View File

@@ -0,0 +1,6 @@
package core.listener;
public interface LearningListener{
void onEpisodeEnd(double sumOfRewards);
void onEpisodeStart();
}

View File

@@ -1,7 +1,10 @@ package evironment.antGame;
package evironment.antGame;
import core.State;
import core.gui.Visualizable;
import evironment.antGame.gui.CellColor;
import javax.swing.*;
import java.awt.*;
import java.util.Arrays;
@@ -10,7 +13,7 @@ import java.util.Arrays;
* Essentially a snapshot of the current Ant Agent
* and therefore has to be deep copied
*/
public class AntState implements State {
public class AntState implements State, Visualizable {
private final Cell[][] knownWorld;
private final Point pos;
private final boolean hasFood;
@@ -29,9 +32,9 @@ public class AntState implements State {
int unknown = 0;
int diff = 0;
for (int i = 0; i < knownWorld.length; i++) {
for (int j = 0; j < knownWorld[i].length; j++) {
if(knownWorld[i][j].getType() == CellType.UNKNOWN){
for (Cell[] cells : knownWorld) {
for (Cell cell : cells) {
if (cell.getType() == CellType.UNKNOWN) {
unknown += 1;
} else {
diff += 1;
@@ -89,4 +92,62 @@ public class AntState implements State {
}
return super.equals(obj);
}
@Override
public JComponent visualize() {
return new JScrollPane() {
private int cellSize;
private final int paneWidth = 500;
private final int paneHeight = 500;
private Font font;
{
setPreferredSize(new Dimension(paneWidth, paneHeight));
cellSize = (paneWidth- knownWorld.length) /knownWorld.length;
font = new Font("plain", Font.BOLD, cellSize);
JPanel worldPanel = new JPanel(){
{
setPreferredSize(new Dimension(knownWorld.length * cellSize, knownWorld[0].length * cellSize));
setVisible(true);
addMouseWheelListener(e -> {
if(e.getWheelRotation() > 0){
cellSize -= 1;
}else {
cellSize += 1;
}
font = new Font("plain", Font.BOLD, cellSize);
setPreferredSize(new Dimension(knownWorld.length * cellSize, knownWorld[0].length * cellSize));
revalidate();
repaint();
});
}
@Override
public void paintComponent(Graphics g) {
super.paintComponent(g);
for (int i = 0; i < knownWorld.length; i++) {
for (int j = 0; j < knownWorld[0].length; j++) {
g.setColor(Color.BLACK);
g.drawRect(i*cellSize, j*cellSize, cellSize, cellSize);
g.setColor(CellColor.map.get(knownWorld[i][j].getType()));
if(knownWorld[i][j].getFood() > 0){
g.setColor(Color.YELLOW);
}
g.fillRect(i*cellSize+1, j*cellSize+1, cellSize -1, cellSize-1);
}
}
if(hasFocus()){
g.setColor(Color.RED);
}else {
g.setColor(Color.BLACK);
}
g.setFont(font);
g.drawString("A", pos.x * cellSize, (pos.y + 1) * cellSize);
}
};
getViewport().add(worldPanel);
setVisible(true);
}
};
}
}

View File

@@ -35,21 +35,15 @@ public class AntWorld implements Environment<AntAction>{
private int tick;
private int maxEpisodeTicks;
MainFrame gui;
public AntWorld(int width, int height, double foodDensity){
grid = new Grid(width, height, foodDensity);
antAgent = new AntAgent(width, height);
myAnt = new Ant();
gui = new MainFrame(this, antAgent);
maxEpisodeTicks = 1000;
reset();
}
public MainFrame getGui(){
return gui;
}
public AntWorld(){
this(Constants.DEFAULT_GRID_WIDTH, Constants.DEFAULT_GRID_HEIGHT, Constants.DEFAULT_FOOD_DENSITY);
}
@@ -166,7 +160,6 @@ public class AntWorld implements Environment<AntAction>{
StepResultEnvironment result = new StepResultEnvironment(newState, reward, done, info);
getGui().update(action, result);
return result;
}
@@ -216,6 +209,6 @@ public class AntWorld implements Environment<AntAction>{
new AntWorld(3, 3, 0.1),
new ListDiscreteActionSpace<>(AntAction.values())
);
monteCarlo.learn(20000,5);
monteCarlo.learn(20000);
}
}

View File

@@ -0,0 +1,21 @@
package example;
import core.RNG;
import core.algo.Method;
import core.controller.RLController;
import evironment.antGame.AntAction;
import evironment.antGame.AntWorld;
public class RunningAnt {
public static void main(String[] args) {
RNG.setSeed(1234);
RLController<AntAction> rl = new RLController<AntAction>()
.setEnvironment(new AntWorld(3,3,0.1))
.setAllowedActions(AntAction.values())
.setMethod(Method.MC_ONPOLICY_EGREEDY)
.setDelay(10)
.setEpisodes(1000);
rl.start();
}
}

View File

@@ -0,0 +1,52 @@
package example;
public class Test {
interface Drawable{
void draw();
}
interface State{
int getInt();
}
static class A implements Drawable, State{
private int k;
public A(int a){
k = a;
}
@Override
public void draw() {
System.out.println("draw " + k);
}
@Override
public int getInt() {
System.out.println("getInt" + k);
return k;
}
}
static class B implements State{
@Override
public int getInt() {
return 0;
}
}
public static void main(String[] args) {
State state = new A(24);
State state2 = new B();
state.getInt();
System.out.println(state2 instanceof Drawable);
drawState(state2);
}
static void drawState(State s){
if(s instanceof Drawable){
Drawable d = (Drawable) s;
d.draw();
}else{
System.out.println("invalid");
}
}
}