Add features to GUI to control learning and move LearningListener interface to controller
- Add metric to display episodes per second
- View no longer implements LearningListener; the controller does. The controller drives all view actions based on learning events and reacts to view events via the ViewListener interface
- Add an executor service for the learning task
- Use instanceof to distinguish between episodic learning and TD learning
- Add a feature to trigger more episodes
- Add checkboxes for smoothing the graph, displaying only the last 100 rewards, and drawing the environment
- Remove the history panel from the AntWorld GUI
parent 34e7e3fdd6
commit b1246f62cc
core/Util.java (new file)
@@ -0,0 +1,15 @@
+package core;
+
+public class Util {
+    public static boolean isNumeric(String strNum) {
+        if (strNum == null) {
+            return false;
+        }
+        try {
+            double d = Double.parseDouble(strNum);
+        } catch (NumberFormatException nfe) {
+            return false;
+        }
+        return true;
+    }
+}
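A quick usage sketch of the new helper, with made-up input (illustrative only, not part of the commit). One caveat: isNumeric accepts anything Double.parseDouble accepts, so a value like "2.5" passes the check but would still make a subsequent Integer.parseInt throw.

    String input = "250"; // hypothetical user input
    if (Util.isNumeric(input)) {
        int episodes = Integer.parseInt(input); // fine for "250", still throws for "2.5"
        System.out.println("Learning " + episodes + " more episodes");
    }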
core/algo/Episodic.java
@@ -2,4 +2,6 @@ package core.algo;

 public interface Episodic {
     int getCurrentEpisode();
+    int getEpisodesToGo();
+    int getEpisodesPerSecond();
 }
core/algo/EpisodicLearning.java
@@ -2,9 +2,14 @@ package core.algo;

 import core.DiscreteActionSpace;
 import core.Environment;
+import core.listener.LearningListener;

 public abstract class EpisodicLearning<A extends Enum> extends Learning<A> implements Episodic{
     protected int currentEpisode;
+    protected int episodesToLearn;
+    protected volatile int episodePerSecond;
+    protected int episodeSumCurrentSecond;
+    private volatile boolean meseaureEpisodeBenchMark;

     public EpisodicLearning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, int delay) {
         super(environment, actionSpace, discountFactor, delay);

@@ -22,8 +27,56 @@ public abstract class EpisodicLearning<A extends Enum> extends Learning<A> implements Episodic
         super(environment, actionSpace);
     }

+    protected void dispatchEpisodeEnd(double recentSumOfRewards){
+        ++episodeSumCurrentSecond;
+        rewardHistory.add(recentSumOfRewards);
+        for(LearningListener l: learningListeners) {
+            l.onEpisodeEnd(rewardHistory);
+        }
+    }
+
+    protected void dispatchEpisodeStart(){
+        for(LearningListener l: learningListeners){
+            l.onEpisodeStart();
+        }
+    }
+
+    @Override
+    public void learn(){
+        learn(0);
+    }
+
+    public void learn(int nrOfEpisodes){
+        meseaureEpisodeBenchMark = true;
+        new Thread(()->{
+            while(meseaureEpisodeBenchMark){
+                episodePerSecond = episodeSumCurrentSecond;
+                episodeSumCurrentSecond = 0;
+                try {
+                    Thread.sleep(1000);
+                } catch (InterruptedException e) {
+                    e.printStackTrace();
+                }
+            }
+        }).start();
+        episodesToLearn += nrOfEpisodes;
+        dispatchLearningStart();
+        for(int i=0; i < nrOfEpisodes; ++i){
+            nextEpisode();
+        }
+        dispatchLearningEnd();
+        meseaureEpisodeBenchMark = false;
+    }
+
+    protected abstract void nextEpisode();
+
     @Override
     public int getCurrentEpisode(){
         return currentEpisode;
     }
+
+    @Override
+    public int getEpisodesToGo(){
+        return episodesToLearn - currentEpisode;
+    }
 }
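The benchmark thread above copies episodeSumCurrentSecond into episodePerSecond once per second and resets it. Since the learner thread increments that plain int concurrently, increments landing between the read and the reset can be lost. A sketch of the same sampling with a ScheduledExecutorService and an AtomicInteger, as an alternative pattern (names are illustrative; this is not what the commit ships):

    import java.util.concurrent.Executors;
    import java.util.concurrent.ScheduledExecutorService;
    import java.util.concurrent.TimeUnit;
    import java.util.concurrent.atomic.AtomicInteger;

    class EpisodeBenchmark {
        private final AtomicInteger episodesThisSecond = new AtomicInteger();
        private volatile int episodesPerSecond;
        private final ScheduledExecutorService sampler = Executors.newSingleThreadScheduledExecutor();

        void start() {
            // Once per second: publish the count and reset it in one atomic step.
            sampler.scheduleAtFixedRate(
                    () -> episodesPerSecond = episodesThisSecond.getAndSet(0),
                    1, 1, TimeUnit.SECONDS);
        }

        void onEpisodeEnd() { episodesThisSecond.incrementAndGet(); } // called by the learner

        int episodesPerSecond() { return episodesPerSecond; }

        void stop() { sampler.shutdownNow(); }
    }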
core/algo/Learning.java
@@ -24,7 +24,7 @@ public abstract class Learning<A extends Enum> {
     protected Set<LearningListener> learningListeners;
     @Setter
     protected int delay;
-    private List<Double> rewardHistory;
+    protected List<Double> rewardHistory;

     public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, int delay){
         this.environment = environment;

@@ -47,29 +47,27 @@ public abstract class Learning<A extends Enum> {
         this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, LearningConfig.DEFAULT_DELAY);
     }

-    public abstract void learn(int nrOfEpisodes);
+    public abstract void learn();

     public void addListener(LearningListener learningListener){
         learningListeners.add(learningListener);
     }

-    protected void dispatchEpisodeEnd(double recentSumOfRewards){
-        rewardHistory.add(recentSumOfRewards);
-        for(LearningListener l: learningListeners) {
-            l.onEpisodeEnd(rewardHistory);
-        }
-    }
-
-    protected void dispatchEpisodeStart(){
-        for(LearningListener l: learningListeners){
-            l.onEpisodeStart();
-        }
-    }
-
     protected void dispatchStepEnd(){
         for(LearningListener l: learningListeners){
             l.onStepEnd();
         }
     }

+    protected void dispatchLearningStart(){
+        for(LearningListener l: learningListeners){
+            l.onLearningStart();
+        }
+    }
+
+    protected void dispatchLearningEnd(){
+        for(LearningListener l: learningListeners){
+            l.onLearningEnd();
+        }
+    }
 }
core/algo/mc/MonteCarloOnPolicyEGreedy.java
@@ -11,27 +11,33 @@ import java.util.*;
  * TODO: Major problem:
  * StateActionPairs are only unique accounting for their position in the episode.
  * For example:
  *
  * <p>
  * startingState -> MOVE_LEFT : very first state action in the episode i = 1
  * imagine the agent does not collect the food and drops it to the start, the agent will receive
  * -1 for every timestamp, hence (startingState -> MOVE_LEFT) will get a value of -10;
  *
  * <p>
  * BUT imagine moving left from the starting position will have no impact on the state because
  * the agent ran into a wall. The known world stays the same.
  * Taking an action after that will have the exact same state but a different action,
  * making the value of this stateActionPair -9 because the stateAction pair took place on the second
  * timestamp; summing up all remaining rewards will be -9...
  *
  * <p>
  * How to counter this problem?
  *
  * @param <A>
  */
 public class MonteCarloOnPolicyEGreedy<A extends Enum> extends EpisodicLearning<A> {

+    private Map<Pair<State, A>, Double> returnSum;
+    private Map<Pair<State, A>, Integer> returnCount;
+
     public MonteCarloOnPolicyEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay) {
         super(environment, actionSpace, discountFactor, delay);
+        currentEpisode = 0;
         this.policy = new EpsilonGreedyPolicy<>(epsilon);
         this.stateActionTable = new StateActionHashTable<>(this.actionSpace);
+        returnSum = new HashMap<>();
+        returnCount = new HashMap<>();
     }

     public MonteCarloOnPolicyEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, int delay) {

@@ -40,71 +46,64 @@ public class MonteCarloOnPolicyEGreedy<A extends Enum> extends EpisodicLearning<A> {

     @Override
-    public void learn(int nrOfEpisodes) {
-        Map<Pair<State, A>, Double> returnSum = new HashMap<>();
-        Map<Pair<State, A>, Integer> returnCount = new HashMap<>();
-
-        for(int i = 0; i < nrOfEpisodes; ++i) {
-            ++currentEpisode;
+    public void nextEpisode() {
+        ++currentEpisode;
         List<StepResult<A>> episode = new ArrayList<>();
         State state = environment.reset();
         dispatchEpisodeStart();
         try {
             Thread.sleep(delay);
         } catch (InterruptedException e) {
             e.printStackTrace();
         }
         double sumOfRewards = 0;
         for (int j = 0; j < 10; ++j) {
             Map<A, Double> actionValues = stateActionTable.getActionValues(state);
             A chosenAction = policy.chooseAction(actionValues);
             StepResultEnvironment envResult = environment.step(chosenAction);
             State nextState = envResult.getState();
             sumOfRewards += envResult.getReward();
             episode.add(new StepResult<>(state, chosenAction, envResult.getReward()));

             if (envResult.isDone()) break;

             state = nextState;

             try {
                 Thread.sleep(delay);
             } catch (InterruptedException e) {
                 e.printStackTrace();
             }
             dispatchStepEnd();
         }

         dispatchEpisodeEnd(sumOfRewards);
-        System.out.printf("Episode %d \t Reward: %f \n", i, sumOfRewards);
+        System.out.printf("Episode %d \t Reward: %f \n", currentEpisode, sumOfRewards);
         Set<Pair<State, A>> stateActionPairs = new HashSet<>();

         for (StepResult<A> sr : episode) {
             stateActionPairs.add(new Pair<>(sr.getState(), sr.getAction()));
         }
         System.out.println("stateActionPairs " + stateActionPairs.size());
         for (Pair<State, A> stateActionPair : stateActionPairs) {
             int firstOccurenceIndex = 0;
             // find first occurrence of the state-action pair
             for (StepResult<A> sr : episode) {
                 if (stateActionPair.getKey().equals(sr.getState()) && stateActionPair.getValue().equals(sr.getAction())) {
                     break;
                 }
                 firstOccurenceIndex++;
             }

             double G = 0;
             for (int l = firstOccurenceIndex; l < episode.size(); ++l) {
                 G += episode.get(l).getReward() * (Math.pow(discountFactor, l - firstOccurenceIndex));
             }
             // slick trick to add G to the entry:
             // if the key does not exist, merge creates a new entry with G as the default value
             returnSum.merge(stateActionPair, G, Double::sum);
             returnCount.merge(stateActionPair, 1, Integer::sum);
             stateActionTable.setValue(stateActionPair.getKey(), stateActionPair.getValue(), returnSum.get(stateActionPair) / returnCount.get(stateActionPair));
         }
-        }
     }

@@ -112,4 +111,9 @@ public class MonteCarloOnPolicyEGreedy<A extends Enum> extends EpisodicLearning<A> {
     public int getCurrentEpisode() {
         return currentEpisode;
     }
+
+    @Override
+    public int getEpisodesPerSecond(){
+        return episodePerSecond;
+    }
 }
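For intuition on the first-visit update above: G discounts every reward from the first occurrence of a state-action pair to the end of the episode, and the stored value is the running average returnSum / returnCount. A standalone sketch with made-up numbers (gamma and rewards are illustrative only):

    double discountFactor = 0.9;            // illustrative gamma
    double[] rewards = {-1.0, -1.0, 10.0};  // rewards from the first visit onward
    int firstOccurenceIndex = 0;

    double G = 0;
    for (int l = firstOccurenceIndex; l < rewards.length; ++l) {
        G += rewards[l] * Math.pow(discountFactor, l - firstOccurenceIndex);
    }
    System.out.println(G); // -1 + 0.9 * (-1) + 0.81 * 10 = 6.2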
core/controller/RLController.java
@@ -3,26 +3,37 @@ package core.controller;
 import core.DiscreteActionSpace;
 import core.Environment;
 import core.ListDiscreteActionSpace;
+import core.algo.EpisodicLearning;
 import core.algo.Learning;
 import core.algo.Method;
 import core.algo.mc.MonteCarloOnPolicyEGreedy;
+import core.gui.LearningView;
 import core.gui.View;
+import core.listener.LearningListener;
 import core.listener.ViewListener;
 import core.policy.EpsilonPolicy;

 import javax.swing.*;
+import java.util.List;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;

-public class RLController<A extends Enum> implements ViewListener {
+public class RLController<A extends Enum> implements ViewListener, LearningListener {
     protected Environment<A> environment;
     protected Learning<A> learning;
     protected DiscreteActionSpace<A> discreteActionSpace;
-    protected View<A> view;
+    protected LearningView learningView;
     private int delay;
     private int nrOfEpisodes;
     private Method method;
     private int prevDelay;
     private boolean fastLearning;
+    private boolean currentlyLearning;
+    private ExecutorService learningExecutor;
+    private List<Double> latestRewardsHistory;
+
+    public RLController(){
+        learningExecutor = Executors.newSingleThreadExecutor();
+    }

     public void start(){

@@ -39,20 +50,37 @@ public class RLController<A extends Enum> implements ViewListener {
             default:
                 throw new RuntimeException("Undefined method");
         }
-        /*
-        not using SwingUtilities here on purpose to ensure the view is fully
-        initialized and can be passed as LearningListener.
-         */
-        view = new View<>(learning, environment, this);
-        learning.addListener(view);
-        learning.learn(nrOfEpisodes);
+        SwingUtilities.invokeLater(()->{
+            learningView = new View<>(learning, environment, this);
+            learning.addListener(this);
+        });
+
+        if(learning instanceof EpisodicLearning){
+            learningExecutor.submit(()->((EpisodicLearning) learning).learn(nrOfEpisodes));
+        }else{
+            learningExecutor.submit(()->learning.learn());
+        }
     }

+    /*************************************************
+     *                VIEW LISTENERS                 *
+     *************************************************/
+    @Override
+    public void onLearnMoreEpisodes(int nrOfEpisodes){
+        if(!currentlyLearning){
+            if(learning instanceof EpisodicLearning){
+                learningExecutor.submit(()->((EpisodicLearning) learning).learn(nrOfEpisodes));
+            }else{
+                throw new RuntimeException("Triggering onLearnMoreEpisodes on non-episodic learning!");
+            }
+        }
+    }
+
     @Override
     public void onEpsilonChange(float epsilon) {
         if(learning.getPolicy() instanceof EpsilonPolicy){
             ((EpsilonPolicy<A>) learning.getPolicy()).setEpsilon(epsilon);
-            SwingUtilities.invokeLater(() -> view.updateLearningInfoPanel());
+            SwingUtilities.invokeLater(() -> learningView.updateLearningInfoPanel());
         }else{
             System.out.println("Trying to call inEpsilonChange on non-epsilon policy");
         }

@@ -65,12 +93,12 @@ public class RLController<A extends Enum> implements ViewListener {

     private void changeLearningDelay(int delay){
         learning.setDelay(delay);
-        SwingUtilities.invokeLater(() -> view.updateLearningInfoPanel());
+        SwingUtilities.invokeLater(() -> learningView.updateLearningInfoPanel());
     }

     @Override
     public void onFastLearnChange(boolean fastLearn) {
-        view.setDrawEveryStep(!fastLearn);
         this.fastLearning = fastLearn;
         if(fastLearn){
             prevDelay = learning.getDelay();
             changeLearningDelay(0);

@@ -79,6 +107,45 @@ public class RLController<A extends Enum> implements ViewListener {
         }
     }

+    /*************************************************
+     *              LEARNING LISTENERS               *
+     *************************************************/
+    @Override
+    public void onLearningStart() {
+        currentlyLearning = true;
+    }
+
+    @Override
+    public void onLearningEnd() {
+        currentlyLearning = false;
+        SwingUtilities.invokeLater(()-> learningView.updateRewardGraph(latestRewardsHistory));
+    }
+
+    @Override
+    public void onEpisodeEnd(List<Double> rewardHistory) {
+        latestRewardsHistory = rewardHistory;
+        SwingUtilities.invokeLater(() ->{
+            if(!fastLearning){
+                learningView.updateRewardGraph(latestRewardsHistory);
+            }
+            learningView.updateLearningInfoPanel();
+        });
+    }
+
+    @Override
+    public void onEpisodeStart() {
+
+    }
+
+    @Override
+    public void onStepEnd() {
+        if(!fastLearning){
+            SwingUtilities.invokeLater(() -> learningView.repaintEnvironment());
+        }
+    }
+
+
     public RLController<A> setMethod(Method method){
         this.method = method;
         return this;

@@ -102,5 +169,4 @@ public class RLController<A extends Enum> implements ViewListener {
         this.nrOfEpisodes = nrOfEpisodes;
         return this;
     }
-
 }
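The controller now splits work across threads: learning runs on a single-thread executor so it never blocks the UI, while every UI mutation is funneled through SwingUtilities.invokeLater onto the event dispatch thread. A minimal generic sketch of that pattern (class and method names here are illustrative, not from the commit):

    import javax.swing.SwingUtilities;
    import java.util.concurrent.ExecutorService;
    import java.util.concurrent.Executors;

    class BackgroundLearningRunner {
        private final ExecutorService worker = Executors.newSingleThreadExecutor();

        void run(Runnable learningTask, Runnable uiUpdate) {
            worker.submit(() -> {
                learningTask.run();                   // heavy work, off the EDT
                SwingUtilities.invokeLater(uiUpdate); // UI changes, on the EDT
            });
        }
    }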
core/gui/LearningInfoPanel.java
@@ -1,11 +1,14 @@
 package core.gui;

+import core.Util;
 import core.algo.Episodic;
+import core.algo.EpisodicLearning;
 import core.algo.Learning;
 import core.listener.ViewListener;
 import core.policy.EpsilonPolicy;

 import javax.swing.*;
+import java.awt.*;

 public class LearningInfoPanel extends JPanel {
     private Learning learning;

@@ -18,6 +21,11 @@ public class LearningInfoPanel extends JPanel {
     private JSlider delaySlider;
     private JButton toggleFastLearningButton;
     private boolean fastLearning;
+    private JCheckBox smoothGraphCheckbox;
+    private JCheckBox last100Checkbox;
+    private JCheckBox drawEnvironmentCheckbox;
+    private JTextField learnMoreEpisodesInput;
+    private JButton learnMoreEpisodesButton;

     public LearningInfoPanel(Learning learning, ViewListener viewListener){
         this.learning = learning;

@@ -47,11 +55,37 @@ public class LearningInfoPanel extends JPanel {
             fastLearning = !fastLearning;
             delaySlider.setEnabled(!fastLearning);
             epsilonSlider.setEnabled(!fastLearning);
+            drawEnvironmentCheckbox.setSelected(!fastLearning);
             viewListener.onFastLearnChange(fastLearning);
         });
+        smoothGraphCheckbox = new JCheckBox("Smoothen Graph");
+        smoothGraphCheckbox.setSelected(false);
+        last100Checkbox = new JCheckBox("Only show last 100 Rewards");
+        last100Checkbox.setSelected(true);
+        drawEnvironmentCheckbox = new JCheckBox("Update Environment");
+        drawEnvironmentCheckbox.setSelected(true);

         add(delayLabel);
         add(delaySlider);
         add(toggleFastLearningButton);

+        if(learning instanceof EpisodicLearning) {
+            learnMoreEpisodesInput = new JTextField();
+            learnMoreEpisodesInput.setMaximumSize(new Dimension(200,20));
+            learnMoreEpisodesButton = new JButton("Learn More Episodes");
+            learnMoreEpisodesButton.addActionListener(e -> {
+                if (Util.isNumeric(learnMoreEpisodesInput.getText())) {
+                    viewListener.onLearnMoreEpisodes(Integer.parseInt(learnMoreEpisodesInput.getText()));
+                } else {
+                    learnMoreEpisodesInput.setText("");
+                }
+            });
+            add(learnMoreEpisodesInput);
+            add(learnMoreEpisodesButton);
+        }
+        add(drawEnvironmentCheckbox);
+        add(smoothGraphCheckbox);
+        add(last100Checkbox);
         refreshLabels();
         setVisible(true);
     }

@@ -60,14 +94,29 @@ public class LearningInfoPanel extends JPanel {
         policyLabel.setText("Policy: " + learning.getPolicy().getClass());
         discountLabel.setText("Discount factor: " + learning.getDiscountFactor());
         if(learning instanceof Episodic){
-            episodeLabel.setText("Episode: " + ((Episodic)(learning)).getCurrentEpisode());
+            episodeLabel.setText("Episode: " + ((Episodic)(learning)).getCurrentEpisode() +
+                    "\t Episodes to go: " + ((Episodic)(learning)).getEpisodesToGo() +
+                    "\t Eps/Sec: " + ((Episodic)(learning)).getEpisodesPerSecond());
         }
         if (learning.getPolicy() instanceof EpsilonPolicy) {
             epsilonLabel.setText("Exploration (Epsilon): " + ((EpsilonPolicy) learning.getPolicy()).getEpsilon());
             epsilonSlider.setValue((int)(((EpsilonPolicy) learning.getPolicy()).getEpsilon() * 100));
         }
         delayLabel.setText("Delay (ms): " + learning.getDelay());
-        delaySlider.setValue(learning.getDelay());
+        if(delaySlider.isEnabled()){
+            delaySlider.setValue(learning.getDelay());
+        }
         toggleFastLearningButton.setText(fastLearning ? "Disable fast-learning" : "Enable fast-learning");
     }

+    protected boolean isSmoothenGraphSelected() {
+        return smoothGraphCheckbox.isSelected();
+    }
+
+    protected boolean isLast100Selected(){
+        return last100Checkbox.isSelected();
+    }
+
+    protected boolean isDrawEnvironmentSelected(){
+        return drawEnvironmentCheckbox.isSelected();
+    }
 }
core/gui/LearningView.java (new file)
@@ -0,0 +1,9 @@
+package core.gui;
+
+import java.util.List;
+
+public interface LearningView {
+    void repaintEnvironment();
+    void updateLearningInfoPanel();
+    void updateRewardGraph(final List<Double> rewardHistory);
+}
core/gui/View.java
@@ -3,7 +3,7 @@ package core.gui;
 import core.Environment;
 import core.algo.Learning;
 import core.listener.ViewListener;
-import core.listener.LearningListener;
+import javafx.util.Pair;
 import lombok.Getter;
 import org.knowm.xchart.QuickChart;
 import org.knowm.xchart.XChartPanel;

@@ -12,8 +12,9 @@ import org.knowm.xchart.XYChart;
 import javax.swing.*;
 import java.awt.*;
 import java.util.List;
+import java.util.concurrent.CopyOnWriteArrayList;

-public class View<A extends Enum> implements LearningListener {
+public class View<A extends Enum> implements LearningView{
     private Learning<A> learning;
     private Environment<A> environment;
     @Getter

@@ -25,14 +26,12 @@ public class View<A extends Enum> implements LearningListener {
     private JFrame environmentFrame;
     private XChartPanel<XYChart> rewardChartPanel;
     private ViewListener viewListener;
-    private boolean drawEveryStep;

     public View(Learning<A> learning, Environment<A> environment, ViewListener viewListener) {
         this.learning = learning;
         this.environment = environment;
         this.viewListener = viewListener;
-        drawEveryStep = true;
-        SwingUtilities.invokeLater(this::initMainFrame);
+        initMainFrame();
     }

     private void initMainFrame() {

@@ -92,46 +91,62 @@ public class View<A extends Enum> implements LearningListener {
         };
     }

-    public void setDrawEveryStep(boolean drawEveryStep){
-        this.drawEveryStep = drawEveryStep;
-    }
-
-    public void updateRewardGraph(List<Double> rewardHistory) {
-        rewardChart.updateXYSeries("rewardHistory", null, rewardHistory, null);
+    public void updateRewardGraph(final List<Double> rewardHistory) {
+        List<Integer> xValues;
+        List<Double> yValues;
+        if(learningInfoPanel.isLast100Selected()){
+            yValues = new CopyOnWriteArrayList<>(rewardHistory.subList(rewardHistory.size() - Math.min(rewardHistory.size(), 100), rewardHistory.size()));
+            xValues = new CopyOnWriteArrayList<>();
+            for(int i = rewardHistory.size() - Math.min(rewardHistory.size(), 100); i < rewardHistory.size(); ++i){
+                xValues.add(i);
+            }
+        }else{
+            if(learningInfoPanel.isSmoothenGraphSelected()){
+                Pair<List<Integer>, List<Double>> XYvalues = smoothenGraph(rewardHistory);
+                xValues = XYvalues.getKey();
+                yValues = XYvalues.getValue();
+            }else{
+                xValues = null;
+                yValues = rewardHistory;
+            }
+        }
+
+        rewardChart.updateXYSeries("rewardHistory", xValues, yValues, null);
         rewardChartPanel.revalidate();
         rewardChartPanel.repaint();
     }

+    private Pair<List<Integer>, List<Double>> smoothenGraph(List<Double> original){
+        int totalXPoints = 100;
+
+        List<Integer> xValues = new CopyOnWriteArrayList<>();
+        List<Double> tmp = new CopyOnWriteArrayList<>();
+        int meanBatch = original.size() / totalXPoints;
+        if(meanBatch < 1){
+            meanBatch = 1;
+        }
+
+        int idx = 0;
+        int batchIdx = 0;
+        double batchSum = 0;
+        for(Double x: original) {
+            ++idx;
+            batchSum += x;
+            if (idx == 1 || ++batchIdx % meanBatch == 0) {
+                tmp.add(batchSum / meanBatch);
+                xValues.add(idx);
+                batchSum = 0;
+            }
+        }
+        return new Pair<>(xValues, tmp);
+    }
+
     public void updateLearningInfoPanel() {
         this.learningInfoPanel.refreshLabels();
     }

-    @Override
-    public void onEpisodeEnd(List<Double> rewardHistory) {
-        SwingUtilities.invokeLater(() ->{
-            if(drawEveryStep){
-                updateRewardGraph(rewardHistory);
-            }
-            updateLearningInfoPanel();
-        });
-    }
-
-    @Override
-    public void onEpisodeStart() {
-        if(drawEveryStep) {
-            SwingUtilities.invokeLater(this::repaintEnvironment);
-        }
-    }
-
-    @Override
-    public void onStepEnd() {
-        if(drawEveryStep){
-            SwingUtilities.invokeLater(this::repaintEnvironment);
-        }
-    }
-
-    private void repaintEnvironment(){
-        if (environmentFrame != null) {
+    public void repaintEnvironment(){
+        if (environmentFrame != null && learningInfoPanel.isDrawEnvironmentSelected()) {
             environmentFrame.repaint();
         }
     }
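To make the smoothing concrete: with 1000 recorded rewards and totalXPoints = 100, meanBatch is 10, so every tenth reward emits one point holding the mean of the last ten values, compressing the series to roughly 100 points. Note that the idx == 1 branch emits the very first reward divided by meanBatch, so the first plotted point is scaled down by that factor.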
core/listener/LearningListener.java
@@ -3,6 +3,8 @@ package core.listener;
 import java.util.List;

 public interface LearningListener{
+    void onLearningStart();
+    void onLearningEnd();
     void onEpisodeEnd(List<Double> rewardHistory);
     void onEpisodeStart();
     void onStepEnd();
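With onLearningStart and onLearningEnd added, a listener now observes the whole lifecycle. A minimal hypothetical logging implementation (illustrative only):

    LearningListener logger = new LearningListener() {
        public void onLearningStart() { System.out.println("learning started"); }
        public void onLearningEnd() { System.out.println("learning finished"); }
        public void onEpisodeStart() { }
        public void onEpisodeEnd(java.util.List<Double> rewardHistory) {
            System.out.println("episodes finished: " + rewardHistory.size());
        }
        public void onStepEnd() { }
    };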
core/listener/ViewListener.java
@@ -4,4 +4,5 @@ public interface ViewListener {
     void onEpsilonChange(float epsilon);
     void onDelayChange(int delay);
     void onFastLearnChange(boolean isFastLearn);
+    void onLearnMoreEpisodes(int nrOfEpisodes);
 }
evironment/antGame/gui/AntWorldComponent.java
@@ -8,14 +8,12 @@ import java.awt.*;

 public class AntWorldComponent extends JComponent {
     private AntWorld antWorld;
-    private HistoryPanel historyPanel;

     public AntWorldComponent(AntWorld antWorld, AntAgent antAgent){
         this.antWorld = antWorld;
         setLayout(new BorderLayout());
         CellsScrollPane worldPane = new CellsScrollPane(antWorld.getCellArray(), antWorld.getAnt(), 10);
         CellsScrollPane antBrainPane = new CellsScrollPane(antAgent.getKnownWorld(), antWorld.getAnt(), 10);
-        historyPanel = new HistoryPanel();

         JComponent mapComponent = new JPanel();
         FlowLayout flowLayout = new FlowLayout();

@@ -23,9 +21,7 @@ public class AntWorldComponent extends JComponent {
         mapComponent.setLayout(flowLayout);
         mapComponent.add(worldPane);
         mapComponent.add(antBrainPane);

         add(BorderLayout.CENTER, mapComponent);
-        add(BorderLayout.SOUTH, historyPanel);
         setVisible(true);
     }
evironment/antGame/gui/HistoryPanel.java (deleted)
@@ -1,28 +0,0 @@
-package evironment.antGame.gui;
-
-import javax.swing.*;
-import java.awt.*;
-
-public class HistoryPanel extends JPanel {
-    private final int panelWidth = 1000;
-    private final int panelHeight = 300;
-    private JTextArea textArea;
-
-    public HistoryPanel(){
-        setPreferredSize(new Dimension(panelWidth, panelHeight));
-        textArea = new JTextArea();
-        textArea.setLineWrap(true);
-        textArea.setWrapStyleWord(true);
-        textArea.setEditable(false);
-        JScrollPane scrollBar = new JScrollPane(textArea, JScrollPane.VERTICAL_SCROLLBAR_ALWAYS, JScrollPane.HORIZONTAL_SCROLLBAR_ALWAYS);
-        scrollBar.setPreferredSize(new Dimension(panelWidth, panelHeight));
-        add(scrollBar);
-        setVisible(true);
-    }
-
-    public void addText(String toAppend){
-        textArea.append(toAppend);
-        textArea.append("\n\n");
-        revalidate();
-    }
-}
RunningAnt.java
@@ -14,7 +14,7 @@ public class RunningAnt {
             .setEnvironment(new AntWorld(3,3,0.1))
             .setAllowedActions(AntAction.values())
             .setMethod(Method.MC_ONPOLICY_EGREEDY)
-            .setDelay(10)
+            .setDelay(200)
             .setEpisodes(100000);
         rl.start();
     }