refo/src/main/java/core/controller/RLController.java

217 lines
7.1 KiB
Java

package core.controller;
import core.DiscreteActionSpace;
import core.Environment;
import core.LearningConfig;
import core.ListDiscreteActionSpace;
import core.algo.EpisodicLearning;
import core.algo.Learning;
import core.algo.Method;
import core.algo.mc.MonteCarloControlEGreedy;
import core.algo.td.QLearningOffPolicyTDControl;
import core.algo.td.SARSA;
import core.listener.LearningListener;
import core.policy.EpsilonPolicy;
import lombok.Setter;
import java.io.*;
import java.util.List;
public class RLController<A extends Enum> implements LearningListener {
protected final String folderPrefix = "learningStates" + File.separator;
protected Environment<A> environment;
protected DiscreteActionSpace<A> discreteActionSpace;
protected Method method;
@Setter
protected int delay = LearningConfig.DEFAULT_DELAY;
@Setter
protected float discountFactor = LearningConfig.DEFAULT_DISCOUNT_FACTOR;
@Setter
protected float learningRate = LearningConfig.DEFAULT_DISCOUNT_FACTOR;
@Setter
protected float epsilon = LearningConfig.DEFAULT_EPSILON;
protected Learning<A> learning;
protected boolean fastLearning;
protected List<Double> latestRewardsHistory;
@Setter
protected int nrOfEpisodes;
protected int prevDelay;
protected volatile boolean printNextEpisode;
@SafeVarargs
public RLController(Environment<A> env, Method method, A... actions) {
setEnvironment(env);
setMethod(method);
setAllowedActions(actions);
printNextEpisode = true;
}
public void start() {
switch(method) {
case MC_CONTROL_FIRST_VISIT:
learning = new MonteCarloControlEGreedy<>(environment, discreteActionSpace, discountFactor, epsilon, delay);
break;
case MC_CONTROL_EVERY_VISIT:
learning = new MonteCarloControlEGreedy<>(environment, discreteActionSpace, discountFactor, epsilon, delay, true);
break;
case SARSA_ON_POLICY_CONTROL:
learning = new SARSA<>(environment, discreteActionSpace, discountFactor, epsilon, learningRate, delay);
break;
case Q_LEARNING_OFF_POLICY_CONTROL:
learning = new QLearningOffPolicyTDControl<>(environment, discreteActionSpace, discountFactor, epsilon, learningRate, delay);
break;
default:
throw new IllegalArgumentException("Undefined method");
}
System.out.println("Initialized learning: " + learning.getClass());
initListeners();
System.out.println("Set listeners");
initLearning();
}
protected void initListeners() {
learning.addListener(this);
new Thread(() -> {
while(learning.isCurrentlyLearning()) {
printNextEpisode = true;
try {
Thread.sleep(30 * 1000);
} catch (InterruptedException e) {
e.printStackTrace();
}
}
}).start();
}
private void initLearning() {
if(learning instanceof EpisodicLearning) {
System.out.println("Starting learning of <" + nrOfEpisodes + "> episodes");
((EpisodicLearning<A>) learning).learn(nrOfEpisodes);
} else {
learning.learn();
}
}
protected void changeLearningDelay(int delay) {
learning.setDelay(delay);
}
protected void learnMoreEpisodes(int nrOfEpisodes) {
if(learning instanceof EpisodicLearning) {
if(learning.isCurrentlyLearning()){
((EpisodicLearning<A>) learning).learnMoreEpisodes(nrOfEpisodes);
}else{
new Thread(() -> ((EpisodicLearning<A>) learning).learn(nrOfEpisodes)).start();
}
} else {
throw new RuntimeException("Triggering onLearnMoreEpisodes on non-episodic learning!");
}
}
protected void changeEpsilon(float epsilon) {
if(learning.getPolicy() instanceof EpsilonPolicy) {
((EpsilonPolicy<A>) learning.getPolicy()).setEpsilon(epsilon);
} else {
System.out.println("Trying to call inEpsilonChange on non-epsilon policy");
}
}
protected void loadState(String fileName) {
FileInputStream fis;
ObjectInputStream in;
try {
fis = new FileInputStream(fileName);
in = new ObjectInputStream(fis);
System.out.println("interrupt" + Thread.currentThread().getId());
learning.interruptLearning();
learning.load(in);
in.close();
} catch (IOException | ClassNotFoundException e) {
e.printStackTrace();
}
}
protected void saveState(String fileName) {
FileOutputStream fos;
ObjectOutputStream out;
try {
fos = new FileOutputStream(folderPrefix + fileName);
out = new ObjectOutputStream(fos);
learning.interruptLearning();
learning.save(out);
out.close();
} catch (IOException e) {
e.printStackTrace();
}
}
protected void changeFastLearning(boolean fastLearn) {
this.fastLearning = fastLearn;
if(fastLearn) {
prevDelay = learning.getDelay();
changeLearningDelay(0);
} else {
changeLearningDelay(prevDelay);
}
}
/*************************************************
** LEARNING LISTENERS **
*************************************************/
@Override
public void onLearningStart() {
}
@Override
public void onLearningEnd() {
System.out.println("Learning finished");
}
@Override
public void onEpisodeStart() {
}
@Override
public void onEpisodeEnd(List<Double> rewardHistory) {
latestRewardsHistory = rewardHistory;
if(printNextEpisode) {
System.out.println("Episode " + learning.getCurrentEpisode() + " Latest Reward: " + rewardHistory.get(rewardHistory.size() - 1));
System.out.println("Eps/sec: " + ((EpisodicLearning<A>) learning).getEpisodePerSecond());
printNextEpisode = false;
}
}
@Override
public void onStepEnd() {
}
/*************************************************
** SETTERS **
*************************************************/
private void setEnvironment(Environment<A> environment) {
if(environment == null) {
throw new IllegalArgumentException("Environment cannot be null");
}
this.environment = environment;
}
private void setMethod(Method method) {
if(method == null) {
throw new IllegalArgumentException("Method cannot be null");
}
this.method = method;
}
private void setAllowedActions(A[] actions) {
if(actions == null || actions.length == 0) {
throw new IllegalArgumentException("There has to be at least one action");
}
this.discreteActionSpace = new ListDiscreteActionSpace<>(actions);
}
}