import numpy as np
import torch
import torch.nn as nn
import gym
from copy import deepcopy
import sys

from torch.nn import functional as F
from typing import Dict, Union, Tuple, Optional, List
from offlinerlkit.policy import SACPolicy
from offlinerlkit.policy import BasePolicy

# Module-level defaults.
# NOTE(review): LAMBDA and BETA are not read anywhere in the visible part of this
# file (the class receives `lmbda` / `beta` through its constructor instead) —
# confirm whether other modules import them before removing.
LAMBDA = 1.0  # 0.6
BETA = 1.0
# Switch for the optional fourth action candidate: when True, the Q-value code in
# this file additionally scores actions returned by
# `pmoe_policy.get_actor_pmoe_actions(...)`.
C=False # c

class PMoEPolicy(SACPolicy):
    """
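    SAC-based policy whose critic update adds a Conservative Q-Learning penalty
    <Ref: https://arxiv.org/abs/2006.04779>, extended with a copied "global" actor
    and mixing-coefficient (PMoE) actor updates.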
    """

    def __init__(
        self,
        actor: nn.Module,
        critic1: nn.Module,
        critic2: nn.Module,
        actor_optim: torch.optim.Optimizer,
        critic1_optim: torch.optim.Optimizer,
        critic2_optim: torch.optim.Optimizer,
        action_space: gym.spaces.Space,
        tau: float = 0.005,
        gamma: float = 0.99,
        alpha: Union[float, Tuple[float, torch.Tensor, torch.optim.Optimizer]] = 0.2,
        cql_weight: float = 1.0,
        temperature: float = 1.0,
        max_q_backup: bool = False,
        deterministic_backup: bool = True,
        with_lagrange: bool = True,
        lagrange_threshold: float = 10.0,
        cql_alpha_lr: float = 1e-4,
        num_repeart_actions: int = 10,
        lmbda: float = 1.0,
        beta: float = 1.0,
    ) -> None:
        """Set up the policy on top of the SAC base class.

        The actor, both critics, their optimizers and ``tau`` / ``gamma`` /
        ``alpha`` are forwarded unchanged to ``SACPolicy.__init__``; every other
        argument is stored on the instance.

        Args:
            action_space: kept as ``self.action_space``; ``update_critic`` reads
                ``low[0]`` / ``high[0]`` from it to draw uniform random actions.
            cql_weight: multiplier applied to the conservative loss terms.
            temperature: temperature used inside the logsumexp of the
                conservative loss.
            max_q_backup: if True, the critic target takes the maximum target-Q
                over several actions sampled per next observation.
            deterministic_backup: if False, ``alpha * log_prob`` is subtracted
                from the next-state value in the critic target.
            with_lagrange: if True, the conservative loss is scaled by a learned
                multiplier ``exp(cql_log_alpha)``.
            lagrange_threshold: value subtracted from the conservative loss when
                ``with_lagrange`` is on.
            cql_alpha_lr: Adam learning rate for ``cql_log_alpha``.
            num_repeart_actions: number of actions sampled per observation in the
                critic update (the misspelling is part of the public signature).
            lmbda: weight of the log-likelihood terms in ``update_actor``.
            beta: temperature of the exponentiated advantage in ``update_actor``.
        """
        super().__init__(
            actor, critic1, critic2,
            actor_optim, critic1_optim, critic2_optim,
            tau=tau, gamma=gamma, alpha=alpha,
        )

        # Settings consumed by the critic update.
        self.action_space = action_space
        self._cql_weight, self._temperature = cql_weight, temperature
        self._max_q_backup, self._deterministic_backup = max_q_backup, deterministic_backup
        self._with_lagrange, self._lagrange_threshold = with_lagrange, lagrange_threshold
        self._num_repeat_actions = num_repeart_actions

        # Learnable log-multiplier for the Lagrange form of the conservative loss,
        # created on the same device as the actor.
        device = self.actor.device
        self.cql_log_alpha = torch.zeros(1, requires_grad=True, device=device)
        self.cql_alpha_optim = torch.optim.Adam([self.cql_log_alpha], lr=cql_alpha_lr)

        # NOTE(review): entropy auto-tuning is forced on regardless of `alpha`.
        # The update methods then read self._target_entropy / self._log_alpha /
        # self.alpha_optim, which presumably only exist when `alpha` was passed
        # as a tuple — confirm against SACPolicy.
        self._is_auto_alpha = True

        # Independent copy of the actor, queried via get_actor_global_actions().
        self.actor_global = deepcopy(actor)
        # Clipping bound for the exponentiated advantage in update_actor().
        self.max_exp_scale = 1.0

        self.lmbda, self.beta = lmbda, beta
        self.best_index = None

    def get_actor_global_actions(
        self,
        obs: torch.Tensor,
        deterministic: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor, bool]:
        """Draw an action from the global actor copy (``self.actor_global``).

        Args:
            obs: observations passed straight to ``self.actor_global``.
            deterministic: if True use the distribution's ``mode()``, otherwise
                a reparameterised sample (``rsample()``).

        Returns:
            A 3-tuple ``(squashed_action, log_prob, deterministic)``: the action,
            its log-probability under the global actor's distribution, and the
            ``deterministic`` flag echoed back. Callers in this class unpack all
            three values.
        """
        dist = self.actor_global(obs)
        # Both mode() and rsample() yield the (squashed, raw) action pair that
        # log_prob() is then called with.
        squashed_action, raw_action = dist.mode() if deterministic else dist.rsample()
        log_prob = dist.log_prob(squashed_action, raw_action)

        return squashed_action, log_prob, deterministic
    
    # def get_actor_pmoe_actions(   ### c
    #     self,
    #     obs: torch.Tensor,
    #     deterministic: bool = False
    # ) -> Tuple[torch.Tensor, torch.Tensor]:


    #     k, weight = self.actor.dist_net.get_mixing_coefficient(obs)  # shape=[1, 2]

    #     actions = [
    #         torch.as_tensor(p.select_action(obs.unsqueeze(0), deterministic=True))  # make sure the result is converted to a Tensor
    #         for p in self.policies
    #     ]
    #     actions = torch.stack(actions, dim=-1).to(self.actor.device)

    #     # weighted sum
    #     pmoe_action = (actions * weight.unsqueeze(-2)).sum(dim=-1).squeeze(0)
    #     return pmoe_action



    def calc_pi_values(
        self,
        obs_pi: torch.Tensor,
        obs_to_pred: torch.Tensor,
        local_action: torch.Tensor,
        pmoe_policy: Optional[float]=None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Element-wise best log-prob-corrected Q-value over several action sources.

        Actions are produced from ``obs_pi`` by the current actor and by the
        global actor copy; ``local_action`` is used as given. When the module flag
        ``C`` equals True, actions from ``pmoe_policy.get_actor_pmoe_actions``
        are added as a further candidate. Each candidate is scored by both
        critics at ``obs_to_pred`` with its (detached) log-probability
        subtracted; ``local_action`` and the PMoE actions use a zero
        log-probability.

        NOTE(review): ``pmoe_policy`` is annotated ``Optional[float]`` but is
        used as an object exposing ``get_actor_pmoe_actions`` — the annotation
        looks wrong; confirm the intended type.

        Returns:
            ``(q1, q2)``: for each critic, the maximum over the candidates.
        """
        policy_act, policy_log_prob = self.actforward(obs_pi)  # pi
        zero_log_prob = torch.zeros_like(policy_log_prob)  # pi_beta has no density here
        global_act, global_log_prob, _ = self.get_actor_global_actions(obs_pi)  # pi_server

        def corrected_q(action, log_prob):
            # Score one candidate with both critics, critic1 first.
            penalty = log_prob.detach()
            value1 = self.critic1(obs_to_pred, action) - penalty
            value2 = self.critic2(obs_to_pred, action) - penalty
            return value1, value2

        scored = [
            corrected_q(policy_act, policy_log_prob),
            corrected_q(local_action, zero_log_prob),
            corrected_q(global_act, global_log_prob),
        ]
        if C == True:  # noqa: E712 - equality test kept to match the flag's existing semantics
            pmoe_act = pmoe_policy.get_actor_pmoe_actions(obs_pi)  # pi_moe
            scored.append(corrected_q(pmoe_act, torch.zeros_like(policy_log_prob)))

        q1_stack = torch.stack([pair[0] for pair in scored])
        q2_stack = torch.stack([pair[1] for pair in scored])

        # Keep the values only; the argmax indices are not needed.
        return torch.max(q1_stack, dim=0)[0], torch.max(q2_stack, dim=0)[0]

    def calc_random_values(
        self,
        obs: torch.Tensor,
        random_act: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Score randomly drawn actions with both critics.

        The same constant ``log(0.5 ** action_dim)`` is subtracted from each
        critic's output, where ``action_dim`` is the size of the last axis of
        ``random_act``. (0.5 per dimension is the density of a uniform
        distribution on an interval of length 2 — presumably actions lie in
        [-1, 1]; confirm against the action space in use.)

        Returns:
            ``(q1 - log_prob, q2 - log_prob)``.
        """
        value1 = self.critic1(obs, random_act)
        value2 = self.critic2(obs, random_act)

        # Identical for both critics, so compute it once.
        uniform_log_prob = np.log(0.5 ** random_act.shape[-1])

        return value1 - uniform_log_prob, value2 - uniform_log_prob
    
    def update_pmax_actor(
        self,
        batch: Dict,
        policies: Optional[List[BasePolicy]] = None,
        best_q_values: Optional[torch.Tensor] = None,  # Q values, optional
        best_actions: Optional[torch.Tensor] = None,  # best actions, optional
        best_log_probs: Optional[torch.Tensor] = None,  # log-probabilities of the best actions, optional
        best_index: Optional[torch.Tensor] = None,  # index of the best policy, optional
    ) -> Dict[str, float]:
        """Run one actor step against the best Q-value offered by ``policies``.

        The current actor's action is scored by ``critic1`` / ``critic2`` of
        every policy in ``policies``. For each critic slot the maximum over
        policies is taken, then the minimum over the two slots; the actor
        minimises ``alpha * log_prob - best_q``. When ``self._is_auto_alpha`` is
        set, the entropy temperature is updated afterwards.

        Args:
            batch: must contain ``"observations"``.
            policies: non-empty list of policies exposing ``critic1`` and
                ``critic2``.
            best_q_values, best_actions, best_log_probs, best_index: accepted for
                interface compatibility only; never read.

        Returns:
            ``{"moe_loss/aloss": <float>}``.

        Raises:
            ValueError: if ``policies`` is ``None`` or empty.
        """
        if not policies:
            raise ValueError("update_pmax_actor requires a non-empty `policies` list")

        obss = batch["observations"]
        # One action / log-prob sample, shared by every critic evaluation below.
        a, log_probs = self.actforward(obss)

        # Each critic output is assumed to be (batch, 1); the dim=2 reduction
        # below only works for that shape.
        q1_all = torch.stack([p.critic1(obss, a) for p in policies], dim=1)  # (batch, num_policies, 1)
        q2_all = torch.stack([p.critic2(obss, a) for p in policies], dim=1)  # (batch, num_policies, 1)
        q_values = torch.stack([q1_all, q2_all], dim=-1)  # (batch, num_policies, 1, 2)

        # Max over policies first, then min over the two critics.
        max_q_values, _ = torch.max(q_values, dim=1)  # (batch, 1, 2)
        best_q, _ = torch.min(max_q_values, dim=2)  # (batch, 1)

        # The mixing-coefficient loss is not part of this update (see
        # update_pmoe_actor); the call is kept only so that `self.k` is still
        # refreshed here as before.
        self.k, _ = self.actor.dist_net.get_mixing_coefficient(obss)

        a_loss = (self._alpha * log_probs - best_q).mean()

        self.actor_optim.zero_grad()
        a_loss.backward()
        self.actor_optim.step()

        if self._is_auto_alpha:
            # Temperature update on the detached log-probs.
            log_probs = log_probs.detach() + self._target_entropy
            alpha_loss = -(self._log_alpha * log_probs).mean()
            self.alpha_optim.zero_grad()
            alpha_loss.backward()
            self.alpha_optim.step()
            self._alpha = self._log_alpha.detach().exp()

        return {
            "moe_loss/aloss": a_loss.item(),
        }
    
    def update_pmoe_actor(
        self,
        batch: Dict,
        policies: Optional[List[BasePolicy]] = None,
        best_q_values: Optional[torch.Tensor] = None,  # Q values, optional
        best_actions: Optional[torch.Tensor] = None,  # best actions, optional
        best_log_probs: Optional[torch.Tensor] = None,  # log-probabilities of the best actions, optional
        best_index: Optional[torch.Tensor] = None,  # index of the best policy, optional
    ) -> Dict[str, float]:
        """Train the actor's mixing coefficients towards the best-scoring policy.

        The current actor's action is scored by ``critic1`` / ``critic2`` of
        every policy in ``policies``. The index of the policy that attains the
        max-over-policies, min-over-critics value is one-hot encoded and used as
        an MSE target for the output of
        ``self.actor.dist_net.get_mixing_coefficient``. When
        ``self._is_auto_alpha`` is set, the entropy temperature is updated
        afterwards.

        Args:
            batch: must contain ``"observations"``.
            policies: non-empty list of policies exposing ``critic1`` and
                ``critic2``.
            best_q_values, best_actions, best_log_probs, best_index: accepted for
                interface compatibility only; never read.

        Returns:
            ``{"moe_loss/mixing_coefficient": <float>}``.

        Raises:
            ValueError: if ``policies`` is ``None`` or empty.
        """
        if not policies:
            raise ValueError("update_pmoe_actor requires a non-empty `policies` list")

        obss = batch["observations"]

        # Only an argmax index and detached log-probs leave this section, so no
        # autograd graph through the actor or the expert critics is needed.
        with torch.no_grad():
            a, log_probs = self.actforward(obss)

            # Each critic output is assumed to be (batch, 1); the dim=2
            # reduction below only works for that shape.
            q1_all = torch.stack([p.critic1(obss, a) for p in policies], dim=1)  # (batch, num_policies, 1)
            q2_all = torch.stack([p.critic2(obss, a) for p in policies], dim=1)  # (batch, num_policies, 1)
            q_values = torch.stack([q1_all, q2_all], dim=-1)  # (batch, num_policies, 1, 2)

            # Max over policies first, then min over the two critics; follow the
            # indices to recover which policy produced the selected value.
            max_q_values, best_policy_indices = torch.max(q_values, dim=1)  # both (batch, 1, 2)
            _, min_indices = torch.min(max_q_values, dim=2)  # (batch, 1)
            best_index = torch.gather(best_policy_indices.squeeze(1), 1, min_indices)  # (batch, 1)

        self.k, mixing_coefficient = self.actor.dist_net.get_mixing_coefficient(obss)
        target = F.one_hot(best_index, self.k).float().squeeze(1)  # (batch, k)
        mixing_coefficient_loss = F.mse_loss(mixing_coefficient, target)

        self.actor_optim.zero_grad()
        mixing_coefficient_loss.backward()
        self.actor_optim.step()

        if self._is_auto_alpha:
            # Temperature update on the detached log-probs.
            log_probs = log_probs.detach() + self._target_entropy
            alpha_loss = -(self._log_alpha * log_probs).mean()
            self.alpha_optim.zero_grad()
            alpha_loss.backward()
            self.alpha_optim.step()
            self._alpha = self._log_alpha.detach().exp()

        return {
            "moe_loss/mixing_coefficient": mixing_coefficient_loss.item(),
        }

    def update_actor(self, batch: Dict, pmoe_policy: Optional[float] = None, policies: Optional[List] = None) -> Tuple[torch.Tensor, Dict[str, float]]:
        """Run one actor step: SAC objective plus a log-likelihood term on dataset actions.

        Side effect: stores ``batch["actions"]`` and ``batch["next_actions"]`` on
        ``self.actions`` / ``self.next_actions``.

        Args:
            batch: must contain ``"observations"``, ``"actions"`` and
                ``"next_actions"``.
            pmoe_policy: only used when the module flag ``C`` equals True, in
                which case ``pmoe_policy.get_actor_pmoe_actions(obss)`` supplies
                an extra action candidate. NOTE(review): annotated
                ``Optional[float]`` but used as an object — confirm the type.
            policies: unused here.

        Returns:
            ``(actor_loss, metrics)`` where ``actor_loss`` is the loss tensor
            that was back-propagated and ``metrics`` holds scalar logs.
        """
        obss, actions = batch["observations"], batch["actions"]
        self.actions, self.next_actions = actions, batch["next_actions"]

        # SAC part: fresh action sample scored by the online critics.
        a, log_probs = self.actforward(obss)
        q1a, q2a = self.critic1(obss, a), self.critic2(obss, a)

        # Log-likelihood of the dataset actions under the current actor.
        dist = self.actor(obss)
        local_log_probs_kl = dist.log_prob(actions)
        global_actions, _, _ = self.get_actor_global_actions(obss)

        # Target-critic values of each action candidate.
        q1_pi, q2_pi = self.critic1_old(obss, a), self.critic2_old(obss, a)
        q1_local, q2_local = self.critic1_old(obss, actions), self.critic2_old(obss, actions)
        q1_global, q2_global = self.critic1_old(obss, global_actions), self.critic2_old(obss, global_actions)

        q1_candidates = [q1_pi, q1_local, q1_global]
        q2_candidates = [q2_pi, q2_local, q2_global]
        if C == True:  # noqa: E712 - equality test kept to match the flag's existing semantics
            pmoe_actions = pmoe_policy.get_actor_pmoe_actions(obss)
            q1_candidates.append(self.critic1_old(obss, pmoe_actions))
            q2_candidates.append(self.critic2_old(obss, pmoe_actions))

        # v = min over critics of the max over candidates.
        q1_piv, _ = torch.max(torch.stack(q1_candidates).to(self.actor.device), dim=0)
        q2_piv, _ = torch.max(torch.stack(q2_candidates).to(self.actor.device), dim=0)
        v = torch.min(q1_piv, q2_piv)

        # Exponentiated advantage of the global actor's action, clipped to
        # +/- max_exp_scale.
        adv_global = torch.min(q1_global, q2_global) - v
        exp_adv_global = torch.clip(
            torch.exp((1.0 / self.beta) * adv_global),
            -self.max_exp_scale,
            self.max_exp_scale,
        ).to(self.actor.device)

        a_loss = (self._alpha * log_probs - torch.min(q1a, q2a)).mean()
        b_loss = - self.lmbda * (exp_adv_global * local_log_probs_kl).mean()
        c_loss = - self.lmbda * (local_log_probs_kl).mean()
        # Loss variants: `a_loss` alone ("pmoe_a"), `a_loss + b_loss`
        # ("original"), `a_loss + c_loss` ("pmoe_ab", the active one).
        # b_loss is still computed, with its gradient path intact, so it can be
        # logged and swapped back in.
        actor_loss = a_loss + c_loss

        self.actor_optim.zero_grad()
        actor_loss.backward()
        self.actor_optim.step()

        if self._is_auto_alpha:
            # Temperature update on the detached log-probs.
            log_probs = log_probs.detach() + self._target_entropy
            alpha_loss = -(self._log_alpha * log_probs).mean()
            self.alpha_optim.zero_grad()
            alpha_loss.backward()
            self.alpha_optim.step()
            self._alpha = self._log_alpha.detach().exp()

        return actor_loss, {
            "loss/actor": actor_loss.item(),
            "loss/aloss": a_loss.item(),
            "loss/bloss": b_loss.item(),
            "loss/closs": c_loss.item(),
            "alpha": self._alpha.item(),
        }

    def update_critic(self, batch: Dict, pmoe_policy: Optional[float] = None, policies: Optional[List] = None) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, float]]:
        """Run one update of both critics: TD loss plus a conservative (CQL-style) penalty.

        Steps:
          1. Build the TD target from the target critics. With
             ``self._max_q_backup`` the max over ``self._num_repeat_actions``
             sampled next actions is used. Otherwise the max over several action
             candidates is used (current actor, dataset next action, global
             actor and, when the module flag ``C`` equals True, the PMoE
             policy), optionally with an entropy term.
          2. Add ``temperature * logsumexp(Q / temperature) - Q(s, a_data)``,
             where the logsumexp runs over all sampled actions of each state
             (policy-derived at ``s``, policy-derived at ``s'``, uniform random).
          3. With ``self._with_lagrange`` the penalty is scaled by a learned
             multiplier, which is itself updated first.
          4. Step both critic optimizers and call ``self._sync_weight()``.

        Side effect: stores ``batch["actions"]`` and ``batch["next_actions"]`` on
        ``self.actions`` / ``self.next_actions``.

        Args:
            batch: must contain ``"observations"``, ``"actions"``,
                ``"next_observations"``, ``"rewards"``, ``"terminals"`` and
                ``"next_actions"``.
            pmoe_policy: only used when ``C`` equals True. NOTE(review):
                annotated ``Optional[float]`` but used as an object exposing
                ``get_actor_pmoe_actions`` — confirm the type.
            policies: unused here.

        Returns:
            ``(critic1_loss, critic2_loss, metrics)``: the two total loss
            tensors and a dict of scalar logs.
        """
        obss, actions, next_obss, rewards, terminals = batch["observations"], batch["actions"], \
            batch["next_observations"], batch["rewards"], batch["terminals"]
        batch_size = obss.shape[0]
        num_repeat = self._num_repeat_actions
        self.actions, self.next_actions = actions, batch["next_actions"]

        def repeat_rows(x: torch.Tensor) -> torch.Tensor:
            # (batch, dim) -> (batch * num_repeat, dim); copies of a row stay adjacent.
            return x.unsqueeze(1) \
                .repeat(1, num_repeat, 1) \
                .view(batch_size * num_repeat, x.shape[-1])

        # ---- TD target -------------------------------------------------------
        if self._max_q_backup:
            with torch.no_grad():
                tmp_next_obss = repeat_rows(next_obss)
                tmp_next_actions, _ = self.actforward(tmp_next_obss)
                tmp_next_q1 = self.critic1_old(tmp_next_obss, tmp_next_actions) \
                    .view(batch_size, num_repeat, 1) \
                    .max(1)[0].view(-1, 1)
                tmp_next_q2 = self.critic2_old(tmp_next_obss, tmp_next_actions) \
                    .view(batch_size, num_repeat, 1) \
                    .max(1)[0].view(-1, 1)
                next_q = torch.min(tmp_next_q1, tmp_next_q2)
        else:
            with torch.no_grad():
                next_actions, next_log_probs = self.actforward(next_obss)
                global_next_actions, _, _ = self.get_actor_global_actions(next_obss)

                candidate_actions = [next_actions, self.next_actions, global_next_actions]
                if C == True:  # noqa: E712 - equality test kept to match the flag's existing semantics
                    candidate_actions.append(pmoe_policy.get_actor_pmoe_actions(next_obss))

                # Per critic: max over candidates; then min over the two critics.
                next_q1_piv, _ = torch.max(
                    torch.stack([self.critic1_old(next_obss, act) for act in candidate_actions]),
                    dim=0,
                )
                next_q2_piv, _ = torch.max(
                    torch.stack([self.critic2_old(next_obss, act) for act in candidate_actions]),
                    dim=0,
                )
                next_q = torch.min(next_q1_piv, next_q2_piv)

                if not self._deterministic_backup:
                    next_q -= self._alpha * next_log_probs

        target_q = rewards + self._gamma * (1 - terminals) * next_q
        q1, q2 = self.critic1(obss, actions), self.critic2(obss, actions)
        critic1_loss = ((q1 - target_q).pow(2)).mean()
        critic2_loss = ((q2 - target_q).pow(2)).mean()
        q1_loss, q2_loss = critic1_loss, critic2_loss

        # ---- Conservative penalty --------------------------------------------
        # Uniform random actions, bounded by the first component of the action
        # space's low/high.
        random_actions = torch.FloatTensor(
            batch_size * num_repeat, actions.shape[-1]
        ).uniform_(self.action_space.low[0], self.action_space.high[0]).to(self.actor.device)

        tmp_obss = repeat_rows(obss)
        tmp_next_obss = repeat_rows(next_obss)
        tmp_actions = repeat_rows(self.actions)
        tmp_next_actions = repeat_rows(self.next_actions)

        obs_pi_value1, obs_pi_value2 = self.calc_pi_values(tmp_obss, tmp_obss, tmp_actions, pmoe_policy)
        next_obs_pi_value1, next_obs_pi_value2 = self.calc_pi_values(tmp_next_obss, tmp_obss, tmp_next_actions, pmoe_policy)
        random_value1, random_value2 = self.calc_random_values(tmp_obss, random_actions)

        # Tensor.reshape returns a new tensor and does not modify in place, so
        # the results must be rebound. Without this the logsumexp below would
        # run over only 3 values per repeated row instead of 3 * num_repeat
        # sampled actions per state.
        (
            obs_pi_value1, obs_pi_value2,
            next_obs_pi_value1, next_obs_pi_value2,
            random_value1, random_value2,
        ) = (
            value.reshape(batch_size, num_repeat, 1)
            for value in (
                obs_pi_value1, obs_pi_value2,
                next_obs_pi_value1, next_obs_pi_value2,
                random_value1, random_value2,
            )
        )

        # cat_q shape: (batch, 3 * num_repeat, 1)
        cat_q1 = torch.cat([obs_pi_value1, next_obs_pi_value1, random_value1], 1)
        cat_q2 = torch.cat([obs_pi_value2, next_obs_pi_value2, random_value2], 1)

        conservative_loss1 = \
            torch.logsumexp(cat_q1 / self._temperature, dim=1).mean() * self._cql_weight * self._temperature - \
            q1.mean() * self._cql_weight
        conservative_loss2 = \
            torch.logsumexp(cat_q2 / self._temperature, dim=1).mean() * self._cql_weight * self._temperature - \
            q2.mean() * self._cql_weight

        if self._with_lagrange:
            cql_alpha = torch.clamp(self.cql_log_alpha.exp(), 0.0, 1e6)
            conservative_loss1 = cql_alpha * (conservative_loss1 - self._lagrange_threshold)
            conservative_loss2 = cql_alpha * (conservative_loss2 - self._lagrange_threshold)

            # The multiplier is trained on the negated mean penalty. The graph
            # is retained because the critic losses below reuse it.
            self.cql_alpha_optim.zero_grad()
            cql_alpha_loss = -(conservative_loss1 + conservative_loss2) * 0.5
            cql_alpha_loss.backward(retain_graph=True)
            self.cql_alpha_optim.step()

        critic1_loss = critic1_loss + conservative_loss1
        critic2_loss = critic2_loss + conservative_loss2

        # ---- Optimizer steps -------------------------------------------------
        self.critic1_optim.zero_grad()
        # Graph retained for the critic2 backward pass that follows.
        critic1_loss.backward(retain_graph=True)
        self.critic1_optim.step()

        self.critic2_optim.zero_grad()
        critic2_loss.backward()
        self.critic2_optim.step()

        self._sync_weight()

        return critic1_loss, critic2_loss, {
            "loss/critic1": critic1_loss.item(),
            "loss/critic2": critic2_loss.item(),
            "loss/q1": q1_loss.item(),
            "loss/q2": q2_loss.item(),
            "cql_alpha": cql_alpha.item() if self._with_lagrange else 0.0,
        }