Distinguish between general learning and episodic learning; enable fast-learning that skips drawing every step to reduce lag

- Repainting every step with no time delay would certainly freeze the app, so "fast-learning" disables it and only refreshes the current-episode label
- Added new abstract class "EpisodicLearning". Maybe just use an interface instead?! Important because TD learning is not episodic and needs another way to represent the received rewards (maybe the mean of the last X rewards or something; a sketch of that idea follows below)
- Opening two JFrames, one with the learning info and one with the environment
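A minimal, purely illustrative sketch of the "mean of the last X rewards" idea for non-episodic (e.g. TD) learners. It is not part of this commit; the class name RollingRewardMean, its methods, and the window size are invented for illustration only.

// Hypothetical sketch, not part of this commit: a fixed-size window over the
// most recently received rewards whose mean could be shown in the GUI instead
// of a current-episode counter.
import java.util.ArrayDeque;
import java.util.Deque;

public class RollingRewardMean {
    private final int windowSize;                 // X = how many recent rewards to average
    private final Deque<Double> window = new ArrayDeque<>();
    private double sum;

    public RollingRewardMean(int windowSize) {
        this.windowSize = windowSize;
    }

    public void addReward(double reward) {
        window.addLast(reward);
        sum += reward;
        if (window.size() > windowSize) {
            sum -= window.removeFirst();          // drop the oldest reward from the running sum
        }
    }

    public double getMean() {
        return window.isEmpty() ? 0.0 : sum / window.size();
    }
}

A TD learner could call addReward(...) once per step, and a label in the LearningInfoPanel could then display getMean(), analogous to how getCurrentEpisode() is shown for Episodic learners.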
parent 7db5a2af3b
commit 34e7e3fdd6
@@ -13,10 +13,6 @@ public class RNG {
         return rng;
     }
 
-    public static void reseed(){
-        rng.setSeed(seed);
-    }
-
     public static void setSeed(int seed){
         RNG.seed = seed;
         rng.setSeed(seed);
@@ -0,0 +1,5 @@
+package core.algo;
+
+public interface Episodic {
+    int getCurrentEpisode();
+}
@@ -0,0 +1,29 @@
+package core.algo;
+
+import core.DiscreteActionSpace;
+import core.Environment;
+
+public abstract class EpisodicLearning<A extends Enum> extends Learning<A> implements Episodic{
+    protected int currentEpisode;
+
+    public EpisodicLearning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, int delay) {
+        super(environment, actionSpace, discountFactor, delay);
+    }
+
+    public EpisodicLearning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor) {
+        super(environment, actionSpace, discountFactor);
+    }
+
+    public EpisodicLearning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, int delay) {
+        super(environment, actionSpace, delay);
+    }
+
+    public EpisodicLearning(Environment<A> environment, DiscreteActionSpace<A> actionSpace) {
+        super(environment, actionSpace);
+    }
+
+    @Override
+    public int getCurrentEpisode(){
+        return currentEpisode;
+    }
+}
@@ -9,8 +9,6 @@ import core.policy.Policy;
 import lombok.Getter;
 import lombok.Setter;
-
-import javax.swing.*;
 import java.util.ArrayList;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Set;
@@ -68,4 +66,10 @@ public abstract class Learning<A extends Enum> {
             l.onEpisodeStart();
         }
     }
+
+    protected void dispatchStepEnd(){
+        for(LearningListener l: learningListeners){
+            l.onStepEnd();
+        }
+    }
 }
@@ -1,10 +1,9 @@
 package core.algo.mc;
 
 import core.*;
-import core.algo.Learning;
+import core.algo.EpisodicLearning;
 import core.policy.EpsilonGreedyPolicy;
 import javafx.util.Pair;
 import lombok.Setter;
 
 import java.util.*;
 
@@ -26,11 +25,11 @@ import java.util.*;
  * How to encounter this problem?
  * @param <A>
  */
-public class MonteCarloOnPolicyEGreedy<A extends Enum> extends Learning<A> {
+public class MonteCarloOnPolicyEGreedy<A extends Enum> extends EpisodicLearning<A> {
 
     public MonteCarloOnPolicyEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay) {
         super(environment, actionSpace, discountFactor, delay);
 
         currentEpisode = 0;
         this.policy = new EpsilonGreedyPolicy<>(epsilon);
         this.stateActionTable = new StateActionHashTable<>(this.actionSpace);
     }
@@ -47,8 +46,15 @@ public class MonteCarloOnPolicyEGreedy<A extends Enum> extends Learning<A> {
         Map<Pair<State, A>, Integer> returnCount = new HashMap<>();
 
         for(int i = 0; i < nrOfEpisodes; ++i) {
+            ++currentEpisode;
             List<StepResult<A>> episode = new ArrayList<>();
             State state = environment.reset();
+            dispatchEpisodeStart();
+            try {
+                Thread.sleep(delay);
+            } catch (InterruptedException e) {
+                e.printStackTrace();
+            }
             double sumOfRewards = 0;
             for(int j=0; j < 10; ++j){
                 Map<A, Double> actionValues = stateActionTable.getActionValues(state);
@@ -67,6 +73,7 @@ public class MonteCarloOnPolicyEGreedy<A extends Enum> extends Learning<A> {
             } catch (InterruptedException e) {
                 e.printStackTrace();
             }
+            dispatchStepEnd();
         }
 
         dispatchEpisodeEnd(sumOfRewards);
@@ -100,4 +107,9 @@ public class MonteCarloOnPolicyEGreedy<A extends Enum> extends Learning<A> {
             }
         }
     }
+
+    @Override
+    public int getCurrentEpisode() {
+        return currentEpisode;
+    }
 }
@@ -7,11 +7,12 @@ import core.algo.Learning;
 import core.algo.Method;
 import core.algo.mc.MonteCarloOnPolicyEGreedy;
 import core.gui.View;
+import core.listener.ViewListener;
 import core.policy.EpsilonPolicy;
 
 import javax.swing.*;
 
-public class RLController<A extends Enum> implements ViewListener{
+public class RLController<A extends Enum> implements ViewListener {
     protected Environment<A> environment;
     protected Learning<A> learning;
     protected DiscreteActionSpace<A> discreteActionSpace;
@@ -19,6 +20,7 @@ public class RLController<A extends Enum> implements ViewListener{
     private int delay;
     private int nrOfEpisodes;
     private Method method;
+    private int prevDelay;
 
     public RLController(){
     }
@@ -41,7 +43,7 @@ public class RLController<A extends Enum> implements ViewListener{
         not using SwingUtilities here on purpose to ensure the view is fully
         initialized and can be passed as LearningListener.
         */
-        view = new View<>(learning, this);
+        view = new View<>(learning, environment, this);
         learning.addListener(view);
         learning.learn(nrOfEpisodes);
     }
@@ -58,10 +60,23 @@ public class RLController<A extends Enum> implements ViewListener{
 
     @Override
     public void onDelayChange(int delay) {
+        changeLearningDelay(delay);
+    }
+
+    private void changeLearningDelay(int delay){
         learning.setDelay(delay);
-        SwingUtilities.invokeLater(() -> {
-            view.updateLearningInfoPanel();
-        });
+        SwingUtilities.invokeLater(() -> view.updateLearningInfoPanel());
+    }
+
+    @Override
+    public void onFastLearnChange(boolean fastLearn) {
+        view.setDrawEveryStep(!fastLearn);
+        if(fastLearn){
+            prevDelay = learning.getDelay();
+            changeLearningDelay(0);
+        }else{
+            changeLearningDelay(prevDelay);
+        }
     }
 
     public RLController<A> setMethod(Method method){
@@ -1,7 +1,8 @@
 package core.gui;
 
+import core.algo.Episodic;
 import core.algo.Learning;
-import core.controller.ViewListener;
+import core.listener.ViewListener;
 import core.policy.EpsilonPolicy;
 
 import javax.swing.*;
@@ -11,39 +12,62 @@ public class LearningInfoPanel extends JPanel {
     private JLabel policyLabel;
     private JLabel discountLabel;
     private JLabel epsilonLabel;
+    private JLabel episodeLabel;
     private JSlider epsilonSlider;
     private JLabel delayLabel;
    private JSlider delaySlider;
+    private JButton toggleFastLearningButton;
+    private boolean fastLearning;
 
     public LearningInfoPanel(Learning learning, ViewListener viewListener){
         this.learning = learning;
         setLayout(new BoxLayout(this, BoxLayout.Y_AXIS));
         policyLabel = new JLabel();
         discountLabel = new JLabel();
         epsilonLabel = new JLabel();
         delayLabel = new JLabel();
+        if(learning instanceof Episodic){
+            episodeLabel = new JLabel();
+            add(episodeLabel);
+        }
         delaySlider = new JSlider(0,1000, learning.getDelay());
         delaySlider.addChangeListener(e -> viewListener.onDelayChange(delaySlider.getValue()));
         add(policyLabel);
         add(discountLabel);
         if(learning.getPolicy() instanceof EpsilonPolicy){
             epsilonLabel = new JLabel();
             epsilonSlider = new JSlider(0, 100, (int)((EpsilonPolicy)learning.getPolicy()).getEpsilon() * 100);
             epsilonSlider.addChangeListener(e -> viewListener.onEpsilonChange(epsilonSlider.getValue() / 100f));
             add(epsilonLabel);
             add(epsilonSlider);
         }
+
+        toggleFastLearningButton = new JButton("Enable fast-learn");
+        fastLearning = false;
+        toggleFastLearningButton.addActionListener(e->{
+            fastLearning = !fastLearning;
+            delaySlider.setEnabled(!fastLearning);
+            epsilonSlider.setEnabled(!fastLearning);
+            viewListener.onFastLearnChange(fastLearning);
+        });
         add(delayLabel);
         add(delaySlider);
+        add(toggleFastLearningButton);
         refreshLabels();
         setVisible(true);
     }
 
-    public void refreshLabels(){
+    public void refreshLabels() {
         policyLabel.setText("Policy: " + learning.getPolicy().getClass());
         discountLabel.setText("Discount factor: " + learning.getDiscountFactor());
-        if(learning.getPolicy() instanceof EpsilonPolicy){
-            epsilonLabel.setText("Exploration (Epsilon): " + ((EpsilonPolicy)learning.getPolicy()).getEpsilon());
+        if(learning instanceof Episodic){
+            episodeLabel.setText("Episode: " + ((Episodic)(learning)).getCurrentEpisode());
+        }
+        if (learning.getPolicy() instanceof EpsilonPolicy) {
+            epsilonLabel.setText("Exploration (Epsilon): " + ((EpsilonPolicy) learning.getPolicy()).getEpsilon());
+            epsilonSlider.setValue((int)(((EpsilonPolicy) learning.getPolicy()).getEpsilon() * 100));
         }
         delayLabel.setText("Delay (ms): " + learning.getDelay());
+        delaySlider.setValue(learning.getDelay());
+        toggleFastLearningButton.setText(fastLearning ? "Disable fast-learning" : "Enable fast-learning");
     }
 }
@@ -1,7 +1,8 @@
 package core.gui;
 
+import core.Environment;
 import core.algo.Learning;
-import core.controller.ViewListener;
+import core.listener.ViewListener;
 import core.listener.LearningListener;
 import lombok.Getter;
 import org.knowm.xchart.QuickChart;
@@ -10,27 +11,31 @@ import org.knowm.xchart.XYChart;
 
 import javax.swing.*;
 import java.awt.*;
 import java.util.ArrayList;
 import java.util.List;
 
 public class View<A extends Enum> implements LearningListener {
     private Learning<A> learning;
+    private Environment<A> environment;
-    @Getter
-    private XYChart chart;
+    private XYChart rewardChart;
     @Getter
     private LearningInfoPanel learningInfoPanel;
     @Getter
     private JFrame mainFrame;
+    private JFrame environmentFrame;
     private XChartPanel<XYChart> rewardChartPanel;
     private ViewListener viewListener;
+    private boolean drawEveryStep;
 
-    public View(Learning<A> learning, ViewListener viewListener){
+    public View(Learning<A> learning, Environment<A> environment, ViewListener viewListener) {
         this.learning = learning;
+        this.environment = environment;
         this.viewListener = viewListener;
-        this.initMainFrame();
+        drawEveryStep = true;
+        SwingUtilities.invokeLater(this::initMainFrame);
     }
 
-    private void initMainFrame(){
+    private void initMainFrame() {
         mainFrame = new JFrame();
         mainFrame.setPreferredSize(new Dimension(1280, 720));
         mainFrame.setLayout(new BorderLayout());
@@ -44,29 +49,40 @@ public class View<A extends Enum> implements LearningListener {
         mainFrame.setDefaultCloseOperation(WindowConstants.EXIT_ON_CLOSE);
         mainFrame.pack();
         mainFrame.setVisible(true);
+
+        if (environment instanceof Visualizable) {
+            environmentFrame = new JFrame() {
+                {
+                    add(((Visualizable) environment).visualize());
+                    pack();
+                    setVisible(true);
+                }
+            };
+
+        }
     }
 
-    private void initLearningInfoPanel(){
+    private void initLearningInfoPanel() {
         learningInfoPanel = new LearningInfoPanel(learning, viewListener);
     }
 
-    private void initRewardChart(){
-        chart =
+    private void initRewardChart() {
+        rewardChart =
                 QuickChart.getChart(
-                        "Rewards per Episode",
+                        "Sum of Rewards per Episode",
                         "Episode",
                         "Reward",
-                        "randomWalk",
-                        new double[] {0},
-                        new double[] {0});
-        chart.getStyler().setLegendVisible(true);
-        chart.getStyler().setXAxisTicksVisible(true);
-        rewardChartPanel = new XChartPanel<>(chart);
-        rewardChartPanel.setPreferredSize(new Dimension(300,300));
+                        "rewardHistory",
+                        new double[]{0},
+                        new double[]{0});
+        rewardChart.getStyler().setLegendVisible(true);
+        rewardChart.getStyler().setXAxisTicksVisible(true);
+        rewardChartPanel = new XChartPanel<>(rewardChart);
+        rewardChartPanel.setPreferredSize(new Dimension(300, 300));
     }
 
-    public void showState(Visualizable state){
-        new JFrame(){
+    public void showState(Visualizable state) {
+        new JFrame() {
             {
                 JComponent stateComponent = state.visualize();
                 setPreferredSize(new Dimension(stateComponent.getWidth(), stateComponent.getHeight()));
@@ -76,25 +92,47 @@ public class View<A extends Enum> implements LearningListener {
         };
     }
 
-    public void updateRewardGraph(List<Double> rewardHistory){
-        chart.updateXYSeries("randomWalk", null, rewardHistory, null);
+    public void setDrawEveryStep(boolean drawEveryStep){
+        this.drawEveryStep = drawEveryStep;
+    }
+
+    public void updateRewardGraph(List<Double> rewardHistory) {
+        rewardChart.updateXYSeries("rewardHistory", null, rewardHistory, null);
         rewardChartPanel.revalidate();
         rewardChartPanel.repaint();
     }
 
-    public void updateLearningInfoPanel(){
+    public void updateLearningInfoPanel() {
         this.learningInfoPanel.refreshLabels();
     }
 
     @Override
     public void onEpisodeEnd(List<Double> rewardHistory) {
-        SwingUtilities.invokeLater(()->{
-            updateRewardGraph(rewardHistory);
-        });
+        SwingUtilities.invokeLater(() ->{
+            if(drawEveryStep){
+                updateRewardGraph(rewardHistory);
+            }
+            updateLearningInfoPanel();
+        });
     }
 
     @Override
     public void onEpisodeStart() {
+        if(drawEveryStep) {
+            SwingUtilities.invokeLater(this::repaintEnvironment);
+        }
     }
 
+    @Override
+    public void onStepEnd() {
+        if(drawEveryStep){
+            SwingUtilities.invokeLater(this::repaintEnvironment);
+        }
+    }
+
+    private void repaintEnvironment(){
+        if (environmentFrame != null) {
+            environmentFrame.repaint();
+        }
+    }
 }
@@ -5,4 +5,5 @@ import java.util.List;
 public interface LearningListener{
     void onEpisodeEnd(List<Double> rewardHistory);
     void onEpisodeStart();
+    void onStepEnd();
 }
@@ -1,6 +1,7 @@
-package core.controller;
+package core.listener;
 
 public interface ViewListener {
     void onEpsilonChange(float epsilon);
     void onDelayChange(int delay);
+    void onFastLearnChange(boolean isFastLearn);
 }
@@ -29,7 +29,6 @@ public class EpsilonGreedyPolicy<A extends Enum> implements EpsilonPolicy<A>{
 
     @Override
     public A chooseAction(Map<A, Double> actionValues) {
-        System.out.println("current epsilon " + epsilon);
         if(RNG.getRandom().nextFloat() < epsilon){
             // Take random action
             return randomPolicy.chooseAction(actionValues);
@@ -1,14 +1,15 @@
 package evironment.antGame;
 
-import core.*;
-import core.algo.Learning;
-import core.algo.mc.MonteCarloOnPolicyEGreedy;
-import evironment.antGame.gui.MainFrame;
-
+import core.Environment;
+import core.State;
+import core.StepResultEnvironment;
+import core.gui.Visualizable;
+import evironment.antGame.gui.AntWorldComponent;
+
 import javax.swing.*;
 import java.awt.*;
 
-public class AntWorld implements Environment<AntAction>{
+public class AntWorld implements Environment<AntAction>, Visualizable {
     /**
      *
      */
@@ -204,14 +205,8 @@ public class AntWorld implements Environment<AntAction>{
         return myAnt;
     }
 
-    public static void main(String[] args) {
-        RNG.setSeed(1993);
-
-        Learning<AntAction> monteCarlo = new MonteCarloOnPolicyEGreedy<>(
-                new AntWorld(3, 3, 0.1),
-                new ListDiscreteActionSpace<>(AntAction.values()),
-                5
-        );
-        monteCarlo.learn(20000);
+    @Override
+    public JComponent visualize() {
+        return new AntWorldComponent(this, this.antAgent);
     }
 }
@@ -1,21 +1,18 @@
 package evironment.antGame.gui;
 
-import core.StepResultEnvironment;
-import evironment.antGame.AntAction;
 import evironment.antGame.AntAgent;
 import evironment.antGame.AntWorld;
 
 import javax.swing.*;
 import java.awt.*;
 
-public class MainFrame extends JFrame {
+public class AntWorldComponent extends JComponent {
     private AntWorld antWorld;
     private HistoryPanel historyPanel;
 
-    public MainFrame(AntWorld antWorld, AntAgent antAgent){
+    public AntWorldComponent(AntWorld antWorld, AntAgent antAgent){
         this.antWorld = antWorld;
         setLayout(new BorderLayout());
-        setDefaultCloseOperation(WindowConstants.EXIT_ON_CLOSE);
         CellsScrollPane worldPane = new CellsScrollPane(antWorld.getCellArray(), antWorld.getAnt(), 10);
         CellsScrollPane antBrainPane = new CellsScrollPane(antAgent.getKnownWorld(), antWorld.getAnt(), 10);
         historyPanel = new HistoryPanel();
@@ -29,14 +26,11 @@ public class MainFrame extends JFrame {
 
         add(BorderLayout.CENTER, mapComponent);
         add(BorderLayout.SOUTH, historyPanel);
-        pack();
-        setVisible(true);
     }
 
-    public void update(AntAction lastAction, StepResultEnvironment stepResultEnvironment){
-        historyPanel.addText(String.format("Tick %d: \t Selected action: %s \t Reward: %f \t Info: %s \n totalPoints: %d \t hasFood: %b \t ",
-                antWorld.getTick(), lastAction.toString(), stepResultEnvironment.getReward(), stepResultEnvironment.getInfo(), antWorld.getAnt().getPoints(), antWorld.getAnt().hasFood()));
 
-        repaint();
+    @Override
+    protected void paintComponent(Graphics g) {
+        super.paintComponent(g);
     }
 }
@@ -8,14 +8,14 @@ import evironment.antGame.AntWorld;
 
 public class RunningAnt {
     public static void main(String[] args) {
-        RNG.setSeed(1234);
+        RNG.setSeed(123);
 
         RLController<AntAction> rl = new RLController<AntAction>()
                 .setEnvironment(new AntWorld(3,3,0.1))
                 .setAllowedActions(AntAction.values())
                 .setMethod(Method.MC_ONPOLICY_EGREEDY)
                 .setDelay(10)
-                .setEpisodes(10000);
+                .setEpisodes(100000);
         rl.start();
     }
 }