package core.algo;
import core.DiscreteActionSpace;
import core.Environment;
import core.LearningConfig;
import core.StepResult;
import core.listener.LearningListener;
import lombok.Getter;
import lombok.Setter;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
public abstract class EpisodicLearning extends Learning implements Episodic {
    @Setter
    protected int currentEpisode = 0;
    protected volatile AtomicInteger episodesToLearn = new AtomicInteger(0);
    @Getter
    protected volatile int episodePerSecond;
    protected int episodeSumCurrentSecond;
    protected double sumOfRewards;
    protected List> episode = new ArrayList<>();
    public EpisodicLearning(Environment environment, DiscreteActionSpace actionSpace, float discountFactor, int delay) {
        super(environment, actionSpace, discountFactor, delay);
        initBenchMarking();
    }
    public EpisodicLearning(Environment environment, DiscreteActionSpace actionSpace, float discountFactor) {
        super(environment, actionSpace, discountFactor);
        initBenchMarking();
    }
    public EpisodicLearning(Environment environment, DiscreteActionSpace actionSpace, int delay) {
        super(environment, actionSpace, delay);
        initBenchMarking();
    }
    public EpisodicLearning(Environment environment, DiscreteActionSpace actionSpace) {
        super(environment, actionSpace);
        initBenchMarking();
    }
    protected abstract void nextEpisode();
    private void initBenchMarking(){
        new Thread(()->{
            while (true){
                episodePerSecond = episodeSumCurrentSecond;
                episodeSumCurrentSecond = 0;
                try {
                    Thread.sleep(1000);
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
            }
        }).start();
    }
    protected void dispatchEpisodeEnd(){
        ++episodeSumCurrentSecond;
        if(rewardHistory.size() > 10000){
            rewardHistory.clear();
        }
        rewardHistory.add(sumOfRewards);
        for(LearningListener l: learningListeners) {
            l.onEpisodeEnd(rewardHistory);
        }
    }
    protected void dispatchEpisodeStart(){
        ++currentEpisode;
        episodesToLearn.decrementAndGet();
        for(LearningListener l: learningListeners){
            l.onEpisodeStart();
        }
    }
    @Override
    public void learn(){
        learn(LearningConfig.DEFAULT_NR_OF_EPISODES);
    }
    private void startLearning(){
        learningExecutor.submit(()->{
            dispatchLearningStart();
            while(episodesToLearn.get() > 0){
                dispatchEpisodeStart();
                nextEpisode();
                dispatchEpisodeEnd();
            }
            synchronized (this){
                dispatchLearningEnd();
                notifyAll();
            }
        });
    }
    /**
     * Stopping the while loop by setting episodesToLearn to 0.
     * The current episode can not be interrupted, so the sleep delay
     * is removed and the calling thread has to wait until the
     * current episode is done.
     * Resetting the delay afterwards.
     */
    @Override
    public synchronized void interruptLearning(){
        episodesToLearn.set(0);
        int prevDelay = delay;
        delay = 0;
        while(currentlyLearning) {
            try {
                wait();
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
        }
        delay = prevDelay;
    }
    public synchronized void learn(int nrOfEpisodes){
        boolean isLearning = episodesToLearn.getAndAdd(nrOfEpisodes) != 0;
        if(!isLearning)
            startLearning();
    }
    @Override
    public int getEpisodesToGo(){
        return episodesToLearn.get();
    }
    public int getCurrentEpisode() {
        return currentEpisode;
    }
    public int getEpisodesPerSecond() {
        return episodePerSecond;
    }
    @Override
    public synchronized void save(ObjectOutputStream oos) throws IOException {
      super.save(oos);
      oos.writeInt(currentEpisode);
    }
    @Override
    public synchronized void load(ObjectInputStream ois) throws IOException, ClassNotFoundException {
       super.load(ois);
       currentEpisode = ois.readInt();
    }
}