import time
from copy import deepcopy
from typing import Any, Dict, List, Optional, Tuple, Type, Union

import gymnasium as gym
import numpy as np
import torch
from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch_as
from torch import nn
from torch.distributions import kl_divergence
import torch.nn.functional as F

from fsrl.policy import BasePolicy
from fsrl.utils import BaseLogger


class CVPO_Disc(BasePolicy):
    """Constrained Variational Policy Optimization (CVPO) for discrete action spaces.

    More details, please refer to https://arxiv.org/abs/2201.11927.

    The E-step builds a non-parametric variational distribution over candidate actions
    (every action when ``disc_sample=True``, otherwise ``sample_act_num`` actions
    sampled from the old policy). It does so by optimizing the dual variables ``eta``
    (temperature of the KL constraint) and ``lambda`` (one per cost critic). The M-step
    fits the parametric categorical policy to that distribution by weighted maximum
    likelihood, regularized by a KL trust region around the old policy.

    :param torch.nn.Module actor: the actor network following the rules in
        :class:`~fsrl.policy.BasePolicy`. (s -> logits)
    :param Union[nn.Module, List[nn.Module]] critics: the critic network(s). Index 0 is
        the reward critic; the remaining ones are cost critics.
        (s, one-hot a -> Q(s, a))
    :param torch.optim.Optimizer actor_optim: the optimizer for the actor network.
    :param torch.optim.Optimizer critic_optim: the optimizer for the critic network(s).
    :param gym.Space action_space: the discrete action space of the environment; its
        ``n`` attribute gives the number of actions.
    :param Type[torch.distributions.Distribution] dist_fn: the distribution class built
        from the actor output, called as ``dist_fn(logits=...)``.
    :param int max_episode_steps: the maximum number of steps per episode for computing
        the step-wise qc threshold.
    :param Optional[BaseLogger] logger: the logger instance for logging training
        information. A fresh ``BaseLogger()`` is created when None. (default=None)
    :param Union[List, float] cost_limit: the constraint limit(s), one per cost critic.
        A scalar is broadcast to every cost critic. (default=np.inf)
    :param float tau: target smoothing coefficient for soft update of target networks.
        (default=0.05)
    :param float gamma: the discount factor for future rewards. (default=0.99)
    :param int n_step: number of steps for multi-step learning. (default=2)
    :param int estep_iter_num: the number of iterations for the E-step. (default=1)
    :param float estep_kl: the KL divergence threshold for the E-step. (default=0.02)
    :param float estep_dual_max: the maximum value for the dual variables in the E-step.
        (default=20)
    :param float estep_dual_lr: the learning rate for the dual variables in the E-step.
        (default=0.02)
    :param int sample_act_num: the number of actions sampled from the old policy in the
        E-step; only used when ``disc_sample=False``. (default=16)
    :param int mstep_iter_num: the number of iterations for the M-step. (default=1)
    :param float mstep_kl: the KL divergence threshold between the old and the new
        categorical policy in the M-step. (default=0.005)
    :param float mstep_dual_max: the maximum value for the dual variable in the M-step.
        (default=0.5)
    :param float mstep_dual_lr: the learning rate for the dual variable in the M-step.
        (default=0.1)
    :param bool deterministic_eval: whether to use deterministic (argmax) action
        selection during evaluation. (default=True)
    :param bool action_scaling: whether to scale the actions according to the action
        space bounds. (default=False)
    :param str action_bound_method: the method for handling actions that exceed the
        action space bounds. (default="clip")
    :param Optional[torch.optim.lr_scheduler.LambdaLR] lr_scheduler: learning rate
        scheduler for the optimizer. (default=None)
    :param bool disc_sample: if True, the E-step enumerates every discrete action
        instead of sampling ``sample_act_num`` actions. (default=True)
    :param cost_shield_distill: optional ensemble network (exposing ``num_ensemble``)
        trained in :meth:`shield_loss` to predict binarized costs from
        (s, one-hot a). Passing it enables shield training, which then also requires
        ``actor_shield_distill``, ``cost_shield_optim`` and ``actor_shield_optim``.
        (default=None)
    :param actor_shield_distill: optional ensemble network trained in
        :meth:`shield_loss` from observations only. (default=None)
    :param cost_shield_optim: optimizer for ``cost_shield_distill``. (default=None)
    :param actor_shield_optim: optimizer for ``actor_shield_distill``. (default=None)

    .. seealso::

        Please refer to :class:`~fsrl.policy.BasePolicy` for more detailed hyperparameter
        explanations and usage.
    """

    def __init__(
        self,
        actor: nn.Module,
        critics: Union[nn.Module, List[nn.Module]],
        actor_optim: torch.optim.Optimizer,
        critic_optim: torch.optim.Optimizer,
        action_space: gym.Space,
        # CVPO specific arguments
        dist_fn: Type[torch.distributions.Distribution],
        max_episode_steps: int,
        logger: Optional[BaseLogger] = None,
        cost_limit: Union[List, float] = np.inf,
        tau: float = 0.05,
        gamma: float = 0.99,
        n_step: int = 2,
        # E-step
        estep_iter_num: int = 1,
        estep_kl: float = 0.02,
        estep_dual_max: float = 20,
        estep_dual_lr: float = 0.02,
        sample_act_num: int = 16,
        # M-step
        mstep_iter_num: int = 1,
        mstep_kl: float = 0.005,
        mstep_dual_max: float = 0.5,
        mstep_dual_lr: float = 0.1,
        # other param
        deterministic_eval: bool = True,
        action_scaling: bool = False,
        action_bound_method: str = "clip",
        lr_scheduler: Optional[torch.optim.lr_scheduler.LambdaLR] = None,
        disc_sample: bool = True,
        cost_shield_distill: Optional[nn.Module] = None,
        actor_shield_distill: Optional[nn.Module] = None,
        cost_shield_optim: Optional[torch.optim.Optimizer] = None,
        actor_shield_optim: Optional[torch.optim.Optimizer] = None,
    ) -> None:
        assert 0.0 <= tau <= 1.0, "tau should be in [0, 1]"
        # Build the logger per instance: a default instance in the signature would be
        # shared by every policy created without an explicit logger.
        if logger is None:
            logger = BaseLogger()
        super().__init__(
            actor=actor,
            critics=critics,
            dist_fn=dist_fn,
            logger=logger,
            gamma=gamma,
            deterministic_eval=deterministic_eval,
            action_scaling=action_scaling,
            action_bound_method=action_bound_method,
            action_space=action_space,
            lr_scheduler=lr_scheduler
        )

        # Frozen copies: actor_old anchors the M-step KL, critics_old gives n-step targets.
        self.actor_old = deepcopy(self.actor)
        self.actor_old.eval()
        self.actor_optim = actor_optim
        self.critics_old = deepcopy(self.critics)
        self.critics_old.eval()
        self.critics_optim = critic_optim
        self.device = next(self.actor.parameters()).device
        self.dtype = next(self.actor.parameters()).dtype

        self.max_episode_steps = max_episode_steps
        # Sets self.cost_limit and the per-step Q-cost thresholds self.qc_thres.
        self.update_cost_limit(cost_limit)

        # E-step init
        self._estep_kl = estep_kl
        self._estep_iter_num = estep_iter_num
        self._estep_dual_max = estep_dual_max
        self._estep_dual_lr = estep_dual_lr
        self._sample_act_num = sample_act_num
        # The first entry is eta, the others are the lambdas of the paper.
        init_dual = np.zeros(self.critics_num)
        init_dual[0] = 1  # init eta to be 1
        self.estep_dual = torch.tensor(
            init_dual, requires_grad=True, device=self.device, dtype=self.dtype
        )
        self.estep_optim = torch.optim.Adam([self.estep_dual], lr=self._estep_dual_lr)

        # M-step init (the dual variable itself is (re)created in pre_update_fn)
        self._mstep_kl = mstep_kl
        self._mstep_iter_num = mstep_iter_num
        self._mstep_dual_max = mstep_dual_max
        self._mstep_dual_lr = mstep_dual_lr

        # Cumulative wall-clock time spent in each phase, for logging.
        self._estep_duration = 0
        self._mstep_duration = 0
        self._shield_duration = 0

        self.tau = tau
        self.action_type = "discrete"
        self._discrete = True
        self._n_step = n_step
        self.__eps = np.finfo(np.float32).eps.item() * 10  # around 1e-6

        self.action_dim = self.action_space.n
        self.disc_sample = disc_sample

        self.cost_shield_distill = cost_shield_distill
        self.actor_shield_distill = actor_shield_distill
        self.cost_shield_optim = cost_shield_optim
        self.actor_shield_optim = actor_shield_optim
        # Always defined: train() and learn() read this flag even without a shield.
        self.update_shield = cost_shield_distill is not None

    def update_cost_limit(self, cost_limit: Union[List, float]) -> None:
        """Update the cost limit(s) and the derived E-step Q-cost thresholds.

        Each threshold is the discounted sum, over ``max_episode_steps`` steps, of a
        uniform per-step cost ``c / max_episode_steps``.

        :param Union[List, float] cost_limit: new cost threshold(s); a scalar is
            broadcast to every cost critic.
        """
        if np.isscalar(cost_limit):
            self.cost_limit = [cost_limit] * (self.critics_num - 1)
        else:
            self.cost_limit = cost_limit

        horizon, g = self.max_episode_steps, self._gamma
        # (1 - g**T) / ((1 - g) * T) tends to 1 as g -> 1; special-case gamma == 1 to
        # avoid a ZeroDivisionError.
        scale = 1.0 if g == 1 else (1 - g**horizon) / (1 - g) / horizon
        self.qc_thres = [c * scale for c in self.cost_limit]

    def pre_update_fn(self, **kwarg: Any) -> Any:
        """Init the M-step dual variable and its optimizer (reset on every call)."""
        self.mstep_dual = torch.zeros(
            1, requires_grad=True, device=self.device, dtype=self.dtype
        )
        self.mstep_optim = torch.optim.Adam([self.mstep_dual], lr=self._mstep_dual_lr)

    def post_update_fn(self, **kwarg: Any) -> Any:
        """Copy the current actor weights into the old actor."""
        with torch.no_grad():
            self.actor_old.load_state_dict(self.actor.state_dict())

    def train(self, mode: bool = True):
        """Set the module in training mode, except for the target networks."""
        self.training = mode
        self.actor.train(mode)
        self.critics.train(mode)
        if self.update_shield:
            self.cost_shield_distill.train(mode)
            self.actor_shield_distill.train(mode)
        return self

    def sync_weight(self) -> None:
        """Soft-update the weight for the target network."""
        self.soft_update(self.critics_old, self.critics, self.tau)

    def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> List[torch.Tensor]:
        """Return one target Q tensor per critic, evaluated at ``obs_next``.

        The next action is drawn from the current actor and scored by the target
        critics (``critics_old``).
        """
        batch = buffer[indices]  # batch.obs_next: s_{t+n}
        next_act_one_hot = self(batch, model="actor", input="obs_next").act_one_hot
        target_q_list = []
        for i in range(self.critics_num):
            target_q, _ = self.critics_old[i].predict(batch.obs_next, next_act_one_hot)
            target_q_list.append(target_q)
        return target_q_list

    def process_fn(
        self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray
    ) -> Batch:
        """Attach n-step returns for every critic to ``batch``."""
        return self.compute_nstep_returns(
            batch, buffer, indices, self._target_q, self._n_step
        )

    def forward(  # type: ignore
        self,
        batch: Batch,
        state: Optional[Union[dict, Batch, np.ndarray]] = None,
        model: str = "actor",
        input: str = "obs",
        **kwargs: Any,
    ) -> Batch:
        """Run the actor named ``model`` on ``batch[input]``.

        :return: a Batch with ``logits``, ``act``, ``act_one_hot`` (float32 one-hot of
            ``act``), ``state`` and ``dist`` (the action distribution).
        """
        model = getattr(self, model)
        obs = batch[input]
        logits, hidden = model(obs, state=state)
        if isinstance(logits, tuple):
            dist = self.dist_fn(*logits)
        else:
            dist = self.dist_fn(logits=logits)
        if self._deterministic_eval and not self.training:
            if self.action_type == "discrete":
                act = logits.argmax(-1)
            elif self.action_type == "continuous":
                act = logits[0]
        else:
            act = dist.sample()
        act_one_hot = F.one_hot(
            act.to(torch.long), num_classes=self.action_dim
        ).to(torch.float32)
        return Batch(logits=logits, act=act, act_one_hot=act_one_hot, state=hidden, dist=dist)

    def critics_loss(
        self, batch: Batch, critics: torch.nn.Module, optimizer: torch.optim.Optimizer
    ) -> Tuple[torch.Tensor, dict]:
        """Take one gradient step on all critics and return the mean TD error.

        :return: (td_average, stats) where ``td_average`` has shape (B,) and is the
            average signed TD error over all Q heads of all critics (detached). It is
            used as priority for a prioritized buffer.
        """
        target_q0 = batch.rets[..., 0].flatten()
        # Per-sample weights, kept as shape (B,) to match the TD errors.
        # A (B, 1) shape would broadcast against (B,) into a (B, B) matrix.
        weight = batch.info.get("weight", np.ones(batch.rew.shape))
        weight = to_torch_as(weight, target_q0).flatten()
        # Discrete actions -> one-hot, shared by every critic.
        act = to_torch_as(batch.act, target_q0).reshape(-1)
        act_one_hot = F.one_hot(
            act.to(torch.long), num_classes=self.action_dim
        ).to(torch.float32)

        loss_critic = 0
        td_sum = 0
        num_td = 0
        stats_critic = {}
        for i in range(self.critics_num):
            target_q = batch.rets[..., i].flatten()
            loss_i = 0
            # Each critic may return several Q heads (e.g. double Q).
            for current_q in critics[i](batch.obs, act_one_hot):
                td = current_q.flatten() - target_q
                td_sum = td_sum + td
                num_td += 1
                loss_i = loss_i + (td.pow(2) * weight).mean()

            loss_critic = loss_critic + loss_i
            stats_critic["loss/loss_q" + str(i)] = loss_i.item()
            stats_critic["estep/val_q" + str(i)] = torch.mean(target_q).item()
            if i >= 1:
                stats_critic["estep/thres_q" + str(i)] = self.qc_thres[i - 1]
        optimizer.zero_grad()
        loss_critic.backward()
        optimizer.step()
        td_average = (td_sum / num_td).detach()
        stats_critic["loss/q_total"] = loss_critic.item()
        return td_average, stats_critic

    def _estep_dual_loss(self, q_values: List[torch.Tensor]) -> torch.Tensor:
        """Dual objective of the E-step.

        g(eta, lambda) = eta * eps + sum_i lambda_i * thres_i
                         + eta * mean_s log mean_k exp((Q_r - sum_i lambda_i Q_ci) / eta)

        :param q_values: one (B, K) tensor per critic (index 0 = reward critic).
        """
        eta = self.estep_dual[0]
        K = q_values[0].shape[-1]
        loss = eta * self._estep_kl
        # Out-of-place arithmetic on purpose: detach() shares storage with
        # q_values[0], so an in-place "-=" would overwrite the caller's Q-values.
        combined_q = q_values[0].detach()  # (B, K)
        for i in range(1, self.critics_num):
            combined_q = combined_q - self.estep_dual[i] * q_values[i].detach()
            loss = loss + self.estep_dual[i] * self.qc_thres[i - 1]
        loss = loss + eta * torch.mean(
            torch.logsumexp(combined_q / eta, dim=1) - np.log(K)
        )
        return loss

    @staticmethod
    def gaussian_kl(
        mu_old: torch.Tensor, std_old: torch.Tensor, mu: torch.Tensor, std: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Decoupled KL between two multivariate Gaussians with diagonal covariance.

        See https://arxiv.org/pdf/1812.02256.pdf Sec. 4.2.1 for details.
        kl_mu = KL(pi(mu_old, std_old) || pi(mu, std_old)),
        kl_std = KL(pi(mu_old, std_old) || pi(mu_old, std)).
        Not used by the discrete policy itself.

        :param mu_old: (B, n)
        :param mu: (B, n)
        :param std_old: (B, n)
        :param std: (B, n)
        :return: kl_mu, kl_std: scalar mean and covariance terms of the KL
        """
        var_old, var = std_old**2, std**2
        # for numerical stability
        var_old = torch.clamp_min(var_old, 1e-6)
        var = torch.clamp_min(var, 1e-6)

        # note, this kl's denominator is the old var rather than the new var
        kl_mu = 0.5 * (mu_old - mu)**2 / var_old
        kl_mu = torch.sum(kl_mu, dim=-1).mean()

        kl_std = 0.5 * (torch.log(var / var_old) + var_old / var - 1)
        kl_std = torch.sum(kl_std, dim=-1).mean()  # Sum over the dimensions

        return kl_mu, kl_std

    @staticmethod
    def categorical_kl(logit_old: torch.Tensor, logit: torch.Tensor) -> torch.Tensor:
        """Return the batch-mean KL(softmax(logit_old) || softmax(logit)) as a scalar."""
        p_old = F.softmax(logit_old, dim=-1)
        log_ratio = F.log_softmax(logit_old, dim=-1) - F.log_softmax(logit, dim=-1)
        return torch.sum(p_old * log_ratio, dim=-1).mean()

    def shield_loss(self, batch: Batch, **kwarg: Any) -> None:
        """Take one gradient step on the cost shield and the actor shield.

        * Cost shield: BCE between ``cost_shield_distill(obs, act=one_hot(a))`` and
          the binarized cost (any non-zero cost -> 1). Each sample is weighted by
          ``info["distill_cost_weight"]``.
        * Actor shield: weighted squared error between ``actor_shield_distill(obs)``
          and a target. The target is the taken action's one-hot when the cost is 0,
          otherwise uniform mass over the other actions. Samples are weighted by
          ``info["distill_actor_weight"]``.

        Losses are averaged per ensemble member and summed over members. Missing
        weight entries default to zeros, which disables the corresponding loss.
        """
        t_start = time.time()
        n_ens = self.cost_shield_distill.num_ensemble
        obs = torch.as_tensor(batch.obs, device=self.device, dtype=torch.float32)  # (B, ds)
        action = torch.as_tensor(batch.act, dtype=torch.long, device=self.device)
        action_one_hot = to_torch_as(F.one_hot(action, num_classes=self.action_dim), obs)
        act_rep = action_one_hot.unsqueeze(0).repeat(n_ens, 1, 1)  # (E, B, A)
        act_rep_rev = (1 - act_rep) / (self.action_dim - 1)

        cost = batch.info.get("cost", np.zeros(batch.rew.shape))
        distill_cost_weight = batch.info.get("distill_cost_weight", np.zeros(batch.rew.shape))
        distill_actor_weight = batch.info.get("distill_actor_weight", np.zeros(batch.rew.shape))
        distill_cost_weight = to_torch_as(distill_cost_weight, obs).reshape(-1, 1) \
            .unsqueeze(0).repeat(n_ens, 1, 1)  # (E, B, 1)
        distill_actor_weight = to_torch_as(distill_actor_weight, obs).reshape(-1, 1) \
            .unsqueeze(0).repeat(n_ens, 1, self.action_dim)  # (E, B, A)
        cost = np.where(cost == 0.0, cost, 1.0)  # binarize
        cost = to_torch_as(cost, obs).reshape(-1, 1)
        cost_target = cost.unsqueeze(0).repeat(n_ens, 1, 1)  # (E, B, 1)
        cost_target_rep = cost_target.repeat(1, 1, self.action_dim)  # (E, B, A)
        actor_shield_target = torch.where(cost_target_rep == 0, act_rep, act_rep_rev)
        bce = nn.BCELoss(reduction="none")

        # update cost shield
        cost_shield_out = self.cost_shield_distill(obs, act=action_one_hot)
        cost_shield_loss = (
            bce(cost_shield_out, cost_target) * distill_cost_weight
        ).mean((1, 2)).sum(0)
        self.cost_shield_optim.zero_grad()
        cost_shield_loss.backward()
        self.cost_shield_optim.step()

        # update actor shield
        actor_shield_logits = self.actor_shield_distill(obs)
        actor_shield_loss = (
            (actor_shield_logits - actor_shield_target) * distill_actor_weight
        ).pow(2).mean((1, 2)).sum(0)
        self.actor_shield_optim.zero_grad()
        actor_shield_loss.backward()
        self.actor_shield_optim.step()

        self.logger.store(
            tab="shield",
            cost_shield_loss=cost_shield_loss.item(),
            actor_shield_loss=actor_shield_loss.item(),
        )

        self._shield_duration += time.time() - t_start
        self.logger.store(tab="shield", shield_train_time=self._shield_duration)

    def policy_loss(self, batch: Batch, **kwarg: Any) -> None:
        """Run one CVPO E-step followed by the M-step on ``batch``.

        Requires :meth:`pre_update_fn` to have been called, because it creates the
        M-step dual variable.
        """
        t_start = time.time()
        obs = torch.as_tensor(batch.obs, device=self.device, dtype=torch.float32)
        sample_act, optimal_q, old_logits = self._estep(batch, obs)
        t_estep = time.time()
        self._estep_duration += t_estep - t_start
        self.logger.store(tab="estep", estep_time=self._estep_duration)

        self._mstep(batch, sample_act, optimal_q, old_logits)
        self._mstep_duration += time.time() - t_estep
        self.logger.store(tab="mstep", mstep_time=self._mstep_duration)

    def _estep(
        self, batch: Batch, obs: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """E-step: optimize the E-step duals and build the variational distribution.

        :return: (sample_act, optimal_q, old_logits).
            * ``sample_act`` (K, B): long action indices.
            * ``optimal_q`` (K, B): variational weights, softmax over K.
            * ``old_logits``: the detached output of ``actor_old``.
        """
        B = obs.shape[0]
        with torch.no_grad():
            old_result = self(batch, model="actor_old", input="obs")
            if self.disc_sample:
                # Enumerate every discrete action for every state.
                K = self.action_dim
                sample_act = torch.arange(K, device=obs.device).unsqueeze(-1).repeat(1, B)
            else:
                K = self._sample_act_num
                sample_act = old_result.dist.sample((K, )).reshape(K, B)
            sample_act_one_hot = to_torch_as(
                F.one_hot(sample_act, num_classes=self.action_dim), obs
            )  # (K, B, A)
            flat_obs = obs.unsqueeze(0).expand(K, *obs.shape).reshape(K * B, *obs.shape[1:])
            flat_act = sample_act_one_hot.reshape(-1, self.action_dim)
            q_values = []  # one (B, K) tensor per critic
            # TODO, use critics old or the current?
            for i in range(self.critics_num):
                q, _ = self.critics[i].predict(flat_obs, flat_act)
                q_values.append(q.reshape(K, B).T)

        for _ in range(self._estep_iter_num):
            self.estep_optim.zero_grad()
            estep_loss = self._estep_dual_loss(q_values)
            estep_loss.backward()
            self.estep_optim.step()
            self.logger.store(tab="loss", estep_loss=estep_loss.item())
        self.estep_dual.data.clamp_(min=self.__eps, max=self._estep_dual_max)

        # Detached python floats of the duals for the rest of the update.
        estep_dual = []
        for i in range(self.critics_num):
            estep_dual.append(self.estep_dual[i].detach().item())
            self.logger.store(**{"estep/dual" + str(i): estep_dual[i]})

        # Optimal non-parametric distribution q(a|s) ∝ exp((Q_r - Σ λ_i Q_ci) / η).
        # Out-of-place subtraction so q_values[0] is not mutated.
        combined_q = q_values[0].T  # (K, B)
        for i in range(1, self.critics_num):
            combined_q = combined_q - estep_dual[i] * q_values[i].T
        optimal_q = torch.softmax(combined_q / estep_dual[0], dim=0).detach()  # (K, B)
        return sample_act, optimal_q, old_result.logits.detach()

    def _mstep(
        self,
        batch: Batch,
        sample_act: torch.Tensor,
        optimal_q: torch.Tensor,
        old_logits: torch.Tensor,
    ) -> None:
        """M-step: weighted MLE on ``sample_act`` plus a KL penalty to the old policy."""
        K, B = optimal_q.shape
        for _ in range(self._mstep_iter_num):
            result = self(batch, model="actor", input="obs")
            logits = result.logits
            # Same distribution that forward() samples actions from, so the fitted
            # policy matches the acting policy (dist_fn(logits=...)).
            dist = result.dist
            likelihood = dist.expand((K, B)).log_prob(sample_act)  # (K, B)
            loss_mle = -torch.mean(optimal_q * likelihood)

            # update the dual variable regularizing the KL
            kl_logits = self.categorical_kl(old_logits, logits)
            mstep_dual_loss = self.mstep_dual * (self._mstep_kl - kl_logits).detach()
            self.mstep_optim.zero_grad()
            mstep_dual_loss.backward()
            self.mstep_optim.step()

            # KL loss
            dual_logits = np.clip(self.mstep_dual.item(), 0.0, self._mstep_dual_max)
            loss_kl = dual_logits * (kl_logits - self._mstep_kl)

            loss_actor = loss_mle + loss_kl
            self.actor_optim.zero_grad()
            loss_actor.backward()
            self.actor_optim.step()

            entropy = torch.mean(dist.entropy()).item()
            self.logger.store(
                tab="mstep",
                mstep_kl=kl_logits.item(),
                mstep_loss_kl=loss_kl.item(),
                mstep_loss_mle=loss_mle.item(),
                mstep_loss_total=loss_actor.item(),
                mstep_dual=dual_logits,
                entropy=entropy
            )

    def update(self, sample_size: int, buffer: Optional[ReplayBuffer],
               **kwargs: Any) -> Dict[str, Any]:
        """Update the policy network and replay buffer.

        It includes 3 function steps: process_fn, learn, and post_process_fn. In
        addition, this function will change the value of ``self.updating``: it will be
        False before this function and will be True when executing :meth:`update`.

        :param int sample_size: 0 means it will extract all the data from the buffer,
            otherwise it will sample a batch with given sample_size.
        :param ReplayBuffer buffer: the corresponding replay buffer.

        :return: an empty dict; all training info is stored in the logger.
        """
        if buffer is None:
            return {}
        batch, indices = buffer.sample(sample_size)
        self.updating = True
        batch = self.process_fn(batch, buffer, indices)
        self.learn(batch, **kwargs)
        self.post_process_fn(batch, buffer, indices)
        if self.lr_scheduler is not None:
            self.lr_scheduler.step()
        self.updating = False
        return {}

    def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
        """Update critics, actor (E/M-step), optionally the shields, then targets.

        :return: the critic statistics that were also sent to the logger.
        """
        td, stats_critic = self.critics_loss(batch, self.critics, self.critics_optim)
        batch.weight = td  # prio-buffer
        self.policy_loss(batch)
        if self.update_shield:
            self.shield_loss(batch)

        self.sync_weight()
        self.logger.store(**stats_critic)
        return stats_critic

    def get_extra_state(self):
        """Hook called by ``state_dict()``; currently stores nothing and returns None.

        TODO: persist ``estep_dual`` / ``estep_optim`` here if resuming training needs
        them. See
        https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.get_extra_state
        """
        return None

    def set_extra_state(self, state):
        """Hook called by ``load_state_dict()``; currently ignores ``state``."""
