from algos.common.critic_preference import CriticSAPreference
from algos.common.critic_multihead import CriticSAMultiHead
from algos.common.network_base import MLP

import numpy as np
import torch

class CriticSADist(CriticSAMultiHead):
    '''
    Distributional state-action critic.

    One MLP is built per (reward, critic) pair; each maps the concatenated
    (state, action) vector to `n_quantiles` values.

    Keys read from `critic_cfg`:
        'n_critics', 'n_quantiles', 'clip_range',
        'mlp' -> 'activation', 'shape'
    '''

    def __init__(self, device:torch.device, state_dim:int, action_dim:int,
                 reward_dim:int, critic_cfg:dict) -> None:
        # These are assigned before super().__init__() because build() reads
        # them; presumably the base constructor is what calls build() --
        # confirm against CriticSAMultiHead.
        self.n_critics = critic_cfg['n_critics']
        self.n_quantiles = critic_cfg['n_quantiles']
        self.reward_dim = reward_dim
        super().__init__(device, state_dim, action_dim, reward_dim, critic_cfg)

    def build(self) -> None:
        '''
        Create the per-(reward, critic) MLP heads, resolve the clip range and
        precompute the quantile fractions used by getLoss().
        '''
        # Plain attribute lookup on torch.nn instead of eval(): the configured
        # value is expected to be a class name such as 'ReLU'.
        activation_name = self.critic_cfg['mlp']['activation']
        self.activation = getattr(torch.nn, activation_name)

        for reward_idx in range(self.reward_dim):
            for critic_idx in range(self.n_critics):
                self.add_module(f"reward{reward_idx}_critic{critic_idx}", MLP(
                    input_size=self.state_dim + self.action_dim,
                    output_size=self.n_quantiles,
                    shape=self.critic_cfg['mlp']['shape'],
                    activation=self.activation,
                ))

        # clip range: string entries are evaluated as Python expressions and
        # the result is written back into critic_cfg (in place).
        # NOTE(review): eval() executes arbitrary code -- the config must be
        # trusted. The expressions can see this module's globals (`np`,
        # `torch`), so the format is left untouched here.
        clip_range = self.critic_cfg['clip_range']
        for item_idx, item in enumerate(clip_range):
            if isinstance(item, str):
                clip_range[item_idx] = eval(item)
        self.clip_range = clip_range

        # calculate cdf: midpoint fractions (i + 0.5) / n_quantiles
        cdf = (torch.arange(
            self.n_quantiles, device=self.device, dtype=torch.float32)
            + 0.5)/self.n_quantiles
        self.cdf = cdf.view(1, 1, 1, -1, 1) # 1 x 1 x 1 x M x 1

    def forward(self, state:torch.Tensor, action:torch.Tensor) -> torch.Tensor:
        '''
        Evaluate every head on the concatenated (state, action) input and
        clamp each output to `clip_range`.

        outputs: 
            batch_size x reward_dim x n_critics x n_quantiles
            or 
            reward_dim x n_critics x n_quantiles
        '''
        concat_x = torch.cat([state, action], dim=-1)
        low, high = self.clip_range[0], self.clip_range[1]
        critics = []
        for reward_idx in range(self.reward_dim):
            quantiles = [
                torch.clamp(
                    self._modules[f"reward{reward_idx}_critic{critic_idx}"](concat_x),
                    low, high)
                for critic_idx in range(self.n_critics)
            ]
            # stack the critics of this reward: ... x n_critics x n_quantiles
            critics.append(torch.stack(quantiles, dim=-2))
        # stack the rewards: ... x reward_dim x n_critics x n_quantiles
        return torch.stack(critics, dim=-3)

    def getLoss(self, state:torch.Tensor, action:torch.Tensor, target:torch.Tensor) -> torch.Tensor:
        '''
        Quantile-regression loss between the predicted quantiles and the
        target values, averaged over all elements.

        state: batch_size x state_dim
        action: batch_size x action_dim
        target: batch_size x reward_dim x n_critics x n_quantiles
            (any trailing layout is accepted as long as it can be flattened
            to batch_size x reward_dim x K)

        returns: scalar tensor
        '''
        batch_size = state.shape[0]
        # Flatten all target values per (batch, reward) into the last axis.
        # reshape() also handles non-contiguous targets, and -1 avoids
        # hard-coding the number of target values.
        target = target.reshape(batch_size, self.reward_dim, 1, 1, -1) # B x R x 1 x 1 x K
        current_quantiles = self.forward(state, action).unsqueeze(-1) # B x R x N x M x 1
        pairwise_delta = target - current_quantiles # B x R x N x M x K

        # Plain (non-Huber) pinball loss: delta * (tau - 1[delta < 0]).
        # A Huber-smoothed variant would instead weight a Huber term of
        # delta by |tau - 1[delta < 0]|.
        below_zero = (pairwise_delta.detach() < 0).float()
        critic_loss = torch.mean(pairwise_delta*(self.cdf - below_zero))

        return critic_loss
    

class CriticSAPreferDist(CriticSAPreference):
    '''
    Preference-conditioned distributional state-action critic.

    One MLP is built per (reward, critic) pair; each maps the concatenated
    (state, action, preference) vector to `n_quantiles` values.

    Keys read from `critic_cfg`:
        'n_critics', 'n_quantiles', 'clip_range',
        'mlp' -> 'activation', 'shape',
                 'normalization' (optional), 'last_activation' (optional)
    '''

    def __init__(self, device:torch.device, state_dim:int, action_dim:int,
                 reward_dim:int, preference_dim:int, critic_cfg:dict) -> None:
        # These are assigned before super().__init__() because build() reads
        # them; presumably the base constructor is what calls build() --
        # confirm against CriticSAPreference.
        self.n_critics = critic_cfg['n_critics']
        self.n_quantiles = critic_cfg['n_quantiles']
        super().__init__(
            device, state_dim, action_dim, reward_dim,
            preference_dim, critic_cfg)

    def build(self) -> None:
        '''
        Create the per-(reward, critic) MLP heads, the optional output
        activation, the clip range and the quantile fractions used by
        getLoss().
        '''
        mlp_cfg = self.critic_cfg['mlp']

        # Plain attribute lookup on torch.nn instead of eval(): the configured
        # value is expected to be a class name such as 'ReLU'.
        activation_name = mlp_cfg['activation']
        self.activation = getattr(torch.nn, activation_name)
        normalization = mlp_cfg.get('normalization', None)

        for reward_idx in range(self.reward_dim):
            for critic_idx in range(self.n_critics):
                self.add_module(f"reward{reward_idx}_critic{critic_idx}", MLP(
                    input_size=self.state_dim + self.action_dim + self.preference_dim,
                    output_size=self.n_quantiles,
                    shape=mlp_cfg['shape'],
                    activation=self.activation,
                    normalization=normalization,
                ))

        # last activation: instantiated from torch.nn when configured,
        # otherwise a parameter-free pass-through. Identity is used rather
        # than a lambda so the module stays picklable.
        last_activation_name = mlp_cfg.get('last_activation', None)
        if last_activation_name is not None:
            self.last_activation = getattr(torch.nn, last_activation_name)()
        else:
            self.last_activation = torch.nn.Identity()

        # clip range: string entries are evaluated as Python expressions and
        # the result is written back into critic_cfg (in place).
        # NOTE(review): eval() executes arbitrary code -- the config must be
        # trusted. The expressions can see this module's globals (`np`,
        # `torch`), so the format is left untouched here.
        clip_range = self.critic_cfg['clip_range']
        for item_idx, item in enumerate(clip_range):
            if isinstance(item, str):
                clip_range[item_idx] = eval(item)
        self.clip_range = clip_range

        # calculate cdf: midpoint fractions (i + 0.5) / n_quantiles
        cdf = (torch.arange(
            self.n_quantiles, device=self.device, dtype=torch.float32)
            + 0.5)/self.n_quantiles
        self.cdf = cdf.view(1, 1, 1, -1, 1) # 1 x 1 x 1 x M x 1

    def forward(self, state:torch.Tensor, action:torch.Tensor, preference:torch.Tensor) -> torch.Tensor:
        '''
        Evaluate every head on the concatenated (state, action, preference)
        input, apply the last activation and clamp to `clip_range`.

        outputs: 
            batch_size x reward_dim x n_critics x n_quantiles
            or 
            reward_dim x n_critics x n_quantiles
        '''
        concat_x = torch.cat([state, action, preference], dim=-1)
        low, high = self.clip_range[0], self.clip_range[1]
        critics = []
        for reward_idx in range(self.reward_dim):
            quantiles = []
            for critic_idx in range(self.n_critics):
                head = self._modules[f"reward{reward_idx}_critic{critic_idx}"]
                x = self.last_activation(head(concat_x)) # (batch_size, n_quantiles)
                quantiles.append(torch.clamp(x, low, high))
            # (batch_size, n_critics, n_quantiles)
            critics.append(torch.stack(quantiles, dim=-2))
        # (batch_size, reward_dim, n_critics, n_quantiles)
        return torch.stack(critics, dim=-3)

    def getLoss(self, state:torch.Tensor, action:torch.Tensor, \
                preference:torch.Tensor, target:torch.Tensor) -> torch.Tensor:
        '''
        Quantile-regression loss between the predicted quantiles and the
        target values, averaged over all elements.

        state: batch_size x state_dim
        action: batch_size x action_dim
        preference: batch_size x preference_dim
        target: batch_size x reward_dim x n_critics x n_quantiles
            (any trailing layout is accepted as long as it can be flattened
            to batch_size x reward_dim x K)

        returns: scalar tensor
        '''
        batch_size = state.shape[0]
        # Flatten all target values per (batch, reward) into the last axis;
        # reshape() also handles non-contiguous targets.
        target = target.reshape(batch_size, self.reward_dim, 1, 1, -1) # B x R x 1 x 1 x K
        current_quantiles = self.forward(state, action, preference).unsqueeze(-1) # B x R x N x M x 1
        pairwise_delta = target - current_quantiles # B x R x N x M x K

        # Plain (non-Huber) pinball loss: delta * (tau - 1[delta < 0]).
        # A Huber-smoothed variant would instead weight a Huber term of
        # delta by |tau - 1[delta < 0]|.
        below_zero = (pairwise_delta.detach() < 0).float()
        critic_loss = torch.mean(pairwise_delta*(self.cdf - below_zero))

        return critic_loss
