distinguish Learning from EpisodicLearning; enable fast-learning without drawing every step to reduce lag

- Repainting every step with zero time delay would freeze the app, so "fast-learning" disables per-step drawing and only refreshes the current-episode label
- Added new abstract class "EpisodicLearning". Maybe just use an interface instead? This matters because TD learning is not episodic and needs another way to represent the rewards received (maybe the mean of the last X rewards or something similar; see the sketch after this list)
- Opening two JFrames, one with the learning info and one with the environment visualization
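A minimal sketch of the "mean of the last X rewards" idea mentioned above (hypothetical; the SlidingRewardTracker class and its methods are made up for illustration and are not part of this commit):

package core.algo;
import java.util.ArrayDeque;
import java.util.Deque;
// Hypothetical helper a non-episodic (e.g. TD) learner could expose instead of
// an episode counter: keeps the last N rewards and reports their running mean.
public class SlidingRewardTracker {
    private final int capacity;
    private final Deque<Double> window = new ArrayDeque<>();
    private double sum;
    public SlidingRewardTracker(int capacity) {
        this.capacity = capacity;
    }
    public void add(double reward) {
        window.addLast(reward);
        sum += reward;
        if (window.size() > capacity) {
            sum -= window.removeFirst();
        }
    }
    public double mean() {
        return window.isEmpty() ? 0.0 : sum / window.size();
    }
}

A TD learner could then call add(reward) from dispatchStepEnd() and plot mean() over time, analogous to the per-episode reward sum that EpisodicLearning reports.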
Jan Löwenstrom 2019-12-21 00:23:09 +01:00
parent 7db5a2af3b
commit 34e7e3fdd6
14 changed files with 188 additions and 75 deletions

View File

@@ -13,10 +13,6 @@ public class RNG {
return rng;
}
public static void reseed(){
rng.setSeed(seed);
}
public static void setSeed(int seed){
RNG.seed = seed;
rng.setSeed(seed);

View File

@@ -0,0 +1,5 @@
package core.algo;
public interface Episodic {
int getCurrentEpisode();
}

View File

@@ -0,0 +1,29 @@
package core.algo;
import core.DiscreteActionSpace;
import core.Environment;
public abstract class EpisodicLearning<A extends Enum> extends Learning<A> implements Episodic{
protected int currentEpisode;
public EpisodicLearning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, int delay) {
super(environment, actionSpace, discountFactor, delay);
}
public EpisodicLearning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor) {
super(environment, actionSpace, discountFactor);
}
public EpisodicLearning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, int delay) {
super(environment, actionSpace, delay);
}
public EpisodicLearning(Environment<A> environment, DiscreteActionSpace<A> actionSpace) {
super(environment, actionSpace);
}
@Override
public int getCurrentEpisode(){
return currentEpisode;
}
}

View File

@@ -9,8 +9,6 @@ import core.policy.Policy;
import lombok.Getter;
import lombok.Setter;
import javax.swing.*;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
@@ -68,4 +66,10 @@ public abstract class Learning<A extends Enum> {
l.onEpisodeStart();
}
}
protected void dispatchStepEnd(){
for(LearningListener l: learningListeners){
l.onStepEnd();
}
}
}

View File

@@ -1,10 +1,9 @@
package core.algo.mc;
import core.*;
import core.algo.Learning;
import core.algo.EpisodicLearning;
import core.policy.EpsilonGreedyPolicy;
import javafx.util.Pair;
import lombok.Setter;
import java.util.*;
@@ -26,11 +25,11 @@ import java.util.*;
* How to counter this problem?
* @param <A>
*/
public class MonteCarloOnPolicyEGreedy<A extends Enum> extends Learning<A> {
public class MonteCarloOnPolicyEGreedy<A extends Enum> extends EpisodicLearning<A> {
public MonteCarloOnPolicyEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay) {
super(environment, actionSpace, discountFactor, delay);
currentEpisode = 0;
this.policy = new EpsilonGreedyPolicy<>(epsilon);
this.stateActionTable = new StateActionHashTable<>(this.actionSpace);
}
@@ -47,8 +46,15 @@ public class MonteCarloOnPolicyEGreedy<A extends Enum> extends Learning<A> {
Map<Pair<State, A>, Integer> returnCount = new HashMap<>();
for(int i = 0; i < nrOfEpisodes; ++i) {
++currentEpisode;
List<StepResult<A>> episode = new ArrayList<>();
State state = environment.reset();
dispatchEpisodeStart();
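// brief pause at the start of each episode so the GUI can keep up; fast-learn sets this delay to 0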
try {
Thread.sleep(delay);
} catch (InterruptedException e) {
e.printStackTrace();
}
double sumOfRewards = 0;
for(int j=0; j < 10; ++j){
Map<A, Double> actionValues = stateActionTable.getActionValues(state);
@@ -67,6 +73,7 @@ public class MonteCarloOnPolicyEGreedy<A extends Enum> extends Learning<A> {
} catch (InterruptedException e) {
e.printStackTrace();
}
dispatchStepEnd();
}
dispatchEpisodeEnd(sumOfRewards);
@@ -100,4 +107,9 @@ public class MonteCarloOnPolicyEGreedy<A extends Enum> extends Learning<A> {
}
}
}
@Override
public int getCurrentEpisode() {
return currentEpisode;
}
}

View File

@@ -7,11 +7,12 @@ import core.algo.Learning;
import core.algo.Method;
import core.algo.mc.MonteCarloOnPolicyEGreedy;
import core.gui.View;
import core.listener.ViewListener;
import core.policy.EpsilonPolicy;
import javax.swing.*;
public class RLController<A extends Enum> implements ViewListener{
public class RLController<A extends Enum> implements ViewListener {
protected Environment<A> environment;
protected Learning<A> learning;
protected DiscreteActionSpace<A> discreteActionSpace;
@@ -19,6 +20,7 @@ public class RLController<A extends Enum> implements ViewListener{
private int delay;
private int nrOfEpisodes;
private Method method;
private int prevDelay;
public RLController(){
}
@@ -41,7 +43,7 @@ public class RLController<A extends Enum> implements ViewListener{
not using SwingUtilities here on purpose to ensure the view is fully
initialized and can be passed as LearningListener.
*/
view = new View<>(learning, this);
view = new View<>(learning, environment, this);
learning.addListener(view);
learning.learn(nrOfEpisodes);
}
@@ -58,10 +60,23 @@ public class RLController<A extends Enum> implements ViewListener{
@Override
public void onDelayChange(int delay) {
changeLearningDelay(delay);
}
private void changeLearningDelay(int delay){
learning.setDelay(delay);
SwingUtilities.invokeLater(() -> {
view.updateLearningInfoPanel();
});
SwingUtilities.invokeLater(() -> view.updateLearningInfoPanel());
}
@Override
public void onFastLearnChange(boolean fastLearn) {
view.setDrawEveryStep(!fastLearn);
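// stash the current delay so it can be restored when fast-learn is switched off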
if(fastLearn){
prevDelay = learning.getDelay();
changeLearningDelay(0);
}else{
changeLearningDelay(prevDelay);
}
}
public RLController<A> setMethod(Method method){

View File

@@ -1,7 +1,8 @@
package core.gui;
import core.algo.Episodic;
import core.algo.Learning;
import core.controller.ViewListener;
import core.listener.ViewListener;
import core.policy.EpsilonPolicy;
import javax.swing.*;
@@ -11,39 +12,62 @@ public class LearningInfoPanel extends JPanel {
private JLabel policyLabel;
private JLabel discountLabel;
private JLabel epsilonLabel;
private JLabel episodeLabel;
private JSlider epsilonSlider;
private JLabel delayLabel;
private JSlider delaySlider;
private JButton toggleFastLearningButton;
private boolean fastLearning;
public LearningInfoPanel(Learning learning, ViewListener viewListener){
this.learning = learning;
setLayout(new BoxLayout(this, BoxLayout.Y_AXIS));
policyLabel = new JLabel();
discountLabel = new JLabel();
epsilonLabel = new JLabel();
delayLabel = new JLabel();
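// only Episodic learners get an episode counter label; non-episodic ones (e.g. TD) will need a different indicator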
if(learning instanceof Episodic){
episodeLabel = new JLabel();
add(episodeLabel);
}
delaySlider = new JSlider(0,1000, learning.getDelay());
delaySlider.addChangeListener(e -> viewListener.onDelayChange(delaySlider.getValue()));
add(policyLabel);
add(discountLabel);
if(learning.getPolicy() instanceof EpsilonPolicy){
epsilonLabel = new JLabel();
epsilonSlider = new JSlider(0, 100, (int)((EpsilonPolicy)learning.getPolicy()).getEpsilon() * 100);
epsilonSlider.addChangeListener(e -> viewListener.onEpsilonChange(epsilonSlider.getValue() / 100f));
add(epsilonLabel);
add(epsilonSlider);
}
toggleFastLearningButton = new JButton("Enable fast-learn");
fastLearning = false;
toggleFastLearningButton.addActionListener(e->{
fastLearning = !fastLearning;
delaySlider.setEnabled(!fastLearning);
if(epsilonSlider != null){ // slider only exists when the policy is an EpsilonPolicy
    epsilonSlider.setEnabled(!fastLearning);
}
viewListener.onFastLearnChange(fastLearning);
});
add(delayLabel);
add(delaySlider);
add(toggleFastLearningButton);
refreshLabels();
setVisible(true);
}
public void refreshLabels(){
public void refreshLabels() {
policyLabel.setText("Policy: " + learning.getPolicy().getClass());
discountLabel.setText("Discount factor: " + learning.getDiscountFactor());
if(learning.getPolicy() instanceof EpsilonPolicy){
epsilonLabel.setText("Exploration (Epsilon): " + ((EpsilonPolicy)learning.getPolicy()).getEpsilon());
if(learning instanceof Episodic){
episodeLabel.setText("Episode: " + ((Episodic)(learning)).getCurrentEpisode());
}
if (learning.getPolicy() instanceof EpsilonPolicy) {
epsilonLabel.setText("Exploration (Epsilon): " + ((EpsilonPolicy) learning.getPolicy()).getEpsilon());
epsilonSlider.setValue((int)(((EpsilonPolicy) learning.getPolicy()).getEpsilon() * 100));
}
delayLabel.setText("Delay (ms): " + learning.getDelay());
delaySlider.setValue(learning.getDelay());
toggleFastLearningButton.setText(fastLearning ? "Disable fast-learning" : "Enable fast-learning");
}
}

View File

@@ -1,7 +1,8 @@
package core.gui;
import core.Environment;
import core.algo.Learning;
import core.controller.ViewListener;
import core.listener.ViewListener;
import core.listener.LearningListener;
import lombok.Getter;
import org.knowm.xchart.QuickChart;
@@ -10,27 +11,31 @@ import org.knowm.xchart.XYChart;
import javax.swing.*;
import java.awt.*;
import java.util.ArrayList;
import java.util.List;
public class View<A extends Enum> implements LearningListener {
private Learning<A> learning;
private Environment<A> environment;
@Getter
private XYChart chart;
private XYChart rewardChart;
@Getter
private LearningInfoPanel learningInfoPanel;
@Getter
private JFrame mainFrame;
private JFrame environmentFrame;
private XChartPanel<XYChart> rewardChartPanel;
private ViewListener viewListener;
private boolean drawEveryStep;
public View(Learning<A> learning, ViewListener viewListener){
public View(Learning<A> learning, Environment<A> environment, ViewListener viewListener) {
this.learning = learning;
this.environment = environment;
this.viewListener = viewListener;
this.initMainFrame();
drawEveryStep = true;
SwingUtilities.invokeLater(this::initMainFrame);
}
private void initMainFrame(){
private void initMainFrame() {
mainFrame = new JFrame();
mainFrame.setPreferredSize(new Dimension(1280, 720));
mainFrame.setLayout(new BorderLayout());
@@ -44,29 +49,40 @@ public class View<A extends Enum> implements LearningListener {
mainFrame.setDefaultCloseOperation(WindowConstants.EXIT_ON_CLOSE);
mainFrame.pack();
mainFrame.setVisible(true);
if (environment instanceof Visualizable) {
environmentFrame = new JFrame() {
{
add(((Visualizable) environment).visualize());
pack();
setVisible(true);
}
};
}
}
private void initLearningInfoPanel(){
private void initLearningInfoPanel() {
learningInfoPanel = new LearningInfoPanel(learning, viewListener);
}
private void initRewardChart(){
chart =
private void initRewardChart() {
rewardChart =
QuickChart.getChart(
"Rewards per Episode",
"Sum of Rewards per Episode",
"Episode",
"Reward",
"randomWalk",
new double[] {0},
new double[] {0});
chart.getStyler().setLegendVisible(true);
chart.getStyler().setXAxisTicksVisible(true);
rewardChartPanel = new XChartPanel<>(chart);
rewardChartPanel.setPreferredSize(new Dimension(300,300));
"rewardHistory",
new double[]{0},
new double[]{0});
rewardChart.getStyler().setLegendVisible(true);
rewardChart.getStyler().setXAxisTicksVisible(true);
rewardChartPanel = new XChartPanel<>(rewardChart);
rewardChartPanel.setPreferredSize(new Dimension(300, 300));
}
public void showState(Visualizable state){
new JFrame(){
public void showState(Visualizable state) {
new JFrame() {
{
JComponent stateComponent = state.visualize();
setPreferredSize(new Dimension(stateComponent.getWidth(), stateComponent.getHeight()));
@@ -76,25 +92,47 @@ public class View<A extends Enum> implements LearningListener {
};
}
public void updateRewardGraph(List<Double> rewardHistory){
chart.updateXYSeries("randomWalk", null, rewardHistory, null);
public void setDrawEveryStep(boolean drawEveryStep){
this.drawEveryStep = drawEveryStep;
}
public void updateRewardGraph(List<Double> rewardHistory) {
rewardChart.updateXYSeries("rewardHistory", null, rewardHistory, null);
rewardChartPanel.revalidate();
rewardChartPanel.repaint();
}
public void updateLearningInfoPanel(){
public void updateLearningInfoPanel() {
this.learningInfoPanel.refreshLabels();
}
@Override
public void onEpisodeEnd(List<Double> rewardHistory) {
SwingUtilities.invokeLater(()->{
updateRewardGraph(rewardHistory);
});
SwingUtilities.invokeLater(() ->{
if(drawEveryStep){
updateRewardGraph(rewardHistory);
}
updateLearningInfoPanel();
});
}
@Override
public void onEpisodeStart() {
if(drawEveryStep) {
SwingUtilities.invokeLater(this::repaintEnvironment);
}
}
@Override
public void onStepEnd() {
if(drawEveryStep){
SwingUtilities.invokeLater(this::repaintEnvironment);
}
}
private void repaintEnvironment(){
if (environmentFrame != null) {
environmentFrame.repaint();
}
}
}

View File

@@ -5,4 +5,5 @@ import java.util.List;
public interface LearningListener{
void onEpisodeEnd(List<Double> rewardHistory);
void onEpisodeStart();
void onStepEnd();
}

View File

@@ -1,6 +1,7 @@
package core.controller;
package core.listener;
public interface ViewListener {
void onEpsilonChange(float epsilon);
void onDelayChange(int delay);
void onFastLearnChange(boolean isFastLearn);
}

View File

@@ -29,7 +29,6 @@ public class EpsilonGreedyPolicy<A extends Enum> implements EpsilonPolicy<A>{
@Override
public A chooseAction(Map<A, Double> actionValues) {
System.out.println("current epsilon " + epsilon);
if(RNG.getRandom().nextFloat() < epsilon){
// Take random action
return randomPolicy.chooseAction(actionValues);

View File

@@ -1,14 +1,15 @@
package evironment.antGame;
import core.*;
import core.algo.Learning;
import core.algo.mc.MonteCarloOnPolicyEGreedy;
import evironment.antGame.gui.MainFrame;
import core.Environment;
import core.State;
import core.StepResultEnvironment;
import core.gui.Visualizable;
import evironment.antGame.gui.AntWorldComponent;
import javax.swing.*;
import java.awt.*;
public class AntWorld implements Environment<AntAction>{
public class AntWorld implements Environment<AntAction>, Visualizable {
/**
*
*/
@@ -204,14 +205,8 @@ public class AntWorld implements Environment<AntAction>{
return myAnt;
}
public static void main(String[] args) {
RNG.setSeed(1993);
Learning<AntAction> monteCarlo = new MonteCarloOnPolicyEGreedy<>(
new AntWorld(3, 3, 0.1),
new ListDiscreteActionSpace<>(AntAction.values()),
5
);
monteCarlo.learn(20000);
@Override
public JComponent visualize() {
return new AntWorldComponent(this, this.antAgent);
}
}

View File

@@ -1,21 +1,18 @@
package evironment.antGame.gui;
import core.StepResultEnvironment;
import evironment.antGame.AntAction;
import evironment.antGame.AntAgent;
import evironment.antGame.AntWorld;
import javax.swing.*;
import java.awt.*;
public class MainFrame extends JFrame {
public class AntWorldComponent extends JComponent {
private AntWorld antWorld;
private HistoryPanel historyPanel;
public MainFrame(AntWorld antWorld, AntAgent antAgent){
public AntWorldComponent(AntWorld antWorld, AntAgent antAgent){
this.antWorld = antWorld;
setLayout(new BorderLayout());
setDefaultCloseOperation(WindowConstants.EXIT_ON_CLOSE);
CellsScrollPane worldPane = new CellsScrollPane(antWorld.getCellArray(), antWorld.getAnt(), 10);
CellsScrollPane antBrainPane = new CellsScrollPane(antAgent.getKnownWorld(), antWorld.getAnt(), 10);
historyPanel = new HistoryPanel();
@@ -29,14 +26,11 @@ public class MainFrame extends JFrame {
add(BorderLayout.CENTER, mapComponent);
add(BorderLayout.SOUTH, historyPanel);
pack();
setVisible(true);
}
public void update(AntAction lastAction, StepResultEnvironment stepResultEnvironment){
historyPanel.addText(String.format("Tick %d: \t Selected action: %s \t Reward: %f \t Info: %s \n totalPoints: %d \t hasFood: %b \t ",
antWorld.getTick(), lastAction.toString(), stepResultEnvironment.getReward(), stepResultEnvironment.getInfo(), antWorld.getAnt().getPoints(), antWorld.getAnt().hasFood()));
repaint();
@Override
protected void paintComponent(Graphics g) {
super.paintComponent(g);
}
}

View File

@@ -8,14 +8,14 @@ import evironment.antGame.AntWorld;
public class RunningAnt {
public static void main(String[] args) {
RNG.setSeed(1234);
RNG.setSeed(123);
RLController<AntAction> rl = new RLController<AntAction>()
.setEnvironment(new AntWorld(3,3,0.1))
.setAllowedActions(AntAction.values())
.setMethod(Method.MC_ONPOLICY_EGREEDY)
.setDelay(10)
.setEpisodes(10000);
.setEpisodes(100000);
rl.start();
}
}