Add GUI features to control learning and move the LearningListener interface to the controller
- Add metric to display episodes per second - view no longer implements the learning listener; the controller does. The controller drives all view actions based on learning events and reacts to view events via ViewListener - add executor service for the learning task - use instanceof to distinguish between episodic learning and TD learning - add feature to trigger more episodes - add checkboxes for smoothing the graph, displaying only the last 100 rewards, and drawing the environment - remove history panel from the AntWorld GUI
This commit is contained in:
parent
34e7e3fdd6
commit
b1246f62cc
|
@ -0,0 +1,15 @@
|
||||||
|
package core;
|
||||||
|
|
||||||
|
/**
 * Small collection of static helper methods shared across the project.
 */
public class Util {

    /**
     * Reports whether the given text can be parsed as a floating-point number.
     *
     * <p>The check delegates to {@link Double#parseDouble(String)}, so decimal
     * and exponent notation (for example {@code "1.5"} or {@code "1e3"}) are
     * accepted as well as plain integers. A {@code true} result therefore does
     * NOT guarantee that {@link Integer#parseInt(String)} will succeed on the
     * same input.
     *
     * @param strNum the text to check; may be {@code null}
     * @return {@code true} if {@code strNum} is non-null and parses as a
     *         {@code double}, {@code false} otherwise
     */
    public static boolean isNumeric(String strNum) {
        if (strNum == null) {
            return false;
        }
        try {
            // Only the success or failure of the parse matters; the value is discarded.
            Double.parseDouble(strNum);
        } catch (NumberFormatException nfe) {
            return false;
        }
        return true;
    }
}
|
|
@ -2,4 +2,6 @@ package core.algo;
|
||||||
|
|
||||||
public interface Episodic {
|
public interface Episodic {
|
||||||
int getCurrentEpisode();
|
int getCurrentEpisode();
|
||||||
|
int getEpisodesToGo();
|
||||||
|
int getEpisodesPerSecond();
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,9 +2,14 @@ package core.algo;
|
||||||
|
|
||||||
import core.DiscreteActionSpace;
|
import core.DiscreteActionSpace;
|
||||||
import core.Environment;
|
import core.Environment;
|
||||||
|
import core.listener.LearningListener;
|
||||||
|
|
||||||
public abstract class EpisodicLearning<A extends Enum> extends Learning<A> implements Episodic{
|
public abstract class EpisodicLearning<A extends Enum> extends Learning<A> implements Episodic{
|
||||||
protected int currentEpisode;
|
protected int currentEpisode;
|
||||||
|
protected int episodesToLearn;
|
||||||
|
protected volatile int episodePerSecond;
|
||||||
|
protected int episodeSumCurrentSecond;
|
||||||
|
private volatile boolean meseaureEpisodeBenchMark;
|
||||||
|
|
||||||
public EpisodicLearning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, int delay) {
|
public EpisodicLearning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, int delay) {
|
||||||
super(environment, actionSpace, discountFactor, delay);
|
super(environment, actionSpace, discountFactor, delay);
|
||||||
|
@ -22,8 +27,56 @@ public abstract class EpisodicLearning<A extends Enum> extends Learning<A> imple
|
||||||
super(environment, actionSpace);
|
super(environment, actionSpace);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
protected void dispatchEpisodeEnd(double recentSumOfRewards){
|
||||||
|
++episodeSumCurrentSecond;
|
||||||
|
rewardHistory.add(recentSumOfRewards);
|
||||||
|
for(LearningListener l: learningListeners) {
|
||||||
|
l.onEpisodeEnd(rewardHistory);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
protected void dispatchEpisodeStart(){
|
||||||
|
for(LearningListener l: learningListeners){
|
||||||
|
l.onEpisodeStart();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void learn(){
|
||||||
|
learn(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void learn(int nrOfEpisodes){
|
||||||
|
meseaureEpisodeBenchMark = true;
|
||||||
|
new Thread(()->{
|
||||||
|
while(meseaureEpisodeBenchMark){
|
||||||
|
episodePerSecond = episodeSumCurrentSecond;
|
||||||
|
episodeSumCurrentSecond = 0;
|
||||||
|
try {
|
||||||
|
Thread.sleep(1000);
|
||||||
|
} catch (InterruptedException e) {
|
||||||
|
e.printStackTrace();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}).start();
|
||||||
|
episodesToLearn += nrOfEpisodes;
|
||||||
|
dispatchLearningStart();
|
||||||
|
for(int i=0; i < nrOfEpisodes; ++i){
|
||||||
|
nextEpisode();
|
||||||
|
}
|
||||||
|
dispatchLearningEnd();
|
||||||
|
meseaureEpisodeBenchMark = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected abstract void nextEpisode();
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int getCurrentEpisode(){
|
public int getCurrentEpisode(){
|
||||||
return currentEpisode;
|
return currentEpisode;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int getEpisodesToGo(){
|
||||||
|
return episodesToLearn - currentEpisode;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -24,7 +24,7 @@ public abstract class Learning<A extends Enum> {
|
||||||
protected Set<LearningListener> learningListeners;
|
protected Set<LearningListener> learningListeners;
|
||||||
@Setter
|
@Setter
|
||||||
protected int delay;
|
protected int delay;
|
||||||
private List<Double> rewardHistory;
|
protected List<Double> rewardHistory;
|
||||||
|
|
||||||
public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, int delay){
|
public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, int delay){
|
||||||
this.environment = environment;
|
this.environment = environment;
|
||||||
|
@ -47,29 +47,27 @@ public abstract class Learning<A extends Enum> {
|
||||||
this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, LearningConfig.DEFAULT_DELAY);
|
this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, LearningConfig.DEFAULT_DELAY);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public abstract void learn();
|
||||||
public abstract void learn(int nrOfEpisodes);
|
|
||||||
|
|
||||||
public void addListener(LearningListener learningListener){
|
public void addListener(LearningListener learningListener){
|
||||||
learningListeners.add(learningListener);
|
learningListeners.add(learningListener);
|
||||||
}
|
}
|
||||||
|
|
||||||
protected void dispatchEpisodeEnd(double recentSumOfRewards){
|
|
||||||
rewardHistory.add(recentSumOfRewards);
|
|
||||||
for(LearningListener l: learningListeners) {
|
|
||||||
l.onEpisodeEnd(rewardHistory);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
protected void dispatchEpisodeStart(){
|
|
||||||
for(LearningListener l: learningListeners){
|
|
||||||
l.onEpisodeStart();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
protected void dispatchStepEnd(){
|
protected void dispatchStepEnd(){
|
||||||
for(LearningListener l: learningListeners){
|
for(LearningListener l: learningListeners){
|
||||||
l.onStepEnd();
|
l.onStepEnd();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
protected void dispatchLearningStart(){
|
||||||
|
for(LearningListener l: learningListeners){
|
||||||
|
l.onLearningStart();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
protected void dispatchLearningEnd(){
|
||||||
|
for(LearningListener l: learningListeners){
|
||||||
|
l.onLearningEnd();
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,27 +11,33 @@ import java.util.*;
|
||||||
* TODO: Major problem:
|
* TODO: Major problem:
|
||||||
* StateActionPairs are only unique accounting for their position in the episode.
|
* StateActionPairs are only unique accounting for their position in the episode.
|
||||||
* For example:
|
* For example:
|
||||||
*
|
* <p>
|
||||||
* startingState -> MOVE_LEFT : very first state action in the episode i = 1
|
* startingState -> MOVE_LEFT : very first state action in the episode i = 1
|
||||||
* image the agent does not collect the food and drops it to the start, the agent will receive
|
* image the agent does not collect the food and drops it to the start, the agent will receive
|
||||||
* -1 for every timestamp hence (startingState -> MOVE_LEFT) will get a value of -10;
|
* -1 for every timestamp hence (startingState -> MOVE_LEFT) will get a value of -10;
|
||||||
*
|
* <p>
|
||||||
* BUT image moving left from the starting position will have no impact on the state because
|
* BUT image moving left from the starting position will have no impact on the state because
|
||||||
* the agent ran into a wall. The known world stays the same.
|
* the agent ran into a wall. The known world stays the same.
|
||||||
* Taking an action after that will have the exact same state but a different action
|
* Taking an action after that will have the exact same state but a different action
|
||||||
* making the value of this stateActionPair -9 because the stateAction pair took place on the second
|
* making the value of this stateActionPair -9 because the stateAction pair took place on the second
|
||||||
* timestamp, summing up all remaining rewards will be -9...
|
* timestamp, summing up all remaining rewards will be -9...
|
||||||
*
|
* <p>
|
||||||
* How to encounter this problem?
|
* How to encounter this problem?
|
||||||
|
*
|
||||||
* @param <A>
|
* @param <A>
|
||||||
*/
|
*/
|
||||||
public class MonteCarloOnPolicyEGreedy<A extends Enum> extends EpisodicLearning<A> {
|
public class MonteCarloOnPolicyEGreedy<A extends Enum> extends EpisodicLearning<A> {
|
||||||
|
|
||||||
|
private Map<Pair<State, A>, Double> returnSum;
|
||||||
|
private Map<Pair<State, A>, Integer> returnCount;
|
||||||
|
|
||||||
public MonteCarloOnPolicyEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay) {
|
public MonteCarloOnPolicyEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay) {
|
||||||
super(environment, actionSpace, discountFactor, delay);
|
super(environment, actionSpace, discountFactor, delay);
|
||||||
currentEpisode = 0;
|
currentEpisode = 0;
|
||||||
this.policy = new EpsilonGreedyPolicy<>(epsilon);
|
this.policy = new EpsilonGreedyPolicy<>(epsilon);
|
||||||
this.stateActionTable = new StateActionHashTable<>(this.actionSpace);
|
this.stateActionTable = new StateActionHashTable<>(this.actionSpace);
|
||||||
|
returnSum = new HashMap<>();
|
||||||
|
returnCount = new HashMap<>();
|
||||||
}
|
}
|
||||||
|
|
||||||
public MonteCarloOnPolicyEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, int delay) {
|
public MonteCarloOnPolicyEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, int delay) {
|
||||||
|
@ -40,12 +46,7 @@ public class MonteCarloOnPolicyEGreedy<A extends Enum> extends EpisodicLearning<
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void learn(int nrOfEpisodes) {
|
public void nextEpisode() {
|
||||||
|
|
||||||
Map<Pair<State, A>, Double> returnSum = new HashMap<>();
|
|
||||||
Map<Pair<State, A>, Integer> returnCount = new HashMap<>();
|
|
||||||
|
|
||||||
for(int i = 0; i < nrOfEpisodes; ++i) {
|
|
||||||
++currentEpisode;
|
++currentEpisode;
|
||||||
List<StepResult<A>> episode = new ArrayList<>();
|
List<StepResult<A>> episode = new ArrayList<>();
|
||||||
State state = environment.reset();
|
State state = environment.reset();
|
||||||
|
@ -77,7 +78,7 @@ public class MonteCarloOnPolicyEGreedy<A extends Enum> extends EpisodicLearning<
|
||||||
}
|
}
|
||||||
|
|
||||||
dispatchEpisodeEnd(sumOfRewards);
|
dispatchEpisodeEnd(sumOfRewards);
|
||||||
System.out.printf("Episode %d \t Reward: %f \n", i, sumOfRewards);
|
System.out.printf("Episode %d \t Reward: %f \n", currentEpisode, sumOfRewards);
|
||||||
Set<Pair<State, A>> stateActionPairs = new HashSet<>();
|
Set<Pair<State, A>> stateActionPairs = new HashSet<>();
|
||||||
|
|
||||||
for (StepResult<A> sr : episode) {
|
for (StepResult<A> sr : episode) {
|
||||||
|
@ -89,7 +90,6 @@ public class MonteCarloOnPolicyEGreedy<A extends Enum> extends EpisodicLearning<
|
||||||
// find first occurance of state action pair
|
// find first occurance of state action pair
|
||||||
for (StepResult<A> sr : episode) {
|
for (StepResult<A> sr : episode) {
|
||||||
if (stateActionPair.getKey().equals(sr.getState()) && stateActionPair.getValue().equals(sr.getAction())) {
|
if (stateActionPair.getKey().equals(sr.getState()) && stateActionPair.getValue().equals(sr.getAction())) {
|
||||||
;
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
firstOccurenceIndex++;
|
firstOccurenceIndex++;
|
||||||
|
@ -106,10 +106,14 @@ public class MonteCarloOnPolicyEGreedy<A extends Enum> extends EpisodicLearning<
|
||||||
stateActionTable.setValue(stateActionPair.getKey(), stateActionPair.getValue(), returnSum.get(stateActionPair) / returnCount.get(stateActionPair));
|
stateActionTable.setValue(stateActionPair.getKey(), stateActionPair.getValue(), returnSum.get(stateActionPair) / returnCount.get(stateActionPair));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int getCurrentEpisode() {
|
public int getCurrentEpisode() {
|
||||||
return currentEpisode;
|
return currentEpisode;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int getEpisodesPerSecond(){
|
||||||
|
return episodePerSecond;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,26 +3,37 @@ package core.controller;
|
||||||
import core.DiscreteActionSpace;
|
import core.DiscreteActionSpace;
|
||||||
import core.Environment;
|
import core.Environment;
|
||||||
import core.ListDiscreteActionSpace;
|
import core.ListDiscreteActionSpace;
|
||||||
|
import core.algo.EpisodicLearning;
|
||||||
import core.algo.Learning;
|
import core.algo.Learning;
|
||||||
import core.algo.Method;
|
import core.algo.Method;
|
||||||
import core.algo.mc.MonteCarloOnPolicyEGreedy;
|
import core.algo.mc.MonteCarloOnPolicyEGreedy;
|
||||||
|
import core.gui.LearningView;
|
||||||
import core.gui.View;
|
import core.gui.View;
|
||||||
|
import core.listener.LearningListener;
|
||||||
import core.listener.ViewListener;
|
import core.listener.ViewListener;
|
||||||
import core.policy.EpsilonPolicy;
|
import core.policy.EpsilonPolicy;
|
||||||
|
|
||||||
import javax.swing.*;
|
import javax.swing.*;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.concurrent.ExecutorService;
|
||||||
|
import java.util.concurrent.Executors;
|
||||||
|
|
||||||
public class RLController<A extends Enum> implements ViewListener {
|
public class RLController<A extends Enum> implements ViewListener, LearningListener {
|
||||||
protected Environment<A> environment;
|
protected Environment<A> environment;
|
||||||
protected Learning<A> learning;
|
protected Learning<A> learning;
|
||||||
protected DiscreteActionSpace<A> discreteActionSpace;
|
protected DiscreteActionSpace<A> discreteActionSpace;
|
||||||
protected View<A> view;
|
protected LearningView learningView;
|
||||||
private int delay;
|
private int delay;
|
||||||
private int nrOfEpisodes;
|
private int nrOfEpisodes;
|
||||||
private Method method;
|
private Method method;
|
||||||
private int prevDelay;
|
private int prevDelay;
|
||||||
|
private boolean fastLearning;
|
||||||
|
private boolean currentlyLearning;
|
||||||
|
private ExecutorService learningExecutor;
|
||||||
|
private List<Double> latestRewardsHistory;
|
||||||
|
|
||||||
public RLController(){
|
public RLController(){
|
||||||
|
learningExecutor = Executors.newSingleThreadExecutor();
|
||||||
}
|
}
|
||||||
|
|
||||||
public void start(){
|
public void start(){
|
||||||
|
@ -39,20 +50,37 @@ public class RLController<A extends Enum> implements ViewListener {
|
||||||
default:
|
default:
|
||||||
throw new RuntimeException("Undefined method");
|
throw new RuntimeException("Undefined method");
|
||||||
}
|
}
|
||||||
/*
|
SwingUtilities.invokeLater(()->{
|
||||||
not using SwingUtilities here on purpose to ensure the view is fully
|
learningView = new View<>(learning, environment, this);
|
||||||
initialized and can be passed as LearningListener.
|
learning.addListener(this);
|
||||||
*/
|
});
|
||||||
view = new View<>(learning, environment, this);
|
|
||||||
learning.addListener(view);
|
if(learning instanceof EpisodicLearning){
|
||||||
learning.learn(nrOfEpisodes);
|
learningExecutor.submit(()->((EpisodicLearning) learning).learn(nrOfEpisodes));
|
||||||
|
}else{
|
||||||
|
learningExecutor.submit(()->learning.learn());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/*************************************************
|
||||||
|
* VIEW LISTENERS *
|
||||||
|
*************************************************/
|
||||||
|
@Override
|
||||||
|
public void onLearnMoreEpisodes(int nrOfEpisodes){
|
||||||
|
if(!currentlyLearning){
|
||||||
|
if(learning instanceof EpisodicLearning){
|
||||||
|
learningExecutor.submit(()->((EpisodicLearning) learning).learn(nrOfEpisodes));
|
||||||
|
}else{
|
||||||
|
throw new RuntimeException("Triggering onLearnMoreEpisodes on non-episodic learning!");
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void onEpsilonChange(float epsilon) {
|
public void onEpsilonChange(float epsilon) {
|
||||||
if(learning.getPolicy() instanceof EpsilonPolicy){
|
if(learning.getPolicy() instanceof EpsilonPolicy){
|
||||||
((EpsilonPolicy<A>) learning.getPolicy()).setEpsilon(epsilon);
|
((EpsilonPolicy<A>) learning.getPolicy()).setEpsilon(epsilon);
|
||||||
SwingUtilities.invokeLater(() -> view.updateLearningInfoPanel());
|
SwingUtilities.invokeLater(() -> learningView.updateLearningInfoPanel());
|
||||||
}else{
|
}else{
|
||||||
System.out.println("Trying to call inEpsilonChange on non-epsilon policy");
|
System.out.println("Trying to call inEpsilonChange on non-epsilon policy");
|
||||||
}
|
}
|
||||||
|
@ -65,12 +93,12 @@ public class RLController<A extends Enum> implements ViewListener {
|
||||||
|
|
||||||
private void changeLearningDelay(int delay){
|
private void changeLearningDelay(int delay){
|
||||||
learning.setDelay(delay);
|
learning.setDelay(delay);
|
||||||
SwingUtilities.invokeLater(() -> view.updateLearningInfoPanel());
|
SwingUtilities.invokeLater(() -> learningView.updateLearningInfoPanel());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void onFastLearnChange(boolean fastLearn) {
|
public void onFastLearnChange(boolean fastLearn) {
|
||||||
view.setDrawEveryStep(!fastLearn);
|
this.fastLearning = fastLearn;
|
||||||
if(fastLearn){
|
if(fastLearn){
|
||||||
prevDelay = learning.getDelay();
|
prevDelay = learning.getDelay();
|
||||||
changeLearningDelay(0);
|
changeLearningDelay(0);
|
||||||
|
@ -79,6 +107,45 @@ public class RLController<A extends Enum> implements ViewListener {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*************************************************
|
||||||
|
* LEARNING LISTENERS *
|
||||||
|
*************************************************/
|
||||||
|
@Override
|
||||||
|
public void onLearningStart() {
|
||||||
|
currentlyLearning = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onLearningEnd() {
|
||||||
|
currentlyLearning = false;
|
||||||
|
SwingUtilities.invokeLater(()-> learningView.updateRewardGraph(latestRewardsHistory));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onEpisodeEnd(List<Double> rewardHistory) {
|
||||||
|
latestRewardsHistory = rewardHistory;
|
||||||
|
SwingUtilities.invokeLater(() ->{
|
||||||
|
if(!fastLearning){
|
||||||
|
learningView.updateRewardGraph(latestRewardsHistory);
|
||||||
|
}
|
||||||
|
learningView.updateLearningInfoPanel();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onEpisodeStart() {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onStepEnd() {
|
||||||
|
if(!fastLearning){
|
||||||
|
SwingUtilities.invokeLater(() -> learningView.repaintEnvironment());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
public RLController<A> setMethod(Method method){
|
public RLController<A> setMethod(Method method){
|
||||||
this.method = method;
|
this.method = method;
|
||||||
return this;
|
return this;
|
||||||
|
@ -102,5 +169,4 @@ public class RLController<A extends Enum> implements ViewListener {
|
||||||
this.nrOfEpisodes = nrOfEpisodes;
|
this.nrOfEpisodes = nrOfEpisodes;
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,11 +1,14 @@
|
||||||
package core.gui;
|
package core.gui;
|
||||||
|
|
||||||
|
import core.Util;
|
||||||
import core.algo.Episodic;
|
import core.algo.Episodic;
|
||||||
|
import core.algo.EpisodicLearning;
|
||||||
import core.algo.Learning;
|
import core.algo.Learning;
|
||||||
import core.listener.ViewListener;
|
import core.listener.ViewListener;
|
||||||
import core.policy.EpsilonPolicy;
|
import core.policy.EpsilonPolicy;
|
||||||
|
|
||||||
import javax.swing.*;
|
import javax.swing.*;
|
||||||
|
import java.awt.*;
|
||||||
|
|
||||||
public class LearningInfoPanel extends JPanel {
|
public class LearningInfoPanel extends JPanel {
|
||||||
private Learning learning;
|
private Learning learning;
|
||||||
|
@ -18,6 +21,11 @@ public class LearningInfoPanel extends JPanel {
|
||||||
private JSlider delaySlider;
|
private JSlider delaySlider;
|
||||||
private JButton toggleFastLearningButton;
|
private JButton toggleFastLearningButton;
|
||||||
private boolean fastLearning;
|
private boolean fastLearning;
|
||||||
|
private JCheckBox smoothGraphCheckbox;
|
||||||
|
private JCheckBox last100Checkbox;
|
||||||
|
private JCheckBox drawEnvironmentCheckbox;
|
||||||
|
private JTextField learnMoreEpisodesInput;
|
||||||
|
private JButton learnMoreEpisodesButton;
|
||||||
|
|
||||||
public LearningInfoPanel(Learning learning, ViewListener viewListener){
|
public LearningInfoPanel(Learning learning, ViewListener viewListener){
|
||||||
this.learning = learning;
|
this.learning = learning;
|
||||||
|
@ -47,11 +55,37 @@ public class LearningInfoPanel extends JPanel {
|
||||||
fastLearning = !fastLearning;
|
fastLearning = !fastLearning;
|
||||||
delaySlider.setEnabled(!fastLearning);
|
delaySlider.setEnabled(!fastLearning);
|
||||||
epsilonSlider.setEnabled(!fastLearning);
|
epsilonSlider.setEnabled(!fastLearning);
|
||||||
|
drawEnvironmentCheckbox.setSelected(!fastLearning);
|
||||||
viewListener.onFastLearnChange(fastLearning);
|
viewListener.onFastLearnChange(fastLearning);
|
||||||
});
|
});
|
||||||
|
smoothGraphCheckbox = new JCheckBox("Smoothen Graph");
|
||||||
|
smoothGraphCheckbox.setSelected(false);
|
||||||
|
last100Checkbox = new JCheckBox("Only show last 100 Rewards");
|
||||||
|
last100Checkbox.setSelected(true);
|
||||||
|
drawEnvironmentCheckbox = new JCheckBox("Update Environment");
|
||||||
|
drawEnvironmentCheckbox.setSelected(true);
|
||||||
|
|
||||||
add(delayLabel);
|
add(delayLabel);
|
||||||
add(delaySlider);
|
add(delaySlider);
|
||||||
add(toggleFastLearningButton);
|
add(toggleFastLearningButton);
|
||||||
|
|
||||||
|
if(learning instanceof EpisodicLearning) {
|
||||||
|
learnMoreEpisodesInput = new JTextField();
|
||||||
|
learnMoreEpisodesInput.setMaximumSize(new Dimension(200,20));
|
||||||
|
learnMoreEpisodesButton = new JButton("Learn More Episodes");
|
||||||
|
learnMoreEpisodesButton.addActionListener(e -> {
|
||||||
|
if (Util.isNumeric(learnMoreEpisodesInput.getText())) {
|
||||||
|
viewListener.onLearnMoreEpisodes(Integer.parseInt(learnMoreEpisodesInput.getText()));
|
||||||
|
} else {
|
||||||
|
learnMoreEpisodesInput.setText("");
|
||||||
|
}
|
||||||
|
});
|
||||||
|
add(learnMoreEpisodesInput);
|
||||||
|
add(learnMoreEpisodesButton);
|
||||||
|
}
|
||||||
|
add(drawEnvironmentCheckbox);
|
||||||
|
add(smoothGraphCheckbox);
|
||||||
|
add(last100Checkbox);
|
||||||
refreshLabels();
|
refreshLabels();
|
||||||
setVisible(true);
|
setVisible(true);
|
||||||
}
|
}
|
||||||
|
@ -60,14 +94,29 @@ public class LearningInfoPanel extends JPanel {
|
||||||
policyLabel.setText("Policy: " + learning.getPolicy().getClass());
|
policyLabel.setText("Policy: " + learning.getPolicy().getClass());
|
||||||
discountLabel.setText("Discount factor: " + learning.getDiscountFactor());
|
discountLabel.setText("Discount factor: " + learning.getDiscountFactor());
|
||||||
if(learning instanceof Episodic){
|
if(learning instanceof Episodic){
|
||||||
episodeLabel.setText("Episode: " + ((Episodic)(learning)).getCurrentEpisode());
|
episodeLabel.setText("Episode: " + ((Episodic)(learning)).getCurrentEpisode() +
|
||||||
|
"\t Episodes to go: " + ((Episodic)(learning)).getEpisodesToGo() +
|
||||||
|
"\t Eps/Sec: " + ((Episodic)(learning)).getEpisodesPerSecond());
|
||||||
}
|
}
|
||||||
if (learning.getPolicy() instanceof EpsilonPolicy) {
|
if (learning.getPolicy() instanceof EpsilonPolicy) {
|
||||||
epsilonLabel.setText("Exploration (Epsilon): " + ((EpsilonPolicy) learning.getPolicy()).getEpsilon());
|
epsilonLabel.setText("Exploration (Epsilon): " + ((EpsilonPolicy) learning.getPolicy()).getEpsilon());
|
||||||
epsilonSlider.setValue((int)(((EpsilonPolicy) learning.getPolicy()).getEpsilon() * 100));
|
epsilonSlider.setValue((int)(((EpsilonPolicy) learning.getPolicy()).getEpsilon() * 100));
|
||||||
}
|
}
|
||||||
delayLabel.setText("Delay (ms): " + learning.getDelay());
|
delayLabel.setText("Delay (ms): " + learning.getDelay());
|
||||||
|
if(delaySlider.isEnabled()){
|
||||||
delaySlider.setValue(learning.getDelay());
|
delaySlider.setValue(learning.getDelay());
|
||||||
|
}
|
||||||
toggleFastLearningButton.setText(fastLearning ? "Disable fast-learning" : "Enable fast-learning");
|
toggleFastLearningButton.setText(fastLearning ? "Disable fast-learning" : "Enable fast-learning");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
protected boolean isSmoothenGraphSelected() {
|
||||||
|
return smoothGraphCheckbox.isSelected();
|
||||||
|
}
|
||||||
|
protected boolean isLast100Selected(){
|
||||||
|
return last100Checkbox.isSelected();
|
||||||
|
}
|
||||||
|
|
||||||
|
protected boolean isDrawEnvironmentSelected(){
|
||||||
|
return drawEnvironmentCheckbox.isSelected();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,9 @@
|
||||||
|
package core.gui;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public interface LearningView {
|
||||||
|
void repaintEnvironment();
|
||||||
|
void updateLearningInfoPanel();
|
||||||
|
void updateRewardGraph(final List<Double> rewardHistory);
|
||||||
|
}
|
|
@ -3,7 +3,7 @@ package core.gui;
|
||||||
import core.Environment;
|
import core.Environment;
|
||||||
import core.algo.Learning;
|
import core.algo.Learning;
|
||||||
import core.listener.ViewListener;
|
import core.listener.ViewListener;
|
||||||
import core.listener.LearningListener;
|
import javafx.util.Pair;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import org.knowm.xchart.QuickChart;
|
import org.knowm.xchart.QuickChart;
|
||||||
import org.knowm.xchart.XChartPanel;
|
import org.knowm.xchart.XChartPanel;
|
||||||
|
@ -12,8 +12,9 @@ import org.knowm.xchart.XYChart;
|
||||||
import javax.swing.*;
|
import javax.swing.*;
|
||||||
import java.awt.*;
|
import java.awt.*;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.concurrent.CopyOnWriteArrayList;
|
||||||
|
|
||||||
public class View<A extends Enum> implements LearningListener {
|
public class View<A extends Enum> implements LearningView{
|
||||||
private Learning<A> learning;
|
private Learning<A> learning;
|
||||||
private Environment<A> environment;
|
private Environment<A> environment;
|
||||||
@Getter
|
@Getter
|
||||||
|
@ -25,14 +26,12 @@ public class View<A extends Enum> implements LearningListener {
|
||||||
private JFrame environmentFrame;
|
private JFrame environmentFrame;
|
||||||
private XChartPanel<XYChart> rewardChartPanel;
|
private XChartPanel<XYChart> rewardChartPanel;
|
||||||
private ViewListener viewListener;
|
private ViewListener viewListener;
|
||||||
private boolean drawEveryStep;
|
|
||||||
|
|
||||||
public View(Learning<A> learning, Environment<A> environment, ViewListener viewListener) {
|
public View(Learning<A> learning, Environment<A> environment, ViewListener viewListener) {
|
||||||
this.learning = learning;
|
this.learning = learning;
|
||||||
this.environment = environment;
|
this.environment = environment;
|
||||||
this.viewListener = viewListener;
|
this.viewListener = viewListener;
|
||||||
drawEveryStep = true;
|
initMainFrame();
|
||||||
SwingUtilities.invokeLater(this::initMainFrame);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private void initMainFrame() {
|
private void initMainFrame() {
|
||||||
|
@ -92,46 +91,62 @@ public class View<A extends Enum> implements LearningListener {
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
public void setDrawEveryStep(boolean drawEveryStep){
|
public void updateRewardGraph(final List<Double> rewardHistory) {
|
||||||
this.drawEveryStep = drawEveryStep;
|
List<Integer> xValues;
|
||||||
|
List<Double> yValues;
|
||||||
|
if(learningInfoPanel.isLast100Selected()){
|
||||||
|
yValues = new CopyOnWriteArrayList<>(rewardHistory.subList(rewardHistory.size() - Math.min(rewardHistory.size(), 100), rewardHistory.size()));
|
||||||
|
xValues = new CopyOnWriteArrayList<>();
|
||||||
|
for(int i = rewardHistory.size() - Math.min(rewardHistory.size(), 100); i <rewardHistory.size(); ++i){
|
||||||
|
xValues.add(i);
|
||||||
|
}
|
||||||
|
}else{
|
||||||
|
if(learningInfoPanel.isSmoothenGraphSelected()){
|
||||||
|
Pair<List<Integer>, List<Double>> XYvalues = smoothenGraph(rewardHistory);
|
||||||
|
xValues = XYvalues.getKey();
|
||||||
|
yValues = XYvalues.getValue();
|
||||||
|
}else{
|
||||||
|
xValues = null;
|
||||||
|
yValues = rewardHistory;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public void updateRewardGraph(List<Double> rewardHistory) {
|
rewardChart.updateXYSeries("rewardHistory", xValues, yValues, null);
|
||||||
rewardChart.updateXYSeries("rewardHistory", null, rewardHistory, null);
|
|
||||||
rewardChartPanel.revalidate();
|
rewardChartPanel.revalidate();
|
||||||
rewardChartPanel.repaint();
|
rewardChartPanel.repaint();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private Pair<List<Integer>, List<Double>> smoothenGraph(List<Double> original){
|
||||||
|
int totalXPoints = 100;
|
||||||
|
|
||||||
|
List<Integer> xValues = new CopyOnWriteArrayList<>();
|
||||||
|
List<Double> tmp = new CopyOnWriteArrayList<>();
|
||||||
|
int meanBatch = original.size() / totalXPoints;
|
||||||
|
if(meanBatch < 1){
|
||||||
|
meanBatch = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
int idx = 0;
|
||||||
|
int batchIdx = 0;
|
||||||
|
double batchSum = 0;
|
||||||
|
for(Double x: original) {
|
||||||
|
++idx;
|
||||||
|
batchSum += x;
|
||||||
|
if (idx == 1 || ++batchIdx % meanBatch == 0) {
|
||||||
|
tmp.add(batchSum / meanBatch);
|
||||||
|
xValues.add(idx);
|
||||||
|
batchSum = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return new Pair<>(xValues, tmp);
|
||||||
|
}
|
||||||
|
|
||||||
public void updateLearningInfoPanel() {
|
public void updateLearningInfoPanel() {
|
||||||
this.learningInfoPanel.refreshLabels();
|
this.learningInfoPanel.refreshLabels();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
public void repaintEnvironment(){
|
||||||
public void onEpisodeEnd(List<Double> rewardHistory) {
|
if (environmentFrame != null && learningInfoPanel.isDrawEnvironmentSelected()) {
|
||||||
SwingUtilities.invokeLater(() ->{
|
|
||||||
if(drawEveryStep){
|
|
||||||
updateRewardGraph(rewardHistory);
|
|
||||||
}
|
|
||||||
updateLearningInfoPanel();
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void onEpisodeStart() {
|
|
||||||
if(drawEveryStep) {
|
|
||||||
SwingUtilities.invokeLater(this::repaintEnvironment);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void onStepEnd() {
|
|
||||||
if(drawEveryStep){
|
|
||||||
SwingUtilities.invokeLater(this::repaintEnvironment);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private void repaintEnvironment(){
|
|
||||||
if (environmentFrame != null) {
|
|
||||||
environmentFrame.repaint();
|
environmentFrame.repaint();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,6 +3,8 @@ package core.listener;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
public interface LearningListener{
|
public interface LearningListener{
|
||||||
|
void onLearningStart();
|
||||||
|
void onLearningEnd();
|
||||||
void onEpisodeEnd(List<Double> rewardHistory);
|
void onEpisodeEnd(List<Double> rewardHistory);
|
||||||
void onEpisodeStart();
|
void onEpisodeStart();
|
||||||
void onStepEnd();
|
void onStepEnd();
|
||||||
|
|
|
@ -4,4 +4,5 @@ public interface ViewListener {
|
||||||
void onEpsilonChange(float epsilon);
|
void onEpsilonChange(float epsilon);
|
||||||
void onDelayChange(int delay);
|
void onDelayChange(int delay);
|
||||||
void onFastLearnChange(boolean isFastLearn);
|
void onFastLearnChange(boolean isFastLearn);
|
||||||
|
void onLearnMoreEpisodes(int nrOfEpisodes);
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,14 +8,12 @@ import java.awt.*;
|
||||||
|
|
||||||
public class AntWorldComponent extends JComponent {
|
public class AntWorldComponent extends JComponent {
|
||||||
private AntWorld antWorld;
|
private AntWorld antWorld;
|
||||||
private HistoryPanel historyPanel;
|
|
||||||
|
|
||||||
public AntWorldComponent(AntWorld antWorld, AntAgent antAgent){
|
public AntWorldComponent(AntWorld antWorld, AntAgent antAgent){
|
||||||
this.antWorld = antWorld;
|
this.antWorld = antWorld;
|
||||||
setLayout(new BorderLayout());
|
setLayout(new BorderLayout());
|
||||||
CellsScrollPane worldPane = new CellsScrollPane(antWorld.getCellArray(), antWorld.getAnt(), 10);
|
CellsScrollPane worldPane = new CellsScrollPane(antWorld.getCellArray(), antWorld.getAnt(), 10);
|
||||||
CellsScrollPane antBrainPane = new CellsScrollPane(antAgent.getKnownWorld(), antWorld.getAnt(), 10);
|
CellsScrollPane antBrainPane = new CellsScrollPane(antAgent.getKnownWorld(), antWorld.getAnt(), 10);
|
||||||
historyPanel = new HistoryPanel();
|
|
||||||
|
|
||||||
JComponent mapComponent = new JPanel();
|
JComponent mapComponent = new JPanel();
|
||||||
FlowLayout flowLayout = new FlowLayout();
|
FlowLayout flowLayout = new FlowLayout();
|
||||||
|
@ -23,9 +21,7 @@ public class AntWorldComponent extends JComponent {
|
||||||
mapComponent.setLayout(flowLayout);
|
mapComponent.setLayout(flowLayout);
|
||||||
mapComponent.add(worldPane);
|
mapComponent.add(worldPane);
|
||||||
mapComponent.add(antBrainPane);
|
mapComponent.add(antBrainPane);
|
||||||
|
|
||||||
add(BorderLayout.CENTER, mapComponent);
|
add(BorderLayout.CENTER, mapComponent);
|
||||||
add(BorderLayout.SOUTH, historyPanel);
|
|
||||||
setVisible(true);
|
setVisible(true);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,28 +0,0 @@
|
||||||
package evironment.antGame.gui;
|
|
||||||
|
|
||||||
import javax.swing.*;
|
|
||||||
import java.awt.*;
|
|
||||||
|
|
||||||
public class HistoryPanel extends JPanel {
|
|
||||||
private final int panelWidth = 1000;
|
|
||||||
private final int panelHeight = 300;
|
|
||||||
private JTextArea textArea;
|
|
||||||
|
|
||||||
public HistoryPanel(){
|
|
||||||
setPreferredSize(new Dimension(panelWidth, panelHeight));
|
|
||||||
textArea = new JTextArea();
|
|
||||||
textArea.setLineWrap(true);
|
|
||||||
textArea.setWrapStyleWord(true);
|
|
||||||
textArea.setEditable(false);
|
|
||||||
JScrollPane scrollBar = new JScrollPane(textArea, JScrollPane.VERTICAL_SCROLLBAR_ALWAYS, JScrollPane.HORIZONTAL_SCROLLBAR_ALWAYS);
|
|
||||||
scrollBar.setPreferredSize(new Dimension(panelWidth, panelHeight));
|
|
||||||
add(scrollBar);
|
|
||||||
setVisible(true);
|
|
||||||
}
|
|
||||||
|
|
||||||
public void addText(String toAppend){
|
|
||||||
textArea.append(toAppend);
|
|
||||||
textArea.append("\n\n");
|
|
||||||
revalidate();
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -14,7 +14,7 @@ public class RunningAnt {
|
||||||
.setEnvironment(new AntWorld(3,3,0.1))
|
.setEnvironment(new AntWorld(3,3,0.1))
|
||||||
.setAllowedActions(AntAction.values())
|
.setAllowedActions(AntAction.values())
|
||||||
.setMethod(Method.MC_ONPOLICY_EGREEDY)
|
.setMethod(Method.MC_ONPOLICY_EGREEDY)
|
||||||
.setDelay(10)
|
.setDelay(200)
|
||||||
.setEpisodes(100000);
|
.setEpisodes(100000);
|
||||||
rl.start();
|
rl.start();
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue