import os
from statistics import mean

import numpy as np
from modules.train.TrainHelper import run_episode


class EvalCallback:
    """Evaluation hook that scores an agent and persists the score history.

    Every call advances an internal step counter (until ``stop_at`` is
    exceeded). A call with ``test=True`` also does the following:

    * runs the agent for ``n_tests`` episodes through ``run_episode``;
    * records each episode's summed reward;
    * prints a mean/std summary;
    * writes the full history to ``<out_dir>/episodes.npy`` and
      ``<out_dir>/scores.npy`` (``np.save`` appends the ``.npy`` suffix).

    Args:
        env: Environment forwarded unchanged to ``run_episode``.
        out_dir: Output directory for the saved arrays; created if missing.
        every_n_steps: Intended evaluation period. NOTE: not currently
            consulted by ``__call__``; only ``test=True`` triggers evaluation.
        n_tests: Number of episodes run per evaluation.
        n_max_steps: Step limit forwarded to ``run_episode`` (presumably a
            per-episode cap; confirm against ``run_episode``).
        skip: ``[start, end]`` step window intended to be excluded from
            periodic evaluation. NOTE: not currently consulted by
            ``__call__``. Defaults to a fresh ``[np.inf, np.inf]``.
        stop_at: Once the step counter exceeds this value, calls are no-ops.
    """

    def __init__(
        self,
        env,
        out_dir,
        every_n_steps=2048,
        n_tests=10,
        n_max_steps=np.inf,
        skip=None,
        stop_at=np.inf,
    ) -> None:
        self.every_n_steps = every_n_steps
        self.n_tests = n_tests
        self.n_max_steps = n_max_steps
        self.out_dir = out_dir
        os.makedirs(self.out_dir, exist_ok=True)
        self.n_steps = 0
        self.scores = []    # one list of per-episode returns per evaluation
        self.episodes = []  # value of the step counter at each evaluation
        self.env = env
        # Build the default per instance: a list literal as the default value
        # would be a single object shared by every EvalCallback instance.
        self.skip = [np.inf, np.inf] if skip is None else skip
        self.stop_at = stop_at

    def __call__(self, agent, test=False):
        """Advance the step counter and, if ``test`` is true, evaluate ``agent``.

        Args:
            agent: Agent forwarded to ``run_episode``.
            test: When true, run ``n_tests`` evaluation episodes and save the
                updated score history to ``out_dir``.
        """
        if self.n_steps > self.stop_at:
            # Past the cut-off: do nothing, not even advance the counter.
            return
        # NOTE: a periodic trigger was previously sketched (and disabled) here:
        #   self.n_steps % self.every_n_steps == 0
        #   and (self.n_steps < self.skip[0] or self.n_steps > self.skip[1])
        # Only explicit ``test=True`` calls evaluate.
        if test:
            results = []
            for _ in range(self.n_tests):
                # run_episode is assumed to return a 3-tuple whose last item
                # supports .sum() (e.g. an array of per-step rewards) -- the
                # positional False/None arguments are opaque from here.
                _, _, rewards = run_episode(self.env, agent, False, self.n_max_steps, None)
                results.append(rewards.sum())
            self.scores.append(results)
            self.episodes.append(self.n_steps)
            print(f"[{self.n_steps}] Score: {mean(results)} +/- {np.std(results)}\n")

            # Rewrite the whole history each time so the files on disk always
            # reflect every evaluation so far.
            np.save(os.path.join(self.out_dir, "episodes"), self.episodes)
            np.save(os.path.join(self.out_dir, "scores"), self.scores)

        self.n_steps += 1