import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import distributions as pyd

from src.modules.nn import MLP, CrossAttention, SelfAttention, LayerNorm


def soft_clamp(
    x: torch.Tensor,
    _min=None,
    _max=None,
) -> torch.Tensor:
    """Softly bound ``x`` with softplus so gradients are kept everywhere.

    Args:
        x: Tensor to bound.
        _min: Optional lower bound; skipped when ``None``.
        _max: Optional upper bound; skipped when ``None``.

    Returns:
        ``x`` unchanged when both bounds are ``None``; otherwise the
        softplus-bounded tensor. The upper bound is applied first and the
        lower bound is applied to that result.
    """
    upper_bounded = x if _max is None else _max - F.softplus(_max - x)
    if _min is None:
        return upper_bounded
    return _min + F.softplus(upper_bounded - _min)


class SelfAttentionBlock(nn.Module):
    """Residual block: normalised self-attention followed by a normalised MLP."""

    def __init__(self, config):
        """Build the sub-layers from ``config`` (reads ``n_embd`` and ``bias``)."""
        super().__init__()
        n_embd, use_bias = config.n_embd, config.bias
        self.ln_1 = LayerNorm(n_embd, bias=use_bias)
        self.attn_1 = SelfAttention(config)
        self.ln_2 = LayerNorm(n_embd, bias=use_bias)
        self.mlp = MLP(config)

    def forward(self, x, z):
        """Apply attention and MLP, each with a residual connection.

        ``z`` is accepted but not used here; the signature matches
        ``CrossAttentionBlock.forward``.
        """
        attended = self.attn_1(self.ln_1(x))
        hidden = x + attended
        return hidden + self.mlp(self.ln_2(hidden))
    

class CrossAttentionBlock(nn.Module):
    """Residual block: normalised cross-attention over ``z``, then an MLP."""

    def __init__(self, config):
        """Build the sub-layers from ``config`` (reads ``n_embd`` and ``bias``)."""
        super().__init__()
        n_embd, use_bias = config.n_embd, config.bias
        self.ln_1 = LayerNorm(n_embd, bias=use_bias)
        self.attn_1 = CrossAttention(config)
        self.ln_2 = LayerNorm(n_embd, bias=use_bias)
        # NOTE: an extra self-attention sub-layer (attn_2 / ln_3) is disabled.
        self.mlp = MLP(config)

    def forward(self, x, z):
        """Attend from the normalised ``x`` to ``z``, then apply the MLP.

        Both sub-layers are wrapped in a residual connection; ``z`` is passed
        to the cross-attention module as given (it is not normalised here).
        """
        attended = self.attn_1(self.ln_1(x), z)
        hidden = x + attended
        return hidden + self.mlp(self.ln_2(hidden))



class TanhNormalWrapper(torch.distributions.Normal):
    """Normal over raw actions, with helpers for tanh-squashed, scaled actions.

    The underlying ``Normal(loc, scale)`` describes the raw (pre-squash)
    action. ``rsample`` and ``mode`` return ``(action, raw_action)`` pairs
    where ``action = max_action * tanh(raw_action)``; ``log_prob`` scores a
    squashed action, including the change-of-variables correction.
    """

    def __init__(self, loc, scale, max_action):
        """Store the Normal parameters and the action scale ``max_action``."""
        super().__init__(loc, scale)
        self._max_action = max_action

    def log_prob(self, action, raw_action=None):
        """Return the log-density of a squashed ``action``.

        Args:
            action: Squashed action, i.e. ``max_action * tanh(raw_action)``.
            raw_action: Optional pre-squash value for ``action``. When given
                it is used directly instead of inverting the tanh.

        Returns:
            Log-probability summed over the last dimension, which is kept
            with size 1.
        """
        squashed_action = action / self._max_action
        if raw_action is None:
            raw_action = self.arctanh(squashed_action)
        gaussian_log_prob = super().log_prob(raw_action).sum(-1, keepdim=True)
        # d(action)/d(raw) = max_action * (1 - tanh(raw)^2); eps keeps the log
        # finite when the squashed action sits on the boundary.
        eps = 1e-6
        log_jacobian = torch.log(
            self._max_action * (1 - squashed_action.pow(2)) + eps
        ).sum(-1, keepdim=True)
        return gaussian_log_prob - log_jacobian

    def mode(self):
        """Return ``(action, raw_action)`` at the mean of the raw Normal.

        Defined as a regular method, so call it as ``dist.mode()``.
        """
        raw_action = self.mean
        action = self._max_action * torch.tanh(raw_action)
        return action, raw_action

    def arctanh(self, x):
        """Inverse tanh with both log arguments clamped to at least 1e-6."""
        one_plus_x = (1 + x).clamp(min=1e-6)
        one_minus_x = (1 - x).clamp(min=1e-6)
        return 0.5 * torch.log(one_plus_x / one_minus_x)

    def rsample(self, sample_shape=torch.Size()):
        """Draw a reparameterised sample and squash it.

        Args:
            sample_shape: Extra leading sample dimensions, forwarded to
                ``Normal.rsample``. Defaults to a single draw.

        Returns:
            ``(action, raw_action)`` with
            ``action = max_action * tanh(raw_action)``.
        """
        raw_action = super().rsample(sample_shape)
        action = self._max_action * torch.tanh(raw_action)
        return action, raw_action
    
    
class AttentionActor(nn.Module):
    """Policy head built on a stack of ``SelfAttentionBlock`` layers.

    ``config.actor_type`` selects the output:

    * ``"deterministic"``: ``forward`` returns an action tensor clipped to
      ``[-max_action, max_action]``.
    * ``"stochastic"``: ``forward`` returns a diagonal ``Normal`` whose log
      standard deviation is soft-clamped between two learnable bounds.
      ``max_action`` is not applied on this path.
    """

    def __init__(self, config):
        """Build the actor from ``config``.

        Reads ``actor_type``, ``max_action``, ``action_dim``, ``n_embd``,
        ``bias`` and ``n_layer``, plus ``log_std_bounds`` for the stochastic
        actor.

        Raises:
            KeyError: If ``config.actor_type`` is not ``"deterministic"`` or
                ``"stochastic"``.
        """
        super().__init__()

        self.actor_type = config.actor_type
        self.max_action = config.max_action
        # Fail fast with a readable message before any layers are built.
        if self.actor_type not in ("deterministic", "stochastic"):
            raise KeyError(f"unknown actor_type: {self.actor_type!r}")
        is_stochastic = self.actor_type == "stochastic"

        if is_stochastic:
            # Learnable per-dimension bounds. Despite the "logvar" names,
            # forward() applies them to the log standard deviation.
            self.register_parameter(
                "max_logvar",
                nn.Parameter(
                    torch.ones(config.action_dim) * config.log_std_bounds[1],
                    requires_grad=True,
                ),
            )
            self.register_parameter(
                "min_logvar",
                nn.Parameter(
                    torch.ones(config.action_dim) * config.log_std_bounds[0],
                    requires_grad=True,
                ),
            )
        self.transformer = nn.ModuleDict(dict(
            h=nn.ModuleList(
                [SelfAttentionBlock(config) for _ in range(config.n_layer)]
            ),
            ln_f=LayerNorm(config.n_embd, bias=config.bias),
        ))
        # The stochastic head emits a mean and a log-std per action dimension.
        out_dim = config.action_dim * 2 if is_stochastic else config.action_dim
        self.trunk = nn.Linear(config.n_embd, out_dim)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights ~ N(0, 0.02); zero Linear biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, x, z):
        """Run the attention stack and produce the policy output.

        Args:
            x: Input indexed as ``x[:, -1, :]`` after the blocks, so it must
                have three dimensions.
            z: Passed through to every block as its second argument.

        Returns:
            A clipped action tensor (deterministic) or a ``Normal``
            distribution (stochastic).

        Raises:
            NotImplementedError: If ``actor_type`` is not supported.
        """
        for block in self.transformer.h:
            x = block(x, z)
        # Only the last position along dim 1 feeds the output head.
        x = self.transformer.ln_f(x[:, -1, :])

        if self.actor_type == "deterministic":
            return self.trunk(x).clip(-self.max_action, self.max_action)
        if self.actor_type == "stochastic":
            mu, log_std = self.trunk(x).chunk(2, dim=-1)
            log_std = soft_clamp(log_std, self.min_logvar, self.max_logvar)
            return pyd.Normal(mu, log_std.exp())
        raise NotImplementedError(
            f"unsupported actor_type: {self.actor_type!r}"
        )

    def loss(self, pred, action):
        """Return the training loss for ``pred`` against the target ``action``.

        Deterministic actors use mean-squared error; stochastic actors use the
        mean negative log-likelihood of ``action`` under ``pred``.

        Raises:
            NotImplementedError: If ``actor_type`` is not supported.
        """
        if self.actor_type == "deterministic":
            return F.mse_loss(pred, action)
        if self.actor_type == "stochastic":
            return -pred.log_prob(action).mean()
        raise NotImplementedError(
            f"unsupported actor_type: {self.actor_type!r}"
        )
