import numpy as np
import torch
import torch.nn as nn
import gym
import sys

from typing import Dict, Union, Tuple
from copy import deepcopy
from offlinerlkit.policy import BasePolicy


class CQLNSPolicy(BasePolicy):
    """
    Actor-critic policy trained with an ensemble of critics, a CQL-style
    conservative penalty and an EDAC-style gradient-diversity regulariser.

    Ensemble-Diversified Actor Critic <Ref: https://arxiv.org/abs/2110.01548>

    The critic ensemble is expected to return Q-values shaped
    ``[num_critics, batch, 1]`` and to expose ``_num_ensemble``.
    """

    def __init__(
        self,
        actor: nn.Module,
        critic_init: nn.Module,
        critics: nn.ModuleList,
        actor_optim: torch.optim.Optimizer,
        critic_optim: torch.optim.Optimizer,
        critics_optim: torch.optim.Optimizer,
        action_space: gym.spaces.Space,
        tau: float = 0.005,
        gamma: float  = 0.99,
        alpha: Union[float, Tuple[float, torch.Tensor, torch.optim.Optimizer]] = 0.2,
        max_q_backup: bool = False,
        deterministic_backup: bool = True,
        eta: float = 1.0,
        cql_weight: float = 1.0,
        temperature: float = 1.0,
        num_repeat_actions:int = 10,
    ) -> None:
        """
        Args:
            actor: policy network; calling it on observations must return a
                distribution offering ``mode()``, ``rsample()`` and
                ``log_prob(squashed_action, raw_action)``.
            critic_init: template for a single critic; one independent copy
                per ensemble member is stored in ``self.critic_list``.
            critics: the critic ensemble that is actually trained.
            actor_optim: optimizer for ``actor``.
            critic_optim: accepted for interface compatibility; not used.
            critics_optim: optimizer for ``critics``.
            action_space: only ``low[0]`` / ``high[0]`` are read, as the
                bounds for uniformly sampled actions.
            tau: Polyak coefficient for the target-critic update.
            gamma: discount factor.
            alpha: fixed entropy coefficient, or a tuple
                ``(target_entropy, log_alpha, alpha_optim)`` to tune it.
            max_q_backup: bootstrap from the max over 10 sampled next actions.
            deterministic_backup: stored, see NOTE in ``_compute_target_q``.
            eta: weight of the gradient-diversity regulariser (``<= 0``
                disables it).
            cql_weight: weight of the conservative penalty.
            temperature: temperature of the logsumexp in the penalty.
            num_repeat_actions: sampled actions per state for the penalty.
        """

        super().__init__()
        self.actor = actor
        self.critics = critics
        self.critics_old = deepcopy(critics)
        self.critics_old.eval()

        self.actor_optim = actor_optim
        self.critics_optim = critics_optim

        self.action_space = action_space

        self._tau = tau
        self._gamma = gamma

        self._is_auto_alpha = False
        if isinstance(alpha, tuple):
            self._is_auto_alpha = True
            self._target_entropy, self._log_alpha, self.alpha_optim = alpha
            self._alpha = self._log_alpha.detach().exp()
        else:
            self._alpha = alpha

        self._max_q_backup = max_q_backup
        self._deterministic_backup = deterministic_backup
        self._eta = eta
        self._cql_weight = cql_weight
        self._temperature = temperature
        self._num_repeat_actions = num_repeat_actions
        self._num_critics = self.critics._num_ensemble
        # One independent copy per ensemble member. ``[module] * n`` would
        # place the *same* object in every slot, so writing weights into
        # slot ``i`` would silently overwrite all the others.
        self.critic_list = [deepcopy(critic_init) for _ in range(self._num_critics)]

    def train(self) -> None:
        """Put the actor and the trained critics in training mode."""
        self.actor.train()
        self.critics.train()

    def eval(self) -> None:
        """Put the actor and the trained critics in evaluation mode."""
        self.actor.eval()
        self.critics.eval()

    def _sync_weight(self) -> None:
        """Polyak-average the trained critics into the target critics."""
        with torch.no_grad():
            for old_param, new_param in zip(self.critics_old.parameters(), self.critics.parameters()):
                old_param.copy_(old_param * (1.0 - self._tau) + new_param * self._tau)

    def calc_pi_value(
        self,
        obs_pi: torch.Tensor,
        obs_to_pred: torch.Tensor
    ) -> torch.Tensor:
        """
        Return ensemble Q-values of policy actions minus their log-probability.

        Actions are sampled from the actor at ``obs_pi`` and evaluated by the
        critic ensemble at ``obs_to_pred``. The log-probability is detached.
        """
        act, log_prob = self.actforward(obs_pi)

        q = self.critics(obs_to_pred, act)

        return q - log_prob.detach()

    def calc_random_value(
        self,
        obs: torch.Tensor,
        random_act: torch.Tensor
    ) -> torch.Tensor:
        """
        Return ensemble Q-values of ``random_act`` minus a uniform log-density.

        The log-density used is ``log(0.5 ** action_dim)``, i.e. a uniform
        distribution over ``[-1, 1]`` per dimension.
        NOTE(review): this presumes the action bounds are +/-1 -- confirm.
        """
        q = self.critics(obs, random_act)

        log_prob = np.log(0.5**random_act.shape[-1])

        return q - log_prob

    def calc_pi_values(
        self,
        i: int,
        obs_pi: torch.Tensor,
        obs_to_pred: torch.Tensor
    ) -> torch.Tensor:
        """Same as ``calc_pi_value`` but using only ``self.critic_list[i]``."""
        act, log_prob = self.actforward(obs_pi)

        q = self.critic_list[i](obs_to_pred, act)

        return q - log_prob.detach()

    def calc_random_values(
        self,
        i: int,
        obs: torch.Tensor,
        random_act: torch.Tensor
    ) -> torch.Tensor:
        """Same as ``calc_random_value`` but using only ``self.critic_list[i]``."""
        q = self.critic_list[i](obs, random_act)

        log_prob = np.log(0.5**random_act.shape[-1])

        return q - log_prob

    def actforward(
        self,
        obs: torch.Tensor,
        deterministic: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Run the actor and return ``(squashed_action, log_prob)``.

        With ``deterministic=True`` the distribution mode is used, otherwise
        a reparameterised sample.
        """
        dist = self.actor(obs)
        if deterministic:
            squashed_action, raw_action = dist.mode()
        else:
            squashed_action, raw_action = dist.rsample()
        log_prob = dist.log_prob(squashed_action, raw_action)
        return squashed_action, log_prob

    def select_action(
        self,
        obs: np.ndarray,
        deterministic: bool = False
    ) -> np.ndarray:
        """
        Return an action for ``obs`` as a NumPy array, without tracking grads.

        ``obs`` is handed to the actor unchanged; any array-to-tensor
        conversion is presumably done inside the actor -- TODO confirm.
        """
        with torch.no_grad():
            action, _ = self.actforward(obs, deterministic)
        return action.cpu().numpy()

    def _update_actor(self, obss: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Take one optimizer step on the actor.

        The loss is the negative ensemble-minimum Q-value of freshly sampled
        actions plus the entropy term ``alpha * log_prob``.

        Returns:
            ``(actor_loss, log_probs)``; ``log_probs`` is needed by the
            temperature update.
        """
        sampled_actions, log_probs = self.actforward(obss)
        # q_pi: [num_critics, batch_size, 1]; the min is taken over critics.
        q_pi = self.critics(obss, sampled_actions)
        actor_loss = -torch.min(q_pi, 0)[0].mean() + self._alpha * log_probs.mean()

        self.actor_optim.zero_grad()
        actor_loss.backward()
        self.actor_optim.step()
        return actor_loss, log_probs

    def _update_alpha(self, log_probs: torch.Tensor) -> torch.Tensor:
        """Take one step on ``log_alpha`` and refresh ``self._alpha``."""
        entropy_gap = log_probs.detach() + self._target_entropy
        alpha_loss = -(self._log_alpha * entropy_gap).mean()
        self.alpha_optim.zero_grad()
        alpha_loss.backward()
        self.alpha_optim.step()
        self._alpha = torch.clamp(self._log_alpha.detach().exp(), 0.0, 1.0)
        return alpha_loss

    def _compute_target_q(
        self,
        next_obss: torch.Tensor,
        stacked_rewards: torch.Tensor,
        terminals: torch.Tensor
    ) -> torch.Tensor:
        """
        Build the TD target shared by all critics.

        One target is formed per reward head (``stacked_rewards`` is
        ``[num_rewards, batch, 1]``) and the element-wise minimum over heads
        is returned. Because the bootstrap value is ``[num_critics, batch, 1]``
        the two leading dimensions must broadcast (equal, or one of them 1).
        """
        batch_size = next_obss.shape[0]
        with torch.no_grad():
            if self._max_q_backup:
                num_backup_samples = 10
                repeated_next_obss = next_obss.unsqueeze(1) \
                    .repeat(1, num_backup_samples, 1) \
                    .view(batch_size * num_backup_samples, next_obss.shape[-1])
                candidate_actions, _ = self.actforward(repeated_next_obss)
                # Max over the sampled actions. No entropy term is applied
                # here: the original code referenced log-probabilities that
                # are never computed on this path and crashed.
                next_value = self.critics_old(repeated_next_obss, candidate_actions) \
                    .view(self._num_critics, batch_size, num_backup_samples, 1).max(2)[0] \
                    .view(self._num_critics, batch_size, 1)
            else:
                next_actions, next_log_probs = self.actforward(next_obss)
                # NOTE(review): the entropy term is always subtracted, so
                # ``self._deterministic_backup`` currently has no effect on
                # the target. Kept as-is to preserve training behaviour.
                next_value = self.critics_old(next_obss, next_actions) - self._alpha * next_log_probs

            per_head_target = stacked_rewards + self._gamma * (1 - terminals) * next_value
            # [batch_size, 1]
            return per_head_target.min(0)[0]

    def _conservative_terms(
        self,
        obss: torch.Tensor,
        next_obss: torch.Tensor,
        action_dim: int,
        data_qs: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Compute the conservative penalty and two diagnostics.

        Returns:
            ``(conservative_loss, cql_mean, cql_std)`` where the last two are
            the mean / std over critics of the per-critic gap between the
            logsumexp term and the dataset Q-values.
        """
        batch_size = obss.shape[0]
        num_repeat = self._num_repeat_actions

        # Sampled on CPU first, then moved, exactly like the legacy
        # ``torch.FloatTensor(...)`` constructor this replaces.
        random_actions = torch.empty(
            batch_size * num_repeat, action_dim, dtype=torch.float32
        ).uniform_(self.action_space.low[0], self.action_space.high[0]).to(self.actor.device)
        # repeated_obss & repeated_next_obss: (batch_size * num_repeat, obs_dim)
        repeated_obss = obss.unsqueeze(1) \
            .repeat(1, num_repeat, 1) \
            .view(batch_size * num_repeat, obss.shape[-1])
        repeated_next_obss = next_obss.unsqueeze(1) \
            .repeat(1, num_repeat, 1) \
            .view(batch_size * num_repeat, obss.shape[-1])

        obs_pi_value = self.calc_pi_value(repeated_obss, repeated_obss)
        next_obs_pi_value = self.calc_pi_value(repeated_next_obss, repeated_obss)
        random_value = self.calc_random_value(repeated_obss, random_actions)

        # cat_qs: [num_critics, batch_size * num_repeat, 3]. The logsumexp
        # below therefore reduces over the three action sources of each
        # repeated sample (the original's ``reshape`` calls discarded their
        # result and never changed this).
        # NOTE(review): canonical CQL reduces over all sampled actions of a
        # state instead -- confirm this variant is intended.
        cat_qs = torch.cat([obs_pi_value, next_obs_pi_value, random_value], 2)
        soft_max_q = torch.logsumexp(cat_qs / self._temperature, dim=2)

        conservative_loss = \
            soft_max_q.mean() * self._cql_weight * self._temperature - \
            data_qs.mean() * self._cql_weight

        per_critic_gap = soft_max_q.mean(dim=1) - data_qs.sum(dim=2).mean(dim=1)
        return conservative_loss, per_critic_gap.mean(), per_critic_gap.std()

    def _gradient_diversity_loss(self, obss: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
        """
        Penalise aligned observation-gradients between different critics.

        Every critic is evaluated on every (obs, action) pair; the gradients
        of the Q-values w.r.t. the observations are L2-normalised and the
        pairwise cosine similarities between distinct critics are summed.
        Only the observation gradient is used.
        """
        obss_tile = obss.unsqueeze(0).repeat(self._num_critics, 1, 1).requires_grad_(True)
        actions_tile = actions.unsqueeze(0).repeat(self._num_critics, 1, 1)
        qs_preds_tile = self.critics(obss_tile, actions_tile)
        # create_graph=True so the penalty can be back-propagated into the
        # critic parameters.
        obs_grads, = torch.autograd.grad(qs_preds_tile.sum(), obss_tile, retain_graph=True, create_graph=True)

        obs_grads = obs_grads / (torch.norm(obs_grads, p=2, dim=2).unsqueeze(-1) + 1e-10)
        obs_grads = obs_grads.transpose(0, 1)

        similarity = torch.einsum('bik,bjk->bij', obs_grads, obs_grads)
        # Zero the diagonal: a critic is trivially aligned with itself.
        masks = torch.eye(self._num_critics, device=obss.device).unsqueeze(dim=0).repeat(similarity.size(0), 1, 1)
        similarity = (1 - masks) * similarity
        return self._eta * torch.mean(torch.sum(similarity, dim=(1, 2))) / (self._num_critics - 1)

    def learn_actor(self, batch: Dict) -> Dict[str, float]:
        """
        Take one actor-only update step.

        The batch tensors are also stored on ``self`` (``obss``, ``actions``,
        ``next_obss``, ``rewards``, ``terminals``), as before.

        Returns:
            ``{"loss/actor": <float>}``.
        """
        self.obss, self.actions, self.next_obss, self.rewards, self.terminals = batch["observations"], batch["actions"], \
            batch["next_observations"], batch["rewards"], batch["terminals"]

        actor_loss, _ = self._update_actor(self.obss)
        return {"loss/actor": actor_loss.item()}

    def learn(self, batch: Dict, num_rewards: int) -> Dict[str, float]:
        """
        Run one full update: actor, optional temperature, critics, targets.

        Args:
            batch: must provide ``observations``, ``actions``,
                ``next_observations``, ``terminals`` and ``reward0`` ...
                ``reward{num_rewards - 1}``.
            num_rewards: number of reward heads to read from ``batch``.

        Returns:
            Scalar training diagnostics keyed by ``loss/...`` (plus
            ``alpha`` when the temperature is tuned automatically).

        Raises:
            ValueError: if ``num_rewards`` is smaller than 1.
        """
        if num_rewards < 1:
            raise ValueError(f"num_rewards must be >= 1, got {num_rewards}")

        obss = batch["observations"]
        actions = batch["actions"]
        next_obss = batch["next_observations"]
        terminals = batch["terminals"]
        # [num_rewards, batch_size, 1]
        stacked_rewards = torch.stack([batch[f"reward{i}"] for i in range(num_rewards)], dim=0)

        # update actor
        actor_loss, log_probs = self._update_actor(obss)
        # No extra actor regulariser is active, so both logged values agree.
        actor_qs_loss = actor_loss

        alpha_loss = None
        if self._is_auto_alpha:
            alpha_loss = self._update_alpha(log_probs)

        # update critic
        # target_q: [batch_size, 1]
        target_q = self._compute_target_q(next_obss, stacked_rewards, terminals)
        # qs: [num_critics, batch_size, 1]
        qs = self.critics(obss, actions)
        qs_loss = ((qs - target_q.unsqueeze(0)).pow(2)).mean(dim=(1, 2)).sum()

        # The dataset Q-values of the penalty are the same forward pass as
        # ``qs``, so they are reused instead of being recomputed.
        conservative_loss, cql_mean, cql_std = self._conservative_terms(
            obss, next_obss, actions.shape[-1], qs
        )
        critics_loss = qs_loss + conservative_loss

        # The diversity term compares distinct critics, so it needs at least
        # two of them (it would otherwise divide by zero).
        grad_loss = None
        if self._eta > 0 and self._num_critics > 1:
            grad_loss = self._gradient_diversity_loss(obss, actions)
            critics_loss = critics_loss + grad_loss

        self.critics_optim.zero_grad()
        critics_loss.backward()
        self.critics_optim.step()

        self._sync_weight()

        result =  {
            "loss/actor": actor_loss.item(),
            "loss/critics": critics_loss.item(),
            "loss/actor_qs": actor_qs_loss.item(),
            "loss/actor_grad": 0.0,
            "loss/qs": qs_loss.item(),
            "loss/conservative": conservative_loss.item(),
            "loss/grad": grad_loss.item() if grad_loss is not None else 0.0,
            "loss/cql_mean": cql_mean.item(),
            "loss/cql_std": cql_std.item(),
        }

        if alpha_loss is not None:
            result["loss/alpha"] = alpha_loss.item()
            result["alpha"] = self._alpha.item()

        return result



