import time
import os
from collections import defaultdict
import numpy as np
import torch
import gym
import copy
from typing import Optional, Dict, List, Tuple, Union
from tqdm import tqdm
from collections import deque
from offlinerlkit.buffer import ReplayBuffer
from offlinerlkit.utils.logger import Logger
from offlinerlkit.utils.util_fns import get_normalized_std_score
from offlinerlkit.policy import BasePolicy, MOPOPolicy
from offlinerlkit.utils.scaler import StandardScaler
import random
import pickle
import faiss

# model-based policy trainer
class MBPolicyTrainer:
    def __init__(
        self,
        args,
        policy: Union[MOPOPolicy],
        eval_env: gym.Env,
        real_buffer: ReplayBuffer,
        fake_buffer: ReplayBuffer,
        logger: Logger,
        rollout_setting: Tuple[int, int, int],
        epoch: int = 1000,
        step_per_epoch: int = 1000,
        batch_size: int = 256,
        real_ratio: float = 0.05,
        eval_episodes: int = 10,
        lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
        dynamics_update_freq: int = 0,

    ) -> None:

        self.args = args
        self.policy = policy

        self.eval_env = eval_env
        self.eval_env.reset()
        self.real_buffer = real_buffer
        self.fake_buffer = fake_buffer
        self.logger = logger

        self.task = args.task
        self.num_timesteps = None

        self._rollout_freq, self._rollout_batch_size, self._rollout_length = rollout_setting
        self._dynamics_update_freq = dynamics_update_freq

        self._epoch = epoch
        self._step_per_epoch = step_per_epoch
        self._batch_size = batch_size
        self._real_ratio = real_ratio
        self._eval_episodes = eval_episodes
        self.lr_scheduler = lr_scheduler



    def anchor_seeker_pretrain_reverse(self, load_reverse_imagination_path, n_epoch, batch_size, lr, asp_which, logger, data=None) -> None:
        self.policy.anchor_seeker_pretrain_reverse(load_reverse_imagination_path, n_epoch, batch_size, lr, asp_which, logger, data)
        self.logger.dumpkvs()


    def train(
        self,
        max_epochs_since_update: int = 5,
    ) -> Dict[str, float]:
        start_time = time.time()

        ## anchor seeker freeze
        self.policy.anchor_seeking_actor_freeze()

        self.policy.dynamics.model.eval()
        self.policy.dynamics.model.requires_grad_(False)
        self.num_timesteps = 0
        self.last_10_performance = deque(maxlen=10)

        self.best_epoch, self.best_last10_epoch, self.best_metric, self.best_last10_metric = None, None, None, None

        checkpoint_last = os.path.join(os.path.dirname(self.logger.checkpoint_dir), "checkpoint_last")
        checkpoint_best = os.path.join(os.path.dirname(self.logger.checkpoint_dir), "checkpoint_best")
        checkpoint_best_last10 = os.path.join(os.path.dirname(self.logger.checkpoint_dir), "checkpoint_best_last10")

        # train loop
        for e in range(1, self._epoch + 1):
            train_start_time = time.time()
            self.policy.train()

            pbar = tqdm(range(self._step_per_epoch), desc=f"Epoch #{e}/{self._epoch}")

            for it in pbar:
                if self.args.rollout_augmentation and self._rollout_length > 0:
                    if self.num_timesteps % self._rollout_freq == 0:
                        init_obss = self.real_buffer.sample(self._rollout_batch_size)["observations"]#.cpu().numpy()

                        rollout_transitions, rollout_info = self.policy.rollout(init_obss, self._rollout_length)
                        self.fake_buffer.add_batch(**rollout_transitions)

                        self.logger.log(
                            "num rollout transitions: {}, reward mean: {:.4f}, reward std: {:.4f}".\
                                format(rollout_info["num_transitions"], rollout_info["reward_mean"], rollout_info["reward_std"])
                        )
                        for _key, _value in rollout_info.items():
                            self.logger.logkv_mean("rollout_info/"+_key, _value)

                real_sample_size = int(self._batch_size * self._real_ratio)
                fake_sample_size = self._batch_size - real_sample_size
                real_batch = self.real_buffer.sample(batch_size=real_sample_size)
                if self._rollout_length > 0:
                    fake_batch = self.fake_buffer.sample(batch_size=fake_sample_size)
                    batch = {"real": real_batch, "fake": fake_batch}
                else: # no rollout

                    batch = {"real": real_batch}

                loss = self.policy.learn(batch)

                for k, v in loss.items():
                    self.logger.logkv_mean(k, v)


                self.num_timesteps += 1

            if self.lr_scheduler is not None:
                self.lr_scheduler.step()



            train_time = time.time() -  train_start_time
            self.logger.logkv("time/train_one_epoch", train_time)

            norm_ep_rew_mean = self.evaluate_current_policy_mujoco(epoch=e)

            # # early stopping
            if e % 100 == 0:
                self.fake_buffer.save_replay_buffer(self.logger._result_dir, e, self.logger)
                self.save_checkpoints(checkpoint_last, checkpoint_best, checkpoint_best_last10, e, norm_ep_rew_mean)
                self.fake_buffer.delete_replay_buffer(self.logger._result_dir, e, self.logger)

        self.logger.log("total time: {:.2f}s".format(time.time() - start_time))
        self.policy.dynamics.save(self.logger.model_dir)
        self.logger.close()


    def save_checkpoints(self, checkpoint_last, checkpoint_best, checkpoint_best_last10, epoch, norm_ep_rew_mean):
        # save random state
        random_states = {}
        random_states["random"] = random.getstate()
        random_states["np"] = np.random.get_state() # dictionary
        random_states["torch"] = torch.get_rng_state() # Tensor
        random_states["torch_cuda"] = torch.cuda.get_rng_state_all() # List[Tensor]

        self.policy.save(checkpoint_last, random_states=random_states, epoch=epoch, logger=self.logger, lr_scheduler=self.lr_scheduler, last_10_performance=self.last_10_performance)

    def evaluate_current_policy_mujoco(self, epoch, verbose=False):
        # d4rl benchmark
        norm_ep_rew_mean = self._evaluate_and_log()
        self.last_10_performance.append(norm_ep_rew_mean)
        self.logger.logkv(f"eval/last_10_performance", np.mean(self.last_10_performance))
        self.logger.dumpkvs()

        return norm_ep_rew_mean

    def _evaluate_and_log(self, knock_level=0):
        eval_start_time = time.time()
        eval_info = self._evaluate(knock_level=knock_level)

        prefix = "eval" if knock_level==0 else f"eval_knock{knock_level:02d}"

        ep_reward_mean, ep_reward_std = np.mean(eval_info["eval/episode_reward"]), np.std(eval_info["eval/episode_reward"])
        ep_length_mean, ep_length_std = np.mean(eval_info["eval/episode_length"]), np.std(eval_info["eval/episode_length"])

        norm_ep_rew_mean = self.eval_env.get_normalized_score(ep_reward_mean) * 100
        norm_ep_rew_std = get_normalized_std_score(self.eval_env, ep_reward_std)

        self.logger.logkv(f"{prefix}/normalized_episode_reward", norm_ep_rew_mean)
        self.logger.logkv(f"{prefix}/normalized_episode_reward_std", norm_ep_rew_std)
        self.logger.logkv(f"{prefix}/episode_reward", ep_reward_mean)
        self.logger.logkv(f"{prefix}/episode_reward_std", ep_reward_std)
        self.logger.logkv(f"{prefix}/episode_length", ep_length_mean)
        self.logger.logkv(f"{prefix}/episode_length_std", ep_length_std)
        self.logger.logkv(f"{prefix}/D_entropy", eval_info["eval/D_entropy"])
        self.logger.logkv(f"{prefix}/log_p", eval_info["eval/log_p"])
        self.logger.set_timestep(self.num_timesteps)
        eval_time = time.time() - eval_start_time
        self.logger.logkv("time/eval_one_epoch", eval_time)

        return norm_ep_rew_mean

    @torch.no_grad()
    def _evaluate(self, knock_level=0) -> Dict[str, List[float]]:
        self.policy.eval()

        obs = self.eval_env.reset()

        eval_ep_info_buffer = []
        num_episodes = 0
        episode_reward, episode_length = 0, 0
        D_entropys, log_ps= [], []
        while num_episodes < self._eval_episodes:

            extra_info = {}

            action = self.policy.select_action(obs.reshape(1, -1), deterministic=True, extra_info=extra_info)
            D_entropys.append(extra_info['dist'].entropy().mean().item())
            log_ps.append(extra_info['log_prob'].mean().item())
            next_obs, reward, terminal, _ = self.eval_env.step(action.flatten())

            episode_reward += reward
            episode_length += 1

            obs = next_obs

            if terminal or episode_length >= self.eval_env._max_episode_steps:
                eval_ep_info_buffer.append({
                        "episode_reward": episode_reward,
                        "episode_length": episode_length,
                })
                num_episodes +=1
                episode_reward, episode_length = 0, 0
                obs = self.eval_env.reset()

        return {
            "eval/episode_reward": [ep_info["episode_reward"] for ep_info in eval_ep_info_buffer],
            "eval/episode_length": [ep_info["episode_length"] for ep_info in eval_ep_info_buffer],
            "eval/D_entropy": np.mean(D_entropys),
            "eval/log_p": np.mean(log_ps),
        }

    def eval_only(self, eval_epoch=10) -> Dict[str, float]:
        start_time = time.time()
        self.policy.eval()
        self.num_timesteps = 0
        self._evaluate_and_log_mean(eval_epoch=eval_epoch) ### main performance metric
        self.logger.log("total time: {:.2f}s".format(time.time() - start_time))
        self.logger.close()

    def _evaluate_and_log_mean(self, eval_epoch=10):
        eval_start_time = time.time()
        prefix = "eval"
        for e in range(eval_epoch):
            eval_info = self._evaluate()
            ep_reward_mean, ep_reward_std = np.mean(eval_info["eval/episode_reward"]), np.std(eval_info["eval/episode_reward"])
            ep_length_mean, ep_length_std = np.mean(eval_info["eval/episode_length"]), np.std(eval_info["eval/episode_length"])

            norm_ep_rew_mean = self.eval_env.get_normalized_score(ep_reward_mean) * 100
            norm_ep_rew_std = get_normalized_std_score(self.eval_env, ep_reward_std)

            self.logger.logkv_mean(f"{prefix}/normalized_episode_reward", norm_ep_rew_mean)
            self.logger.logkv_mean(f"{prefix}/normalized_episode_reward_std", norm_ep_rew_std)
            self.logger.logkv_mean(f"{prefix}/episode_reward", ep_reward_mean)
            self.logger.logkv_mean(f"{prefix}/episode_reward_std", ep_reward_std)
            self.logger.logkv_mean(f"{prefix}/episode_length", ep_length_mean)
            self.logger.logkv_mean(f"{prefix}/episode_length_std", ep_length_std)

            self.logger.logkv_mean(f"{prefix}/D_entropy", eval_info["eval/D_entropy"])
            self.logger.logkv_mean(f"{prefix}/log_p", eval_info["eval/log_p"])
        self.logger.set_timestep(self.num_timesteps)
        eval_time = time.time() - eval_start_time
        self.logger.logkv("time/eval_one_epoch", eval_time)

        self.logger.dumpkvs(exclude=["dynamics_training_progress"])

        return norm_ep_rew_mean
