import pickle
from typing import Callable, Dict, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F

from offlinerlkit.policy import TD3Policy
from offlinerlkit.utils.noise import GaussianNoise
from offlinerlkit.utils.scaler import StandardScaler

class OSDTD3BCPolicy(TD3Policy):
    """
    TD3+BC <Ref: https://arxiv.org/abs/2106.06860> with per-sample loss weights.

    The TD3+BC critic and actor losses are replaced by weighted averages. Each
    transition's weight is a clamped ``w(s, a)`` computed from a separately
    trained dual network ``nu_network`` and a scalar multiplier ``lam_v``.
    ``learn_osd`` trains ``nu_network`` / ``lam_v`` with a DICE-style
    objective (the code follows GenDICE/OptiDICE-like formulas). ``learn`` runs
    one weighted TD3+BC update using the current weights.
    """

    def __init__(
        self,
        actor: nn.Module,
        critic1: nn.Module,
        critic2: nn.Module,
        actor_optim: torch.optim.Optimizer,
        critic1_optim: torch.optim.Optimizer,
        critic2_optim: torch.optim.Optimizer,
        nu_network: nn.Module,
        nu_optimizer: torch.optim.Optimizer,
        lam_v: torch.Tensor,
        lam_v_optimizer: torch.optim.Optimizer,
        osd_alpha: float,
        lower: float,
        higher: float,
        osd_beta: float,
        weight_type: str,
        dataset_type: str,
        tau: float = 0.005,
        gamma: float = 0.99,
        max_action: float = 1.0,
        exploration_noise: Callable = GaussianNoise,
        policy_noise: float = 0.2,
        noise_clip: float = 0.5,
        update_actor_freq: int = 2,
        alpha: float = 2.5,
        scaler: StandardScaler = None
    ) -> None:
        """
        Args:
            actor, critic1, critic2, actor_optim, critic1_optim, critic2_optim,
            tau, gamma, max_action, exploration_noise, policy_noise, noise_clip,
            update_actor_freq: forwarded unchanged to ``TD3Policy``.
            nu_network: network applied to observations to produce nu(s). It is
                used to compute the residual ``e_v`` and hence the sample weights.
            nu_optimizer: optimizer stepped with ``v_loss`` in ``learn_osd``.
            lam_v: multiplier tensor, read with ``.item()`` so it must hold a
                single element. Presumably a leaf tensor with
                ``requires_grad=True`` that ``lam_v_optimizer`` owns
                (TODO confirm with the caller).
            lam_v_optimizer: optimizer stepped with ``lamb_v_loss``.
            osd_alpha: divisor of ``(e_v - lam_v)`` before the weight function
                is applied, and coefficient of the f-divergence term.
            lower, higher: clamp range for the sample weights used in ``learn``.
            osd_beta: coefficient of the squared-residual penalty ``E[e_v^2]``.
            weight_type: ``'median'`` shifts weights so the batch median maps
                to 1 before clamping. Any other value clamps the raw weights.
            dataset_type: ``'dice'`` appends a trailing 0 feature to
                observations in ``select_action``.
            alpha: TD3+BC coefficient that scales the normalized Q term.
            scaler: optional observation scaler used by ``select_action``.
        """
        super().__init__(
            actor,
            critic1,
            critic2,
            actor_optim,
            critic1_optim,
            critic2_optim,
            tau=tau,
            gamma=gamma,
            max_action=max_action,
            exploration_noise=exploration_noise,
            policy_noise=policy_noise,
            noise_clip=noise_clip,
            update_actor_freq=update_actor_freq
        )

        self._alpha = alpha
        self.scaler = scaler

        self._nu_network = nu_network
        self._lamb_v = lam_v
        self._osd_alpha = osd_alpha
        self._nu_optimizer = nu_optimizer
        self._lamb_v_optimizer = lam_v_optimizer
        self._lower = lower
        self._higher = higher
        self._osd_beta = osd_beta
        self._weight_type = weight_type
        self._dataset_type = dataset_type

        # Number of learn_osd steps performed so far.
        self.debug_count = 0

        # Scale applied to lam_v inside the weight pre-activation.
        self._lamb_scale = 1.0
        # Coefficient of _orthogonal_regularization(nu_network) in v_loss;
        # None disables the regularizer.
        self._v_l2_reg = 1e-4

        # f(x) = x*log(x) - x + 1 for x < 1, and 0.5*(x - 1)^2 otherwise
        # (the "soft chi^2" generator). Kept for reference; not called in
        # this class.
        self._f_fn = lambda x: torch.where(x < 1, x * (torch.log(x + 1e-10) - 1) + 1, 0.5 * (x - 1) ** 2)
        zero = torch.zeros(1)
        # (f')^{-1}(y): exp(y) for y < 0, and y + 1 otherwise. Clamping the
        # exp argument to <= 0 keeps the unselected torch.where branch finite.
        self._f_prime_inv_fn = lambda x: torch.where(x < 0, torch.exp(torch.minimum(x, zero.to(x.device))), x + 1)
        # g(y) = f((f')^{-1}(y)): exp(y)*(y - 1) + 1 for y < 0, 0.5*y^2 otherwise.
        self._g_fn = lambda x: torch.where(x < 0, torch.exp(torch.minimum(x, zero.to(x.device))) * (torch.minimum(x, zero.to(x.device)) - 1) + 1, 0.5 * x ** 2)
        # Sample weight as a function of the pre-activation.
        self._r_fn = lambda x: self._f_prime_inv_fn(x)
        # log of _r_fn; not called in this class.
        self._log_r_fn = lambda x: torch.where(x < 0, x, torch.log(torch.maximum(x, zero.to(x.device)) + 1))

    def train(self) -> None:
        """Put actor, critics and nu_network into training mode."""
        self.actor.train()
        self.critic1.train()
        self.critic2.train()
        self._nu_network.train()

    def eval(self) -> None:
        """Put actor, critics and nu_network into evaluation mode."""
        self.actor.eval()
        self.critic1.eval()
        self.critic2.eval()
        self._nu_network.eval()

    def _sync_weight(self) -> None:
        """Polyak-average the online networks into their target copies with rate tau."""
        pairs = (
            (self.actor_old, self.actor),
            (self.critic1_old, self.critic1),
            (self.critic2_old, self.critic2),
        )
        for target_net, online_net in pairs:
            for o, n in zip(target_net.parameters(), online_net.parameters()):
                o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau)

    def select_action(self, obs: np.ndarray, deterministic: bool = False) -> np.ndarray:
        """
        Return the actor's action for ``obs``, with exploration noise added and
        the result clipped unless ``deterministic`` is True.

        The observation is scaled with ``self.scaler`` when one is set. For
        ``dataset_type == 'dice'``, one trailing 0 feature is appended to each
        observation (to the last axis).
        """
        if self.scaler is not None:
            obs = self.scaler.transform(obs)
        if self._dataset_type == 'dice':
            # For 1-D input this equals np.append(obs, 0), including the
            # float64 result dtype. For a batch, np.append would flatten every
            # row into one vector; here each row gets its own 0.
            obs = np.atleast_1d(obs)
            obs = np.concatenate([obs, np.zeros((*obs.shape[:-1], 1))], axis=-1)
        with torch.no_grad():
            action = self.actor(obs).cpu().numpy()
        if not deterministic:
            action = action + self.exploration_noise(action.shape)
            action = np.clip(action, -self._max_action, self._max_action)
        return action

    def _orthogonal_regularization(self, network: nn.Module) -> torch.Tensor:
        """
        Return the sum over 2-D weight matrices W of ||(W W^T) * (1 - I)||^2.

        This penalizes off-diagonal entries of W W^T. Parameters that are not
        2-D weight matrices are skipped, for example 1-D normalization weights,
        which ``torch.mm`` cannot handle. The result is a zero tensor when
        there is nothing to regularize.
        """
        reg = None
        for name, param in network.named_parameters():
            if 'weight' not in name or param.dim() != 2:
                continue
            prod = torch.mm(param, param.T)
            off_diag = prod * (1 - torch.eye(prod.shape[0], device=param.device))
            term = torch.square(off_diag).sum()
            reg = term if reg is None else reg + term
        if reg is None:
            return torch.zeros(())
        return reg

    def _dual_residual(
        self,
        obss: torch.Tensor,
        next_obss: torch.Tensor,
        rewards: torch.Tensor,
        terminals: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Return ``(e_v, preactivation_v)`` for a batch of transitions, where
        ``e_v = r + (1 - done) * gamma * nu(s') - nu(s)`` and
        ``preactivation_v = (e_v - lamb_scale * lam_v) / osd_alpha``.
        """
        v_values = self._nu_network(obss)
        next_v_values = self._nu_network(next_obss)
        e_v = rewards + (1 - terminals) * self._gamma * next_v_values - v_values
        preactivation_v = (e_v - self._lamb_scale * self._lamb_v) / self._osd_alpha
        return e_v, preactivation_v

    @staticmethod
    def _weighted_mean(values: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
        """Return ``mean(values * weights) / (mean(weights) + 1e-6)``."""
        return (values * weights).mean() / (weights.mean() + 1e-6)

    def v_loss(
        self,
        initial_v_values: torch.Tensor,
        e_v: torch.Tensor,
        w_v: torch.Tensor,
        f_w_v: torch.Tensor,
        result: Optional[Dict[str, float]] = None,
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """
        Return the loss minimized with respect to ``nu_network``, plus diagnostics.

        loss = (1 - gamma) * mean(nu(s0)) - osd_alpha * mean(f_w_v)
               + mean(w_v * (e_v - lam_v)) + lam_v + osd_beta * mean(e_v^2)
               [+ _v_l2_reg * orthogonal_reg(nu_network)]

        ``result`` (a new dict when omitted) is updated in place with float
        diagnostics and also returned.
        """
        if result is None:
            result = {}
        v_loss0 = ((1 - self._gamma) * initial_v_values).mean()
        v_loss1 = - self._osd_alpha * f_w_v.mean()
        v_loss2 = (w_v * (e_v - self._lamb_v)).mean()
        v_loss3 = self._lamb_v
        v_loss4 = torch.square(e_v).mean()

        v_loss = v_loss0 + v_loss1 + v_loss2 + v_loss3 + self._osd_beta * v_loss4
        v_l2_norm = self._orthogonal_regularization(self._nu_network)

        if self._v_l2_reg is not None:
            v_loss = v_loss + self._v_l2_reg * v_l2_norm

        result.update({
            'v_loss0': v_loss0.item(),
            'v_loss1': v_loss1.item(),
            'v_loss2': v_loss2.item(),
            'v_loss3': v_loss3.item(),
            'v_loss4': v_loss4.item(),
            'v_loss': v_loss.item(),
            'w_v': w_v.mean().item(),
            'w_v_min': w_v.min().item(),
            'w_v_max': w_v.max().item(),
            'v_l2_norm': v_l2_norm.item()
        })
        return v_loss, result

    def lamb_v_loss(
        self,
        e_v: torch.Tensor,
        w_v: torch.Tensor,
        f_w_v: torch.Tensor,
        result: Optional[Dict[str, float]] = None,
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """
        Return the loss minimized with respect to ``lam_v``, plus diagnostics.

        Only ``lam_v`` carries gradient. ``e_v``, ``w_v`` and ``f_w_v`` are all
        detached, so the lam_v gradient is ``mean(1 - lamb_scale * w_v)``.
        ``result`` (a new dict when omitted) is updated in place and returned.
        """
        # GenDICE regularization: E_D[w(s,a)] = 1
        if result is None:
            result = {}
        # e_v is detached in the osd_beta term too. That term does not depend
        # on lam_v, so the loss value and the lam_v gradient are unchanged.
        # Without the detach, backward() would re-enter the nu_network graph,
        # whose parameters learn_osd has already updated in place, and
        # autograd's in-place version check would raise.
        e_v = e_v.detach()
        lamb_v_loss = (
            - self._osd_alpha * f_w_v.detach()
            + w_v.detach() * (e_v - self._lamb_scale * self._lamb_v)
            + self._osd_beta * torch.square(e_v)
            + self._lamb_v
        ).mean()
        result.update({
            'lamb_v_loss': lamb_v_loss.item(),
            'lamb_v': self._lamb_v.item(),
        })

        return lamb_v_loss, result

    def learn_osd(self, initial_batch: Dict, batch: Dict) -> Dict[str, float]:
        """
        Take one optimizer step on ``nu_network`` (with ``v_loss``) and one on
        ``lam_v`` (with ``lamb_v_loss``).

        Args:
            initial_batch: dict with an "init_observations" tensor.
            batch: dict with "observations", "next_observations", "rewards"
                and "terminals" tensors.

        Returns:
            Float diagnostics from both losses.

        This method does not advance ``self._cnt``. That counter controls TD3's
        delayed actor update in ``learn``. If ``learn_osd`` and ``learn`` were
        interleaved and both advanced it, then with update_actor_freq=2 every
        ``learn`` call would see an odd count and the actor would never be
        updated.
        """
        obss, next_obss = batch["observations"], batch["next_observations"]
        rewards, terminals = batch["rewards"], batch["terminals"]
        init_obss = initial_batch['init_observations']

        initial_v_values = self._nu_network(init_obss)
        e_v, preactivation_v = self._dual_residual(obss, next_obss, rewards, terminals)
        w_v = self._r_fn(preactivation_v)
        f_w_v = self._g_fn(preactivation_v)

        v_loss, loss_result = self.v_loss(initial_v_values, e_v, w_v, f_w_v, result={})
        lamb_v_loss, loss_result = self.lamb_v_loss(e_v, w_v, f_w_v, result=loss_result)

        self._nu_optimizer.zero_grad()
        v_loss.backward()
        self._nu_optimizer.step()

        # zero_grad first, to discard the lam_v gradient that v_loss.backward()
        # accumulated.
        self._lamb_v_optimizer.zero_grad()
        lamb_v_loss.backward()
        self._lamb_v_optimizer.step()
        self.debug_count += 1

        return loss_result

    def learn(self, init_batch: Dict, batch: Dict) -> Dict[str, float]:
        """
        Run one weighted TD3+BC update. The critics are updated on every call;
        the actor and target networks every ``self._freq`` calls.

        Args:
            init_batch: accepted for symmetry with ``learn_osd``; not used here.
            batch: dict with "observations", "actions", "next_observations",
                "rewards" and "terminals" tensors.

        Returns:
            Float metrics: the last actor loss, both critic losses, and the
            mean/max/min of the clamped sample weights.
        """
        obss, actions, next_obss, rewards, terminals = batch["observations"], batch["actions"], \
            batch["next_observations"], batch["rewards"], batch["terminals"]

        # Sample weights from the current nu_network / lam_v, treated as
        # constants for this update.
        with torch.no_grad():
            _, preactivation_v = self._dual_residual(obss, next_obss, rewards, terminals)
            w_v = self._r_fn(preactivation_v)
            if self._weight_type == 'median':
                # Shift so the batch-median weight maps to 1 before clamping.
                median = w_v.quantile(q=0.5)
                weights = torch.clamp(w_v - median + 1., self._lower, self._higher)
            else:
                weights = torch.clamp(w_v, self._lower, self._higher)

        # update critic
        q1, q2 = self.critic1(obss, actions), self.critic2(obss, actions)
        with torch.no_grad():
            noise = (torch.randn_like(actions) * self._policy_noise).clamp(-self._noise_clip, self._noise_clip)
            next_actions = (self.actor_old(next_obss) + noise).clamp(-self._max_action, self._max_action)
            next_q = torch.min(self.critic1_old(next_obss, next_actions), self.critic2_old(next_obss, next_actions))
            target_q = rewards + self._gamma * (1 - terminals) * next_q

        critic1_loss = self._weighted_mean((q1 - target_q).pow(2), weights)
        critic2_loss = self._weighted_mean((q2 - target_q).pow(2), weights)

        self.critic1_optim.zero_grad()
        critic1_loss.backward()
        self.critic1_optim.step()

        self.critic2_optim.zero_grad()
        critic2_loss.backward()
        self.critic2_optim.step()

        # update actor
        if self._cnt % self._freq == 0:
            a = self.actor(obss)
            q = self.critic1(obss, a)
            # TD3+BC normalizer: alpha divided by the (weighted) mean of |Q|,
            # using the same weighted mean as every other term.
            lmbda = self._alpha / self._weighted_mean(q.abs(), weights).detach()
            actor_loss = -lmbda * self._weighted_mean(q, weights) \
                + self._weighted_mean((a - actions).pow(2), weights)
            self.actor_optim.zero_grad()
            actor_loss.backward()
            self.actor_optim.step()
            self._last_actor_loss = actor_loss.item()
            self._sync_weight()

        self._cnt += 1

        return {
            "loss/actor": self._last_actor_loss,
            "loss/critic1": critic1_loss.item(),
            "loss/critic2": critic2_loss.item(),
            'indicators/weights_m': weights.mean().item(),
            'indicators/weights_max': weights.max().item(),
            'indicators/weights_min': weights.min().item()
        }