from algos.common.actor_gaussian import ActorGaussian
from algos.common.actor_squash import ActorSquash
from algos.common.network_base import MLP

from torch.distributions.transformed_distribution import TransformedDistribution
from torch.distributions.transforms import TanhTransform
from typing import Tuple
import numpy as np
import torch

class ActorGaussianPreference(ActorGaussian):
    """Gaussian actor whose network input is ``[state, preference]``.

    The preference vector has ``n_objs`` entries (presumably one weight per
    objective -- confirm against the caller). Apart from the extra input,
    the class produces a mean and a (clamped) log-std for a diagonal
    ``torch.distributions.Normal``.

    Attributes such as ``activation``, ``last_activation``, ``actor_cfg``,
    ``log_std_fix``, ``log_std_state_cond`` and ``log_std_init`` are read
    here but not defined in this class; they are expected to be provided by
    ``ActorGaussian``.
    """

    def __init__(self, device:torch.device, state_dim:int, action_dim:int, \
                action_bound_min:np.ndarray, action_bound_max:np.ndarray, \
                n_objs:int, actor_cfg:dict, \
                log_std_min:float=-4.0, log_std_max:float=2.0) -> None:
        """Store ``n_objs`` and delegate everything else to ``ActorGaussian``.

        Args:
            device: Device passed through to the base class.
            state_dim: Size of the state vector.
            action_dim: Size of the action vector.
            action_bound_min: Lower action bound, forwarded to the base class.
            action_bound_max: Upper action bound, forwarded to the base class.
            n_objs: Length of the preference vector appended to the state.
            actor_cfg: Actor configuration; ``actor_cfg['mlp']['shape']`` is
                read in :meth:`build`.
            log_std_min: Lower clamp for the log standard deviation.
            log_std_max: Upper clamp for the log standard deviation.
        """
        # Set before the base constructor runs because build() reads it.
        # NOTE(review): assumes ActorGaussian.__init__ invokes build() -- confirm.
        self.n_objs = n_objs
        ActorGaussian.__init__(
            self, device, state_dim, action_dim,
            action_bound_min, action_bound_max, actor_cfg,
            log_std_min, log_std_max)

    def build(self) -> None:
        """Create the shared trunk, the mean head and the log-std source.

        The log-std comes from one of three places:

        * ``log_std_fix``: a constant tensor filled with ``log_std_init``;
        * ``log_std_state_cond``: a linear head on the trunk features;
        * otherwise: a learnable, state-independent parameter ``log_std``
          initialised to zeros.
        """
        # Trunk: the last entry of the configured shape is the feature size,
        # the preceding entries are the hidden layers.
        self.add_module('model', MLP(
            input_size=self.state_dim + self.n_objs,
            output_size=self.actor_cfg['mlp']['shape'][-1],
            shape=self.actor_cfg['mlp']['shape'][:-1],
            activation=self.activation,
        ))
        self.add_module("mean_decoder", torch.nn.Sequential(
            self.activation(),
            torch.nn.Linear(self.actor_cfg['mlp']['shape'][-1], self.action_dim),
        ))
        if self.log_std_fix:
            # Constant log-std, shaped like the features but with action_dim last.
            self.std_decoder = lambda x: torch.ones(
                *x.shape[:-1], self.action_dim, dtype=torch.float, device=self.device)*self.log_std_init
        elif self.log_std_state_cond:
            self.add_module("std_decoder", torch.nn.Sequential(
                self.activation(),
                torch.nn.Linear(self.actor_cfg['mlp']['shape'][-1], self.action_dim),
            ))
        else:
            self.log_std = torch.nn.Parameter(torch.zeros(self.action_dim, dtype=torch.float32, device=self.device, requires_grad=True))
            # Expand over *all* leading dimensions of the features so the
            # result always matches the mean's shape. Using only x.shape[0]
            # was correct for 2-D inputs but gave a wrong shape for unbatched
            # (1-D) or multi-batch-dim inputs; this mirrors the fixed branch.
            self.std_decoder = lambda x: self.log_std.expand(*x.shape[:-1], -1)

    def forward(self, state:torch.Tensor, preference:torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute the action distribution parameters.

        Args:
            state: State tensor; concatenated with ``preference`` on the last
                dimension.
            preference: Preference tensor with matching leading dimensions.

        Returns:
            ``(mean, log_std, std)`` where ``log_std`` is clamped to
            ``[log_std_min, log_std_max]`` and ``std = exp(log_std)``.
        """
        x = torch.cat([state, preference], dim=-1)
        x = self.model(x)
        mean = self.last_activation(self.mean_decoder(x))
        log_std = self.std_decoder(x)
        log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
        std = torch.exp(log_std)
        return mean, log_std, std

    def updateActionDist(self, state:torch.Tensor, preference:torch.Tensor, epsilon:torch.Tensor) -> None:
        """Refresh the cached distribution for ``(state, preference)``.

        Stores ``action_mean``, ``action_log_std``, ``action_std``, the
        sample ``normal_action = mean + epsilon * std`` and the
        ``Normal(mean, std)`` distribution on the instance.

        Args:
            state: State tensor.
            preference: Preference tensor.
            epsilon: Noise tensor scaled by the standard deviation.
        """
        self.action_mean, self.action_log_std, self.action_std = self.forward(state, preference)
        self.normal_action = self.action_mean + epsilon*self.action_std
        self.action_dist = torch.distributions.Normal(self.action_mean, self.action_std)


class ActorSquashPreference(ActorSquash):
    """Tanh-squashed Gaussian actor fed with ``[state, preference]``.

    The preference vector has ``n_objs`` entries and is concatenated to the
    state before the shared MLP trunk. Attributes such as ``activation``,
    ``last_activation``, ``log_std_fix`` and ``log_std_init`` are read here
    but are expected to come from ``ActorSquash``.
    """

    def __init__(self, device:torch.device, state_dim:int, action_dim:int,
                action_bound_min:np.ndarray, action_bound_max:np.ndarray,
                n_objs:int, actor_cfg:dict,
                log_std_min:float=-4.0, log_std_max:float=2.0) -> None:
        """Record ``n_objs`` and hand the remaining arguments to ``ActorSquash``."""
        # Stored first: build() needs it.
        # NOTE(review): presumably the base constructor triggers build() -- confirm.
        self.n_objs = n_objs
        ActorSquash.__init__(
            self,
            device,
            state_dim,
            action_dim,
            action_bound_min,
            action_bound_max,
            actor_cfg,
            log_std_min,
            log_std_max,
        )

    def build(self) -> None:
        """Register the trunk, the mean head and the log-std decoder.

        With ``log_std_fix`` the decoder returns a constant tensor filled
        with ``log_std_init``; otherwise it is a linear head on the trunk
        features.
        """
        layer_sizes = self.actor_cfg['mlp']['shape']
        feature_dim = layer_sizes[-1]

        # Modules are created in the same order as before (trunk, mean head,
        # std head) so parameter initialisation is unaffected.
        trunk = MLP(
            input_size=self.state_dim + self.n_objs,
            output_size=feature_dim,
            shape=layer_sizes[:-1],
            activation=self.activation,
        )
        self.add_module('model', trunk)

        mean_head = torch.nn.Sequential(
            self.activation(),
            torch.nn.Linear(feature_dim, self.action_dim),
        )
        self.add_module("mean_decoder", mean_head)

        if not self.log_std_fix:
            std_head = torch.nn.Sequential(
                self.activation(),
                torch.nn.Linear(feature_dim, self.action_dim),
            )
            self.add_module("std_decoder", std_head)
            return

        def constant_log_std(x):
            # Same leading dimensions as the features, action_dim last.
            out_shape = (*x.shape[:-1], self.action_dim)
            base = torch.ones(*out_shape, dtype=torch.float, device=self.device)
            return base*self.log_std_init

        self.std_decoder = constant_log_std

    def forward(self, state:torch.Tensor, preference:torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(mean, log_std, std)`` for the given state and preference.

        ``log_std`` is clamped to ``[log_std_min, log_std_max]`` and
        ``std`` is its exponential.
        """
        features = self.model(torch.cat([state, preference], dim=-1))
        mean = self.last_activation(self.mean_decoder(features))
        log_std = torch.clamp(
            self.std_decoder(features), self.log_std_min, self.log_std_max)
        return mean, log_std, torch.exp(log_std)

    def updateActionDist(self, state:torch.Tensor, preference:torch.Tensor, epsilon:torch.Tensor) -> None:
        """Cache the tanh-transformed distribution and a pre-squash sample.

        Stores ``action_mean``, ``action_log_std``, ``action_std``, the
        distribution ``TanhTransform`` applied to ``Normal(mean, std)``, and
        ``normal_action = mean + epsilon * std`` on the instance.
        """
        mean, log_std, std = self.forward(state, preference)
        self.action_mean = mean
        self.action_log_std = log_std
        self.action_std = std

        gaussian = torch.distributions.Normal(self.action_mean, self.action_std)
        self.action_dist = TransformedDistribution(gaussian, TanhTransform())
        self.normal_action = self.action_mean + epsilon*self.action_std
