import numpy as np
import torch
import torch.nn as nn
import gym
from copy import deepcopy
import sys

from torch.nn import functional as F
from typing import Dict, Union, Tuple
from offlinerlkit.policy import SACPolicy

# Mixing weight for the two log-likelihood terms added to the actor loss in
# DRPolicy.learn: LAMBDA scales the term evaluated on the batch's dataset
# actions and (1 - LAMBDA) scales the term evaluated on actions sampled from
# the global actor.
LAMBDA = 0.6

class DRPolicy(SACPolicy):
    """
    SAC-style policy with a conservative (CQL) critic and a doubly regularized actor.

    The critic update follows Conservative Q-Learning
    <Ref: https://arxiv.org/abs/2006.04779>. On top of the usual SAC actor
    objective, the actor loss adds two log-likelihood terms evaluated under the
    current actor distribution: one on the dataset actions (weighted by
    ``LAMBDA``) and one on actions sampled from ``actor_global`` (weighted by
    ``1 - LAMBDA``). ``actor_global`` is a deep copy of the actor taken at
    construction time; this class never steps an optimizer on it.
    """

    def __init__(
        self,
        actor: nn.Module,
        critic1: nn.Module,
        critic2: nn.Module,
        actor_optim: torch.optim.Optimizer,
        critic1_optim: torch.optim.Optimizer,
        critic2_optim: torch.optim.Optimizer,
        action_space: gym.spaces.Space,
        tau: float = 0.005,
        gamma: float  = 0.99,
        alpha: Union[float, Tuple[float, torch.Tensor, torch.optim.Optimizer]] = 0.2,
        cql_weight: float = 1.0,
        temperature: float = 1.0,
        max_q_backup: bool = False,
        deterministic_backup: bool = True,
        with_lagrange: bool = True,
        lagrange_threshold: float = 10.0,
        cql_alpha_lr: float = 1e-4,
        num_repeart_actions:int = 10,
    ) -> None:
        """
        Build the policy.

        Args:
            actor, critic1, critic2: networks forwarded to the base class.
            actor_optim, critic1_optim, critic2_optim: their optimizers.
            action_space: only ``low[0]`` / ``high[0]`` are read, as the bounds
                for the uniformly sampled actions in the conservative loss.
            tau, gamma, alpha: forwarded unchanged to the base class.
            cql_weight: multiplier of the conservative loss.
            temperature: temperature of the logsumexp in the conservative loss.
            max_q_backup: if True, the TD target takes the max over
                ``num_repeart_actions`` sampled next actions.
            deterministic_backup: if False (and ``max_q_backup`` is False), the
                entropy term ``alpha * log_prob`` is subtracted from the target.
            with_lagrange: if True, the conservative loss is scaled by a
                learned multiplier trained against ``lagrange_threshold``.
            lagrange_threshold: target value used by the Lagrange multiplier.
            cql_alpha_lr: learning rate of the multiplier's Adam optimizer.
            num_repeart_actions: (sic - the misspelled name is kept for caller
                compatibility) number of actions sampled per state.
        """
        super().__init__(
            actor,
            critic1,
            critic2,
            actor_optim,
            critic1_optim,
            critic2_optim,
            tau=tau,
            gamma=gamma,
            alpha=alpha
        )

        self.action_space = action_space
        self._cql_weight = cql_weight
        self._temperature = temperature
        self._max_q_backup = max_q_backup
        self._deterministic_backup = deterministic_backup
        self._with_lagrange = with_lagrange
        self._lagrange_threshold = lagrange_threshold

        # Log of the Lagrange multiplier applied to the conservative loss.
        self.cql_log_alpha = torch.zeros(1, requires_grad=True, device=self.actor.device)
        self.cql_alpha_optim = torch.optim.Adam([self.cql_log_alpha], lr=cql_alpha_lr)

        self._num_repeat_actions = num_repeart_actions
        # NOTE(review): this overrides whatever the base class derived from
        # `alpha`, so the entropy-temperature update in `learn` never runs
        # unless a caller flips the flag back. Confirm this is intended.
        self._is_auto_alpha = False
        # Independent copy of the actor used as the "global" reference policy.
        self.actor_global =  deepcopy(actor)

    def get_actor_global_actions(
        self,
        obs: torch.Tensor,
        deterministic: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Draw actions from the global actor.

        Args:
            obs: observations fed to ``actor_global``.
            deterministic: use the distribution's ``mode()`` instead of
                ``rsample()``.

        Returns:
            ``(squashed_action, log_prob)`` where ``log_prob`` is evaluated by
            the global actor's own distribution.
        """
        dist = self.actor_global(obs)
        if deterministic:
            squashed_action, raw_action = dist.mode()
        else:
            squashed_action, raw_action = dist.rsample()
        log_prob = dist.log_prob(squashed_action, raw_action)
        return squashed_action, log_prob

    def calc_pi_values(
        self,
        obs_pi: torch.Tensor,
        obs_to_pred: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Q-values of policy actions, corrected by the (detached) action log-prob.

        Actions are sampled from the current policy at ``obs_pi`` and scored by
        both critics at ``obs_to_pred``.

        Returns:
            ``(q1 - log_prob, q2 - log_prob)``.
        """
        act, log_prob = self.actforward(obs_pi)

        q1 = self.critic1(obs_to_pred, act)
        q2 = self.critic2(obs_to_pred, act)

        # Detached so the correction term does not back-propagate through the
        # log-probability.
        detached_log_prob = log_prob.detach()
        return q1 - detached_log_prob, q2 - detached_log_prob

    def calc_random_values(
        self,
        obs: torch.Tensor,
        random_act: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Q-values of the given random actions, corrected by a constant log-density.

        The constant is ``log(0.5 ** action_dim)``.
        # NOTE(review): that is the log-density of a uniform distribution on
        # [-1, 1]^action_dim - presumably the action bounds; confirm against
        # the `action_space` actually used.

        Returns:
            ``(q1 - log_prob, q2 - log_prob)``.
        """
        q1 = self.critic1(obs, random_act)
        q2 = self.critic2(obs, random_act)

        log_prob = np.log(0.5**random_act.shape[-1])

        return q1 - log_prob, q2 - log_prob

    @staticmethod
    def _histogram_tv(
        samples_a: torch.Tensor,
        samples_b: torch.Tensor,
        bins: int
    ) -> torch.Tensor:
        """
        Histogram estimate of the TV distance between two sample sets.

        Args:
            samples_a, samples_b: tensors of shape ``[num_samples, dims]``.
            bins: number of histogram bins per dimension.

        Returns:
            0-dim tensor: the per-dimension TV estimates averaged over ``dims``
            (0 when there are no dimensions).
        """
        num_dims = samples_a.shape[1]
        total = torch.zeros((), device=samples_a.device)
        if num_dims == 0:
            return total

        for dim in range(num_dims):
            col_a = samples_a[:, dim]
            col_b = samples_b[:, dim]

            # Shared range so both histograms use identical bin edges.
            both = torch.cat([col_a, col_b])
            low = torch.min(both).item()
            high = torch.max(both).item()

            # All samples identical: the distance is 0 and histc would be
            # handed a degenerate range, so skip it.
            if low == high:
                continue

            hist_a = torch.histc(col_a, bins=bins, min=low, max=high)
            hist_b = torch.histc(col_b, bins=bins, min=low, max=high)

            # Normalised bin masses; TV = 0.5 * sum |p - q|. (Dividing by the
            # bin width to get densities and multiplying it back cancels out.)
            prob_a = hist_a / hist_a.sum()
            prob_b = hist_b / hist_b.sum()
            total = total + 0.5 * torch.sum(torch.abs(prob_a - prob_b))

        return total / num_dims

    def approximate_tv_distance_continuous(self, dist_client, dist_global, obss, 
                                       num_samples=1000, bins=100):
        """
        Approximate the total variation (TV) distance between two continuous
        distributions with per-dimension histograms.

        Both distributions are sampled once with ``sample((num_samples,))``.
        When the samples carry a batch axis (``ndim >= 3`` and
        ``shape[1] == batch_size``), state ``i`` is compared using only its own
        samples ``[:, i]``. Otherwise the distributions are treated as shared
        by every state and the same estimate is written to every entry.

        Args:
            dist_client: distribution of the client policy; must provide
                ``sample(sample_shape)`` returning a tensor.
            dist_global: distribution of the global policy, same requirements.
            obss (torch.Tensor): current batch of states; only ``shape[0]`` is
                read, as the batch size.
            num_samples (int): number of actions sampled per state.
            bins (int): number of histogram bins.

        Returns:
            torch.Tensor: approximate TV distance per state, shape = [batch_size].
        """
        device = self.actor.device
        batch_size = obss.shape[0]

        # One TV estimate per state.
        tv_distance = torch.zeros(batch_size, device=device)
        if batch_size == 0:
            return tv_distance

        # Sample once instead of re-sampling the full batch for every state.
        samples_client = dist_client.sample((num_samples,)).to(device)
        samples_global = dist_global.sample((num_samples,)).to(device)

        has_batch_axis = samples_client.ndim >= 3 and samples_client.shape[1] == batch_size
        if has_batch_axis:
            for i in range(batch_size):
                # [num_samples, action_dim] for state i only.
                tv_distance[i] = self._histogram_tv(
                    samples_client[:, i].reshape(num_samples, -1),
                    samples_global[:, i].reshape(num_samples, -1),
                    bins,
                )
        else:
            # No per-state axis: every state shares the same pair of
            # distributions, so a single estimate is broadcast.
            tv_distance[:] = self._histogram_tv(
                samples_client.reshape(num_samples, -1),
                samples_global.reshape(num_samples, -1),
                bins,
            )

        return tv_distance

    def learn(self, batch: Dict, flag: bool = False) -> Dict[str, float]:
        """
        Run one actor update and one critic update on ``batch``.

        Args:
            batch: dict with keys ``"observations"``, ``"actions"``,
                ``"next_observations"``, ``"rewards"`` and ``"terminals"``.
            flag: when True, also estimate the TV distance between the client
                and global action distributions (reported as ``"loss/TV"``);
                otherwise ``"loss/TV"`` is 0.

        Returns:
            Dict of scalar training statistics.
        """
        obss, actions, next_obss, rewards, terminals = batch["observations"], batch["actions"], \
            batch["next_observations"], batch["rewards"], batch["terminals"]
        batch_size = obss.shape[0]

        # ---- update actor ----
        a, log_probs = self.actforward(obss)
        q1a, q2a = self.critic1(obss, a), self.critic2(obss, a)

        dist = self.actor(obss)
        local_log_probs_kl = dist.log_prob(actions)

        # The global actions are plain data for this loss: sampling them
        # without a graph keeps backward() from accumulating never-used
        # gradients in actor_global. The actor's own gradient is unaffected.
        with torch.no_grad():
            global_actions, _ = self.get_actor_global_actions(obss)
        global_log_probs_kl = dist.log_prob(global_actions)

        # SAC objective plus the two log-likelihood regularizers.
        a_loss = (self._alpha * log_probs - torch.min(q1a, q2a)).mean()
        b_loss = - LAMBDA * (local_log_probs_kl).mean()
        c_loss = - (1 - LAMBDA) * (global_log_probs_kl).mean()
        actor_loss = a_loss + b_loss + c_loss

        # TV distance between client and global policy, measured before the
        # actor step and only when requested.
        tv_distance = torch.tensor(0.)
        if flag:
            with torch.no_grad():
                dist_global = self.actor_global(obss)
            tv_distance = self.approximate_tv_distance_continuous(
                dist, dist_global, obss, num_samples=10, bins=10
            ).max()

        self.actor_optim.zero_grad()
        actor_loss.backward()
        self.actor_optim.step()

        if self._is_auto_alpha:
            log_probs = log_probs.detach() + self._target_entropy
            alpha_loss = -(self._log_alpha * log_probs).mean()
            self.alpha_optim.zero_grad()
            alpha_loss.backward()
            self.alpha_optim.step()
            self._alpha = self._log_alpha.detach().exp()

        # ---- compute td error ----
        if self._max_q_backup:
            with torch.no_grad():
                tmp_next_obss = next_obss.unsqueeze(1) \
                    .repeat(1, self._num_repeat_actions, 1) \
                    .view(batch_size * self._num_repeat_actions, next_obss.shape[-1])
                tmp_next_actions, _ = self.actforward(tmp_next_obss)
                # Max over the sampled next actions, per critic.
                tmp_next_q1 = self.critic1_old(tmp_next_obss, tmp_next_actions) \
                    .view(batch_size, self._num_repeat_actions, 1) \
                    .max(1)[0].view(-1, 1)
                tmp_next_q2 = self.critic2_old(tmp_next_obss, tmp_next_actions) \
                    .view(batch_size, self._num_repeat_actions, 1) \
                    .max(1)[0].view(-1, 1)
                next_q = torch.min(tmp_next_q1, tmp_next_q2)
        else:
            with torch.no_grad():
                next_actions, next_log_probs = self.actforward(next_obss)
                next_q = torch.min(
                    self.critic1_old(next_obss, next_actions),
                    self.critic2_old(next_obss, next_actions)
                )

                if not self._deterministic_backup:
                    next_q -= self._alpha * next_log_probs

        target_q = rewards + self._gamma * (1 - terminals) * next_q
        q1, q2 = self.critic1(obss, actions), self.critic2(obss, actions)
        critic1_loss = ((q1 - target_q).pow(2)).mean()
        critic2_loss = ((q2 - target_q).pow(2)).mean()
        # Keep the pure TD losses for logging before the conservative term is added.
        q1_loss, q2_loss = critic1_loss, critic2_loss

        # ---- compute conservative loss ----
        random_actions = torch.FloatTensor(
            batch_size * self._num_repeat_actions, actions.shape[-1]
        ).uniform_(self.action_space.low[0], self.action_space.high[0]).to(self.actor.device)
        # tmp_obss & tmp_next_obss: (batch_size * num_repeat, obs_dim)
        tmp_obss = obss.unsqueeze(1) \
            .repeat(1, self._num_repeat_actions, 1) \
            .view(batch_size * self._num_repeat_actions, obss.shape[-1])
        tmp_next_obss = next_obss.unsqueeze(1) \
            .repeat(1, self._num_repeat_actions, 1) \
            .view(batch_size * self._num_repeat_actions, obss.shape[-1])

        obs_pi_value1, obs_pi_value2 = self.calc_pi_values(tmp_obss, tmp_obss)
        next_obs_pi_value1, next_obs_pi_value2 = self.calc_pi_values(tmp_next_obss, tmp_obss)
        random_value1, random_value2 = self.calc_random_values(tmp_obss, random_actions)

        # Tensor.reshape returns a new tensor, so the result must be kept;
        # otherwise the logsumexp below runs over 3 values per row instead of
        # the 3 * num_repeat values belonging to each state.
        def per_state(value: torch.Tensor) -> torch.Tensor:
            return value.reshape(batch_size, self._num_repeat_actions, 1)

        # cat_q shape: (batch_size, 3 * num_repeat, 1)
        cat_q1 = torch.cat(
            [per_state(obs_pi_value1), per_state(next_obs_pi_value1), per_state(random_value1)], 1
        )
        cat_q2 = torch.cat(
            [per_state(obs_pi_value2), per_state(next_obs_pi_value2), per_state(random_value2)], 1
        )

        conservative_loss1 = \
            torch.logsumexp(cat_q1 / self._temperature, dim=1).mean() * self._cql_weight * self._temperature - \
            q1.mean() * self._cql_weight
        conservative_loss2 = \
            torch.logsumexp(cat_q2 / self._temperature, dim=1).mean() * self._cql_weight * self._temperature - \
            q2.mean() * self._cql_weight

        if self._with_lagrange:
            cql_alpha = torch.clamp(self.cql_log_alpha.exp(), 0.0, 1e6)
            conservative_loss1 = cql_alpha * (conservative_loss1 - self._lagrange_threshold)
            conservative_loss2 = cql_alpha * (conservative_loss2 - self._lagrange_threshold)

            self.cql_alpha_optim.zero_grad()
            cql_alpha_loss = -(conservative_loss1 + conservative_loss2) * 0.5
            # The graph is shared with the critic losses below.
            cql_alpha_loss.backward(retain_graph=True)
            self.cql_alpha_optim.step()

        critic1_loss = critic1_loss + conservative_loss1
        critic2_loss = critic2_loss + conservative_loss2

        # ---- update critic ----
        self.critic1_optim.zero_grad()
        critic1_loss.backward(retain_graph=True)
        self.critic1_optim.step()

        self.critic2_optim.zero_grad()
        critic2_loss.backward()
        self.critic2_optim.step()

        self._sync_weight()

        result =  {
            "loss/actor": actor_loss.item(),
            "loss/critic1": critic1_loss.item(),
            "loss/critic2": critic2_loss.item(),
            "loss/q1": q1_loss.item(),
            "loss/q2": q2_loss.item(),
            "loss/aloss": a_loss.item(),
            "loss/bloss": b_loss.item(),
            "loss/closs": c_loss.item(),
            "loss/TV": tv_distance.item(),
        }

        if self._is_auto_alpha:
            result["loss/alpha"] = alpha_loss.item()
        # float() accepts both a plain Python number and a one-element tensor,
        # whereas `.item()` only exists on tensors.
        result["alpha"] = float(self._alpha)
        if self._with_lagrange:
            result["loss/cql_alpha"] = cql_alpha_loss.item()
            result["cql_alpha"] = cql_alpha.item()

        return result

