import numpy as np
import torch
import torch.nn as nn
import gym

from torch.nn import functional as F
from typing import Dict, Union, Tuple
from offlinerlkit.policy import SACPolicy
from typing import Optional, Sequence, cast
import pickle

class DWCQLPolicy(SACPolicy):
    """
    Discriminator-weighted Conservative Q-Learning policy.

    Builds on :class:`SACPolicy` and adds the CQL regulariser
    <Ref: https://arxiv.org/abs/2006.04779>. Every per-sample loss term
    (actor, TD error and conservative penalty) is re-weighted by the
    normalized state-action ratios produced by a discriminator network,
    which is itself trained at the start of every :meth:`learn` call.
    """

    def __init__(
        self,
        actor: nn.Module,
        critic1: nn.Module,
        critic2: nn.Module,
        actor_optim: torch.optim.Optimizer,
        critic1_optim: torch.optim.Optimizer,
        critic2_optim: torch.optim.Optimizer,
        discriminator: nn.Module,
        discriminator_optim: torch.optim.Optimizer,
        weight_type: str,
        action_space: gym.spaces.Space,
        tau: float = 0.005,
        gamma: float  = 0.99,
        alpha: Union[float, Tuple[float, torch.Tensor, torch.optim.Optimizer]] = 0.2,
        cql_weight: float = 1.0,
        temperature: float = 1.0,
        max_q_backup: bool = False,
        deterministic_backup: bool = True,
        with_lagrange: bool = True,
        lagrange_threshold: float = 10.0,
        cql_alpha_lr: float = 1e-4,
        num_repeart_actions:int = 10,
        task: str = '',
    ) -> None:
        """
        Store the CQL / discriminator configuration on top of the SAC setup.

        Args:
            actor, critic1, critic2: networks forwarded to ``SACPolicy``.
                ``actor`` must expose a ``device`` attribute.
            actor_optim, critic1_optim, critic2_optim: optimizers forwarded
                to ``SACPolicy``.
            discriminator: module providing
                ``compute_normalized_ratio_and_logits(obs, act, next_obs)``.
            discriminator_optim: optimizer stepping the discriminator.
            weight_type: stored as ``_weight_type``; not read in this class.
            action_space: its ``low[0]`` / ``high[0]`` bound the uniform
                random actions sampled for the conservative loss.
            tau, gamma, alpha: forwarded to ``SACPolicy``.
            cql_weight: multiplier of the conservative loss.
            temperature: temperature of the log-sum-exp over Q-values.
            max_q_backup: if True, back up the max target Q over several
                sampled next actions instead of a single sample.
            deterministic_backup: if False, subtract the entropy term from
                the single-sample target Q.
            with_lagrange: if True, learn the conservative-loss multiplier.
            lagrange_threshold: value subtracted from the conservative loss
                before it is scaled by the learned multiplier.
            cql_alpha_lr: learning rate of the multiplier's Adam optimizer.
            num_repeart_actions: number of actions sampled per observation
                (parameter spelling kept for backward compatibility).
            task: stored as ``task``; not read in this class.
        """
        super().__init__(
            actor,
            critic1,
            critic2,
            actor_optim,
            critic1_optim,
            critic2_optim,
            tau=tau,
            gamma=gamma,
            alpha=alpha
        )

        self.action_space = action_space
        self._cql_weight = cql_weight
        self._temperature = temperature
        self._max_q_backup = max_q_backup
        self._deterministic_backup = deterministic_backup
        self._with_lagrange = with_lagrange
        self._lagrange_threshold = lagrange_threshold

        # Log of the Lagrange multiplier; starts at log(1) = 0.
        self.cql_log_alpha = torch.zeros(1, requires_grad=True, device=self.actor.device)
        self.cql_alpha_optim = torch.optim.Adam([self.cql_log_alpha], lr=cql_alpha_lr)

        self._num_repeat_actions = num_repeart_actions

        self._discriminator = discriminator
        self._discriminator_optim = discriminator_optim
        self._weight_type = weight_type
        # Fixed coefficients of the discriminator objective.
        self.discriminator_discount = 1.0
        self.discriminator_weight_temp = 1.0
        self.discriminator_kl_penalty_coef  = 0.
        self.discriminator_flow_coef = 0.1

        self.task = task

    def _repeat_obs(self, obs: torch.Tensor) -> torch.Tensor:
        """Repeat each row of ``obs`` ``num_repeat`` times consecutively.

        Returns a ``(batch_size * num_repeat, obs_dim)`` tensor in which the
        copies of one observation occupy adjacent rows.
        """
        return obs.unsqueeze(1) \
            .repeat(1, self._num_repeat_actions, 1) \
            .view(obs.shape[0] * self._num_repeat_actions, obs.shape[-1])

    def calc_pi_values(
        self,
        obs_pi: torch.Tensor,
        obs_to_pred: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Evaluate both critics on policy actions, corrected by log-probability.

        Actions are sampled from the policy at ``obs_pi`` and scored by the
        critics at ``obs_to_pred``. The (detached) log-probability of the
        sampled actions is subtracted from each Q-value.
        """
        act, log_prob = self.actforward(obs_pi)

        q1 = self.critic1(obs_to_pred, act)
        q2 = self.critic2(obs_to_pred, act)

        return q1 - log_prob.detach(), q2 - log_prob.detach()

    def calc_random_values(
        self,
        obs: torch.Tensor,
        random_act: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Evaluate both critics on random actions, corrected by log-density.

        The correction is ``act_dim * log(0.5)``, i.e. a density of 0.5 per
        action dimension (presumably a uniform proposal over [-1, 1] --
        confirm against the action space in use).
        """
        q1 = self.critic1(obs, random_act)
        q2 = self.critic2(obs, random_act)

        # Computed in log-space so that 0.5 ** act_dim cannot underflow.
        log_prob = random_act.shape[-1] * np.log(0.5)

        return q1 - log_prob, q2 - log_prob

    def _orthogonal_regularization(self, network):
        """
        Sum of squared off-diagonal entries of ``W @ W.T`` over weight matrices.

        Only 2-D parameters whose name contains ``'weight'`` contribute, so
        1-D weights (e.g. of normalisation layers) are skipped instead of
        making ``torch.mm`` fail. Returns the int ``0`` when nothing matches.
        """
        reg = 0
        for name, param in network.named_parameters():
            if 'weight' in name and param.dim() == 2:
                gram = torch.mm(param, param.T)
                # Zero the diagonal so only cross-row correlations are penalised.
                off_diagonal = gram * (1 - torch.eye(gram.shape[0], device=param.device))
                reg += torch.square(off_diagonal).sum()

        return reg

    def _update_discriminator(
        self,
        batch: Dict
    ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, np.ndarray]]:
        """
        Run one optimisation step of the discriminator.

        Args:
            batch: dict with keys ``'observations'``, ``'actions'``,
                ``'next_observations'`` and ``'rewards'``.

        Returns:
            ``(normalized_dsa_ratios, normalized_ds_ratios, losses)`` where
            the ratios are the (non-detached) outputs of the discriminator
            computed before the step and ``losses`` maps loss names to
            numpy scalars.
        """
        assert self._discriminator_optim is not None

        self._discriminator_optim.zero_grad()
        normalized_dsa_ratios, normalized_ds_ratios, logits_sa, logits_s, logits_next_s = \
            self._discriminator.compute_normalized_ratio_and_logits(
                batch['observations'], batch['actions'], batch['next_observations'])

        # Rewards are used as they are (no reward scaler is applied).
        normalized_rewards = batch['rewards']
        reward_loss = -(normalized_dsa_ratios * normalized_rewards).sum()
        # NOTE(review): evaluates to NaN if any ratio is exactly 0 (0 * log 0),
        # and NaN times the zero coefficient below is still NaN -- confirm the
        # discriminator only emits strictly positive ratios.
        kl_loss = (normalized_ds_ratios * torch.log(normalized_ds_ratios)).sum()
        flow_loss = torch.square(
                self.discriminator_discount * torch.exp((logits_sa + logits_s) / self.discriminator_weight_temp) -
                torch.exp(logits_next_s / self.discriminator_weight_temp)).mean()

        loss = reward_loss + \
            self.discriminator_kl_penalty_coef * kl_loss + \
            self.discriminator_flow_coef * flow_loss

        loss.backward()
        self._discriminator_optim.step()

        return normalized_dsa_ratios, normalized_ds_ratios, {
            "discriminator_loss": loss.cpu().detach().numpy(),
            "reward_loss": reward_loss.cpu().detach().numpy(),
            "kl_loss": kl_loss.cpu().detach().numpy(),
            "flow_loss": flow_loss.cpu().detach().numpy(),
        }

    def learn(self, init_batch: Dict, batch: Dict)-> Dict[str, float]:
        """
        Perform one update of discriminator, actor, entropy/Lagrange
        multipliers and both critics.

        Args:
            init_batch: accepted for interface compatibility; not read.
            batch: dict with keys ``'observations'``, ``'actions'``,
                ``'next_observations'``, ``'rewards'`` and ``'terminals'``.

        Returns:
            Dict of scalar training statistics (losses and weight indicators).
        """
        normalized_dsa_ratios, _, _ = self._update_discriminator(batch)

        # Per-sample importance weights; detached so no policy/critic loss
        # back-propagates into the discriminator.
        # assumes shape (batch_size, 1), matching the critic outputs -- TODO confirm
        weights = normalized_dsa_ratios.detach()
        weight_norm = weights.mean()

        def weighted_mean(values: torch.Tensor) -> torch.Tensor:
            # Self-normalised weighted average over the batch.
            return (values * weights).mean() / weight_norm

        obss, actions, next_obss, rewards, terminals = batch["observations"], batch["actions"], \
            batch["next_observations"], batch["rewards"], batch["terminals"]
        batch_size = obss.shape[0]

        # update actor
        a, log_probs = self.actforward(obss)
        q1a, q2a = self.critic1(obss, a), self.critic2(obss, a)
        actor_loss = weighted_mean(self._alpha * log_probs - torch.min(q1a, q2a))
        self.actor_optim.zero_grad()
        actor_loss.backward()
        self.actor_optim.step()

        if self._is_auto_alpha:
            log_probs = log_probs.detach() + self._target_entropy
            alpha_loss = -(self._log_alpha * log_probs).mean()
            self.alpha_optim.zero_grad()
            alpha_loss.backward()
            self.alpha_optim.step()
            self._alpha = self._log_alpha.detach().exp()

        # compute td error
        if self._max_q_backup:
            with torch.no_grad():
                tmp_next_obss = self._repeat_obs(next_obss)
                tmp_next_actions, _ = self.actforward(tmp_next_obss)
                # Max over the sampled actions of every next observation.
                tmp_next_q1 = self.critic1_old(tmp_next_obss, tmp_next_actions) \
                    .view(batch_size, self._num_repeat_actions, 1) \
                    .max(1)[0].view(-1, 1)
                tmp_next_q2 = self.critic2_old(tmp_next_obss, tmp_next_actions) \
                    .view(batch_size, self._num_repeat_actions, 1) \
                    .max(1)[0].view(-1, 1)
                next_q = torch.min(tmp_next_q1, tmp_next_q2)
        else:
            with torch.no_grad():
                next_actions, next_log_probs = self.actforward(next_obss)
                next_q = torch.min(
                    self.critic1_old(next_obss, next_actions),
                    self.critic2_old(next_obss, next_actions)
                )
                if not self._deterministic_backup:
                    next_q -= self._alpha * next_log_probs

        target_q = rewards + self._gamma * (1 - terminals) * next_q
        q1, q2 = self.critic1(obss, actions), self.critic2(obss, actions)
        critic1_loss = weighted_mean((q1 - target_q).pow(2))
        critic2_loss = weighted_mean((q2 - target_q).pow(2))

        # compute conservative loss
        # Sampled on the CPU and then moved, so the CPU random stream is used.
        random_actions = torch.empty(
            batch_size * self._num_repeat_actions, actions.shape[-1], dtype=torch.float32
        ).uniform_(self.action_space.low[0], self.action_space.high[0]).to(self.actor.device)
        # tmp_obss & tmp_next_obss: (batch_size * num_repeat, obs_dim)
        tmp_obss = self._repeat_obs(obss)
        tmp_next_obss = self._repeat_obs(next_obss)

        obs_pi_value1, obs_pi_value2 = self.calc_pi_values(tmp_obss, tmp_obss)
        next_obs_pi_value1, next_obs_pi_value2 = self.calc_pi_values(tmp_next_obss, tmp_obss)
        random_value1, random_value2 = self.calc_random_values(tmp_obss, random_actions)

        # Tensor.reshape is not in-place: the results must be rebound, otherwise
        # the log-sum-exp below runs over 3 entries per row instead of
        # 3 * num_repeat entries per observation.
        (
            obs_pi_value1, obs_pi_value2,
            next_obs_pi_value1, next_obs_pi_value2,
            random_value1, random_value2,
        ) = (
            value.reshape(batch_size, self._num_repeat_actions, 1)
            for value in (
                obs_pi_value1, obs_pi_value2,
                next_obs_pi_value1, next_obs_pi_value2,
                random_value1, random_value2,
            )
        )

        # cat_q shape: (batch_size, 3 * num_repeat, 1)
        cat_q1 = torch.cat([obs_pi_value1, next_obs_pi_value1, random_value1], 1)
        cat_q2 = torch.cat([obs_pi_value2, next_obs_pi_value2, random_value2], 1)

        # logsumexp over dim 1 yields (batch_size, 1), aligned with the weights.
        conservative_loss1 = \
            weighted_mean(torch.logsumexp(cat_q1 / self._temperature, dim=1)) * self._cql_weight * self._temperature - \
            weighted_mean(q1) * self._cql_weight
        conservative_loss2 = \
            weighted_mean(torch.logsumexp(cat_q2 / self._temperature, dim=1)) * self._cql_weight * self._temperature - \
            weighted_mean(q2) * self._cql_weight

        if self._with_lagrange:
            cql_alpha = torch.clamp(self.cql_log_alpha.exp(), 0.0, 1e6)
            conservative_loss1 = cql_alpha * (conservative_loss1 - self._lagrange_threshold)
            conservative_loss2 = cql_alpha * (conservative_loss2 - self._lagrange_threshold)

            self.cql_alpha_optim.zero_grad()
            cql_alpha_loss = -(conservative_loss1 + conservative_loss2) * 0.5
            # The graph is reused by the critic losses below.
            cql_alpha_loss.backward(retain_graph=True)
            self.cql_alpha_optim.step()

        critic1_loss = critic1_loss + conservative_loss1
        critic2_loss = critic2_loss + conservative_loss2

        # update critic
        self.critic1_optim.zero_grad()
        # Both critic losses share the sampled policy actions (and cql_alpha),
        # so the graph has to survive the first backward pass.
        critic1_loss.backward(retain_graph=True)
        self.critic1_optim.step()

        self.critic2_optim.zero_grad()
        critic2_loss.backward()
        self.critic2_optim.step()

        self._sync_weight()

        result =  {
            "loss/actor": actor_loss.item(),
            "loss/critic1": critic1_loss.item(),
            "loss/critic2": critic2_loss.item(),
            # Converted to Python floats like every other entry.
            'indicators/weights_m': weight_norm.item(),
            'indicators/weights_max': weights.max().item(),
            'indicators/weights_min': weights.min().item()
        }

        if self._is_auto_alpha:
            result["loss/alpha"] = alpha_loss.item()
            result["alpha"] = self._alpha.item()
        if self._with_lagrange:
            result["loss/cql_alpha"] = cql_alpha_loss.item()
            result["cql_alpha"] = cql_alpha.item()

        return result

