package core.algo;
import core.DiscreteActionSpace;
import core.Environment;
import core.LearningConfig;
import core.StateActionTable;
import core.listener.LearningListener;
import core.policy.Policy;
import lombok.Getter;
import lombok.Setter;
import java.io.Serializable;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CopyOnWriteArraySet;
/**
 * Abstract base for learning algorithms that act on an {@link Environment} through a
 * {@link DiscreteActionSpace}.
 *
 * <p>This class stores shared configuration and state and provides listener registration and
 * notification helpers ({@code dispatch*}); the algorithm itself is {@link #learn()}, which
 * subclasses implement. Lombok's class-level {@code @Getter} generates a getter for every
 * field; {@code stateActionTable} and {@code delay} additionally get Lombok setters.
 */
@Getter
public abstract class Learning implements Serializable {
    /** Policy of the concrete algorithm. Not assigned in this class — presumably set by subclasses. */
    protected Policy policy;

    /** Discrete action space supplied at construction. */
    protected DiscreteActionSpace actionSpace;

    /** State-action table. Not initialized in this class; settable via the Lombok setter. */
    @Setter
    protected StateActionTable stateActionTable;

    /** Environment supplied at construction. */
    protected Environment environment;

    /**
     * Discount factor supplied at construction (or {@code LearningConfig.DEFAULT_DISCOUNT_FACTOR}).
     * It is only stored here; how it is applied is up to subclasses.
     */
    protected float discountFactor;

    /**
     * Listeners notified by the {@code dispatch*} methods.
     *
     * <p>Backed by a {@link CopyOnWriteArraySet} rather than a {@code HashSet}: a listener added
     * while a dispatch is running (from inside a callback, or from another thread) no longer
     * causes a {@link java.util.ConcurrentModificationException}, because dispatch iterates a
     * snapshot. Iteration follows registration order; duplicates (per {@code equals}) are ignored.
     */
    protected Set<LearningListener> learningListeners;

    /**
     * Delay supplied at construction (or {@code LearningConfig.DEFAULT_DELAY}); settable via the
     * Lombok setter. NOTE(review): units and use are defined by subclasses — presumably a pause
     * between steps; confirm.
     */
    @Setter
    protected int delay;

    /**
     * Reward history. Nothing in this class appends to it. Backed by a
     * {@link CopyOnWriteArrayList}, so it can be iterated while it is being appended to without
     * a {@link java.util.ConcurrentModificationException}.
     *
     * <p>NOTE(review): left as a raw {@code List} because the recorded element type is not
     * visible here; parameterize once it is confirmed against the subclasses.
     */
    protected List rewardHistory;

    /**
     * Creates a learner with an explicit discount factor and delay.
     *
     * @param environment    environment the algorithm operates on
     * @param actionSpace    discrete action space
     * @param discountFactor discount factor, stored as-is (no range check)
     * @param delay          delay value, stored as-is (no range check)
     */
    public Learning(Environment environment, DiscreteActionSpace actionSpace, float discountFactor, int delay) {
        this.environment = environment;
        this.actionSpace = actionSpace;
        this.discountFactor = discountFactor;
        this.delay = delay;
        this.learningListeners = new CopyOnWriteArraySet<>();
        this.rewardHistory = new CopyOnWriteArrayList<>();
    }

    /**
     * Creates a learner with the given discount factor and {@code LearningConfig.DEFAULT_DELAY}.
     *
     * <p>Overload note: pass a {@code float} such as {@code 0.9f}. An {@code int} argument selects
     * {@link #Learning(Environment, DiscreteActionSpace, int)} instead and is treated as a delay,
     * and a {@code double} literal such as {@code 0.9} matches no overload.
     *
     * @param environment    environment the algorithm operates on
     * @param actionSpace    discrete action space
     * @param discountFactor discount factor, stored as-is
     */
    public Learning(Environment environment, DiscreteActionSpace actionSpace, float discountFactor) {
        this(environment, actionSpace, discountFactor, LearningConfig.DEFAULT_DELAY);
    }

    /**
     * Creates a learner with the given delay and {@code LearningConfig.DEFAULT_DISCOUNT_FACTOR}.
     *
     * @param environment environment the algorithm operates on
     * @param actionSpace discrete action space
     * @param delay       delay value, stored as-is
     */
    public Learning(Environment environment, DiscreteActionSpace actionSpace, int delay) {
        this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, delay);
    }

    /**
     * Creates a learner with {@code LearningConfig.DEFAULT_DISCOUNT_FACTOR} and
     * {@code LearningConfig.DEFAULT_DELAY}.
     *
     * @param environment environment the algorithm operates on
     * @param actionSpace discrete action space
     */
    public Learning(Environment environment, DiscreteActionSpace actionSpace) {
        this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, LearningConfig.DEFAULT_DELAY);
    }

    /** Runs the learning algorithm; implemented by subclasses. */
    public abstract void learn();

    /**
     * Registers a listener to be notified by the {@code dispatch*} methods. Registering a
     * listener that is already present (per {@code equals}) has no effect.
     *
     * @param learningListener listener to register
     * @throws NullPointerException if {@code learningListener} is null — rejected here so the
     *                              failure surfaces at registration rather than mid-dispatch
     */
    public void addListener(LearningListener learningListener) {
        learningListeners.add(Objects.requireNonNull(learningListener, "learningListener"));
    }

    /**
     * Calls {@code onStepEnd()} on every registered listener, in registration order. An
     * exception thrown by a listener propagates and skips the remaining listeners.
     */
    protected void dispatchStepEnd() {
        learningListeners.forEach(LearningListener::onStepEnd);
    }

    /**
     * Calls {@code onLearningStart()} on every registered listener, in registration order. An
     * exception thrown by a listener propagates and skips the remaining listeners.
     */
    protected void dispatchLearningStart() {
        learningListeners.forEach(LearningListener::onLearningStart);
    }

    /**
     * Calls {@code onLearningEnd()} on every registered listener, in registration order. An
     * exception thrown by a listener propagates and skips the remaining listeners.
     */
    protected void dispatchLearningEnd() {
        learningListeners.forEach(LearningListener::onLearningEnd);
    }
}