from algos.common.network_base import MLP, initWeights
import numpy as np
import torch

class Multiplier(torch.nn.Module):
    """
    Preference-conditioned multiplier network with a bounded output.

    An MLP maps a preference vector of size `preference_dim` to `cost_dim`
    raw values. `forward` then squashes them with a shifted, scaled tanh:

        out = max_value * (1 + tanh(raw + log_init_value)) / 2

    so the output lies in [0, max_value] and equals `init_value` whenever
    the raw MLP output is zero.
    (Presumably these are per-cost Lagrange multipliers -- confirm with the
    calling algorithm.)

    Expected `model_cfg` keys:
        model_cfg['mlp']['activation']: name of an activation class in
            `torch.nn` (e.g. 'ReLU').
        model_cfg['mlp']['shape']: forwarded unchanged to `MLP`.
        model_cfg['init_value']: output value at zero raw MLP output.
        model_cfg['clip_range']: sequence whose second entry is the upper
            bound `max_value`; string entries are evaluated as Python
            expressions (e.g. '1e3').
    """

    def __init__(self, device:torch.device, preference_dim:int, cost_dim:int, model_cfg:dict) -> None:
        """
        Build the MLP and precompute the tanh offset.

        Raises:
            AttributeError: if the activation name is not found in `torch.nn`.
            ValueError: if `max_value` is not finite or `init_value` lies
                outside [0, max_value] (forward would only produce NaN).
        """
        super().__init__()

        self.device = device
        self.preference_dim = preference_dim
        self.cost_dim = cost_dim

        # for model
        mlp_cfg = model_cfg['mlp']
        self.activation = self._resolve_activation(mlp_cfg['activation'])
        self.add_module('model', MLP(
            input_size=self.preference_dim, output_size=self.cost_dim,
            shape=mlp_cfg['shape'], activation=self.activation,
        ))

        # for output range
        self.init_value = model_cfg['init_value']
        # A new list is built so that the caller's config is not modified.
        self.clip_range = self._parse_clip_range(model_cfg['clip_range'])
        self.max_value = self.clip_range[1]
        # Offset added to the raw MLP output before the tanh squashing.
        # (The name is historical: the value is an arctanh, not a log.)
        self.log_init_value = self._initial_offset(self.init_value, self.max_value)

    @staticmethod
    def _resolve_activation(activation_name:str):
        """Look up an activation class by (possibly dotted) name inside `torch.nn`."""
        resolved = torch.nn
        for part in activation_name.split('.'):
            resolved = getattr(resolved, part)
        return resolved

    @staticmethod
    def _parse_clip_range(raw_range) -> list:
        """Return a copy of `raw_range` in which string entries are evaluated."""
        # SECURITY: eval() executes arbitrary Python. It is kept for backward
        # compatibility with configs that use expressions such as 'np.inf';
        # only load configuration files from trusted sources.
        return [eval(item) if isinstance(item, str) else item for item in raw_range]

    @staticmethod
    def _initial_offset(init_value, max_value) -> float:
        """
        Return arctanh(2*init_value/max_value - 1), i.e. the pre-tanh shift
        that makes the squashed output equal `init_value` at zero raw output.
        """
        if not np.isfinite(max_value):
            raise ValueError(
                f'clip_range upper bound must be finite, got {max_value}.'
            )
        scaled = 2.0*init_value/max_value - 1.0
        # Written so that a NaN `scaled` is rejected as well.
        if not (-1.0 <= scaled <= 1.0):
            raise ValueError(
                f'init_value ({init_value}) must lie within [0, {max_value}].'
            )
        return float(np.arctanh(scaled))

    def forward(self, preference:torch.Tensor) -> torch.Tensor:
        """
        preference: (batch_size, preference_dim)

        Returns the MLP output squashed into [0, max_value]; it has the same
        shape as the MLP output.
        """
        x = self.model(preference)
        # The scalar offset broadcasts over x; no ones_like tensor is needed.
        x = self.max_value*(1.0 + torch.tanh(x + self.log_init_value))/2.0
        return x

    def initialize(self) -> None:
        """Apply `initWeights` to this module and every submodule."""
        self.apply(initWeights)
