from algos.common.network_base import MLP, initWeights
from algos.common.actor_base import (
    ActorBase, unnormalize
)

from typing import Tuple
import numpy as np
import torch

# Small positive constant (1e-8). It is not referenced by the Actor class in
# this module; presumably intended as a numerical-stability epsilon -- confirm
# whether anything still imports it before removing.
EPS = 1e-8


class Actor(ActorBase):
    '''
    Preference-conditioned actor that outputs a tanh-squashed action mean.

    The network input is the concatenation of a state and a preference
    vector (`state_dim + reward_dim` features in total). Exploration noise is
    a Gaussian centred on that mean whose standard deviation is passed in by
    the caller of `updateActionDist`; it is not a learned parameter.

    NOTE(review): `add_module` / `children` are called on `self`, so
    `ActorBase` is presumably a `torch.nn.Module` subclass -- confirm.
    '''

    def __init__(self, device:torch.device, state_dim:int, action_dim:int, reward_dim:int,
                 action_bound_min:np.ndarray, action_bound_max:np.ndarray, actor_cfg:dict) -> None:
        '''
        Store the dimensions, action bounds and config, then build the network.

        device: device handed to `ActorBase` and used for the bound tensors.
        state_dim: number of state features.
        action_dim: number of action components produced by the mean decoder.
        reward_dim: number of preference features appended to the state.
        action_bound_min / action_bound_max: bounds forwarded to `unnormalize`
            in `sample`; converted to float32 tensors on `device`.
        actor_cfg: config dict; `build` reads `actor_cfg['mlp']['activation']`
            and `actor_cfg['mlp']['shape']`.
        '''
        super().__init__(device)

        self.state_dim = state_dim
        self.action_dim = action_dim
        self.reward_dim = reward_dim
        # Kept as plain tensor attributes (not registered buffers), so they do
        # not appear in the state dict and are created directly on `device`.
        self.action_bound_min = torch.tensor(
            action_bound_min, device=device, dtype=torch.float32
        )
        self.action_bound_max = torch.tensor(
            action_bound_max, device=device, dtype=torch.float32
        )
        self.actor_cfg = actor_cfg

        # build model
        self.build()

    def build(self) -> None:
        '''
        Create the `model` trunk and the `mean_decoder` head from `actor_cfg`.

        `actor_cfg['mlp']['activation']` names a class under `torch.nn`
        (e.g. 'ReLU'); `actor_cfg['mlp']['shape']` lists layer widths, the
        last of which is the trunk's output width and the decoder's input
        width. Raises AttributeError if the activation name cannot be found.
        '''
        mlp_cfg = self.actor_cfg['mlp']
        activation_name = mlp_cfg['activation']

        # Resolve the activation by attribute lookup rather than eval(), so a
        # config string can only ever name something under torch.nn and can
        # never execute arbitrary code. Dotted names are walked part by part.
        activation = torch.nn
        for part in str(activation_name).split('.'):
            activation = getattr(activation, part)
        self.activation = activation

        layer_shape = mlp_cfg['shape']
        feature_dim = layer_shape[-1]
        self.add_module('model', MLP(
            input_size=(self.state_dim + self.reward_dim),
            output_size=feature_dim,
            shape=layer_shape[:-1],
            activation=self.activation,
        ))
        # The decoder applies the activation first and then a linear layer
        # down to the action dimension; tanh is applied later in `forward`.
        self.add_module("mean_decoder", torch.nn.Sequential(
            self.activation(),
            torch.nn.Linear(feature_dim, self.action_dim),
        ))

    def forward(self, state:torch.Tensor, preference:torch.Tensor) -> torch.Tensor:
        '''
        Compute the tanh-squashed action mean.

        `state` and `preference` are concatenated along the last dimension,
        so all their leading dimensions must match.

        output: (mean,)
        '''
        x = torch.cat([state, preference], dim=-1)
        x = self.model(x)
        mean = torch.tanh(self.mean_decoder(x))
        return mean

    def updateActionDist(self, state:torch.Tensor, preference:torch.Tensor, noise_scale:float) -> None:
        '''
        Refresh the cached action distribution and draw one sample from it.

        Sets `action_mean`, `action_std` (every entry equal to `noise_scale`),
        `action_dist` (a Normal with those parameters) and `normal_action`
        (the drawn sample). `sample`, `getDist`, `getEntropy` and
        `getLogProb` all read this cached state, so call this method first.
        '''
        self.action_mean = self.forward(state, preference)
        self.action_std = torch.ones_like(self.action_mean) * noise_scale
        self.action_dist = torch.distributions.Normal(self.action_mean, self.action_std)
        # Drawn with .sample(), so no gradient flows through the noise.
        self.normal_action = self.action_dist.sample()

    def sample(self, deterministic:bool=False) -> Tuple[torch.Tensor, torch.Tensor]:
        '''
        Return `(norm_action, unnorm_action)` from the cached distribution.

        deterministic=True returns the cached mean; otherwise the cached
        Gaussian sample clamped to [-1, 1]. `unnorm_action` is the result of
        `unnormalize(norm_action, action_bound_min, action_bound_max)`
        (presumably a rescale from [-1, 1] to the bounds -- confirm in
        algos.common.actor_base).
        '''
        if deterministic:
            norm_action = self.action_mean
        else:
            norm_action = torch.clamp(self.normal_action, -1.0, 1.0)
        unnorm_action = unnormalize(norm_action, self.action_bound_min, self.action_bound_max)
        return norm_action, unnorm_action

    def getDist(self) -> torch.distributions.Distribution:
        '''
        return the action distribution cached by `updateActionDist`.
        '''
        return self.action_dist

    def getEntropy(self) -> torch.Tensor:
        '''
        return entropy of action distribution given state.

        The per-component entropies are summed over the last dimension and
        the result is averaged over all remaining dimensions.
        '''
        entropy = torch.mean(torch.sum(self.action_dist.entropy(), dim=-1))
        return entropy

    def getLogProb(self) -> torch.Tensor:
        '''
        return log probability of action given state.

        This scores the cached, unclamped sample `normal_action` (not the
        clamped action returned by `sample`) and sums over the last dimension.
        '''
        log_prob = torch.sum(self.action_dist.log_prob(self.normal_action), dim=-1)
        return log_prob

    def initialize(self) -> None:
        '''
        Apply `initWeights` to every direct child module via `Module.apply`,
        which also visits each child's own submodules.
        '''
        for module in self.children():
            module.apply(initWeights)
