import os
import numpy as np
import torch
import torch.nn as nn
import gym
from glob import glob
import re
from tqdm import tqdm
import random
from torch.nn import functional as F
from typing import Dict, Union, Tuple, Optional, List
from collections import defaultdict
from offlinerlkit.utils.scaler import StandardScaler
from offlinerlkit.policy import SACPolicy
from offlinerlkit.dynamics import ReverseEnsembleDynamics, EnsembleDynamics
from offlinerlkit.modules.anchor_seeker import HeuristicAnchorSeeker
from copy import deepcopy
import wandb
import pickle
import math
from offlinerlkit.utils.logger import Logger
from collections import deque

class MOPOPolicy(SACPolicy):
    """SAC policy that learns from real data mixed with model-generated rollouts.

    On top of ``SACPolicy`` this class adds:
      * ``rollout``: rolling the policy forward through ``self.dynamics``;
      * supervised pre-training / evaluation of the ``anchor_seeking_actor``
        sub-networks attached to the actor and to ``critic1``;
      * checkpoint ``save`` / ``load`` that also persists optimizer, alpha and
        RNG state;
      * helpers that toggle anchor mode and (un)freeze the anchor seekers on the
        actor and all four critic networks.
    """

    def __init__(
        self,
        args,
        dynamics: EnsembleDynamics,
        actor: nn.Module,
        critic1: nn.Module,
        critic2: nn.Module,
        actor_optim: torch.optim.Optimizer,
        critic1_optim: torch.optim.Optimizer,
        critic2_optim: torch.optim.Optimizer,
        tau: float = 0.005,
        gamma: float  = 0.99,
        alpha: Union[float, Tuple[float, torch.Tensor, torch.optim.Optimizer]] = 0.2,
        device: str = "cpu",
    ) -> None:
        """Store the dynamics model and hand the SAC components to ``SACPolicy``.

        Args:
            args: run configuration; this class reads ``as_holdout_ratio`` and
                ``alpha_lr`` from it.
            dynamics: learned dynamics model used by ``rollout``.
            actor, critic1, critic2: policy and twin Q networks.
            actor_optim, critic1_optim, critic2_optim: their optimizers.
            tau, gamma, alpha: passed through unchanged to ``SACPolicy``.
            device: torch device on which tensors created here are placed.
        """
        super().__init__(
            actor,
            critic1,
            critic2,
            actor_optim,
            critic1_optim,
            critic2_optim,
            tau=tau,
            gamma=gamma,
            alpha=alpha,
        )
        self.args = args
        self.dynamics = dynamics
        self.device = torch.device(device)
        self.anchor_sharing = True

    @torch.no_grad()
    def rollout(
        self,
        init_obss: Union[np.ndarray, torch.Tensor],
        rollout_length: int
    ) -> Tuple[Dict[str, np.ndarray], Dict]:
        """Roll the current policy forward through the learned dynamics model.

        Starting from ``init_obss``, actions are chosen by the policy and passed to
        ``self.dynamics.step`` for at most ``rollout_length`` steps. After each step
        the branches flagged in ``terminals`` are dropped, and the rollout stops
        early once every branch has terminated.

        Args:
            init_obss: batch of start observations. A NumPy array selects the
                NumPy code path; anything else is treated as a torch tensor.
            rollout_length: maximum number of model steps.

        Returns:
            ``(transitions, info)``. ``transitions`` has keys ``obss``, ``next_obss``,
            ``actions``, ``rewards`` and ``terminals``, each concatenated over all
            steps along axis 0. ``info`` has ``num_transitions``, ``reward_mean``
            and ``reward_std``.
        """
        # numpy version
        if isinstance(init_obss, np.ndarray):
            num_transitions = 0
            rewards_arr = np.array([])
            rollout_transitions = defaultdict(list)

            observations = init_obss
            for _ in range(rollout_length):
                actions = super().select_action(observations)
                next_observations, rewards, terminals, _ = self.dynamics.step(observations, actions)
                rollout_transitions["obss"].append(observations)
                rollout_transitions["next_obss"].append(next_observations)
                rollout_transitions["actions"].append(actions)
                rollout_transitions["rewards"].append(rewards)
                rollout_transitions["terminals"].append(terminals)
                num_transitions += len(observations)
                rewards_arr = np.append(rewards_arr, rewards.flatten())
                # Continue only the branches that have not terminated.
                nonterm_mask = (~terminals).flatten()
                if nonterm_mask.sum() == 0:
                    break
                observations = next_observations[nonterm_mask]
            for k, v in rollout_transitions.items():
                rollout_transitions[k] = np.concatenate(v, axis=0)

            return rollout_transitions, \
                {"num_transitions": num_transitions, "reward_mean": rewards_arr.mean(), "reward_std": rewards_arr.std()}

        # tensor version
        num_transitions = 0
        rewards_arr = torch.tensor([], device=self.device)
        rollout_transitions = defaultdict(list)

        observations = init_obss
        for _ in range(rollout_length):
            actions, _ = super().actforward(observations)
            next_observations, rewards, terminals, _ = self.dynamics.step(observations, actions)
            rollout_transitions["obss"].append(observations)
            rollout_transitions["next_obss"].append(next_observations)
            rollout_transitions["actions"].append(actions)
            rollout_transitions["rewards"].append(rewards)
            rollout_transitions["terminals"].append(terminals)

            num_transitions += len(observations)
            rewards_arr = torch.cat((rewards_arr, rewards.reshape(-1)), dim=0)

            # Continue only the branches that have not terminated.
            nonterm_mask = (~terminals).reshape(-1)
            if nonterm_mask.sum() == 0:
                break
            observations = next_observations[nonterm_mask]

        for k, v in rollout_transitions.items():
            rollout_transitions[k] = torch.cat(v, dim=0)

        return rollout_transitions, \
            {"num_transitions": num_transitions, "reward_mean": rewards_arr.mean().item(), "reward_std": rewards_arr.std().item()}

    def select_action(
        self,
        obs: np.ndarray,
        deterministic: bool = False,
        *,
        extra_info: Optional[Dict] = None,
    ) -> np.ndarray:
        """Delegate to ``SACPolicy.select_action``, forwarding ``extra_info`` only when given."""
        if extra_info is not None:
            return super().select_action(obs, deterministic, extra_info=extra_info)
        return super().select_action(obs, deterministic)

    def _get_anchor_seeker(self, mode: str) -> nn.Module:
        """Return the anchor seeker trained for ``mode`` ('actor' -> actor, 'critic' -> critic1)."""
        if mode == 'actor':
            return self.actor.anchor_seeking_actor
        if mode == 'critic':
            return self.critic1.anchor_seeking_actor
        raise NotImplementedError

    @staticmethod
    def _loss_changed(old_loss: Optional[float], new_loss: float, tol: float = 1e-3) -> bool:
        """Return True if ``new_loss`` differs from ``old_loss`` by more than ``tol`` (relative)."""
        if old_loss is None:
            return True  # first epoch: nothing to compare against yet
        if old_loss == 0:
            return new_loss != 0  # avoid 0/0 when a loss is exactly zero
        return abs((old_loss - new_loss) / old_loss) > tol

    def anchor_seeker_pretrain_reverse(self, load_reverse_imagination_path, n_epoch, batch_size, lr, asp_which, logger: Logger, data = None) -> None:
        """Pre-train the anchor seeker(s) by regressing actions from observations (MSE).

        The data (``"observations"`` / ``"actions"``) is either loaded from
        ``<load_reverse_imagination_path>/reverse_imagination.pkl`` or taken from
        ``data``. The last ``args.as_holdout_ratio`` fraction of samples (no
        shuffling) is held out and evaluated after every epoch when the ratio is
        non-zero. Each selected anchor seeker gets its own Adam optimizer.

        Training stops early once, for ``max_epochs_since_update`` consecutive
        epochs, no trained mode's mean loss changed by more than 0.1 % relative to
        the previous epoch. The trained anchor seekers are then saved under
        ``logger.model_dir``.

        Args:
            load_reverse_imagination_path: directory holding ``reverse_imagination.pkl``,
                or ``None`` to use ``data``.
            n_epoch: maximum number of epochs.
            batch_size: mini-batch size for training and evaluation.
            lr: Adam learning rate.
            asp_which: ``'actor'``, ``'critic'`` or ``'both'``.
            logger: metrics/message logger; also provides ``model_dir``.
            data: mapping with ``"observations"`` and ``"actions"``, used when no
                path is given.

        Raises:
            AssertionError: if a requested anchor seeker is missing or no data is given.
            NotImplementedError: if ``asp_which`` names an unknown mode.
        """
        modes_to_update = ['actor', 'critic'] if asp_which == 'both' else [asp_which]
        has_actor_anchor_seeker = hasattr(self.actor, 'anchor_seeking_actor')
        has_critic_anchor_seeker = hasattr(self.critic1, 'anchor_seeking_actor')
        assert has_actor_anchor_seeker or has_critic_anchor_seeker, "Either actor or critic should have anchor seeker"
        if 'actor' in modes_to_update:
            assert has_actor_anchor_seeker, "Actor should have anchor seeker"
        if 'critic' in modes_to_update:
            assert has_critic_anchor_seeker, "Critic should have anchor seeker"

        # only tensor version
        if load_reverse_imagination_path is not None:
            path = os.path.join(load_reverse_imagination_path, "reverse_imagination.pkl")
            # NOTE: pickle can execute arbitrary code on load; only use trusted files.
            with open(path, 'rb') as f:
                data = pickle.load(f)
        else:
            assert data is not None, f"data should be given, but load_reverse_imagination_path is {load_reverse_imagination_path}"

        # Both tensors go to self.device so observations and actions always match.
        observations = torch.as_tensor(data["observations"], device=self.device)
        actions = torch.as_tensor(data["actions"], device=self.device)

        # Hold out the tail of the data set for evaluation.
        split_idx = int((1 - self.args.as_holdout_ratio) * observations.shape[0])
        train_observations, eval_observations = observations[:split_idx], observations[split_idx:]
        train_actions, eval_actions = actions[:split_idx], actions[split_idx:]
        eval_init_obss, eval_one_hot_rollout_length = None, None

        sample_num = train_observations.shape[0]
        # True division: the final partial batch is trained on as well.
        n_train_batches = math.ceil(sample_num / batch_size)

        self.dynamics.model.eval()
        self.dynamics.model.requires_grad_(False)
        self._optim_actor_anchor_seeker = torch.optim.Adam(self.actor.anchor_seeking_actor.parameters(), lr=lr) if has_actor_anchor_seeker else None
        self._optim_critic_anchor_seeker = torch.optim.Adam(self.critic1.anchor_seeking_actor.parameters(), lr=lr) if has_critic_anchor_seeker else None
        optimizers = {'actor': self._optim_actor_anchor_seeker, 'critic': self._optim_critic_anchor_seeker}

        logger.log("Pretraining anchor seeker for actor and critic")

        if has_actor_anchor_seeker:
            self.actor.anchor_seeking_actor.train()
        if has_critic_anchor_seeker:
            self.critic1.anchor_seeking_actor.train()

        max_epochs_since_update = 5
        # Only modes that are actually trained take part in early stopping; an
        # untrained mode always has loss 0 and would otherwise cause 0/0.
        prev_losses = {mode: None for mode in modes_to_update}
        epochs_since_update = 0
        for i_epoch in range(n_epoch):
            accumulated_losses = {'actor': 0.0, 'critic': 0.0}

            for i_batch in tqdm(range(n_train_batches), desc=f"Epoch{i_epoch+1}/{n_epoch}"):
                start = i_batch * batch_size
                batch_obs = train_observations[start:start + batch_size]
                batch_actions = train_actions[start:start + batch_size]

                for mode in modes_to_update:
                    pred_actions = self._get_anchor_seeker(mode)(batch_obs)
                    loss = ((batch_actions - pred_actions) ** 2).mean()
                    optimizers[mode].zero_grad()
                    loss.backward()
                    optimizers[mode].step()
                    accumulated_losses[mode] += loss.item()

            # max(..., 1) avoids dividing by zero when there are no training samples.
            mean_losses = {mode: total / max(n_train_batches, 1) for mode, total in accumulated_losses.items()}
            logger.logkv("train/actor_anchor_seeking_loss", mean_losses['actor'])
            logger.logkv("train/critic_anchor_seeking_loss", mean_losses['critic'])
            logger.dumpkvs()

            if bool(self.args.as_holdout_ratio):
                self.anchor_seeker_reverse_eval(i_epoch, batch_size, modes_to_update, eval_observations, eval_actions, eval_one_hot_rollout_length, eval_init_obss, logger)

            # Early stopping: count consecutive epochs in which no loss moved noticeably.
            changed = any(self._loss_changed(prev_losses[mode], mean_losses[mode]) for mode in modes_to_update)
            prev_losses = {mode: mean_losses[mode] for mode in modes_to_update}
            epochs_since_update = 0 if changed else epochs_since_update + 1
            if epochs_since_update >= max_epochs_since_update:
                break

        for mode in modes_to_update:
            random_states = {
                "np": np.random.get_state(),  # tuple
                "torch": torch.get_rng_state(),  # Tensor
                "torch_cuda": torch.cuda.get_rng_state_all(),  # List[Tensor]
            }
            self._get_anchor_seeker(mode).save(logger.model_dir, mode, random_states)
            logger.log(f"Saved {mode} anchor seeker in Dir: {os.path.join(logger.model_dir, mode)}")

    def anchor_seeker_reverse_eval(self, i_epoch, batch_size, modes_to_update, eval_observations, eval_actions, eval_one_hot_rollout_length, eval_init_obss, logger:Logger):
        """Log the mean held-out MSE of each anchor seeker in ``modes_to_update``.

        ``i_epoch``, ``eval_one_hot_rollout_length`` and ``eval_init_obss`` are not
        used; they are kept for call compatibility. An empty holdout set logs 0.
        """
        accumulated_losses = {'actor': 0.0, 'critic': 0.0}
        sample_num = eval_observations.shape[0]
        n_batches = math.ceil(sample_num / batch_size)
        # Evaluation only: no autograd graph is needed.
        with torch.no_grad():
            for i_batch in range(n_batches):
                start = i_batch * batch_size
                batch_obs = eval_observations[start:start + batch_size]
                batch_actions = eval_actions[start:start + batch_size]

                for mode in modes_to_update:
                    pred_actions = self._get_anchor_seeker(mode)(batch_obs)
                    loss = ((pred_actions - batch_actions) ** 2).mean()
                    accumulated_losses[mode] += loss.item()

        denom = max(n_batches, 1)
        logger.logkv("eval/actor_anchor_seeking_loss", accumulated_losses['actor'] / denom)
        logger.logkv("eval/critic_anchor_seeking_loss", accumulated_losses['critic'] / denom)
        logger.dumpkvs()

    def learn(self, batch: Dict) -> Dict[str, float]:
        """Run one SAC update on ``batch["real"]``, concatenated with ``batch["fake"]`` if present."""
        if "fake" in batch:  # real buffer & fake (model) buffer
            real_batch, fake_batch = batch["real"], batch["fake"]
            mix_batch = {k: torch.cat([real_batch[k], fake_batch[k]], 0) for k in real_batch.keys()}
        else:  # only real buffer
            mix_batch = batch["real"]

        return super().learn(mix_batch)

    def save(self, save_path: str, random_states: dict, epoch, logger, lr_scheduler=None, last_10_performance=None) -> None:
        """Write ``policy_<epoch>.pth`` to ``save_path`` and delete older epoch checkpoints.

        The checkpoint stores the full module state dict, the given RNG states,
        network and optimizer state dicts, alpha values, the optional LR scheduler
        state, ``epoch`` and ``last_10_performance`` (as a list, or None).
        """
        os.makedirs(save_path, exist_ok=True)
        # NOTE(review): _log_alpha / alpha_optim may not exist on SACPolicy when
        # alpha is not learned (assumption — confirm); store None in that case.
        log_alpha = getattr(self, "_log_alpha", None)
        alpha_optim = getattr(self, "alpha_optim", None)
        data = dict(
            state_dict=self.state_dict(),
            random_states=random_states,

            actor=self.actor.state_dict(),
            critic1=self.critic1.state_dict(),
            critic2=self.critic2.state_dict(),

            alpha=self._alpha,
            log_alpha=log_alpha,
            alpha_optim=alpha_optim.state_dict() if alpha_optim is not None else None,

            actor_optim=self.actor_optim.state_dict(),
            critic1_optim=self.critic1_optim.state_dict(),
            critic2_optim=self.critic2_optim.state_dict(),

            lr_scheduler=lr_scheduler.state_dict() if lr_scheduler is not None else None,
            epoch=epoch,
            last_10_performance=list(last_10_performance) if last_10_performance is not None else None,
        )
        torch.save(data, os.path.join(save_path, f"policy_{epoch:04}.pth"))
        logger.log(f"[Epoch: {epoch}] Saved policy to: {save_path}")

        # Keep only the newest checkpoint: remove files from earlier epochs.
        for file in glob(os.path.join(save_path, "policy_*.pth")):
            match = re.search(r'policy_(\d+)\.pth$', file)
            if match is None:
                continue
            file_epoch = int(match.group(1))
            if file_epoch < epoch:
                os.remove(file)
                logger.log(f"[Epoch: {epoch}] Removed policy_{file_epoch:04}.pth")

    def load(self, load_path: str, logger, lr_scheduler, policy_trainer) -> None:
        """Restore a checkpoint written by ``save``; training resumes at ``epoch + 1``.

        Restores networks, optimizers, alpha, RNG states and (when both are
        available) the LR scheduler state. It then attaches ``lr_scheduler`` and the
        saved ``last_10_performance`` to ``policy_trainer``.
        """
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        data = torch.load(load_path, map_location=self.device)
        self.load_state_dict(data['state_dict'])

        # Load random states ("random" is optional: save() does not guarantee it).
        random_states = data['random_states']
        if "random" in random_states:
            random.setstate(random_states["random"])
        np.random.set_state(random_states["np"])
        torch.set_rng_state(torch.ByteTensor(random_states["torch"].cpu()))
        torch.cuda.set_rng_state_all([torch.ByteTensor(t.cpu()) for t in random_states["torch_cuda"]])

        # Load state_dict for actors and critics
        self.actor.load_state_dict(data['actor'])
        self.critic1.load_state_dict(data['critic1'])
        self.critic2.load_state_dict(data['critic2'])

        # alpha: the optimizer is rebuilt around the restored log_alpha tensor.
        self._alpha = data['alpha']
        if data.get('log_alpha') is not None:
            self._log_alpha = data['log_alpha']
            self.alpha_optim = torch.optim.Adam([self._log_alpha], lr=self.args.alpha_lr)
            if data.get('alpha_optim') is not None:
                self.alpha_optim.load_state_dict(data['alpha_optim'])

        # Load state_dict for optimizers
        self.actor_optim.load_state_dict(data['actor_optim'])
        self.critic1_optim.load_state_dict(data['critic1_optim'])
        self.critic2_optim.load_state_dict(data['critic2_optim'])

        if lr_scheduler is not None and data.get('lr_scheduler') is not None:
            lr_scheduler.load_state_dict(data['lr_scheduler'])

        # Resume from the epoch after the saved one.
        self._epoch = data['epoch'] + 1

        logger.log(f"[Epoch: {self._epoch}] Loaded policy from: {load_path}")

        policy_trainer.lr_scheduler = lr_scheduler
        saved_performance = data.get('last_10_performance')
        if saved_performance is not None:
            # Reuse the trainer's existing deque length if it has one; 10 otherwise.
            existing = getattr(policy_trainer, 'last_10_performance', None)
            maxlen = getattr(existing, 'maxlen', 10)
            policy_trainer.last_10_performance = deque(saved_performance, maxlen=maxlen)

    def _anchor_networks(self):
        """Return the actor and the four critic networks, in a fixed order."""
        return (self.actor, self.critic1, self.critic2, self.critic1_old, self.critic2_old)

    def set_anchor_mode(self, anchor_mode):
        """Set ``anchor_mode`` on the actor and all critics, then call ``reset_anchor_sharing``.

        ``reset_anchor_sharing`` is not defined in this class (presumably inherited).
        """
        for net in self._anchor_networks():
            net.anchor_mode = anchor_mode
        self.reset_anchor_sharing()

    def anchor_seeking_actor_freeze(self):
        """Call ``freeze()`` on the anchor seeker of the actor and of every critic."""
        for net in self._anchor_networks():
            net.anchor_seeking_actor.freeze()

    def anchor_seeking_actor_unfreeze(self):
        """Call ``unfreeze()`` on the anchor seeker of the actor and of every critic."""
        for net in self._anchor_networks():
            net.anchor_seeking_actor.unfreeze()


