# Public API of this module: the two SmODE-based policy networks.
__all__ = ["DetermPolicy", "StochaPolicy"]
import numpy as np
import torch
import torch.nn as nn
from gops.utils.common_utils import get_activation_func
from gops.utils.act_distribution_cls import Action_Distribution
from .ode_cell.wirings import Connected
from .ode_cell.SmODE import SmODE
def SmODE(sensory_units, command_units, state_dim, output_dim, lambda1, lambda2):
    """Build a batch-first SmODE cell that uses a ``Connected`` wiring.

    Args:
        sensory_units: number of sensory units, forwarded to ``Connected``.
        command_units: number of command units, forwarded to ``Connected``;
            must be at least ``output_dim``.
        state_dim: first positional argument of the SmODE cell; the policies
            in this module pass the width of their last hidden layer here.
        output_dim: number of output units, forwarded to ``Connected``.
        lambda1, lambda2: forwarded unchanged to the SmODE cell (presumably
            weights of its regularization terms -- confirm in ode_cell.SmODE).

    Returns:
        The constructed SmODE cell, created with ``batch_first=True``.

    Raises:
        ValueError: if ``command_units < output_dim``.
    """
    # This factory rebinds the module-level name ``SmODE``, shadowing the class
    # imported from ``.ode_cell.SmODE``. Import the class under a private alias.
    # The previous ``SmODE = SmODE(...)`` made ``SmODE`` a local name and raised
    # UnboundLocalError on every call.
    from .ode_cell.SmODE import SmODE as _SmODECell

    # Explicit check instead of ``assert`` so it survives ``python -O``.
    if command_units < output_dim:
        raise ValueError(
            f"command_units ({command_units}) must be greater than or equal to "
            f"output_dim ({output_dim})"
        )
    # NOTE(review): the meaning/order of these positional arguments is defined
    # by ode_cell.wirings.Connected; they are passed exactly as before.
    wiring = Connected(
        sensory_units,
        command_units,
        output_dim,
        sensory_units,
        command_units,
        output_dim,
        output_dim,
    )
    return _SmODECell(
        state_dim, wiring, batch_first=True, lambda1=lambda1, lambda2=lambda2
    )
def mlp(sizes, activation, output_activation=nn.Identity):
    """Build a fully connected network as an ``nn.Sequential``.

    ``sizes`` lists the layer widths ``[in, h1, ..., out]``. Every linear
    layer is followed by an activation module: ``activation()`` after all
    but the last layer, ``output_activation()`` after the last one. Fewer
    than two sizes yields an empty ``nn.Sequential``.
    """
    last_idx = len(sizes) - 2
    modules = []
    for idx, (fan_in, fan_out) in enumerate(zip(sizes[:-1], sizes[1:])):
        act_cls = output_activation if idx == last_idx else activation
        modules.append(nn.Linear(fan_in, fan_out))
        modules.append(act_cls())
    return nn.Sequential(*modules)
class DetermPolicy(nn.Module, Action_Distribution):
    """
    Deterministic policy: an MLP encoder followed by a SmODE cell.

    Input: observation.
    Output: action, squashed with ``tanh`` and affinely mapped onto
    ``[act_low_lim, act_high_lim]``. If ``training`` is passed as a keyword
    (any value) to ``forward``, the ODE regularization term is returned too.
    """

    def __init__(self, **kwargs):
        super().__init__()
        obs_dim, act_dim = kwargs["obs_dim"], kwargs["act_dim"]
        hidden_sizes = kwargs["hidden_sizes"]
        sensory_units, command_units = kwargs["sensory_units"], kwargs["command_units"]
        lambda1, lambda2 = kwargs["lambda1"], kwargs["lambda2"]

        # Encoder obs_dim -> ... -> hidden_sizes[-1]; its output feeds the ODE cell.
        self.pi = mlp(
            [obs_dim, *hidden_sizes],
            get_activation_func(kwargs["hidden_activation"]),
            get_activation_func(kwargs["output_activation"]),
        )
        self.ode = SmODE(
            sensory_units,
            command_units,
            hidden_sizes[-1],
            act_dim,
            lambda1=lambda1,
            lambda2=lambda2,
        )
        # Recurrent state, carried between forward calls made with sample=True.
        self.hx = None

        for buffer_name in ("act_high_lim", "act_low_lim"):
            self.register_buffer(buffer_name, torch.from_numpy(kwargs[buffer_name]))
        self.action_distribution_cls = kwargs["action_distribution_cls"]
        # Regularization term from the most recent non-sampling forward pass.
        self.regular_loss = 0

    def set_hx_none(self):
        """Reset the recurrent state used by sampling forward passes."""
        self.hx = None

    def _advance_ode(self, features, sample):
        """Run the ODE cell on ``features``, each row as a length-1 sequence.

        With ``sample`` truthy the stored ``self.hx`` is fed in and replaced by
        the returned state; otherwise the cell starts from ``None`` and
        ``self.regular_loss`` is refreshed from its third output.
        """
        seq = features.reshape(-1, 1, features.shape[-1])
        if not sample:
            out, _, self.regular_loss = self.ode(seq, None, sample=sample)
        else:
            out, self.hx, _ = self.ode(seq, self.hx, sample=sample)
        return out.reshape(-1, out.shape[-1])

    def forward(self, obs, sample=False, **kwargs):
        pre_tanh = self._advance_ode(self.pi(obs), sample)
        half_span = (self.act_high_lim - self.act_low_lim) / 2
        midpoint = (self.act_high_lim + self.act_low_lim) / 2
        action = half_span * torch.tanh(pre_tanh) + midpoint
        return (action, self.regular_loss) if "training" in kwargs else action
class StochaPolicy(nn.Module, Action_Distribution):
    """
    Approximated function of stochastic policy (MLP encoder + SmODE cell).

    Input: observation.
    Output: parameters of action distribution, i.e. ``cat(mean, std)`` along
    the last dim. If ``training`` is passed as a keyword (any value) to
    ``forward``, the ODE regularization term is returned as well.

    ``std_type`` selects how the standard deviation is produced:
      * ``"mlp_separated"``: a separate MLP maps obs -> log-std.
      * ``"mlp_shared"``: the ODE head emits ``2 * act_dim`` values, split
        into mean and log-std.
      * ``"parameter"``: a learned, state-independent log-std parameter.
    """

    _STD_TYPES = ("mlp_separated", "mlp_shared", "parameter")

    def __init__(self, **kwargs):
        super().__init__()
        obs_dim = kwargs["obs_dim"]
        act_dim = kwargs["act_dim"]
        hidden_sizes = kwargs["hidden_sizes"]
        self.std_type = kwargs["std_type"]
        sensory_units = kwargs["sensory_units"]
        command_units = kwargs["command_units"]
        lambda1 = kwargs["lambda1"]
        lambda2 = kwargs["lambda2"]
        # Fail fast: previously an unknown std_type built a half-initialised
        # module whose forward() crashed with UnboundLocalError.
        if self.std_type not in self._STD_TYPES:
            raise ValueError(
                f"Unsupported std_type {self.std_type!r}; "
                f"expected one of {self._STD_TYPES}"
            )
        hidden_act = get_activation_func(kwargs["hidden_activation"])
        output_act = get_activation_func(kwargs["output_activation"])

        # In "mlp_shared" mode the ODE head emits mean and log-std together.
        ode_out_dim = act_dim * 2 if self.std_type == "mlp_shared" else act_dim
        # The ODE cell is created before the MLPs (as originally) so parameter
        # registration order and RNG-driven initialisation are unchanged.
        self.ode = SmODE(
            sensory_units,
            command_units,
            hidden_sizes[-1],
            ode_out_dim,
            lambda1=lambda1,
            lambda2=lambda2,
        )
        # Recurrent state, carried between forward calls made with sample=True.
        self.hx = None
        # Regularization term from the most recent non-sampling forward pass.
        self.regular_loss = 0

        trunk_sizes = [obs_dim] + list(hidden_sizes)
        if self.std_type == "mlp_separated":
            self.mean = mlp(trunk_sizes, hidden_act, output_act)
            self.log_std = mlp(trunk_sizes + [act_dim], hidden_act, output_act)
        else:
            self.policy = mlp(trunk_sizes, hidden_act, output_act)
            if self.std_type == "parameter":
                self.log_std = nn.Parameter(-0.5 * torch.ones(1, act_dim))

        self.min_log_std = kwargs["min_log_std"]
        self.max_log_std = kwargs["max_log_std"]
        self.register_buffer("act_high_lim", torch.from_numpy(kwargs["act_high_lim"]))
        self.register_buffer("act_low_lim", torch.from_numpy(kwargs["act_low_lim"]))
        self.action_distribution_cls = kwargs["action_distribution_cls"]

    def set_hx_none(self):
        """Reset the recurrent state used by sampling forward passes."""
        self.hx = None

    def _ode_head(self, features, sample):
        """Run the ODE cell on ``features``, each row as a length-1 sequence.

        With ``sample`` truthy the stored ``self.hx`` is fed in and replaced by
        the returned state; otherwise the cell starts from ``None`` and
        ``self.regular_loss`` is refreshed from its third output. The result
        is flattened back to ``(-1, out_dim)``.
        """
        seq = features.reshape(-1, 1, features.shape[-1])
        if sample:
            out, self.hx, _ = self.ode(seq, self.hx, sample=sample)
        else:
            out, _, self.regular_loss = self.ode(seq, None, sample=sample)
        return out.reshape(-1, out.shape[-1])

    def forward(self, obs, sample=False, **kwargs):
        if self.std_type == "mlp_separated":
            action_mean = self._ode_head(self.mean(obs), sample)
            action_log_std = self.log_std(obs)
        elif self.std_type == "mlp_shared":
            action_mean, action_log_std = torch.chunk(
                self._ode_head(self.policy(obs), sample), chunks=2, dim=-1
            )
        elif self.std_type == "parameter":
            action_mean = self._ode_head(self.policy(obs), sample)
            # Broadcast the (1, act_dim) parameter to the batch shape.
            action_log_std = self.log_std + torch.zeros_like(action_mean)
        else:
            raise ValueError(f"Unsupported std_type {self.std_type!r}")
        action_std = torch.clamp(
            action_log_std, self.min_log_std, self.max_log_std
        ).exp()
        dist_params = torch.cat((action_mean, action_std), dim=-1)
        if "training" in kwargs:
            return dist_params, self.regular_loss
        return dist_params
