from algos.common.critic_multihead import CriticSAMultiHead
from algos.common.network_base import MLP

import numpy as np
import torch

class CriticSAPreference(CriticSAMultiHead):
    """Multi-head state-action critic that is additionally conditioned on a preference vector.

    The network input is ``torch.cat([state, action, preference], dim=-1)``.
    The output holds one value per reward head, clamped element-wise to
    ``critic_cfg['clip_range']``.

    Config keys read by :meth:`build`:
        - ``critic_cfg['mlp']['activation']``: name of a ``torch.nn`` attribute
          (e.g. ``'ReLU'``) passed to ``MLP`` as the activation.
        - ``critic_cfg['mlp']['shape']``: hidden-layer shape forwarded to ``MLP``.
        - ``critic_cfg['no_share']`` (optional, default ``False``): when truthy,
          a separate single-output MLP is built for each reward head.
          Otherwise a single MLP with ``reward_dim`` outputs serves all heads.
        - ``critic_cfg['clip_range']``: ``[low, high]`` output bounds. String
          entries are evaluated as Python expressions and replaced in place.
    """

    # Loss functions accepted by getLoss, keyed by its ``type`` argument.
    _LOSS_FNS = {
        'smooth_l1_loss': torch.nn.functional.smooth_l1_loss,
        'mse_loss': torch.nn.functional.mse_loss,
    }

    def __init__(self, device:torch.device, state_dim:int, action_dim:int,
                 reward_dim:int, preference_dim:int, critic_cfg:dict) -> None:
        """Create the critic.

        Args:
            device: forwarded unchanged to the parent constructor.
            state_dim: feature size of ``state`` (part of the MLP input size).
            action_dim: feature size of ``action`` (part of the MLP input size).
            reward_dim: number of reward heads, i.e. output values per input.
            preference_dim: feature size of ``preference`` (part of the MLP input size).
            critic_cfg: critic configuration dict (see the class docstring).
        """
        # Set before super().__init__ because build() reads both attributes and
        # the parent constructor presumably calls build() — TODO confirm in
        # CriticSAMultiHead.
        self.reward_dim = reward_dim
        self.preference_dim = preference_dim
        super().__init__(device, state_dim, action_dim, reward_dim, critic_cfg)

    def build(self) -> None:
        """Create the MLP head(s) and resolve the output clip range from ``critic_cfg``."""
        mlp_cfg = self.critic_cfg['mlp']
        # getattr resolves the same torch.nn attribute as the old
        # eval(f'torch.nn.{name}') without executing config text as code.
        self.activation = getattr(torch.nn, mlp_cfg['activation'])
        self.no_share = self.critic_cfg.get('no_share', False)
        input_size = self.state_dim + self.action_dim + self.preference_dim
        if self.no_share:
            # One independent single-output network per reward head,
            # registered as model_0, model_1, ...
            for reward_idx in range(self.reward_dim):
                self.add_module(f'model_{reward_idx}', MLP(
                    input_size=input_size, output_size=1,
                    shape=mlp_cfg['shape'], activation=self.activation,
                ))
        else:
            # One shared network that emits all reward heads at once.
            self.add_module('model', MLP(
                input_size=input_size, output_size=self.reward_dim,
                shape=mlp_cfg['shape'], activation=self.activation,
            ))
        clip_range = self.critic_cfg['clip_range']
        for item_idx, item in enumerate(clip_range):
            if isinstance(item, str):
                # String bounds are evaluated in this module's namespace
                # (presumably values such as 'np.inf'; numpy is imported for
                # this). NOTE(review): eval on config text is only safe for
                # trusted configs. The list is updated in place, so the shared
                # critic_cfg also sees the numeric values.
                clip_range[item_idx] = eval(item)
        self.clip_range = clip_range

    def forward(self, state:torch.Tensor, action:torch.Tensor, preference:torch.Tensor) -> torch.Tensor:
        """Return the clamped per-head critic values.

        ``state``, ``action`` and ``preference`` are concatenated along the last
        dimension, so their other dimensions must agree for ``torch.cat``.
        The result is clamped to ``[clip_range[0], clip_range[1]]``. It
        presumably has ``reward_dim`` entries along the last dimension, assuming
        ``MLP`` maps the last dimension to ``output_size``.
        """
        x = torch.cat([state, action, preference], dim=-1)
        if self.no_share:
            # Each per-head network yields one column; stack them side by side.
            x = torch.cat([self._modules[f'model_{reward_idx}'](x) for reward_idx in range(self.reward_dim)], dim=-1)
        else:
            x = self.model(x)
        return torch.clamp(x, self.clip_range[0], self.clip_range[1])

    def getLoss(self, state:torch.Tensor, action:torch.Tensor,
                preference:torch.Tensor, target:torch.Tensor, type:str='mse_loss') -> torch.Tensor:
        """Compute the regression loss between the critic output and ``target``.

        Args:
            state, action, preference: inputs forwarded to :meth:`forward`.
            target: regression target compared against the critic output.
            type: loss name, either ``'mse_loss'`` (default) or ``'smooth_l1_loss'``.
                The name shadows the builtin but is kept for keyword callers.

        Returns:
            The loss tensor produced by the matching ``torch.nn.functional`` loss
            with its default reduction.

        Raises:
            ValueError: if ``type`` is not a supported loss name. Previously an
                unsupported name silently returned ``None``.
        """
        try:
            loss_fn = self._LOSS_FNS[type]
        except KeyError:
            raise ValueError(
                f"unsupported loss type {type!r}; expected one of {sorted(self._LOSS_FNS)}"
            ) from None
        return loss_fn(self.forward(state, action, preference), target)
