import gymnasium
import gymnasium as gym
import pandas as pd
from gymnasium.core import WrapperObsType

from risk_morl.buffer import MultiObjectiveReplayBuffer
from time import time
from datetime import timedelta
import numpy as np
from typing import Optional, Callable, Any
from tqdm import trange
from stable_baselines3.common.vec_env import VecEnv
from risk_morl.offpolicy import RiskSensitiveOffPolicyRLJax
from collections import deque
import os
from risk_morl.dqn_policy.policy import MODQNPolicy
from basics.epsilon_greedy_scheduler import LinearScheduler
from flax import nnx


class MODQN(RiskSensitiveOffPolicyRLJax):
    """Multi-objective DQN agent with epsilon-greedy exploration.

    The agent owns a ``MODQNPolicy`` and a ``MultiObjectiveReplayBuffer`` and
    keeps one weight vector per sub-environment in ``self.state``; at the end
    of an episode the weighted sum ``(self.state[i] * score[i]).sum()`` is
    logged as the episode utility.

    NOTE(review): ``self.state``, ``self.logger``, ``self.np_rng``,
    ``self.score_deque``, ``self.step_deque``, ``self.reward_dim``,
    ``self.test`` and ``self.pretrain_step`` are not defined in this class;
    they are presumably provided by ``RiskSensitiveOffPolicyRLJax`` - confirm.
    """
    policy: MODQNPolicy
    buffer: MultiObjectiveReplayBuffer
    scheduler: LinearScheduler

    def __init__(self,
                 env: gym.Env | VecEnv,
                 test_env: gym.Env,
                 gamma: float = 0.99,
                 batch_size: int = 256,
                 buffer_size: int = int(3e+5),
                 lr: float = 3e-4,
                 risk_measure: Callable = lambda x: 0.1 * x,
                 policy_kwargs: Optional[dict] = None,
                 custom_weights: Optional[Any] = None,
                 seed: int = 42
                 ):
        """Create the agent.

        Args:
            env: Training environment, forwarded to the base class.
            test_env: Environment stored on ``self.test_env``.
            gamma: Forwarded to the base class as ``gamma``.
            batch_size: Forwarded to the base class; also the size of the
                batch sampled in ``train_step``.
            buffer_size: Forwarded to the base class; capacity handed to the
                replay buffer in ``build_buffer``.
            lr: Forwarded to the base class as ``critic_lr``.
            risk_measure: Callable passed to the policy under the
                ``"risk_measure"`` keyword.
            policy_kwargs: Extra keyword arguments for ``MODQNPolicy``. The
                dict is copied, so the caller's object is never modified.
            custom_weights: If not ``None``, passed to the policy under the
                ``'custom_weights'`` keyword.
            seed: Forwarded to the base class.
        """
        # Work on a private copy: adding our own keys to the caller's dict
        # would leak state between agents that share one kwargs object.
        policy_kwargs = {} if policy_kwargs is None else dict(policy_kwargs)
        policy_kwargs["risk_measure"] = risk_measure
        if custom_weights is not None:
            policy_kwargs['custom_weights'] = custom_weights

        super().__init__(env=env,
                         gamma=gamma, batch_size=batch_size,
                         buffer_size=buffer_size,
                         critic_lr=lr,
                         policy_kwargs=policy_kwargs,
                         seed=seed)
        self.test_env = test_env
        self.test_state = np.zeros(shape=(self.reward_dim,))
        # Unit vector along the first reward dimension.
        x = np.zeros(shape=(self.reward_dim,))
        x[0] = 1.
        # 100 references to one copy of that unit vector (unbounded deque).
        self.bias = deque([x.copy()] * 100)
        # Utilities of the most recent 100 finished episodes.
        self.utility_deque = deque(maxlen=100)

    def build_buffer(self):
        """Create the replay buffer from the environment's spaces and sizes."""
        self.buffer = MultiObjectiveReplayBuffer(
            observation_space=self.env.observation_space,
            action_space=self.env.action_space,
            n_envs=self.env.num_envs,
            num_reward=self.reward_dim,
            buffer_size=self.buffer_size,
            seed=self.seed
        )

    def build_policy(self):
        """Create the ``MODQNPolicy`` from ``self.policy_kwargs``.

        ``'actor_lr'`` is removed from the kwargs before they are forwarded
        (presumably injected by the base class - confirm). A missing key is
        tolerated instead of raising ``KeyError``.
        """
        self.policy_kwargs.pop('actor_lr', None)
        self.policy = MODQNPolicy(
            self.env,
            reward_dim=self.reward_dim,
            **self.policy_kwargs
        )

    def predict(self, observation, *, state: Optional[Any] = None,
                deterministic: bool = True) -> tuple[np.ndarray, Any]:
        """Delegate to ``self.policy.predict(observation, state)``.

        ``deterministic`` is accepted for interface compatibility but is not
        forwarded to the policy.

        Returns:
            Whatever the policy returns; ``learn`` unpacks it as
            ``(action, state)``.
        """
        return self.policy.predict(observation, state, )

    def get_state(self):
        """Return one freshly drawn random weight vector per sub-environment."""
        return np.asarray([self.random_weight() for _ in range(self.env.num_envs)])

    def set_state(self, index):
        """Replace ``self.state[index]`` with a preference sampled by the policy."""
        self.state[index] = self.policy.sample_preference(1).squeeze()

    def train_step(self):
        """Run one policy update on a batch of ``self.batch_size`` transitions."""
        return self.policy.train_step(self.buffer.sample(self.batch_size))

    def _sample_random_actions(self) -> np.ndarray:
        """Draw one action per sub-environment from the env's action space."""
        return np.asarray([self.env.action_space.sample() for _ in range(self.env.num_envs)])

    def _record_time_stats(self, s: int, n_steps: int, train_start_time: float) -> None:
        """Record fps, elapsed wall time and ETA for loop iteration ``s``."""
        # Clamp to a tiny positive value so a zero reading from a coarse
        # clock cannot cause a ZeroDivisionError below.
        time_spent = max(time() - train_start_time, 1e-9)
        fps = self.env.num_envs * s / time_spent
        self.logger.record("Time/fps", fps)
        self.logger.record("Time/elapsed", str(timedelta(seconds=int(time_spent))))

        steps_per_second = (s + 1) / time_spent
        remaining_frames = n_steps - s
        eta = timedelta(seconds=int(remaining_frames / steps_per_second))
        self.logger.record("Time/eta", str(eta))

    def learn(self, n_steps: int,
              log_interval: int = 4,
              train_frequency: int = 1,
              eps_greedy_fraction: float = 0.1,
              min_prob: float = 0.05,
              max_prob: float = 1.,
              n_train: int = 1,
              learning_start: int = 100,
              test_interval: int = int(1e+5),
              test_env: Optional[gym.Env] = None,
              episodic_learn: bool = False,
              need_pretrain: bool = False,
              exploration_fraction: float = 0.1,
              begin_epsilon_greedy: float = 1.,
              end_epsilon_greedy: float = 0.,
              ):
        """Collect experience for ``n_steps`` loop iterations and train the policy.

        Args:
            n_steps: Number of loop iterations; each one steps every
                sub-environment once.
            log_interval: The logger is dumped whenever the running episode
                count reaches a multiple of this value.
            train_frequency: In step-wise mode, train on iterations where
                ``s % train_frequency == 0``.
            eps_greedy_fraction: ``int(n_steps * eps_greedy_fraction)`` is the
                first argument given to ``LinearScheduler``.
            min_prob: Passed to ``LinearScheduler`` as ``min_prob``.
            max_prob: Passed to ``LinearScheduler`` as ``max_prob``.
            n_train: Number of ``train_step`` calls per training trigger
                (step-wise mode), or the multiplier applied to the finished
                episode's length (episodic mode).
            learning_start: Iterations before this index always act randomly
                and never train.
            test_interval: Call ``self.test(test_env, 10)`` on every positive
                iteration index that is a multiple of this value.
            test_env: Environment handed to ``self.test``; no testing is done
                when it is ``None``.
            episodic_learn: If true, train only when an episode ends instead
                of on individual steps.
            need_pretrain: In episodic mode, run ``learning_start * 10`` calls
                to ``self.pretrain_step`` once, before the first training.
            exploration_fraction: Accepted but not used by this method.
            begin_epsilon_greedy: Accepted but not used by this method.
            end_epsilon_greedy: Accepted but not used by this method.
        """
        scheduler = LinearScheduler(
            int(n_steps * eps_greedy_fraction),
            min_prob=min_prob,
            max_prob=max_prob,
            rngs=nnx.Rngs(self.seed))

        last_obs = self.env.reset()
        score = np.zeros((self.env.num_envs, self.reward_dim))
        epicnt = 0
        step_cnt = np.zeros(self.env.num_envs, dtype=np.int32)
        at_least_one_train = False
        train_start_time = time()
        for s in range(n_steps):
            if s < learning_start:
                # Warm-up: act randomly, but still call predict so that
                # self.state keeps being refreshed by the policy.
                action = self._sample_random_actions()
                _, self.state = self.predict(last_obs, state=self.state, deterministic=False)
            else:
                # epsilon greedy
                rand_action, prob = scheduler()
                self.logger.record('train/eps_prob', 1 - prob.item())
                if bool(rand_action):
                    action = self._sample_random_actions()
                else:
                    action, self.state = self.predict(last_obs, state=self.state, deterministic=False)
                if s % train_frequency == 0 and self.buffer.size() > 5 and not episodic_learn:
                    # Only show a progress bar for long training bursts.
                    it = range(n_train) if n_train < 100 else trange(n_train)
                    at_least_one_train = True
                    for _ in it:
                        self.train_step()

            if s % test_interval == 0 and s > 0 and test_env is not None:
                self.test(test_env, 10)

            next_obs, reward, done, info = self.env.step(action)

            self.buffer.add(
                obs=last_obs.copy(), next_obs=next_obs, action=action, reward=reward, done=done, infos=info,
            )
            score = score + reward
            last_obs = next_obs.copy()
            step_cnt += 1
            self._record_time_stats(s, n_steps, train_start_time)

            if not done.any():
                continue

            finished_scores = []
            # With several sub-environments more than one episode can end on
            # the same step; remember whether any of them hit a multiple of
            # log_interval so the dump is not skipped when the count jumps.
            dump_due = False
            for i in np.where(done)[0]:
                epicnt += 1
                if epicnt % log_interval == 0:
                    dump_due = True
                finished_scores.append(score[i].copy())
                self.score_deque.append(score[i].copy())
                self.step_deque.append(step_cnt[i])
                utility = (self.state[i] * score[i]).sum()
                self.utility_deque.append(utility)

                if episodic_learn and s > learning_start:
                    if need_pretrain:
                        for _ in trange(learning_start * 10):
                            self.pretrain_step()
                        need_pretrain = False

                    # Number of updates scales with the finished episode's length.
                    learnings = int(n_train) * int(step_cnt[i])
                    at_least_one_train = True
                    for _ in trange(learnings):
                        self.train_step()

                self.logger.record(key='Episode/num_epi', value=epicnt)
                self.logger.record(key='Episode/epilen', value=step_cnt[i])
                self.logger.record(key='Episode/utility', value=utility)
                self.logger.record(key='Episode/utility-MA100', value=np.mean(self.utility_deque))

                # Reset the per-env accumulators and draw a new preference
                # for the episode that starts next in this sub-environment.
                score[i] = 0
                step_cnt[i] = 0
                self.set_state(i)

            # Mean over the episodes that ended on this step: (reward_dim,)
            mean_scores = np.asarray(finished_scores).mean(axis=0)
            if self.score_deque:
                deque_mean = np.mean(np.stack(self.score_deque, axis=0), axis=0)
            else:
                deque_mean = np.zeros(self.reward_dim, dtype=np.float32)

            for r_dim in range(self.reward_dim):
                self.logger.record(key=f"Episode/score_{r_dim}", value=float(mean_scores[r_dim]))
                self.logger.record(key=f"Episode/mean_score_{r_dim}", value=float(deque_mean[r_dim]))

            self.logger.record(key='Episode/mean_epilen', value=float(np.mean(self.step_deque)))
            self.logger.record(key='Train/current_step', value=s * self.env.num_envs)

            if at_least_one_train:
                for k, v in self.get_train_log().items():
                    self.logger.record(key=f"Train/{k}", value=v)
                at_least_one_train = False
            if dump_due:
                self.logger.dump()

    def get_train_log(self) -> dict:
        """Return the policy's training statistics (``self.policy.log_all()``)."""
        return self.policy.log_all()

    def random_weight(self, ):
        """Draw a weight vector from a Dirichlet with every concentration set to 0.1."""
        return self.np_rng.dirichlet(alpha=np.ones(self.reward_dim, ) * 0.1)

    def save(self, path, dir: Optional[str] = None):
        """Save the policy to ``path``.

        Args:
            path: File name or path handed to ``self.policy.save``.
            dir: Optional directory; it is created if missing and ``path`` is
                joined onto it.
        """
        if dir is not None:
            os.makedirs(dir, exist_ok=True)
            path = os.path.join(dir, path)
        self.policy.save(path)

    def load(self, path, dir: Optional[str] = None):
        """Load the policy from ``path``, optionally joined onto ``dir``."""
        if dir is not None:
            path = os.path.join(dir, path)
        self.policy.load(path)


if __name__ == '__main__':
    from gymnasium.wrappers import RescaleAction, TransformReward

    import mo_gymnasium
    from risk_morl.utils.mo_utils import test, df_to_points, hypervolume_for_preferences, hypervolume
    from risk_morl.utils.env_util import SafetyGymnasiumMO, Hopper2d
    from functools import partial
    from mo_gymnasium.envs.minecart import minecart
    from mo_gymnasium.envs.mujoco import hopper_v5, reacher_v5
    from gymnasium.wrappers import TransformObservation

    from sbx import DQN


    class FruitTreeWrapper(gymnasium.ObservationWrapper):
        """Wrap ``fruit-tree-v0`` so that observations are two concatenated one-hot vectors.

        The first two entries of the raw observation are each used as an
        index into a 64x64 identity matrix, and the two selected rows are
        concatenated into a 128-dimensional ``float32`` vector.
        """

        def __init__(self):
            env = mo_gymnasium.make("fruit-tree-v0")
            super().__init__(env)
            # Presumably the number of objectives of fruit-tree-v0 - confirm.
            self.reward_dim = 6
            self.observation_space = gymnasium.spaces.Box(low=np.zeros(128, dtype=np.float32),
                                                          high=np.ones(128, dtype=np.float32),
                                                          shape=(128,),
                                                          dtype=np.float32)
            self.action_space = env.action_space
            # float32 so that the rows returned by observation() match the
            # dtype declared in observation_space (np.eye defaults to float64).
            self.eye = np.eye(64, dtype=np.float32)

        def observation(self, obs):
            """Return the one-hot encoding of ``obs[0]`` followed by that of ``obs[1]``."""
            first, second = obs[0], obs[1]
            # np.concatenate already allocates a new array, so the lookup
            # table is never aliased by the returned observation.
            return np.concatenate([self.eye[first], self.eye[second]], axis=-1)


    env = FruitTreeWrapper()
    test_env = FruitTreeWrapper()

    import jax
    # Import the evaluation helper once, from the same package path the rest
    # of this script uses, instead of re-importing it on every loop iteration.
    from risk_morl.utils.mo_utils import test


    def risk_measure(x):
        """Identity risk measure: return the input unchanged."""
        return x


    # Train and evaluate one agent per seed (seeds 0-4).
    for seed in range(5):
        model = MODQN(env=env, test_env=test_env, batch_size=32, gamma=0.99,
                      risk_measure=jax.jit(risk_measure), lr=3e-4, policy_kwargs={"soft_update_ratio": 5e-3},
                      seed=seed, buffer_size=int(1e+4)
                      )

        model.learn(int(1e+5), log_interval=1, train_frequency=1,
                    max_prob=1, min_prob=0.1, learning_start=100, eps_greedy_fraction=0.5, )

        # NOTE(review): every seed writes to the same file, so only the last
        # run's checkpoint survives - confirm whether that is intended.
        model.save("fruit_tree.pkl")
        # NOTE(review): evaluation uses the training env rather than test_env - confirm.
        df = test(env, model, 4, 6, 'FTN')
        df.to_csv(f"FTN_{seed}.csv")



