split GUI parts from controller into sub class
This commit is contained in:
parent
195722e98f
commit
518683b676
|
@ -5,4 +5,4 @@
|
||||||
<profile default="true" name="Default" enabled="true" />
|
<profile default="true" name="Default" enabled="true" />
|
||||||
</annotationProcessing>
|
</annotationProcessing>
|
||||||
</component>
|
</component>
|
||||||
</project>
|
</project>
|
|
@ -10,6 +10,7 @@ import core.gui.View;
|
||||||
import core.listener.LearningListener;
|
import core.listener.LearningListener;
|
||||||
import core.listener.ViewListener;
|
import core.listener.ViewListener;
|
||||||
import core.policy.EpsilonPolicy;
|
import core.policy.EpsilonPolicy;
|
||||||
|
import lombok.Setter;
|
||||||
|
|
||||||
import javax.swing.*;
|
import javax.swing.*;
|
||||||
import java.io.*;
|
import java.io.*;
|
||||||
|
@ -18,24 +19,29 @@ import java.util.concurrent.ExecutorService;
|
||||||
import java.util.concurrent.Executors;
|
import java.util.concurrent.Executors;
|
||||||
|
|
||||||
public class RLController<A extends Enum> implements ViewListener, LearningListener {
|
public class RLController<A extends Enum> implements ViewListener, LearningListener {
|
||||||
private final String folderPrefix = "learningStates" + File.separator;
|
protected final String folderPrefix = "learningStates" + File.separator;
|
||||||
private Environment<A> environment;
|
protected Environment<A> environment;
|
||||||
private DiscreteActionSpace<A> discreteActionSpace;
|
protected DiscreteActionSpace<A> discreteActionSpace;
|
||||||
private Method method;
|
protected Method method;
|
||||||
private int delay = LearningConfig.DEFAULT_DELAY;
|
@Setter
|
||||||
private float discountFactor = LearningConfig.DEFAULT_DISCOUNT_FACTOR;
|
protected int delay = LearningConfig.DEFAULT_DELAY;
|
||||||
private float epsilon = LearningConfig.DEFAULT_EPSILON;
|
@Setter
|
||||||
private Learning<A> learning;
|
protected float discountFactor = LearningConfig.DEFAULT_DISCOUNT_FACTOR;
|
||||||
private LearningView learningView;
|
@Setter
|
||||||
private boolean fastLearning;
|
protected float epsilon = LearningConfig.DEFAULT_EPSILON;
|
||||||
private List<Double> latestRewardsHistory;
|
protected Learning<A> learning;
|
||||||
private int nrOfEpisodes;
|
protected boolean fastLearning;
|
||||||
private int prevDelay;
|
protected List<Double> latestRewardsHistory;
|
||||||
|
@Setter
|
||||||
|
protected int nrOfEpisodes;
|
||||||
|
protected int prevDelay;
|
||||||
|
protected volatile boolean printNextEpisode;
|
||||||
|
|
||||||
public RLController(Environment<A> env, Method method, A... actions){
|
public RLController(Environment<A> env, Method method, A... actions){
|
||||||
setEnvironment(env);
|
setEnvironment(env);
|
||||||
setMethod(method);
|
setMethod(method);
|
||||||
setAllowedActions(actions);
|
setAllowedActions(actions);
|
||||||
|
printNextEpisode = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void start(){
|
public void start(){
|
||||||
|
@ -48,20 +54,29 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
|
||||||
default:
|
default:
|
||||||
throw new IllegalArgumentException("Undefined method");
|
throw new IllegalArgumentException("Undefined method");
|
||||||
}
|
}
|
||||||
|
System.out.println("Initialized learning: " + learning.getClass());
|
||||||
initGUI();
|
initListeners();
|
||||||
|
System.out.println("Set listeners");
|
||||||
initLearning();
|
initLearning();
|
||||||
}
|
}
|
||||||
|
|
||||||
private void initGUI(){
|
protected void initListeners(){
|
||||||
SwingUtilities.invokeLater(()->{
|
learning.addListener(this);
|
||||||
learningView = new View<>(learning, environment, this);
|
new Thread(() -> {
|
||||||
learning.addListener(this);
|
while (true){
|
||||||
});
|
printNextEpisode = true;
|
||||||
|
try {
|
||||||
|
Thread.sleep(30*1000);
|
||||||
|
} catch (InterruptedException e) {
|
||||||
|
e.printStackTrace();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}).start();
|
||||||
}
|
}
|
||||||
|
|
||||||
private void initLearning(){
|
private void initLearning(){
|
||||||
if(learning instanceof EpisodicLearning){
|
if(learning instanceof EpisodicLearning){
|
||||||
|
System.out.println("Starting learning of <" + nrOfEpisodes + "> episodes");
|
||||||
((EpisodicLearning) learning).learn(nrOfEpisodes);
|
((EpisodicLearning) learning).learn(nrOfEpisodes);
|
||||||
}else{
|
}else{
|
||||||
learning.learn();
|
learning.learn();
|
||||||
|
@ -78,7 +93,6 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
|
||||||
}else{
|
}else{
|
||||||
throw new RuntimeException("Triggering onLearnMoreEpisodes on non-episodic learning!");
|
throw new RuntimeException("Triggering onLearnMoreEpisodes on non-episodic learning!");
|
||||||
}
|
}
|
||||||
learningView.updateLearningInfoPanel();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -88,10 +102,9 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
|
||||||
try {
|
try {
|
||||||
fis = new FileInputStream(fileName);
|
fis = new FileInputStream(fileName);
|
||||||
in = new ObjectInputStream(fis);
|
in = new ObjectInputStream(fis);
|
||||||
System.out.println("interrup" + Thread.currentThread().getId());
|
System.out.println("interrupt" + Thread.currentThread().getId());
|
||||||
learning.interruptLearning();
|
learning.interruptLearning();
|
||||||
learning.load(in);
|
learning.load(in);
|
||||||
SwingUtilities.invokeLater(() -> learningView.updateLearningInfoPanel());
|
|
||||||
in.close();
|
in.close();
|
||||||
} catch (IOException | ClassNotFoundException e) {
|
} catch (IOException | ClassNotFoundException e) {
|
||||||
e.printStackTrace();
|
e.printStackTrace();
|
||||||
|
@ -117,7 +130,6 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
|
||||||
public void onEpsilonChange(float epsilon) {
|
public void onEpsilonChange(float epsilon) {
|
||||||
if(learning.getPolicy() instanceof EpsilonPolicy){
|
if(learning.getPolicy() instanceof EpsilonPolicy){
|
||||||
((EpsilonPolicy<A>) learning.getPolicy()).setEpsilon(epsilon);
|
((EpsilonPolicy<A>) learning.getPolicy()).setEpsilon(epsilon);
|
||||||
SwingUtilities.invokeLater(() -> learningView.updateLearningInfoPanel());
|
|
||||||
}else{
|
}else{
|
||||||
System.out.println("Trying to call inEpsilonChange on non-epsilon policy");
|
System.out.println("Trying to call inEpsilonChange on non-epsilon policy");
|
||||||
}
|
}
|
||||||
|
@ -128,9 +140,8 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
|
||||||
changeLearningDelay(delay);
|
changeLearningDelay(delay);
|
||||||
}
|
}
|
||||||
|
|
||||||
private void changeLearningDelay(int delay){
|
protected void changeLearningDelay(int delay){
|
||||||
learning.setDelay(delay);
|
learning.setDelay(delay);
|
||||||
SwingUtilities.invokeLater(() -> learningView.updateLearningInfoPanel());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -153,19 +164,8 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void onLearningEnd() {
|
public void onLearningEnd() {
|
||||||
SwingUtilities.invokeLater(()-> learningView.updateRewardGraph(latestRewardsHistory));
|
System.out.println("Learning finished");
|
||||||
onSaveState( method.toString() + System.currentTimeMillis()/1000 + (learning instanceof EpisodicLearning ? "e " + ((EpisodicLearning) learning).getCurrentEpisode() : ""));
|
onSaveState( method.toString() + System.currentTimeMillis()/1000 + (learning instanceof EpisodicLearning ? "e" + ((EpisodicLearning) learning).getCurrentEpisode() : ""));
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void onEpisodeEnd(List<Double> rewardHistory) {
|
|
||||||
latestRewardsHistory = rewardHistory;
|
|
||||||
SwingUtilities.invokeLater(() ->{
|
|
||||||
if(!fastLearning){
|
|
||||||
learningView.updateRewardGraph(latestRewardsHistory);
|
|
||||||
}
|
|
||||||
learningView.updateLearningInfoPanel();
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -174,12 +174,19 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void onStepEnd() {
|
public void onEpisodeEnd(List<Double> rewardHistory) {
|
||||||
if(!fastLearning){
|
latestRewardsHistory = rewardHistory;
|
||||||
SwingUtilities.invokeLater(() -> learningView.repaintEnvironment());
|
if(printNextEpisode){
|
||||||
|
System.out.println("Episode " + ((EpisodicLearning) learning).getCurrentEpisode() + " Latest Reward: " + rewardHistory.get(rewardHistory.size()-1));
|
||||||
|
System.out.println("Eps/sec: " + ((EpisodicLearning) learning).getEpisodePerSecond());
|
||||||
|
printNextEpisode = false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onStepEnd() {
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
/*************************************************
|
/*************************************************
|
||||||
** SETTERS **
|
** SETTERS **
|
||||||
|
@ -205,19 +212,4 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
|
||||||
}
|
}
|
||||||
this.discreteActionSpace = new ListDiscreteActionSpace<>(actions);
|
this.discreteActionSpace = new ListDiscreteActionSpace<>(actions);
|
||||||
}
|
}
|
||||||
|
|
||||||
public void setDelay(int delay){
|
|
||||||
this.delay = delay;
|
|
||||||
}
|
|
||||||
|
|
||||||
public void setEpisodes(int nrOfEpisodes){
|
|
||||||
this.nrOfEpisodes = nrOfEpisodes;
|
|
||||||
}
|
|
||||||
|
|
||||||
public void setDiscountFactor(float discountFactor){
|
|
||||||
this.discountFactor = discountFactor;
|
|
||||||
}
|
|
||||||
public void setEpsilon(float epsilon){
|
|
||||||
this.epsilon = epsilon;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,73 @@
|
||||||
|
package core.controller;
|
||||||
|
|
||||||
|
import core.Environment;
|
||||||
|
import core.algo.Method;
|
||||||
|
import core.gui.LearningView;
|
||||||
|
import core.gui.View;
|
||||||
|
|
||||||
|
import javax.swing.*;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public class RLControllerGUI<A extends Enum> extends RLController<A> {
|
||||||
|
private LearningView learningView;
|
||||||
|
|
||||||
|
public RLControllerGUI(Environment<A> env, Method method, A... actions) {
|
||||||
|
super(env, method, actions);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected void initListeners() {
|
||||||
|
SwingUtilities.invokeLater(() -> {
|
||||||
|
learningView = new View<>(learning, environment, this);
|
||||||
|
learning.addListener(this);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onLearnMoreEpisodes(int nrOfEpisodes) {
|
||||||
|
super.onLearnMoreEpisodes(nrOfEpisodes);
|
||||||
|
learningView.updateLearningInfoPanel();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onLoadState(String fileName) {
|
||||||
|
super.onLoadState(fileName);
|
||||||
|
SwingUtilities.invokeLater(() -> learningView.updateLearningInfoPanel());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onEpsilonChange(float epsilon) {
|
||||||
|
super.onEpsilonChange(epsilon);
|
||||||
|
SwingUtilities.invokeLater(() -> learningView.updateLearningInfoPanel());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected void changeLearningDelay(int delay) {
|
||||||
|
super.changeLearningDelay(delay);
|
||||||
|
SwingUtilities.invokeLater(() -> learningView.updateLearningInfoPanel());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onLearningEnd() {
|
||||||
|
super.onLearningEnd();
|
||||||
|
SwingUtilities.invokeLater(() -> learningView.updateRewardGraph(latestRewardsHistory));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onEpisodeEnd(List<Double> rewardHistory) {
|
||||||
|
super.onEpisodeEnd(rewardHistory);
|
||||||
|
SwingUtilities.invokeLater(() -> {
|
||||||
|
if (!fastLearning) {
|
||||||
|
learningView.updateRewardGraph(latestRewardsHistory);
|
||||||
|
}
|
||||||
|
learningView.updateLearningInfoPanel();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onStepEnd() {
|
||||||
|
if (!fastLearning) {
|
||||||
|
SwingUtilities.invokeLater(() -> learningView.repaintEnvironment());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -3,6 +3,7 @@ package example;
|
||||||
import core.RNG;
|
import core.RNG;
|
||||||
import core.algo.Method;
|
import core.algo.Method;
|
||||||
import core.controller.RLController;
|
import core.controller.RLController;
|
||||||
|
import core.controller.RLControllerGUI;
|
||||||
import evironment.jumpingDino.DinoAction;
|
import evironment.jumpingDino.DinoAction;
|
||||||
import evironment.jumpingDino.DinoWorld;
|
import evironment.jumpingDino.DinoWorld;
|
||||||
|
|
||||||
|
@ -10,15 +11,15 @@ public class JumpingDino {
|
||||||
public static void main(String[] args) {
|
public static void main(String[] args) {
|
||||||
RNG.setSeed(55);
|
RNG.setSeed(55);
|
||||||
|
|
||||||
RLController<DinoAction> rl = new RLController<>(
|
RLController<DinoAction> rl = new RLControllerGUI<>(
|
||||||
new DinoWorld(true, true),
|
new DinoWorld(true, true),
|
||||||
Method.MC_ONPOLICY_EGREEDY,
|
Method.MC_ONPOLICY_EGREEDY,
|
||||||
DinoAction.values());
|
DinoAction.values());
|
||||||
|
|
||||||
rl.setDelay(200);
|
rl.setDelay(0);
|
||||||
rl.setDiscountFactor(1f);
|
rl.setDiscountFactor(1f);
|
||||||
rl.setEpsilon(0.15f);
|
rl.setEpsilon(0.15f);
|
||||||
rl.setEpisodes(5000);
|
rl.setNrOfEpisodes(100000);
|
||||||
|
|
||||||
rl.start();
|
rl.start();
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,7 +16,7 @@ public class RunningAnt {
|
||||||
AntAction.values());
|
AntAction.values());
|
||||||
|
|
||||||
rl.setDelay(200);
|
rl.setDelay(200);
|
||||||
rl.setEpisodes(10000);
|
rl.setNrOfEpisodes(10000);
|
||||||
rl.setDiscountFactor(1f);
|
rl.setDiscountFactor(1f);
|
||||||
rl.setEpsilon(0.15f);
|
rl.setEpsilon(0.15f);
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue