import math

import torch
from typing import Optional
from torch import nn


def init_weights(m):
    """
    Re-initialize a single module in place.

    - ``nn.Linear``: weights are drawn with Xavier (Glorot) uniform and the
      bias, when the layer has one, is set to zero.
    - Any other module type is left untouched.

    No Kaiming initialization is applied here, whatever activation follows
    the layer.
    """
    # Guard clause: only fully connected layers are handled.
    if not isinstance(m, nn.Linear):
        return

    nn.init.xavier_uniform_(m.weight)

    bias = m.bias
    if bias is not None:
        nn.init.zeros_(bias)


def generate_MLP(
    in_dim=1,
    out_dim=1,
    width=128,
    n_layers=3,
    final_activation="",
    device="cpu",
    add_norm="none",
    dropout=0.0,
):
    """
    Build an MLP, optionally with layer normalization in the hidden stages.

    Args:
        in_dim: Size of the input features.
        out_dim: Size of the output features.
        width: Number of units in every hidden layer.
        n_layers: Number of hidden stages placed before the output layer.
        final_activation: Name of the activation appended after the output
            layer; forwarded unchanged to the selected builder.
        device: Device the returned network is moved to.
        add_norm: ``"none"`` selects ``generate_MLP_basic``; ``"ln"`` selects
            ``generate_MLP_ln``.
        dropout: Dropout probability forwarded to the selected builder.

    Returns:
        The network produced by the selected builder.

    Raises:
        ValueError: If ``add_norm`` is neither ``"none"`` nor ``"ln"``.
    """
    # Reject unknown options explicitly; without this check an unsupported
    # value would surface later as an unrelated unbound-variable error.
    if add_norm not in ("none", "ln"):
        raise ValueError(f"Unknown add_norm: {add_norm!r}. Use 'none' or 'ln'.")

    builder = generate_MLP_ln if add_norm == "ln" else generate_MLP_basic
    return builder(in_dim, out_dim, width, n_layers, final_activation, device, dropout)


def generate_MLP_basic(
    in_dim=1,
    out_dim=1,
    width=128,
    n_layers=3,
    final_activation="",
    device="cpu",
    dropout=0.0,
):
    """
    Build a plain fully connected network as a ``torch.nn.Sequential``.

    Layout: ``n_layers`` hidden stages of ``Linear -> [Dropout] -> LeakyReLU``
    (the dropout layer is present only when ``dropout > 0``), followed by a
    ``Linear`` projection to ``out_dim`` and an optional final activation.

    ``final_activation`` may be ``"tanh"``, ``"sig"``, ``"relu"`` or
    ``"LeakyReLU"``; any other value adds no final activation.

    The network is cast to float32 and moved to ``device`` before it is
    returned.
    """
    layers = []
    fan_in = in_dim
    for _ in range(n_layers):
        stage = [torch.nn.Linear(fan_in, width)]
        if dropout > 0.0:
            stage.append(torch.nn.Dropout(p=dropout))
        stage.append(torch.nn.LeakyReLU())
        layers.extend(stage)
        fan_in = width  # every stage after the first maps width -> width

    # Output projection.
    layers.append(torch.nn.Linear(fan_in, out_dim))

    # Name -> module class; unrecognized names intentionally add nothing.
    known_activations = (
        ("tanh", torch.nn.Tanh),
        ("sig", torch.nn.Sigmoid),
        ("relu", torch.nn.ReLU),
        ("LeakyReLU", torch.nn.LeakyReLU),
    )
    for key, factory in known_activations:
        if final_activation == key:
            layers.append(factory())
            break

    return torch.nn.Sequential(*layers).float().to(device=device)


def generate_MLP_ln(
    in_dim=1,
    out_dim=1,
    width=128,
    n_layers=3,
    final_activation="",
    device="cpu",
    dropout=0.0,
):
    """
    Build a layer-normalized fully connected network as ``torch.nn.Sequential``.

    Layout: ``n_layers`` hidden stages of
    ``Linear -> LayerNorm -> LeakyReLU -> [Dropout]`` (the dropout layer is
    present only when ``dropout > 0``), followed by a ``Linear`` projection to
    ``out_dim`` and an optional final activation.

    ``final_activation`` may be ``"tanh"``, ``"sig"``, ``"relu"`` or
    ``"LeakyReLU"``; any other value adds no final activation.

    The network is cast to float32 and moved to ``device`` before it is
    returned.
    """
    layers = []
    fan_in = in_dim
    for _ in range(n_layers):
        stage = [
            torch.nn.Linear(fan_in, width),
            torch.nn.LayerNorm(width),
            torch.nn.LeakyReLU(),
        ]
        if dropout > 0.0:
            stage.append(torch.nn.Dropout(p=dropout))
        layers.extend(stage)
        fan_in = width  # every stage after the first maps width -> width

    # Output projection.
    layers.append(torch.nn.Linear(fan_in, out_dim))

    # Name -> module class; unrecognized names intentionally add nothing.
    known_activations = (
        ("tanh", torch.nn.Tanh),
        ("sig", torch.nn.Sigmoid),
        ("relu", torch.nn.ReLU),
        ("LeakyReLU", torch.nn.LeakyReLU),
    )
    for key, factory in known_activations:
        if final_activation == key:
            layers.append(factory())
            break

    return torch.nn.Sequential(*layers).float().to(device=device)


class SinusoidalPositionalEncoding(nn.Module):
    """
    Fixed sine/cosine positional encoding added to a batch-first sequence.

    A table ``pe`` of shape ``[max_len, d_model]`` is precomputed and stored
    as a (non-trainable) buffer. Even feature columns hold
    ``sin(pos * freq)`` and odd columns hold ``cos(pos * freq)``, with
    ``freq = exp(-log(10000) * 2i / d_model)``.

    Both even and odd ``d_model`` are supported: for odd ``d_model`` there is
    one more sine column than cosine column.
    """

    def __init__(self, d_model: int, max_len: int = 5000):
        """
        Args:
            d_model: Feature dimension of the sequences to be encoded.
            max_len: Number of positions to precompute; the longest sequence
                that ``forward`` can handle.
        """
        super().__init__()
        position = torch.arange(0, max_len).unsqueeze(1)  # [L, 1]
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )  # [ceil(D / 2)]
        angles = position * div_term  # [L, ceil(D / 2)]

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # There are only floor(D / 2) odd columns, so drop the surplus
        # frequency when d_model is odd (no-op for even d_model).
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer("pe", pe)  # [L, D]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Add the first ``T`` rows of the encoding table to ``x``.

        Args:
            x: Tensor of shape ``[B, T, D]`` with ``D == d_model``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        # x: [B, T, D]
        T = x.size(1)
        return x + self.pe[:T, :]


class LearnablePositionalEncoding(nn.Module):
    """
    Trainable positional encoding: one learned ``d_model`` vector per position.

    The vectors live in an ``nn.Embedding`` with ``max_len`` rows and are
    added to the input sequence position by position.
    """

    def __init__(self, d_model: int, max_len: int = 5000):
        """
        Args:
            d_model: Feature dimension of the sequences to be encoded.
            max_len: Number of positions in the embedding table.
        """
        super().__init__()
        self.pos_embedding = nn.Embedding(max_len, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Add the embeddings of positions ``0..T-1`` to ``x``.

        Args:
            x: Tensor of shape ``[B, T, D]``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        # Unpacking also enforces that x is exactly three-dimensional.
        _, seq_len, _ = x.shape
        positions = torch.arange(seq_len, device=x.device)  # [T]
        offsets = self.pos_embedding(positions)  # [T, D]
        # Broadcast the same per-position offsets over the batch dimension.
        return x + offsets.unsqueeze(0)


class TransformerIndSA(nn.Module):
    """
    Causal transformer over interleaved state and action tokens.

    Each time step of the input is split into a state part (first ``obs_dim``
    features) and an action part (next ``action_dim`` features). Both are
    projected to ``hidden_dim`` and interleaved as
    ``[s1, a1, s2, a2, ..., sT, aT]``, giving a sequence of length ``2T``.
    A token-type embedding and an optional positional encoding are added, the
    sequence is passed through a causally masked transformer encoder, and an
    MLP head maps every token to ``out_dim``.
    """

    def __init__(
        self,
        obs_dim,
        action_dim,
        out_dim,
        num_layers=2,
        num_heads=2,
        hidden_dim=64,
        dropout=0.1,
        out_layer_hidden_dim: int = 128,
        out_layer_n_hidden: int = 2,
        max_len: int = 64,
        pos_encoding: str = "sinusoidal",  # "sinusoidal" | "learned" | "none"
        device="cpu",
        output_mode: str = "actions",  # "all" | "actions" | "states"
        **ignored,
    ):
        """
        Args:
            obs_dim: Number of state features per time step.
            action_dim: Number of action features per time step.
            out_dim: Output size of the MLP head.
            num_layers: Number of transformer encoder layers.
            num_heads: Number of attention heads per layer.
            hidden_dim: Token embedding size; also used as the encoder's
                feed-forward width.
            dropout: Dropout probability inside the encoder layers.
            out_layer_hidden_dim: Hidden width of the MLP head.
            out_layer_n_hidden: Number of hidden stages in the MLP head.
            max_len: Maximum number of time steps ``T``; positional tables
                are sized ``2 * max_len`` because tokens are interleaved.
            pos_encoding: ``"sinusoidal"``, ``"learned"`` or ``"none"``
                (case-insensitive).
            device: Device on which the submodules are created.
            output_mode: Which token positions ``forward`` returns:
                ``"all"``, ``"actions"`` or ``"states"``. Checked in
                ``forward``, not here.
            **ignored: Extra keyword arguments are accepted and discarded.

        Raises:
            ValueError: If ``pos_encoding`` is not a recognized option.
        """
        super().__init__()
        # Construction-time device only. forward() derives the device from
        # its tensors so the model keeps working after ``.to(...)``.
        self.device = device
        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.hidden_dim = hidden_dim
        self.output_mode = output_mode

        # Project each token type directly to hidden_dim (no concat).
        self.state_embedding = nn.Linear(self.obs_dim, hidden_dim, device=self.device)
        self.action_embedding = nn.Linear(
            self.action_dim, hidden_dim, device=self.device
        )

        # Token-type (segment) embedding: 0 = state, 1 = action.
        self.token_type_embedding = nn.Embedding(2, hidden_dim, device=self.device)

        # --- Positional encoding choice ---
        pos_encoding = pos_encoding.lower()
        if pos_encoding == "sinusoidal":
            self.pos_encoder = SinusoidalPositionalEncoding(
                d_model=hidden_dim, max_len=2 * max_len  # doubled, since we interleave
            ).to(self.device)
        elif pos_encoding == "learned":
            self.pos_encoder = LearnablePositionalEncoding(
                d_model=hidden_dim, max_len=2 * max_len
            ).to(self.device)
        elif pos_encoding == "none":
            self.pos_encoder = None
        else:
            raise ValueError(
                f"Unknown pos_encoding: {pos_encoding}. Use 'sinusoidal', 'learned', or 'none'."
            )

        self.transformer_encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim,
            dropout=dropout,
            device=self.device,
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(
            self.transformer_encoder_layer,
            num_layers=num_layers,
        )

        self.fc_out = generate_MLP(
            in_dim=hidden_dim,
            out_dim=out_dim,
            n_layers=out_layer_n_hidden,
            width=out_layer_hidden_dim,
            device=self.device,
        )

    def _build_causal_mask(self, seq_len: int, device):
        """Return a ``[seq_len, seq_len]`` bool mask; True blocks attention to later positions."""
        # True = masked (disallow attending). Shape [S, S]
        return torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=device), diagonal=1
        )

    def forward(self, x, src_key_padding_mask: Optional[torch.Tensor] = None):
        """
        Run the model on a batch (or a single sequence) of state/action steps.

        Args:
            x: ``[B, T, D]`` or ``[T, D]``. The first ``obs_dim`` features of
                the last axis are the state and the next ``action_dim`` are
                the action. A 2-D input is treated as a batch of one and the
                batch axis is removed again from the output.
            src_key_padding_mask: Optional ``[B, 2T]`` boolean mask where True
                marks a padding position. It indexes the interleaved token
                sequence, so build it after interleaving.

        Returns:
            ``[B, 2T, out_dim]`` for ``output_mode == "all"``, otherwise
            ``[B, T, out_dim]`` holding the outputs at the action tokens
            (``"actions"``) or state tokens (``"states"``). The leading batch
            axis is dropped when ``x`` was 2-D.

        Raises:
            ValueError: If ``self.output_mode`` is not a recognized option.
        """
        eval_batch = False
        if x.dim() == 2:
            x = x.unsqueeze(0)  # [1, T, D]
            eval_batch = True

        B, T, D = x.shape
        # Split
        states = x[..., : self.obs_dim]  # [B, T, obs_dim]
        actions = x[
            ..., self.obs_dim : self.obs_dim + self.action_dim
        ]  # [B, T, action_dim]

        # Embed to hidden
        s_h = self.state_embedding(states)  # [B, T, H]
        a_h = self.action_embedding(actions)  # [B, T, H]

        # Interleave as tokens: [s1, a1, s2, a2, ..., sT, aT] -> [B, 2T, H]
        tokens = torch.stack((s_h, a_h), dim=2).reshape(B, 2 * T, self.hidden_dim)

        # Use the device the data actually lives on: ``self.device`` is only
        # the construction-time value and goes stale after ``model.to(...)``.
        device = tokens.device

        # Token-type ids 0,1,0,1,... of shape [2T]; batch axis added below.
        type_ids = torch.arange(2 * T, device=device) % 2
        tokens = tokens + self.token_type_embedding(type_ids)[None, :, :]

        # Positional encoding (if any)
        if self.pos_encoder is not None:
            tokens = self.pos_encoder(tokens)  # [B, 2T, H]

        # Causal mask so each token only sees itself and past
        causal_mask = self._build_causal_mask(seq_len=2 * T, device=device)

        # If you have variable-length sequences, pass src_key_padding_mask with shape [B, 2T]
        # where True means "this position is padding".
        h = self.transformer_encoder(
            tokens,
            mask=causal_mask,
            src_key_padding_mask=src_key_padding_mask,
        )  # [B, 2T, H]

        y = self.fc_out(h)  # [B, 2T, out_dim]

        # Select which positions to return
        if self.output_mode == "all":
            out = y  # [B, 2T, out_dim]
        elif self.output_mode == "actions":
            out = y[:, 1::2, :]  # keep action tokens -> [B, T, out_dim]
        elif self.output_mode == "states":
            out = y[:, 0::2, :]  # keep state tokens  -> [B, T, out_dim]
        else:
            raise ValueError("output_mode must be one of {'all','actions','states'}")

        if eval_batch:
            out = out.squeeze(0)
        return out
