refo/src/main/java/core/algo/Learning.java

108 lines
3.3 KiB
Java

package core.algo;
import core.DiscreteActionSpace;
import core.Environment;
import core.LearningConfig;
import core.StateActionTable;
import core.listener.LearningListener;
import core.policy.Policy;
import lombok.Getter;
import lombok.Setter;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.CopyOnWriteArrayList;
/**
*
* @param <A> discrete action type for a specific environment
*/
@Getter
public abstract class Learning<A extends Enum>{
// TODO: temp testing -> extract to dedicated test
protected int checkSum;
protected int rewardCheckSum;
// current discrete timestamp t
protected int timestamp;
protected int currentEpisode;
protected Policy<A> policy;
protected DiscreteActionSpace<A> actionSpace;
@Setter
protected StateActionTable<A> stateActionTable;
protected Environment<A> environment;
protected float discountFactor;
protected Set<LearningListener> learningListeners;
@Setter
protected int delay;
protected List<Double> rewardHistory;
protected volatile boolean currentlyLearning;
public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, int delay) {
this.environment = environment;
this.actionSpace = actionSpace;
this.discountFactor = discountFactor;
this.delay = delay;
currentlyLearning = false;
learningListeners = new HashSet<>();
rewardHistory = new CopyOnWriteArrayList<>();
}
public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor) {
this(environment, actionSpace, discountFactor, LearningConfig.DEFAULT_DELAY);
}
public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, int delay) {
this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, delay);
}
public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace) {
this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, LearningConfig.DEFAULT_DELAY);
}
public abstract void learn();
public void addListener(LearningListener learningListener) {
learningListeners.add(learningListener);
}
protected void dispatchStepEnd() {
for (LearningListener l : learningListeners) {
l.onStepEnd();
}
}
protected void dispatchLearningStart() {
currentlyLearning = true;
for (LearningListener l : learningListeners) {
l.onLearningStart();
}
}
protected void dispatchLearningEnd() {
currentlyLearning = false;
for (LearningListener l : learningListeners) {
l.onLearningEnd();
}
}
public synchronized void interruptLearning(){
//TODO: for non episodic learning
}
public void save(ObjectOutputStream oos) throws IOException {
oos.writeObject(rewardHistory);
// oos.writeObject(stateActionTable);
}
public void load(ObjectInputStream ois) throws IOException, ClassNotFoundException {
rewardHistory = (List<Double>) ois.readObject();
stateActionTable = (StateActionTable<A>) ois.readObject();
}
}