import gymnasium as gym
import numpy as np
from typing import Optional, Callable, Any
from keras_progbar import Progbar
from scipy.interpolate import interp1d
from rl.utils.risk_utils import wang, pow


class BaseOffline(object):
    """
    Interface class for offline reinforcement-learning algorithms.

    It implements no learning itself; it provides the shared training, evaluation,
    checkpointing and risk-scoring procedures. Subclasses are expected to override
    `predict` and `train_step` and to set the attributes declared below.
    """
    np_rng: np.random.Generator  # RNG used to draw a fresh reset seed for every evaluation episode
    env: gym.Env  # default environment used by `train` and `score`
    policy: Any  # must provide save_checkpoint(path, overwrite) and load_checkpoint(path)
    risk_type: str  # risk measure name passed to `score` by `train_pipeline`
    risk_eta: float  # risk-measure parameter passed to `score` by `train_pipeline`

    def predict(self, observation, *args, **kwargs) -> np.ndarray:
        """Return the action for `observation`. Subclasses must override; the base returns None."""
        pass

    def train_step(self, *args, **kwargs) -> dict:
        """
        Run one training update and return a dict of scalar metrics to display.

        Subclasses must override; the base implementation returns None.
        """
        pass

    def score_metric(self, array):
        """Return the (mean, standard deviation) of `array`."""
        return np.mean(array), np.std(array)

    def epoch_learn(self, len_epoch):
        """Call `train_step` `len_epoch` times, showing a progress bar with the returned metrics."""
        progbar = Progbar(len_epoch)
        for _ in range(len_epoch):
            # `or {}` tolerates a train_step that returns None (as the base class does).
            loss_info = self.train_step() or {}
            progbar.add(1, list(loss_info.items()))

    def train(self,
              epoch: int,
              len_epoch: int = 1000,
              eval_interval: int = 5,
              n_eval: int = 10,
              normalizer: Optional[Callable] = None,
              ) -> None:
        """
        Train for `epoch` epochs of `len_epoch` steps each, evaluating periodically.

        Args:
            epoch: number of epochs to run.
            len_epoch: number of `train_step` calls per epoch.
            eval_interval: evaluate and print after every epoch index divisible by
                this value (so epoch 0 is always evaluated).
            n_eval: number of episodes per evaluation.
            normalizer: optional callable applied to the mean score; its result is
                printed multiplied by 100 as a percentage.

        A final evaluation on `self.env` is always run and printed after training.
        """
        for e in range(epoch):
            self.epoch_learn(len_epoch)
            if e % eval_interval == 0:
                scores = self.evaluate(n_eval, self.env)
                mean, std = np.mean(scores), np.std(scores)

                print(f"EPOCH {e}:::")
                if normalizer is not None:
                    print(f"SCORE {mean} +/- {std}:  NORMALIZED {normalizer(mean) * 100:.2f}%")
                else:
                    print(f"SCORE {mean} +/- {std}")
        scores = self.evaluate(n_eval, self.env)
        mean, std = np.mean(scores), np.std(scores)

        if normalizer is not None:
            print(f"SCORE {mean} +/- {std}: NORMALIZED {normalizer(mean) * 100:.2f}%")
        else:
            print(f"SCORE {mean} +/- {std}")

    def evaluate(self, n_eval: int, env):
        """
        Roll out `n_eval` episodes in `env` and return the summed reward of each.

        Each episode is reset with a seed drawn from `self.np_rng`, and actions come
        from `self.predict(obs, deterministic=False)`. An episode ends as soon as the
        environment reports termination or truncation.

        Returns:
            np.ndarray with one total reward per episode, in rollout order.
        """
        scores = []
        for _ in range(n_eval):
            seed = self.np_rng.integers(0, 2 ** 30, size=(1,)).item()
            obs, _ = env.reset(seed=seed)
            score = 0
            finished = False
            while not finished:
                action = self.predict(obs, deterministic=False)
                obs, reward, terminated, truncated, _ = env.step(action)
                score += reward
                finished = terminated or truncated
            scores.append(score)
        return np.asarray(scores)

    def save(self, path, overwrite=True):
        """Save the policy checkpoint to `path` via `self.policy.save_checkpoint`."""
        self.policy.save_checkpoint(path, overwrite)

    def load(self, path):
        """Load the policy checkpoint from `path` via `self.policy.load_checkpoint`."""
        self.policy.load_checkpoint(path)

    def score(self, risk_measure: str, risk_eta: float,
              n_eval: int = 1000, eval_env: Optional[gym.Env] = None):
        """
        Evaluate the policy and summarize the episode returns with a risk measure.

        Args:
            risk_measure: one of the following.
                - 'cvar': mean of the worst `risk_eta` fraction of episodes
                  (at least one episode is always kept).
                - 'wang', 'power' / 'pow': average of the empirical quantile
                  function evaluated at probability levels distorted by `wang` /
                  `pow` (imported from rl.utils.risk_utils).
            risk_eta: parameter of the chosen risk measure.
            n_eval: number of evaluation episodes.
            eval_env: environment to evaluate in; defaults to `self.env`.

        Returns:
            (risk_value, evaluations), where `evaluations` is the array of episode
            returns sorted in ascending order.

        Raises:
            NotImplementedError: for an unknown `risk_measure`. This is raised before
                any episodes are rolled out.
        """
        if risk_measure not in ('cvar', 'wang', 'power', 'pow'):
            raise NotImplementedError(f"risk measure {risk_measure} is not implemented yet.")
        if eval_env is None:
            eval_env = self.env
        evaluations = self.evaluate(n_eval, eval_env)
        evaluations.sort()

        if risk_measure == 'cvar':
            # Keep at least the worst episode so a small risk_eta does not average an
            # empty slice (which would yield NaN).
            n_tail = max(1, int(len(evaluations) * risk_eta))
            return evaluations[:n_tail].mean(), evaluations

        if len(evaluations) < 2:
            # interp1d needs at least two points; with a single sample the quantile
            # function is constant.
            return evaluations.mean(), evaluations

        # Empirical quantile function: probability level in [0, 1] -> return.
        # The sorted returns are placed at evenly spaced levels and linearly
        # interpolated. Levels that fall marginally outside [0, 1] (e.g. from
        # floating-point error in the distortion) are clamped to the extreme returns.
        empirical_quantile_fn = interp1d(np.linspace(0, 1, len(evaluations)), evaluations,
                                         kind='slinear', bounds_error=False,
                                         fill_value=(evaluations[0], evaluations[-1]))
        linspace = np.linspace(0, 1, 10000)
        distortion = wang if risk_measure == 'wang' else pow
        return empirical_quantile_fn(distortion(linspace, risk_eta)).mean(), evaluations

    def train_pipeline(self, path, normalizer: Optional[Callable] = None,
                       train_steps: int = 1000_000, n_eval: int = 1000,
                       ):
        """
        Train, save and risk-score the policy in one call.

        Runs `train` as a single epoch of `train_steps` steps, saves the policy to
        `path`, then prints the `self.risk_type` / `self.risk_eta` risk value over
        `n_eval` episodes together with the mean and standard deviation of the returns.
        The subclass must define `risk_type` and `risk_eta`.
        """
        self.train(epoch=1, len_epoch=train_steps, normalizer=normalizer)
        self.save(path, overwrite=True)
        risk_value, total_score = self.score(self.risk_type, self.risk_eta, n_eval=n_eval)
        print(f"{self.risk_type}_{self.risk_eta}:\t{risk_value}, MEAN:\t{np.mean(total_score)}+/-{np.std(total_score)}")

