import numpy as np
from stable_baselines3.common.vec_env import VecEnv
import gymnasium as gym
from typing import Optional, List
from risk_morl.utils.logger import WandbOutputFormat
from stable_baselines3.common.logger import Logger, HumanOutputFormat, KVWriter
import sys
from datetime import timedelta
from time import time
from collections import deque
from tqdm import trange
from risk_morl.buffer import MultiObjectiveReplayBuffer
from risk_morl.utils.env_util import reward_dim, MODummyVecEnv


class RiskSensitiveOffPolicyRLJax(object):
    """Generic driver for risk-sensitive, multi-objective off-policy agents.

    This base class owns the vectorised environment, a
    ``MultiObjectiveReplayBuffer``, a stable-baselines3 ``Logger`` and the
    environment-interaction / training loop (:meth:`learn`).  The learning
    algorithm itself is supplied by subclasses through the hook methods
    ``build_policy``, ``predict``, ``train_step``, ``pretrain_step``,
    ``get_train_log``, ``get_state`` and ``set_state``, which are no-ops here.
    """

    def __init__(self,
                 env: gym.Env | VecEnv,
                 gamma: float = 0.99,
                 batch_size: int = 256,
                 buffer_size: int = int(1e+6),
                 critic_lr: float = 3e-4,
                 actor_lr: float = 3e-4,
                 policy_kwargs: Optional[dict] = None,
                 wandb_kwargs: Optional[dict] = None,
                 *,
                 seed: int = 42
                 ):
        """Set up the environment, policy, logger and replay buffer.

        Args:
            env: A single gym env or an SB3 ``VecEnv``.  A non-``VecEnv`` is
                wrapped in a one-env ``MODummyVecEnv``.
            gamma: Discount factor; stored and forwarded via ``policy_kwargs``.
            batch_size: Stored as ``self.batch_size`` (not read in this class;
                presumably consumed by subclasses).
            buffer_size: Capacity passed to the replay buffer.
            critic_lr: Forwarded to the policy via ``policy_kwargs``.
            actor_lr: Stored and forwarded via ``policy_kwargs``.
            policy_kwargs: Extra policy keyword arguments.  The dict is copied,
                so the caller's object is never modified.
            wandb_kwargs: If given, a ``WandbOutputFormat(**wandb_kwargs)`` is
                added to the logger outputs (stdout is always included).
            seed: Seed for the numpy generator, the policy and the buffer.
        """
        self.reward_dim = reward_dim(env)
        if not isinstance(env, VecEnv):
            print("wrap with dummy vec env")
            # Bind the env through a default argument.  A bare ``lambda: env``
            # resolves ``env`` only when called, and ``env`` is rebound to the
            # wrapper on this very line, so a lazily-invoked factory would
            # return the wrapper instead of the original environment.
            env = MODummyVecEnv([lambda _env=env: _env])
        self.env = env
        self.gamma = gamma
        self.buffer_size = buffer_size
        self.batch_size = batch_size
        self.seed = seed
        self.actor_lr = actor_lr
        # Work on a copy so ``update`` below does not mutate the caller's dict.
        policy_kwargs = dict(policy_kwargs) if policy_kwargs is not None else {}
        policy_kwargs.update(critic_lr=critic_lr, actor_lr=actor_lr, gamma=gamma, seed=seed)
        self.policy_kwargs = policy_kwargs

        # Rolling windows over the 100 most recently finished episodes.
        self.score_deque = deque(maxlen=100)
        # NOTE(review): cost_deque is never read or written in this class;
        # presumably used by subclasses.
        self.cost_deque = deque(maxlen=100)
        self.step_deque = deque(maxlen=100)
        self.build_policy()
        output_format: List[KVWriter] = [HumanOutputFormat(sys.stdout)]
        if wandb_kwargs is not None:
            output_format.append(WandbOutputFormat(**wandb_kwargs))
        self.logger = Logger(folder=None, output_formats=output_format)
        self.np_rng = np.random.default_rng(seed)
        self.state: Optional = self.get_state()

        self.build_buffer()

    def build_buffer(self):
        """Create ``self.buffer`` sized for ``self.env`` and ``self.reward_dim``."""
        self.buffer = MultiObjectiveReplayBuffer(
            observation_space=self.env.observation_space,
            action_space=self.env.action_space,
            n_envs=self.env.num_envs,
            num_reward=self.reward_dim,
            buffer_size=self.buffer_size,
            seed=self.seed
        )

    def get_state(self):
        """Hook: initial value for ``self.state``.  Returns ``None`` here."""
        return None

    def build_policy(self):
        """Hook: construct the policy (called from ``__init__``).  No-op here."""
        pass

    def predict(self, observation, *, state: Optional = None, deterministic: bool = True) -> np.ndarray:
        """Hook: return actions for ``observation``.  Returns ``None`` here.

        ``learn`` calls it with the batched VecEnv observation and
        ``deterministic=False``; ``test`` calls it with a single ``test_env``
        observation and the default ``deterministic=True``.
        """
        pass

    def get_train_log(self) -> dict:
        """Hook: scalars to log under ``Train/<key>``.  Returns ``None`` here."""
        pass

    def train_step(self):
        """Hook: perform one gradient update.  No-op here."""
        pass

    def pretrain_step(self):
        """Hook: one pre-training update, used by episodic ``learn``.  No-op here."""
        pass

    def set_state(self, index):
        """Hook: called by ``learn`` with the index of an env whose episode ended."""
        pass

    def _record_time_stats(self, step, n_steps, step_start_time, train_start_time):
        """Record ``Time/fps``, ``Time/elapsed`` and ``Time/eta`` after step ``step`` (0-based)."""
        now = time()
        step_duration = now - step_start_time
        time_spent = now - train_start_time
        # On coarse clocks both durations can be exactly 0; skip the rate-based
        # records in that case instead of raising ZeroDivisionError.
        if step_duration > 0:
            self.logger.record_mean("Time/fps", self.env.num_envs / step_duration)
        self.logger.record("Time/elapsed", str(timedelta(seconds=int(time_spent))))
        if time_spent > 0:
            steps_per_second = (step + 1) / time_spent
            # Step ``step`` has completed, so ``step + 1`` steps are done.
            remaining_steps = n_steps - (step + 1)
            eta = timedelta(seconds=int(remaining_steps / steps_per_second))
            self.logger.record("Time/eta", str(eta))

    def _record_episode_scores(self, finished_scores):
        """Log per-objective scores of just-finished episodes.

        ``Episode/score_{k}`` is the mean over ``finished_scores`` and
        ``Episode/mean_score_{k}`` the mean over ``self.score_deque``.
        """
        if not finished_scores:
            return
        mean_scores = np.asarray(finished_scores).mean(axis=0)  # (reward_dim,)
        if self.score_deque:
            deque_mean = np.asarray(self.score_deque).mean(axis=0)
        else:
            deque_mean = np.zeros(self.reward_dim, dtype=np.float32)
        for r_dim in range(self.reward_dim):
            self.logger.record(key=f"Episode/score_{r_dim}", value=float(mean_scores[r_dim]))
            self.logger.record(key=f"Episode/mean_score_{r_dim}", value=float(deque_mean[r_dim]))

    def learn(self, n_steps: int,
              log_interval: int = 4,
              train_frequency: int = 1,
              n_train: int = 1,
              learning_start: int = 100,
              test_interval: int = int(1e+5),
              test_env: Optional[gym.Env] = None,
              episodic_learn: bool = False,
              need_pretrain: bool = False,
              ):
        """Run the interaction / training loop for ``n_steps`` vectorised steps.

        Args:
            n_steps: Number of ``self.env.step`` calls (each advances all envs).
            log_interval: Dump the logger once at least this many episodes
                have finished since the previous dump.
            train_frequency: In non-episodic mode, train every this many steps.
            n_train: ``train_step`` calls per update (non-episodic), or per
                env step of a finished episode (episodic).
            learning_start: Number of initial steps using uniformly sampled
                actions and no training.
            test_interval: Call ``self.test(test_env, 10)`` every this many
                steps (never at step 0) when ``test_env`` is given.
            test_env: Optional single, non-vectorised evaluation environment.
            episodic_learn: If True, train only when an episode finishes.
            need_pretrain: In episodic mode, run ``learning_start * 10``
                ``pretrain_step`` calls before the first episodic update.
        """
        num_envs = self.env.num_envs
        last_obs = self.env.reset()
        score = np.zeros((num_envs, self.reward_dim), dtype=np.float32)
        epicnt = 0.
        last_dump_epicnt = 0.  # value of ``epicnt`` at the previous logger dump
        step_cnt = np.zeros(num_envs, dtype=np.int32)
        at_least_one_train: bool = False
        train_start_time = time()
        for s in range(n_steps):
            start_time = time()
            if s < learning_start:
                # Warm-up phase: uniformly random actions, no updates.
                action = np.asarray([self.env.action_space.sample() for _ in range(num_envs)])
            else:
                action = self.predict(last_obs, deterministic=False)
                if s % train_frequency == 0 and not episodic_learn:
                    # Show a progress bar only for long update bursts.
                    it = range(n_train) if n_train < 100 else trange(n_train)
                    at_least_one_train = True
                    for _ in it:
                        self.train_step()

            if s % test_interval == 0 and s > 0 and test_env is not None:
                self.test(test_env, 10)

            next_obs, reward, done, info = self.env.step(action)
            self.buffer.add(
                obs=last_obs.copy(), next_obs=next_obs, action=action, reward=reward, done=done, infos=info,
            )
            score = score + reward
            last_obs = next_obs.copy()
            step_cnt += 1
            self._record_time_stats(s, n_steps, start_time, train_start_time)

            if not done.any():
                continue

            finished_scores = []
            for i in np.where(done)[0]:
                epicnt += 1
                finished_scores.append(score[i].copy())
                self.score_deque.append(score[i].copy())
                self.step_deque.append(step_cnt[i])
                # NOTE(review): strict ``>`` here vs. ``s >= learning_start`` in
                # the non-episodic branch above; kept as-is, confirm intent.
                if episodic_learn and s > learning_start:
                    if need_pretrain:
                        for _ in trange(learning_start * 10):
                            self.pretrain_step()
                        need_pretrain = False
                    # Update burst proportional to the finished episode's length.
                    at_least_one_train = True
                    for _ in trange(int(n_train) * int(step_cnt[i])):
                        self.train_step()
                self.logger.record(key='Episode/num_epi', value=epicnt)
                self.logger.record(key='Episode/epilen', value=step_cnt[i])
                # Reset env ``i``'s accumulators for its next episode.
                score[i] = 0
                step_cnt[i] = 0
                self.set_state(i)

            self._record_episode_scores(finished_scores)
            self.logger.record(key='Episode/mean_epilen', value=np.mean(self.step_deque))
            # Env transitions collected so far: step ``s`` has completed.
            self.logger.record(key='Train/current_step', value=(s + 1) * num_envs)

            if at_least_one_train:
                # The base-class ``get_train_log`` returns None; treat as empty.
                train_log = self.get_train_log() or {}
                for key, value in train_log.items():
                    self.logger.record(key=f"Train/{key}", value=value)
                at_least_one_train = False
            # Several envs can finish on the same step, so ``epicnt`` may jump
            # past an exact multiple of ``log_interval``; a modulo test would
            # then silently skip the dump.  Count episodes since the last dump.
            if epicnt - last_dump_epicnt >= log_interval:
                self.logger.dump()
                last_dump_epicnt = epicnt

    def test(self, test_env, n_test):
        """Evaluate ``predict`` (deterministic) for ``n_test`` episodes and print stats.

        ``test_env`` is driven with the 5-tuple ``step`` API until it reports
        termination or truncation.  Prints ``mean+/-std. CV@R50%:cvar``, where
        ``cvar`` is the mean of the lowest half of the sorted returns (at least
        one episode).  For vector-valued rewards each objective is sorted and
        summarised independently.

        Returns:
            The sorted episode returns as an ndarray.  (The original returned
            ``None``; callers that ignore the result are unaffected.)
        """
        returns = []
        for _ in range(n_test):
            obs, _ = test_env.reset()
            finished = False
            episode_return = 0
            while not finished:
                action = self.predict(obs)
                obs, reward, terminated, truncated, _info = test_env.step(action)
                episode_return += reward
                finished = terminated or truncated
            returns.append(episode_return)
        # Sort along the episode axis.  ``list.sort`` raises for vector-valued
        # returns; ``np.sort(axis=0)`` sorts each objective independently.
        returns = np.sort(np.asarray(returns), axis=0)
        # Keep at least one episode in the tail so n_test == 1 is not nan.
        tail = max(1, len(returns) // 2)
        mean = returns.mean(axis=0)
        std = returns.std(axis=0)
        cvar = returns[:tail].mean(axis=0)
        if returns.ndim == 1:
            print(f"{mean:.4f}+/-{std}. CV@R50%:{cvar}")
        else:
            print(f"{mean}+/-{std}. CV@R50%:{cvar}")
        return returns
