import time
import os

import numpy as np
import torch
import gym
import torch.nn.functional as F
from typing import Optional, Dict, List, Tuple, Union
from tqdm import tqdm
from collections import deque
from offlinerlkit.dynamics import BaseDynamics, ReverseEnsembleDynamics
from offlinerlkit.buffer import ReplayBuffer
from offlinerlkit.utils.logger import Logger
from offlinerlkit.utils.util_fns import get_normalized_std_score
from offlinerlkit.policy import BasePolicy

import random
from collections import defaultdict
from copy import deepcopy
import ipdb
import math
# model-free policy trainer
class ReversePolicyTrainer:
    """Train a rollout-capable policy on real data and use it to synthesize data.

    The policy is optimized with ``policy.learn`` on batches sampled from
    ``real_buffer``. After each epoch its rollouts are checked against the
    forward ``dynamics`` model. ``generate`` fills ``fake_buffer`` with
    policy rollouts seeded from the real dataset.
    """

    def __init__(
        self,
        args,
        reverse_dynamics: ReverseEnsembleDynamics,
        dynamics: BaseDynamics,
        policy: Union[BasePolicy],
        eval_env: gym.Env,
        real_buffer: ReplayBuffer,
        fake_buffer: ReplayBuffer,
        logger: Logger,
        rollout_setting: Tuple[int, int, int],
        epoch: int = 1000,
        step_per_epoch: int = 1000,
        batch_size: int = 256,
        eval_episodes: int = 10,
        lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
    ) -> None:
        """Store collaborators and hyper-parameters.

        Args:
            args: run configuration; this class only reads ``args.device``
                (in ``generate``).
            reverse_dynamics: stored as an attribute; not used by the methods
                of this class.
            dynamics: forward dynamics model used by ``_evaluate``.
            policy: model trained via ``learn`` and sampled via ``rollout``.
            eval_env: environment whose RNG state is checkpointed.
            real_buffer: source of training batches and rollout seeds.
            fake_buffer: destination of generated rollouts.
            logger: metric logger / checkpoint directory provider.
            rollout_setting: ``(rollout_epoch, rollout_batch_size,
                rollout_length)``; only ``rollout_length`` is used here.
            epoch: maximum number of training epochs.
            step_per_epoch: gradient steps per epoch.
            batch_size: batch size for training and for ``generate`` chunks.
            eval_episodes: stored; not used by the methods of this class.
            lr_scheduler: optional scheduler stepped once per epoch.
        """
        self.args = args
        self.policy = policy
        self.reverse_dynamics = reverse_dynamics
        self.dynamics = dynamics
        self.eval_env = eval_env
        self.real_buffer = real_buffer
        self.fake_buffer = fake_buffer
        self.logger = logger

        self._rollout_epoch, self._rollout_batch_size, \
            self._rollout_length = rollout_setting

        self._epoch = epoch
        self._step_per_epoch = step_per_epoch
        self._batch_size = batch_size
        self._eval_episodes = eval_episodes
        self.lr_scheduler = lr_scheduler
        self.num_timesteps = None

    @staticmethod
    def _relative_improvement(old_loss: float, new_loss: float) -> float:
        """Return ``(old_loss - new_loss) / old_loss`` without dividing by zero.

        A zero ``old_loss`` yields 0.0 when ``new_loss`` is also zero (no
        change) and ``inf`` otherwise (treated as a significant change).
        """
        if old_loss == 0:
            return 0.0 if new_loss == 0 else math.inf
        return (old_loss - new_loss) / old_loss

    def train(
        self,
        max_epochs_since_update: int = 5,
    ) -> Dict[str, float]:
        """Fit the policy on real data with plateau-based early stopping.

        Each epoch runs ``step_per_epoch`` calls to ``policy.learn``. It then
        steps the LR scheduler (if any) and reads ``loss/vae`` from the last
        step's loss dict. Training stops once ``max_epochs_since_update``
        consecutive epochs change ``loss/vae`` by at most 0.1% (relative).
        The stopping epoch is not checkpointed or evaluated.

        Checkpoints (siblings of ``logger.checkpoint_dir``):
            * ``checkpoint_last``: every 10 epochs.
            * ``checkpoint_best``: whenever ``loss/vae`` reaches a new minimum.
            * ``checkpoint_best_last10``: whenever the mean ``loss/vae`` of
              the last 10 epochs reaches a new minimum.
        The policy ``state_dict`` is also written to
        ``checkpoint_dir/policy.pth`` every epoch and to
        ``model_dir/policy.pth`` at the end.

        Args:
            max_epochs_since_update: patience, in epochs, for early stopping.

        Returns:
            ``{"last_10_performance": m}``. Here ``m`` is the mean ``loss/vae``
            (lower is better) over the last (up to) 10 trained epochs, or NaN
            if no epoch ran.

        Raises:
            ValueError: if training epochs are requested but
                ``step_per_epoch`` is smaller than 1.
        """
        if self._epoch > 0 and self._step_per_epoch < 1:
            raise ValueError(f"step_per_epoch must be >= 1, got {self._step_per_epoch}")

        start_time = time.time()

        self.num_timesteps = 0
        # rolling window of per-epoch VAE losses (lower is better)
        last_10_performance = deque(maxlen=10)
        old_loss = 1e10
        cnt = 0
        best_metric, best_last10_metric = None, None
        best_epoch, best_last10_epoch = None, None

        ckpt_root = os.path.dirname(self.logger.checkpoint_dir)
        checkpoint_last = os.path.join(ckpt_root, "checkpoint_last")
        checkpoint_best = os.path.join(ckpt_root, "checkpoint_best")
        checkpoint_best_last10 = os.path.join(ckpt_root, "checkpoint_best_last10")

        # train loop
        for e in range(1, self._epoch + 1):

            self.policy.train()

            pbar = tqdm(range(self._step_per_epoch), desc=f"Epoch #{e}/{self._epoch}")
            for _ in pbar:
                batch = self.real_buffer.sample(self._batch_size)
                loss = self.policy.learn(batch)

                for k, v in loss.items():
                    self.logger.logkv_mean(k, v)

                self.num_timesteps += 1

            if self.lr_scheduler is not None:
                self.lr_scheduler.step()

            # the epoch's metric is the VAE loss of its last gradient step
            vae_loss = loss["loss/vae"].item()
            last_10_performance.append(vae_loss)

            # early stopping: count consecutive epochs with <= 0.1% relative change
            improvement = self._relative_improvement(old_loss, vae_loss)
            old_loss = vae_loss
            if abs(improvement) > 0.001:
                cnt = 0
            else:
                cnt += 1
            if cnt >= max_epochs_since_update:
                break

            # save random state so a resumed run can reproduce sampling
            random_states = {}
            random_states["random"] = random.getstate()
            random_states["np"] = np.random.get_state()  # dictionary
            random_states["torch"] = torch.get_rng_state()  # Tensor
            random_states["torch_cuda"] = torch.cuda.get_rng_state_all()  # List[Tensor]
            random_states["eval_envs"] = self.eval_env.np_random.bit_generator.state

            # save checkpoint
            if e % 10 == 0:
                self.policy.save(checkpoint_last, random_states=random_states)

            # save best checkpoint (lowest single-epoch VAE loss)
            if best_metric is None or vae_loss < best_metric:
                best_metric = vae_loss
                best_epoch = e
                self.policy.save(checkpoint_best, random_states=random_states)

            # save best_last10 checkpoint (lowest 10-epoch mean VAE loss)
            if len(last_10_performance) == 10:
                last10_mean = np.mean(last_10_performance)
                if best_last10_metric is None or last10_mean < best_last10_metric:
                    best_last10_metric = last10_mean
                    best_last10_epoch = e
                    self.policy.save(checkpoint_best_last10, random_states=random_states)

            self._evaluate()
            # save checkpoint
            torch.save(self.policy.state_dict(), os.path.join(self.logger.checkpoint_dir, "policy.pth"))

        self.logger.log("total time: {:.2f}s".format(time.time() - start_time))
        self.logger.log(f"best epoch: {best_epoch}, best last10 epoch: {best_last10_epoch}")
        torch.save(self.policy.state_dict(), os.path.join(self.logger.model_dir, "policy.pth"))

        if not last_10_performance:
            return {"last_10_performance": float("nan")}
        return {"last_10_performance": float(np.mean(last_10_performance))}

    @torch.no_grad()
    def _evaluate(self) -> None:
        """Log how closely policy rollouts agree with the forward dynamics model.

        Samples ``100 * batch_size`` real observations once. For rollout
        lengths 1, 3 and 5 it rolls the policy out from them and feeds the
        generated ``(obss, actions)`` through ``self.dynamics.step``. The MSE
        between that prediction and the rollout's ``next_obss``, divided by
        the rollout length, is logged under ``eval/rollout_mse_len{L}``.
        """
        batch = self.real_buffer.sample(self._batch_size * 100)
        init_obss = batch["observations"]
        for eval_rollout_length in (1, 3, 5):
            rollout_transitions, _ = self.policy.rollout(init_obss, eval_rollout_length)

            anchors, _, _, _ = self.dynamics.step(rollout_transitions['obss'], rollout_transitions['actions'])

            # NOTE(review): assumes dynamics.step returns a tensor that F.mse_loss
            # accepts against rollout next_obss (same type/device) — confirm.
            loss = F.mse_loss(anchors, rollout_transitions['next_obss']) / eval_rollout_length
            self.logger.logkv_mean(f"eval/rollout_mse_len{eval_rollout_length}", float(loss))

    def generate(self) -> None:
        """Fill ``fake_buffer`` with policy rollouts seeded from the real dataset.

        The buffer is reset first. Rollouts start from every stored
        ``next_observation`` (paired with its ``action`` as ``prev_action``)
        in chunks of ``batch_size``, each of length ``rollout_length``.
        Per-chunk rollout statistics are logged, and the resulting buffer is
        saved to ``logger.result_dir``.
        """
        self.fake_buffer.reset()
        dataset = self.real_buffer.sample_all()
        init_obss = torch.tensor(dataset['next_observations'], device=self.args.device)
        prev_actions = torch.tensor(dataset['actions'], device=self.args.device)
        # the seeding below assumes sample_all() returned a completely filled buffer
        assert len(init_obss) == self.real_buffer._max_size, f'len(init_obss): {len(init_obss)}, self.real_buffer._max_size: {self.real_buffer._max_size}'

        for i in tqdm(range(math.ceil(self.real_buffer._max_size/self._batch_size))):
            init_obs = init_obss[i*self._batch_size : (i+1)*self._batch_size]
            prev_action = prev_actions[i*self._batch_size : (i+1)*self._batch_size]
            rollout_transitions, rollout_info = self.policy.rollout(init_obs, self._rollout_length, prev_action)
            self.fake_buffer.add_batch(**rollout_transitions)
            self.logger.log(
                "num rollout transitions: {}, reward mean: {:.4f}, reward std: {:.4f}".\
                    format(rollout_info["num_transitions"], rollout_info["reward_mean"], rollout_info["reward_std"])
            )
            for _key, _value in rollout_info.items():
                self.logger.logkv_mean("rollout_info/"+_key, _value)

        self.fake_buffer.save(path=self.logger.result_dir, data=self.fake_buffer.sample_all())
        self.logger.log(f'Fake buffer saved successfully in {self.logger.result_dir}')
