import gymnasium as gym
import pandas as pd

from risk_morl.buffer import MOReplayBufferWeight
from time import time
from datetime import timedelta
import numpy as np
from typing import Optional, Callable, Literal
from tqdm import trange
from stable_baselines3.common.vec_env import VecEnv
from risk_morl.offpolicy import RiskSensitiveOffPolicyRLJax
from risk_morl.kr_iqn_sac.policy import KRCriticPolicy, MarginalIQNPolicy, AblationKRCriticPolicy, AblationKRCriticPositionalEncodingPolicy
from risk_morl.ewps.ewp_policy import EWPPolicy

from collections import deque
import os


class MOSAC(RiskSensitiveOffPolicyRLJax):
    """Multi-objective, risk-sensitive SAC agent backed by a ``KRCriticPolicy``.

    The agent keeps one preference/weight vector per sub-environment in
    ``self.state``; that vector is passed to the policy on every ``predict``
    call and stored with each transition in a ``MOReplayBufferWeight``.

    NOTE(review): ``self.reward_dim``, ``self.state``, ``self.logger``,
    ``self.score_deque``, ``self.step_deque`` and ``self.pretrain_step`` are not
    defined in this class; they are presumably provided by
    ``RiskSensitiveOffPolicyRLJax`` -- confirm against the base class.
    """
    policy: KRCriticPolicy
    buffer: MOReplayBufferWeight

    def __init__(self,
                 env: gym.Env | VecEnv,
                 test_env: gym.Env,
                 gamma: float = 0.99,
                 batch_size: int = 256,
                 buffer_size: int = int(1e+6),
                 critic_lr: float = 3e-4,
                 actor_lr: float = 3e-4,

                 risk_measure: Callable = lambda x: 0.1 * x,
                 policy_kwargs: Optional[dict] = None,
                 ent_coef: float | Literal['auto'] = 'auto',
                 target_entropy: float | Literal['auto'] = 'auto',
                 actor_marginal_risk: bool = False,
                 comonotone: bool = False,
                 seed: int = 42
                 ):
        """Create the agent and forward the shared options to the base class.

        Args:
            env: Training environment; ``learn`` uses it through the vectorised
                API (``num_envs``, 4-tuple ``step``).
            test_env: Stored on ``self.test_env``. ``learn`` does not read it.
            gamma, batch_size, buffer_size, critic_lr, actor_lr, seed:
                Forwarded unchanged to the base-class constructor.
            risk_measure, ent_coef, target_entropy, actor_marginal_risk,
            comonotone: Injected into ``policy_kwargs`` under the same names
                and later expanded into the policy constructor.
            policy_kwargs: Extra keyword arguments for the policy. The dict is
                copied, so the caller's object is never modified.
        """
        # Work on a shallow copy: the original code wrote the five keys below
        # straight into the caller's dict.
        policy_kwargs = dict(policy_kwargs) if policy_kwargs is not None else {}
        policy_kwargs["risk_measure"] = risk_measure
        policy_kwargs['target_entropy'] = target_entropy
        policy_kwargs['ent_coef'] = ent_coef
        policy_kwargs['actor_marginal_risk'] = actor_marginal_risk
        policy_kwargs['comonotone'] = comonotone

        super().__init__(env=env,
                         gamma=gamma, batch_size=batch_size,
                         buffer_size=buffer_size,
                         actor_lr=actor_lr,
                         critic_lr=critic_lr,
                         policy_kwargs=policy_kwargs,
                         seed=seed)
        self.test_env = test_env
        self.test_state = np.zeros(shape=(self.reward_dim,))
        x = np.zeros(shape=(self.reward_dim,))
        x[0] = 1.
        # 100 *independent* copies of the first unit vector. The previous
        # ``[x.copy()] * 100`` put the same array object into every slot, so an
        # in-place edit of one entry would have changed all of them.
        # ``self.bias`` is not read anywhere in this class.
        self.bias = deque(x.copy() for _ in range(100))

    def build_buffer(self):
        """Create the weight-aware replay buffer and store it on ``self.buffer``."""
        self.buffer = MOReplayBufferWeight(
            observation_space=self.env.observation_space,
            action_space=self.env.action_space,
            n_envs=self.env.num_envs,
            num_reward=self.reward_dim,
            buffer_size=self.buffer_size,
            seed=self.seed
        )

    def build_policy(self):
        """Instantiate ``KRCriticPolicy`` from ``self.policy_kwargs``.

        Subclasses override this method to plug in a different policy class.
        """
        self.policy = KRCriticPolicy(
            self.env,
            reward_dim=self.reward_dim,
            **self.policy_kwargs
        )

    def predict(self, observation, *, state: Optional[np.ndarray] = None, deterministic: bool = True,
                langevin: bool = False) -> tuple:
        """Delegate action selection to ``self.policy.predict``.

        Returns:
            Whatever the policy returns. Every caller in this class unpacks it
            as a pair ``(action, state)``.
        """
        return self.policy.predict(observation, state, langevin=langevin, deterministic=deterministic)

    def get_state(self):
        """Sample one preference vector per sub-environment and stack them."""
        return np.asarray([self.policy.sample_preference(1).squeeze() for _ in range(self.env.num_envs)])

    def set_state(self, index):
        """Replace the preference vector of sub-environment ``index`` with a fresh sample."""
        self.state[index] = self.policy.sample_preference(1).squeeze()

    def train_step(self):
        """Run one policy update on a batch of ``self.batch_size`` replayed transitions."""
        return self.policy.train_step(self.buffer.sample(self.batch_size))

    def _run_train_steps(self, n_updates: int, *, show_progress: bool) -> None:
        """Call ``train_step`` ``n_updates`` times, optionally behind a tqdm bar."""
        iterator = trange(n_updates) if show_progress else range(n_updates)
        for _ in iterator:
            self.train_step()

    def _record_timing(self, step: int, n_steps: int, train_start_time: float) -> None:
        """Record throughput, elapsed time and ETA after finishing loop index ``step``."""
        steps_done = step + 1
        # Clamp so a coarse clock that reports zero elapsed time cannot raise
        # ZeroDivisionError on the very first iterations.
        time_spent = max(time() - train_start_time, 1e-9)
        # fps and ETA now share the same "steps completed" count; previously
        # fps used ``step`` while the ETA velocity used ``step + 1``.
        self.logger.record("Time/fps", self.env.num_envs * steps_done / time_spent)
        self.logger.record("Time/elapsed", str(timedelta(seconds=int(time_spent))))
        steps_per_second = steps_done / time_spent
        eta_seconds = (n_steps - steps_done) / steps_per_second
        self.logger.record("Time/eta", str(timedelta(seconds=int(eta_seconds))))

    def learn(self, n_steps: int,
              log_interval: int = 4,
              train_frequency: int = 1,
              n_train: int = 1,
              learning_start: int = 100,
              test_interval: int = int(1e+5),
              test_env: Optional[gym.Env] = None,
              episodic_learn: bool = False,
              need_pretrain: bool = False,
              exploration_fraction: float = 0.1,
              begin_epsilon_greedy: float = 1.,
              end_epsilon_greedy: float = 0.,
              ):
        """Collect experience from ``self.env`` and train the policy.

        Args:
            n_steps: Number of vectorised environment steps to run.
            log_interval: The logger is dumped whenever the episode counter is
                a multiple of this value.
            train_frequency: In step-wise mode, train on every step index that
                is a multiple of this value.
            n_train: Updates per training trigger (step-wise mode) or updates
                per environment step of the finished episode (episodic mode).
            learning_start: Steps below this index use uniformly sampled
                actions and never trigger training.
            test_interval: Evaluate every ``test_interval`` steps (not at step 0).
            test_env: Environment for the periodic evaluation. Evaluation is
                skipped when this is ``None``; ``self.test_env`` is not used
                as a fallback.
            episodic_learn: If true, train only when an episode ends instead of
                on every ``train_frequency`` steps.
            need_pretrain: In episodic mode, run ``learning_start * 10``
                ``pretrain_step`` calls once, before the first episodic update.
            exploration_fraction, begin_epsilon_greedy, end_epsilon_greedy:
                Accepted for interface compatibility; not used by this method.
        """
        last_obs = self.env.reset()
        score = np.zeros((self.env.num_envs, self.reward_dim))
        epicnt = 0.  # kept as a float so the logged 'Episode/num_epi' value type is unchanged
        step_cnt = np.zeros(self.env.num_envs, dtype=np.int32)
        at_least_one_train: bool = False
        train_start_time = time()

        for s in range(n_steps):
            if s < learning_start:
                # Warm-up: random actions, but still call the policy so that
                # ``self.state`` is advanced the same way as during training.
                action = np.asarray([self.env.action_space.sample() for _ in range(self.env.num_envs)])
                _, self.state = self.predict(last_obs, state=self.state, deterministic=False)
            else:
                action, self.state = self.predict(last_obs, state=self.state, deterministic=False)
                if s % train_frequency == 0 and self.buffer.size() > 5 and not episodic_learn:
                    at_least_one_train = True
                    # Only show a progress bar for long update bursts.
                    self._run_train_steps(n_train, show_progress=n_train >= 100)

            if s % test_interval == 0 and s > 0 and test_env is not None:
                self.test(test_env, 10)

            next_obs, reward, done, info = self.env.step(action)

            self.buffer.add(
                obs=last_obs.copy(), next_obs=next_obs.copy(),
                action=action, reward=reward.copy(), done=done, infos=info,
                weight=np.asarray(self.state.copy())
            )
            score = score + reward
            last_obs = next_obs.copy()
            step_cnt += 1
            self._record_timing(s, n_steps, train_start_time)

            if not done.any():
                continue

            finished_scores = []
            for i in np.where(done)[0]:
                epicnt += 1
                finished_scores.append(score[i].copy())
                self.score_deque.append(score[i].copy())
                self.step_deque.append(step_cnt[i])
                if episodic_learn and s > learning_start:
                    if need_pretrain:
                        for _ in trange(learning_start * 10):
                            self.pretrain_step()
                        need_pretrain = False

                    # Episodic mode: n_train updates per step of the episode
                    # that just finished in sub-environment i.
                    at_least_one_train = True
                    self._run_train_steps(int(n_train) * int(step_cnt[i]), show_progress=True)
                self.logger.record(key='Episode/num_epi', value=epicnt)
                self.logger.record(key='Episode/epilen', value=step_cnt[i])
                score[i] = 0
                step_cnt[i] = 0
                # Draw a new preference vector for the episode that starts next.
                self.set_state(i)

            # Mean over the episodes that ended on this step: (num_done, reward_dim) -> (reward_dim,)
            mean_scores = np.asarray(finished_scores).mean(axis=0)
            if len(self.score_deque) > 0:
                deque_mean = np.mean(np.stack(self.score_deque, axis=0), axis=0)
            else:
                deque_mean = np.zeros(self.reward_dim, dtype=np.float32)

            for r_dim in range(self.reward_dim):
                self.logger.record(key=f"Episode/score_{r_dim}", value=float(mean_scores[r_dim]))
                self.logger.record(key=f"Episode/mean_score_{r_dim}", value=float(deque_mean[r_dim]))

            self.logger.record(key='Episode/mean_epilen', value=float(np.mean(self.step_deque)))
            self.logger.record(key='Train/current_step', value=s * self.env.num_envs)

            if at_least_one_train:
                for k, metric in self.get_train_log().items():
                    self.logger.record(key=f"Train/{k}", value=metric)
                at_least_one_train = False
            if epicnt % log_interval == 0:
                self.logger.dump()

    def get_train_log(self) -> dict:
        """Return the policy's training metrics (``self.policy.log_all()``)."""
        return self.policy.log_all()

    def save(self, path, dir: Optional[str] = None):
        """Save the policy to ``path``, placed inside ``dir`` when one is given.

        ``dir`` is created if it does not exist yet.
        """
        if dir is not None:
            os.makedirs(dir, exist_ok=True)
            path = os.path.join(dir, path)
        self.policy.save(path)

    def load(self, path, dir: Optional[str] = None):
        """Load the policy from ``path`` (inside ``dir`` when given) and return ``self``."""
        if dir is not None:
            path = os.path.join(dir, path)
        self.policy.load(path)
        return self

    def test(self, test_env, n_test):
        """Roll out ``n_test`` episodes on ``test_env`` and print the score statistics.

        The preference vector of sub-environment 0 is used for every rollout.
        ``test_env`` is driven through the single-environment Gymnasium API
        (``reset`` returning a pair, ``step`` returning a 5-tuple). Nothing is
        returned; the per-objective mean and standard deviation are printed.
        """
        scores = []
        state = self.state[0].copy()
        print("weight", state)
        state = state[None]  # add a leading batch axis for the policy

        for _ in range(n_test):
            obs, _ = test_env.reset()
            done = False
            score = np.zeros(shape=self.reward_dim)
            while not done:
                action, _ = self.predict(obs, state=state)
                action = action.reshape(test_env.action_space.shape)

                obs, reward, done, timeout, info = test_env.step(action)
                score += reward
                # An episode ends on either termination or truncation.
                done = done or timeout
            scores.append(score)
        scores = np.asarray(scores)
        print(f"{np.mean(scores, axis=0)} +/- {np.std(scores, axis=0)}")


class EWPSAC(MOSAC):
    """``MOSAC`` variant whose policy is an ``EWPPolicy``."""
    policy: EWPPolicy

    def build_policy(self):
        """Build an ``EWPPolicy`` from ``self.policy_kwargs``.

        The ``actor_marginal_risk`` and ``comonotone`` entries are removed from
        ``self.policy_kwargs`` (in place) before the policy is constructed --
        presumably because ``EWPPolicy`` does not accept them; confirm against
        its constructor.
        """
        for dropped_key in ('actor_marginal_risk', 'comonotone'):
            # pop with a default is a no-op when the key is already absent.
            self.policy_kwargs.pop(dropped_key, None)

        policy_cls = EWPPolicy
        self.policy = policy_cls(self.env, reward_dim=self.reward_dim, **self.policy_kwargs)


class MOSACAblation(MOSAC):
    """``MOSAC`` variant that swaps in ``AblationKRCriticPolicy``."""
    policy: AblationKRCriticPolicy
    buffer: MOReplayBufferWeight

    def build_policy(self):
        """Construct the ablation policy with the same arguments ``MOSAC`` uses."""
        policy_cls = AblationKRCriticPolicy
        self.policy = policy_cls(self.env, reward_dim=self.reward_dim, **self.policy_kwargs)

class MOSACAblationPositionalEncoding(MOSAC):
    """``MOSAC`` variant that swaps in ``AblationKRCriticPositionalEncodingPolicy``."""
    policy: AblationKRCriticPositionalEncodingPolicy
    buffer: MOReplayBufferWeight

    def build_policy(self):
        """Construct the positional-encoding ablation policy with the usual arguments."""
        policy_cls = AblationKRCriticPositionalEncodingPolicy
        self.policy = policy_cls(self.env, reward_dim=self.reward_dim, **self.policy_kwargs)


class MarginalMOSAC(MOSAC):
    """``MOSAC`` variant that swaps in ``MarginalIQNPolicy``."""

    def build_policy(self):
        """Construct ``MarginalIQNPolicy`` with the same arguments ``MOSAC`` uses."""
        policy_cls = MarginalIQNPolicy
        self.policy = policy_cls(self.env, reward_dim=self.reward_dim, **self.policy_kwargs)


if __name__ == '__main__':
    # Experiment script: trains MarginalMOSAC on mo-halfcheetah-v5 for five
    # seeds, saving the policy and an evaluation CSV after each run.
    from gymnasium.wrappers import RescaleAction, TransformObservation
    import mo_gymnasium
    from stable_baselines3.common.vec_env import DummyVecEnv

    from risk_morl.utils.mo_utils import test, df_to_points, hypervolume_for_preferences, hypervolume
    from risk_morl.utils.env_util import SafetyGymnasiumMO, Hopper2d, MODummyVecEnv

    from utils.misc import fix_seed

    from functools import partial
    from mo_gymnasium.envs.mujoco import hopper_v5

    # Alternative environments used in earlier experiments:
    # SafetyGymnasiumMO('SafetyPointGoal1-v0',  scale_reward=1., max_episode_steps=1000)  # RescaleAction(mo_gymnasium.make("mo-hopper-v5", max_episode_steps=500),  min_action=-1., max_action=1.)
    # SafetyGymnasiumMO('SafetyPointGoal1-v0')
    import jax
    import os
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'

    jnp = jax.numpy

    def risk_measure(x):
        """Return ``x`` with the index-1 slice along axis -2 multiplied by 0.5."""
        halved = x[..., 1, :] * 0.5
        return x.at[..., 1, :].set(halved)


    def triangle_measure(x):
        """Map the index-0/1 slices ``(u, v)`` to ``0.5 * (1 - sqrt(u), v * sqrt(u))``."""
        u = x[..., 0, :]
        v = x[..., 1, :]
        root_u = jnp.sqrt(u)
        return 0.5 * jnp.stack([1 - root_u, v * root_u], axis=-2)


    def hopper_measure(x):
        """Three-row variant of ``triangle_measure``.

        Rows 0 and 1 follow the triangle mapping, row 2 is passed through; the
        stack is scaled by 0.5 and 0.5 is then added to row 2 only.
        """
        u = x[..., 0, :]
        v = x[..., 1, :]
        w = x[..., 2, :]
        root_u = jnp.sqrt(u)
        scaled = 0.5 * jnp.stack([1 - root_u, v * root_u, w], axis=-2)
        offset = jnp.zeros_like(scaled)
        offset = offset.at[..., 2, :].set(0.5 * jnp.ones_like(offset[..., 2, :]))
        return scaled + offset


    def neutral_measure(x):
        """Identity mapping: return the input unchanged."""
        return x


    def simplex_measure(x):
        """Chain three warped slices so each takes a share of what the previous ones left."""
        def f(q):
            return 1 - jnp.power(jnp.sqrt(1 - q), 1 / 3)

        first = f(x[..., 0, :])
        second = (1 - first) * f(x[..., 1, :])
        third = (1 - first - second) * f(x[..., 2, :])
        return 0.5 * jnp.stack([first, second, third], axis=-2)


    def make_env():
        """Build one action-rescaled mo-halfcheetah-v5 environment (500-step limit)."""
        base_env = mo_gymnasium.make("mo-halfcheetah-v5", max_episode_steps=500)
        return RescaleAction(base_env, min_action=-1., max_action=1.)


    for i in range(5):
        fix_seed(i)

        # Same construction order as before: evaluation env first, then the
        # single-env vectorised training env.
        test_env = make_env()
        env = MODummyVecEnv([make_env])

        # Built fresh for every run so no dict is shared between models.
        policy_options = {
            "soft_update_ratio": 5e-3,
            "truncation_lower": 2,
            "truncation_upper": 2,
        }
        model = MarginalMOSAC(
            env=env,
            test_env=test_env,
            batch_size=256,
            risk_measure=jax.jit(neutral_measure),
            critic_lr=3e-4,
            gamma=0.99,
            actor_marginal_risk=False,
            comonotone=False,
            actor_lr=3e-4,
            seed=i,
            policy_kwargs=policy_options,
        )

        model.learn(1_000_000, log_interval=1, train_frequency=1)
        model.save(f"hopper{i}_point.pkl")

        result = test(test_env, model, 6, 3, 'mo_sac', 100)
        result.to_csv(f'hopper{i}_point.csv')
