import numpy as np
import torch
import torch.nn as nn
import gym
import sys

from torch.distributions import Normal, TanhTransform, TransformedDistribution
from copy import deepcopy
from typing import Dict, Union, Tuple
from offlinerlkit.policy import BasePolicy




class IQL2Policy(BasePolicy):
    """
    Implicit Q-Learning <Ref: https://arxiv.org/abs/2110.06169>

    Variant whose actor objective is the usual advantage-weighted
    log-likelihood plus two extra log-density terms evaluated with
    ``actor.get_log_density``:

    * ``-alpha1 * log pi(a_data | s)`` on the dataset actions, and
    * ``-(1 - alpha1) * log pi(a_global | s)`` on the mode of
      ``actor_global``, a deep copy of ``actor`` taken at construction time.

    The actor is therefore expected to expose ``get_log_density(obs, action)``
    and to return a distribution object offering ``mode()``, ``sample()`` and
    ``log_prob()`` when called.
    """

    def __init__(
        self,
        actor: nn.Module,
        critic1: nn.Module,
        critic2: nn.Module,
        critic_v: nn.Module,
        actor_optim: torch.optim.Optimizer,
        critic1_optim: torch.optim.Optimizer,
        critic2_optim: torch.optim.Optimizer,
        critic_v_optim: torch.optim.Optimizer,
        action_space: gym.spaces.Space,
        tau: float = 0.005,
        gamma: float  = 0.99,
        expectile: float = 0.8,
        temperature: float = 0.1
    ) -> None:
        """
        Args:
            actor: policy network; calling it on observations yields an
                action distribution.
            critic1, critic2: Q-networks called as ``critic(obs, action)``.
            critic_v: state-value network called as ``critic_v(obs)``.
            actor_optim, critic1_optim, critic2_optim, critic_v_optim:
                optimizers stepped for the corresponding networks.
            action_space: its ``low`` / ``high`` bounds are used to clip
                selected actions.
            tau: interpolation factor of the soft target-critic update.
            gamma: discount factor of the TD target.
            expectile: expectile used by the value-network regression.
            temperature: factor applied to ``q - v`` before exponentiation in
                the actor weights.
        """
        super().__init__()

        self.actor = actor
        # Target critics start as exact copies and are kept in eval mode.
        self.critic1, self.critic1_old = critic1, deepcopy(critic1)
        self.critic1_old.eval()
        self.critic2, self.critic2_old = critic2, deepcopy(critic2)
        self.critic2_old.eval()
        self.critic_v = critic_v

        self.actor_optim = actor_optim
        self.critic1_optim = critic1_optim
        self.critic2_optim = critic2_optim
        self.critic_v_optim = critic_v_optim

        self.action_space = action_space
        self._tau = tau
        self._gamma = gamma
        self._expectile = expectile
        self._temperature = temperature

        # Reference actor used only as a fixed target inside ``learn``; no
        # optimizer in this class updates it.
        self.actor_global = deepcopy(actor)
        self.alpha1 = 0.6
        # Not used by the current actor loss; kept so external code reading
        # the attribute keeps working.
        self.alpha2 = 6.0

        # Value estimates of the most recent ``learn`` batch. Initialised here
        # so ``get_v`` does not raise AttributeError before the first update.
        self.v = None

    def train(self) -> None:
        """Switch actor, both critics and the value net to training mode."""
        self.actor.train()
        self.critic1.train()
        self.critic2.train()
        self.critic_v.train()

    def eval(self) -> None:
        """Switch actor, both critics and the value net to evaluation mode."""
        self.actor.eval()
        self.critic1.eval()
        self.critic2.eval()
        self.critic_v.eval()

    def get_v(self) -> Union[torch.Tensor, None]:
        """Return ``critic_v(obss)`` from the latest ``learn`` call.

        Returns ``None`` if ``learn`` has not been called yet.
        """
        return self.v

    def _sync_weight(self) -> None:
        """Soft-update both target critics towards their online critics."""
        pairs = (
            (self.critic1_old, self.critic1),
            (self.critic2_old, self.critic2),
        )
        for target_net, online_net in pairs:
            for o, n in zip(target_net.parameters(), online_net.parameters()):
                o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau)

    def _sample_action(
        self,
        actor: nn.Module,
        obs: np.ndarray,
        deterministic: bool
    ) -> np.ndarray:
        """Query ``actor`` for an action and clip it to the action space."""
        if len(obs.shape) == 1:
            # Promote a single observation to a batch of one.
            obs = obs.reshape(1, -1)
        with torch.no_grad():
            dist = actor(obs)
            action = dist.mode() if deterministic else dist.sample()
        action = action.cpu().numpy()
        # Clip with the full bound arrays so each action dimension uses its
        # own limits. For spaces with identical bounds in every dimension this
        # equals clipping with low[0] / high[0].
        return np.clip(action, self.action_space.low, self.action_space.high)

    def select_action_global(self, obs: np.ndarray, deterministic: bool = False) -> np.ndarray:
        """Select an action with ``actor_global``.

        Args:
            obs: one observation (1-D) or a batch of observations.
            deterministic: use the distribution mode instead of a sample.

        Returns:
            Batch of actions clipped to the action-space bounds.
        """
        return self._sample_action(self.actor_global, obs, deterministic)

    def select_action(self, obs: np.ndarray, deterministic: bool = False) -> np.ndarray:
        """Select an action with the trained ``actor``.

        Args:
            obs: one observation (1-D) or a batch of observations.
            deterministic: use the distribution mode instead of a sample.

        Returns:
            Batch of actions clipped to the action-space bounds.
        """
        return self._sample_action(self.actor, obs, deterministic)

    def _expectile_regression(self, diff: torch.Tensor) -> torch.Tensor:
        """Return the element-wise asymmetric squared loss of ``diff``.

        Positive entries are weighted by ``expectile``, all others by
        ``1 - expectile``.
        """
        weight = torch.where(diff > 0, self._expectile, (1 - self._expectile))
        return weight * (diff**2)

    def learn(self, batch: Dict) -> Dict[str, float]:
        """Run one gradient step on value net, critics and actor.

        Args:
            batch: mapping with the keys ``"observations"``, ``"actions"``,
                ``"next_observations"``, ``"rewards"`` and ``"terminals"``.

        Returns:
            Scalar losses and mean Q / V statistics of this step.
        """
        obss, actions, next_obss, rewards, terminals = batch["observations"], batch["actions"], \
            batch["next_observations"], batch["rewards"], batch["terminals"]

        # update value net: expectile regression of V towards min target Q
        with torch.no_grad():
            q1, q2 = self.critic1_old(obss, actions), self.critic2_old(obss, actions)
            q = torch.min(q1, q2)
        v = self.critic_v(obss)
        self.v = v
        critic_v_loss = self._expectile_regression(q - v).mean()
        self.critic_v_optim.zero_grad()
        critic_v_loss.backward()
        self.critic_v_optim.step()

        # update critic: TD target bootstraps from V(next_obs), no gradient
        q1, q2 = self.critic1(obss, actions), self.critic2(obss, actions)
        with torch.no_grad():
            next_v = self.critic_v(next_obss)
            target_q = rewards + self._gamma * (1 - terminals) * next_v

        critic1_loss = ((q1 - target_q).pow(2)).mean()
        critic2_loss = ((q2 - target_q).pow(2)).mean()

        self.critic1_optim.zero_grad()
        critic1_loss.backward()
        self.critic1_optim.step()

        self.critic2_optim.zero_grad()
        critic2_loss.backward()
        self.critic2_optim.step()

        # update actor: advantage weights use the freshly updated value net
        with torch.no_grad():
            q1, q2 = self.critic1_old(obss, actions), self.critic2_old(obss, actions)
            q = torch.min(q1, q2)
            v = self.critic_v(obss)
            exp_a = torch.exp((q - v) * self._temperature)
            # Cap the weights so a few large advantages cannot dominate.
            exp_a = torch.clip(exp_a, None, 100.0)

        dist = self.actor(obss)
        log_probs = dist.log_prob(actions)

        # Log-density regularisers (this section was labelled "TVD" in the
        # original code).
        logp_prob_local = self.actor.get_log_density(obss, actions)
        # The global actor is a fixed target: run it without autograd so the
        # backward pass does not accumulate gradients on its parameters,
        # which no optimizer here ever steps or zeroes.
        with torch.no_grad():
            server_actions = self.actor_global(obss).mode()
        logp_prob_server = self.actor.get_log_density(obss, server_actions)

        # actor loss
        # NOTE: an earlier formulation used squared distances between the
        # actor mode and the global / dataset actions (weighted by alpha1 and
        # alpha2); the log-density terms below replace it.
        a_loss = - (exp_a * log_probs).mean()
        b_loss = - (self.alpha1 * logp_prob_local).mean()
        c_loss = - ((1 - self.alpha1) * logp_prob_server).mean()
        actor_loss = a_loss + b_loss + c_loss

        # eval q&v
        eval_q = q
        eval_v = v
        eval_qv = q - v

        self.actor_optim.zero_grad()
        actor_loss.backward()
        nn.utils.clip_grad_norm_(self.actor.parameters(), max_norm=1, norm_type=2)
        self.actor_optim.step()

        self._sync_weight()

        return {
            "loss/actor": actor_loss.item(),
            "loss/q1": critic1_loss.item(),
            "loss/q2": critic2_loss.item(),
            "loss/v": critic_v_loss.item(),
            "loss/a": a_loss.item(),
            "loss/b": b_loss.item(),
            "value/q": eval_q.mean().item(),
            "value/v": eval_v.mean().item(),
            "value/qv": eval_qv.mean().item(),
            "loss/c": c_loss.item(),
        }

    def get_log_density(self, observations, action):
        """Return the summed tanh-Gaussian log-density of ``action``.

        NOTE(review): this method reads ``self.backbone``, ``self.action_dim``,
        ``self.log_std_multiplier``, ``self.log_std_offset``,
        ``self.tanh_gaussian`` and a module-level ``EPS``, none of which are
        defined in the visible part of this file. It looks like it was copied
        from an actor network class; ``learn`` calls
        ``self.actor.get_log_density`` instead of this method. Confirm whether
        it is needed on the policy before relying on it.
        """
        base_network_output = self.backbone(observations)
        mean, log_std = torch.split(base_network_output, self.action_dim, dim=-1)
        log_std = self.log_std_multiplier() * log_std + self.log_std_offset()
        log_std = torch.clamp(log_std, self.tanh_gaussian.log_std_min, self.tanh_gaussian.log_std_max)
        std = torch.exp(log_std)
        action_distribution = TransformedDistribution(
            Normal(mean, std), TanhTransform(cache_size=1)
        )
        # Keep actions strictly inside (-1, 1) before the tanh inverse.
        action_clip = torch.clip(action, -1. + EPS, 1. - EPS)
        logp_prob = torch.sum(action_distribution.log_prob(action_clip), dim=-1)

        return logp_prob

    def log_prob(self, action, raw_action=None):
        """Return a log-probability with a ``log(1 - action**2)`` correction.

        NOTE(review): relies on ``self.arctanh`` and on a parent-class
        ``log_prob``, neither of which is defined in this class; presumably
        copied from a tanh-squashed distribution wrapper. Verify that the base
        class provides them before calling.
        """
        if raw_action is None:
            raw_action = self.arctanh(action)
        log_prob = super().log_prob(raw_action).sum(-1, keepdim=True)
        # eps guards the logarithm when |action| is numerically 1.
        eps = 1e-6
        log_prob = log_prob - torch.log((1 - action.pow(2)) + eps).sum(-1, keepdim=True)
        return log_prob