split GUI parts from controller into sub class
This commit is contained in:
parent
195722e98f
commit
518683b676
|
@ -5,4 +5,4 @@
|
|||
<profile default="true" name="Default" enabled="true" />
|
||||
</annotationProcessing>
|
||||
</component>
|
||||
</project>
|
||||
</project>
|
|
@ -10,6 +10,7 @@ import core.gui.View;
|
|||
import core.listener.LearningListener;
|
||||
import core.listener.ViewListener;
|
||||
import core.policy.EpsilonPolicy;
|
||||
import lombok.Setter;
|
||||
|
||||
import javax.swing.*;
|
||||
import java.io.*;
|
||||
|
@ -18,24 +19,29 @@ import java.util.concurrent.ExecutorService;
|
|||
import java.util.concurrent.Executors;
|
||||
|
||||
public class RLController<A extends Enum> implements ViewListener, LearningListener {
|
||||
private final String folderPrefix = "learningStates" + File.separator;
|
||||
private Environment<A> environment;
|
||||
private DiscreteActionSpace<A> discreteActionSpace;
|
||||
private Method method;
|
||||
private int delay = LearningConfig.DEFAULT_DELAY;
|
||||
private float discountFactor = LearningConfig.DEFAULT_DISCOUNT_FACTOR;
|
||||
private float epsilon = LearningConfig.DEFAULT_EPSILON;
|
||||
private Learning<A> learning;
|
||||
private LearningView learningView;
|
||||
private boolean fastLearning;
|
||||
private List<Double> latestRewardsHistory;
|
||||
private int nrOfEpisodes;
|
||||
private int prevDelay;
|
||||
protected final String folderPrefix = "learningStates" + File.separator;
|
||||
protected Environment<A> environment;
|
||||
protected DiscreteActionSpace<A> discreteActionSpace;
|
||||
protected Method method;
|
||||
@Setter
|
||||
protected int delay = LearningConfig.DEFAULT_DELAY;
|
||||
@Setter
|
||||
protected float discountFactor = LearningConfig.DEFAULT_DISCOUNT_FACTOR;
|
||||
@Setter
|
||||
protected float epsilon = LearningConfig.DEFAULT_EPSILON;
|
||||
protected Learning<A> learning;
|
||||
protected boolean fastLearning;
|
||||
protected List<Double> latestRewardsHistory;
|
||||
@Setter
|
||||
protected int nrOfEpisodes;
|
||||
protected int prevDelay;
|
||||
protected volatile boolean printNextEpisode;
|
||||
|
||||
public RLController(Environment<A> env, Method method, A... actions){
|
||||
setEnvironment(env);
|
||||
setMethod(method);
|
||||
setAllowedActions(actions);
|
||||
printNextEpisode = true;
|
||||
}
|
||||
|
||||
public void start(){
|
||||
|
@ -48,20 +54,29 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
|
|||
default:
|
||||
throw new IllegalArgumentException("Undefined method");
|
||||
}
|
||||
|
||||
initGUI();
|
||||
System.out.println("Initialized learning: " + learning.getClass());
|
||||
initListeners();
|
||||
System.out.println("Set listeners");
|
||||
initLearning();
|
||||
}
|
||||
|
||||
private void initGUI(){
|
||||
SwingUtilities.invokeLater(()->{
|
||||
learningView = new View<>(learning, environment, this);
|
||||
learning.addListener(this);
|
||||
});
|
||||
protected void initListeners(){
|
||||
learning.addListener(this);
|
||||
new Thread(() -> {
|
||||
while (true){
|
||||
printNextEpisode = true;
|
||||
try {
|
||||
Thread.sleep(30*1000);
|
||||
} catch (InterruptedException e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
}
|
||||
}).start();
|
||||
}
|
||||
|
||||
private void initLearning(){
|
||||
if(learning instanceof EpisodicLearning){
|
||||
System.out.println("Starting learning of <" + nrOfEpisodes + "> episodes");
|
||||
((EpisodicLearning) learning).learn(nrOfEpisodes);
|
||||
}else{
|
||||
learning.learn();
|
||||
|
@ -78,7 +93,6 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
|
|||
}else{
|
||||
throw new RuntimeException("Triggering onLearnMoreEpisodes on non-episodic learning!");
|
||||
}
|
||||
learningView.updateLearningInfoPanel();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -88,10 +102,9 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
|
|||
try {
|
||||
fis = new FileInputStream(fileName);
|
||||
in = new ObjectInputStream(fis);
|
||||
System.out.println("interrup" + Thread.currentThread().getId());
|
||||
System.out.println("interrupt" + Thread.currentThread().getId());
|
||||
learning.interruptLearning();
|
||||
learning.load(in);
|
||||
SwingUtilities.invokeLater(() -> learningView.updateLearningInfoPanel());
|
||||
in.close();
|
||||
} catch (IOException | ClassNotFoundException e) {
|
||||
e.printStackTrace();
|
||||
|
@ -117,7 +130,6 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
|
|||
public void onEpsilonChange(float epsilon) {
|
||||
if(learning.getPolicy() instanceof EpsilonPolicy){
|
||||
((EpsilonPolicy<A>) learning.getPolicy()).setEpsilon(epsilon);
|
||||
SwingUtilities.invokeLater(() -> learningView.updateLearningInfoPanel());
|
||||
}else{
|
||||
System.out.println("Trying to call inEpsilonChange on non-epsilon policy");
|
||||
}
|
||||
|
@ -128,9 +140,8 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
|
|||
changeLearningDelay(delay);
|
||||
}
|
||||
|
||||
private void changeLearningDelay(int delay){
|
||||
protected void changeLearningDelay(int delay){
|
||||
learning.setDelay(delay);
|
||||
SwingUtilities.invokeLater(() -> learningView.updateLearningInfoPanel());
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -153,19 +164,8 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
|
|||
|
||||
@Override
|
||||
public void onLearningEnd() {
|
||||
SwingUtilities.invokeLater(()-> learningView.updateRewardGraph(latestRewardsHistory));
|
||||
onSaveState( method.toString() + System.currentTimeMillis()/1000 + (learning instanceof EpisodicLearning ? "e " + ((EpisodicLearning) learning).getCurrentEpisode() : ""));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onEpisodeEnd(List<Double> rewardHistory) {
|
||||
latestRewardsHistory = rewardHistory;
|
||||
SwingUtilities.invokeLater(() ->{
|
||||
if(!fastLearning){
|
||||
learningView.updateRewardGraph(latestRewardsHistory);
|
||||
}
|
||||
learningView.updateLearningInfoPanel();
|
||||
});
|
||||
System.out.println("Learning finished");
|
||||
onSaveState( method.toString() + System.currentTimeMillis()/1000 + (learning instanceof EpisodicLearning ? "e" + ((EpisodicLearning) learning).getCurrentEpisode() : ""));
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -174,12 +174,19 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
|
|||
}
|
||||
|
||||
@Override
|
||||
public void onStepEnd() {
|
||||
if(!fastLearning){
|
||||
SwingUtilities.invokeLater(() -> learningView.repaintEnvironment());
|
||||
public void onEpisodeEnd(List<Double> rewardHistory) {
|
||||
latestRewardsHistory = rewardHistory;
|
||||
if(printNextEpisode){
|
||||
System.out.println("Episode " + ((EpisodicLearning) learning).getCurrentEpisode() + " Latest Reward: " + rewardHistory.get(rewardHistory.size()-1));
|
||||
System.out.println("Eps/sec: " + ((EpisodicLearning) learning).getEpisodePerSecond());
|
||||
printNextEpisode = false;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onStepEnd() {
|
||||
}
|
||||
|
||||
|
||||
/*************************************************
|
||||
** SETTERS **
|
||||
|
@ -205,19 +212,4 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
|
|||
}
|
||||
this.discreteActionSpace = new ListDiscreteActionSpace<>(actions);
|
||||
}
|
||||
|
||||
public void setDelay(int delay){
|
||||
this.delay = delay;
|
||||
}
|
||||
|
||||
public void setEpisodes(int nrOfEpisodes){
|
||||
this.nrOfEpisodes = nrOfEpisodes;
|
||||
}
|
||||
|
||||
public void setDiscountFactor(float discountFactor){
|
||||
this.discountFactor = discountFactor;
|
||||
}
|
||||
public void setEpsilon(float epsilon){
|
||||
this.epsilon = epsilon;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,73 @@
|
|||
package core.controller;
|
||||
|
||||
import core.Environment;
|
||||
import core.algo.Method;
|
||||
import core.gui.LearningView;
|
||||
import core.gui.View;
|
||||
|
||||
import javax.swing.*;
|
||||
import java.util.List;
|
||||
|
||||
public class RLControllerGUI<A extends Enum> extends RLController<A> {
|
||||
private LearningView learningView;
|
||||
|
||||
public RLControllerGUI(Environment<A> env, Method method, A... actions) {
|
||||
super(env, method, actions);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void initListeners() {
|
||||
SwingUtilities.invokeLater(() -> {
|
||||
learningView = new View<>(learning, environment, this);
|
||||
learning.addListener(this);
|
||||
});
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onLearnMoreEpisodes(int nrOfEpisodes) {
|
||||
super.onLearnMoreEpisodes(nrOfEpisodes);
|
||||
learningView.updateLearningInfoPanel();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onLoadState(String fileName) {
|
||||
super.onLoadState(fileName);
|
||||
SwingUtilities.invokeLater(() -> learningView.updateLearningInfoPanel());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onEpsilonChange(float epsilon) {
|
||||
super.onEpsilonChange(epsilon);
|
||||
SwingUtilities.invokeLater(() -> learningView.updateLearningInfoPanel());
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void changeLearningDelay(int delay) {
|
||||
super.changeLearningDelay(delay);
|
||||
SwingUtilities.invokeLater(() -> learningView.updateLearningInfoPanel());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onLearningEnd() {
|
||||
super.onLearningEnd();
|
||||
SwingUtilities.invokeLater(() -> learningView.updateRewardGraph(latestRewardsHistory));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onEpisodeEnd(List<Double> rewardHistory) {
|
||||
super.onEpisodeEnd(rewardHistory);
|
||||
SwingUtilities.invokeLater(() -> {
|
||||
if (!fastLearning) {
|
||||
learningView.updateRewardGraph(latestRewardsHistory);
|
||||
}
|
||||
learningView.updateLearningInfoPanel();
|
||||
});
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onStepEnd() {
|
||||
if (!fastLearning) {
|
||||
SwingUtilities.invokeLater(() -> learningView.repaintEnvironment());
|
||||
}
|
||||
}
|
||||
}
|
|
@ -3,6 +3,7 @@ package example;
|
|||
import core.RNG;
|
||||
import core.algo.Method;
|
||||
import core.controller.RLController;
|
||||
import core.controller.RLControllerGUI;
|
||||
import evironment.jumpingDino.DinoAction;
|
||||
import evironment.jumpingDino.DinoWorld;
|
||||
|
||||
|
@ -10,15 +11,15 @@ public class JumpingDino {
|
|||
public static void main(String[] args) {
|
||||
RNG.setSeed(55);
|
||||
|
||||
RLController<DinoAction> rl = new RLController<>(
|
||||
RLController<DinoAction> rl = new RLControllerGUI<>(
|
||||
new DinoWorld(true, true),
|
||||
Method.MC_ONPOLICY_EGREEDY,
|
||||
DinoAction.values());
|
||||
|
||||
rl.setDelay(200);
|
||||
rl.setDelay(0);
|
||||
rl.setDiscountFactor(1f);
|
||||
rl.setEpsilon(0.15f);
|
||||
rl.setEpisodes(5000);
|
||||
rl.setNrOfEpisodes(100000);
|
||||
|
||||
rl.start();
|
||||
}
|
||||
|
|
|
@ -16,7 +16,7 @@ public class RunningAnt {
|
|||
AntAction.values());
|
||||
|
||||
rl.setDelay(200);
|
||||
rl.setEpisodes(10000);
|
||||
rl.setNrOfEpisodes(10000);
|
||||
rl.setDiscountFactor(1f);
|
||||
rl.setEpsilon(0.15f);
|
||||
|
||||
|
|
Loading…
Reference in New Issue