import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import pdb
from torch.nn import TransformerEncoder, TransformerEncoderLayer


class SimpleMLP(nn.Module):
    """Plain feed-forward network.

    Every hidden layer is ``Linear -> ReLU -> Dropout``; a final ``Linear``
    maps the last hidden width to ``out_dim`` with no activation.

    Args:
        in_dim: number of input features.
        out_dim: number of output features.
        hidden_dims: sequence with the width of each hidden layer (may be empty).
        dropout: dropout probability applied after every hidden activation.
    """

    def __init__(self, in_dim, out_dim, hidden_dims, dropout=0.0):
        super().__init__()

        self.nlayers = len(hidden_dims)

        # Consecutive (fan_in, fan_out) pairs: in_dim -> h0 -> h1 -> ...
        widths = [in_dim] + [hidden_dims[k] for k in range(self.nlayers)]
        hidden_layers = [
            nn.Linear(fan_in, fan_out)
            for fan_in, fan_out in zip(widths[:-1], widths[1:])
        ]
        self.fcs = nn.ModuleList(hidden_layers)
        self.fc_out = nn.Linear(widths[-1], out_dim)
        self.dropout = nn.Dropout(dropout)

        # Xavier init is the active scheme; `_init_weights` is kept as an
        # alternative but is not applied here.
        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-uniform initialise every parameter with more than one dimension."""
        for param in self.parameters():
            if param.dim() <= 1:
                # Biases (1-D) keep their default initialisation.
                continue
            nn.init.xavier_uniform_(param)

    def _init_weights(self, m):
        """Alternative per-module init: truncated normal for Linear, identity for LayerNorm."""
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            return
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        """Run ``x`` through the hidden stack and the output projection."""
        hidden = x
        for layer in self.fcs:
            hidden = self.dropout(F.relu(layer(hidden)))
        return self.fc_out(hidden)


class TranslationMLP(nn.Module):
    """MLP that maps a flattened input window to a ``(Tout, Sout)`` output window.

    The wrapped ``SimpleMLP`` takes ``Sin * Tin`` input features and produces
    ``Sout * Tout`` output features.

    Args:
        Sin: size of the input's last dimension.
        Sout: size of the output's last dimension.
        Tin: input window length.
        Tout: output window length.
        hidden_dims: hidden layer widths forwarded to ``SimpleMLP``.
        dropout: dropout probability forwarded to ``SimpleMLP``.
    """

    def __init__(self, Sin, Sout, Tin, Tout, hidden_dims, dropout=0.0):
        super().__init__()

        self.Sin, self.Sout = Sin, Sout
        self.Tin, self.Tout = Tin, Tout

        self.mlp = SimpleMLP(
            in_dim=Sin * Tin,
            out_dim=Sout * Tout,
            hidden_dims=hidden_dims,
            dropout=dropout,
        )

    def forward(self, x):
        """
        Args:
            x: tensor whose non-batch dimensions flatten to ``Sin * Tin``
               features, e.g. (N, Tin, Sin)
        Returns:
            tensor of shape (N, Tout, Sout)
        """
        flat_in = x.flatten(start_dim=1)
        flat_out = self.mlp(flat_in)
        n_batch = flat_out.shape[0]
        return flat_out.reshape(n_batch, self.Tout, self.Sout)


class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding for sequence-first tensors.

    A table of shape ``(maxlen, 1, emb_size)`` is precomputed and registered as
    the buffer ``pos_embedding``. Even feature indices hold ``sin`` terms and
    odd feature indices hold ``cos`` terms.

    Args:
        emb_size: feature dimension of the embeddings (even or odd).
        dropout: dropout probability applied after adding the encoding.
        maxlen: number of positions precomputed in the table.
    """

    def __init__(self, emb_size: int, dropout: float = 0.0, maxlen: int = 5000):
        super(PositionalEncoding, self).__init__()
        den = torch.exp(-torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        angles = pos * den  # (maxlen, ceil(emb_size / 2))

        pos_embedding = torch.zeros((maxlen, emb_size))
        pos_embedding[:, 0::2] = torch.sin(angles)
        # For odd emb_size there is one fewer odd index than even index, so the
        # last cosine column has no slot; drop it instead of failing on a
        # shape mismatch. For even emb_size the slice is a no-op.
        pos_embedding[:, 1::2] = torch.cos(angles[:, : emb_size // 2])
        # Singleton middle axis broadcasts over the batch dimension.
        pos_embedding = pos_embedding.unsqueeze(-2)

        self.dropout = nn.Dropout(dropout)
        self.register_buffer("pos_embedding", pos_embedding)

    def forward(self, token_embedding: torch.Tensor):
        """Add the encoding for the first ``token_embedding.size(0)`` positions.

        Args:
            token_embedding: tensor whose first dimension is the sequence
                position, e.g. (seq_len, batch, emb_size).
        Returns:
            tensor of the same shape as ``token_embedding``.
        """
        return self.dropout(
            token_embedding + self.pos_embedding[: token_embedding.size(0), :]
        )


class ScaleEncoding(nn.Module):
    """Adds a learned linear embedding of a 2-feature scale descriptor to the input.

    Args:
        emb_size: feature dimension the 2 scale features are projected to.
        dropout: dropout probability applied to the sum.
    """

    def __init__(self, emb_size: int, dropout: float = 0.0):
        super().__init__()

        self.proj = nn.Linear(2, emb_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, scale: torch.Tensor):
        """
        Args:
            x: (S, B, d_model)
            scale: (S, B, 2)
        Returns:
            tensor with the shape of ``x``: dropout(x + proj(scale))
        """
        scale_embedding = self.proj(scale)
        combined = x + scale_embedding
        return self.dropout(combined)


class TimeEncoder(nn.Module):
    """Transformer encoder applied along the time axis.

    Pipeline: optional linear projection ``d_in -> d_model`` (only created
    when the two differ), sinusoidal positional encoding, then a stack of
    ``nlayers`` ``TransformerEncoderLayer`` blocks.

    Args:
        d_in: feature dimension of the input.
        d_model: feature dimension used inside the transformer.
        nhead: number of attention heads.
        d_hid: width of the feed-forward sub-layer.
        nlayers: number of encoder layers.
        seq_len: longest sequence supported; the positional table holds
            ``seq_len + 1`` positions.
        dropout: dropout probability used in the positional encoding and
            in the encoder layers.
        batch_first: if True, tensors are (batch, seq, feature) instead of
            (seq, batch, feature).
        norm_first: forwarded to ``TransformerEncoderLayer``.
    """

    def __init__(
        self,
        d_in: int,
        d_model: int,
        nhead: int,
        d_hid: int,
        nlayers: int,
        seq_len: int = 5000,
        dropout: float = 0.0,
        batch_first: bool = False,
        norm_first: bool = False,
    ):
        super(TimeEncoder, self).__init__()
        self.model_type = "Time"
        self.d_model = d_model
        # Remembered so forward() can present a sequence-first tensor to the
        # positional encoder regardless of the transformer's layout.
        self.batch_first = batch_first

        self.has_linear_in = d_in != d_model
        if self.has_linear_in:
            self.linear_in = nn.Linear(d_in, d_model)

        self.pos_encoder = PositionalEncoding(d_model, dropout, seq_len + 1)

        encoder_layers = TransformerEncoderLayer(
            d_model,
            nhead,
            dim_feedforward=d_hid,
            dropout=dropout,
            batch_first=batch_first,
            norm_first=norm_first,
        )
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)

        # Truncated-normal init is the active scheme; `_reset_parameters`
        # (Xavier) is kept as an alternative but is not applied here.
        self.apply(self._init_weights)

    def _reset_parameters(self):
        """Xavier-uniform initialise every parameter with more than one dimension."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def _init_weights(self, m):
        """Per-module init: truncated normal for Linear, identity for LayerNorm."""
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Arguments:
            x: tensor of shape (seq_len, batch_size, d_in), or
               (batch_size, seq_len, d_in) when ``batch_first=True``
        Returns:
            y: tensor with the same layout as ``x`` and last dimension d_model
        """
        if self.has_linear_in:
            x = self.linear_in(x)

        # The positional encoder indexes positions along dim 0, so a
        # batch-first tensor must be viewed sequence-first while it is applied;
        # otherwise the encoding would vary per batch item, not per time step.
        if self.batch_first:
            x = self.pos_encoder(x.transpose(0, 1)).transpose(0, 1)
        else:
            x = self.pos_encoder(x)

        x = self.transformer_encoder(x)

        return x


class SensorEncoder(nn.Module):
    """Transformer encoder over a token sequence with an added scale embedding.

    Pipeline: optional linear projection ``d_in -> d_model`` (only created
    when the two differ), learned scale encoding, sinusoidal positional
    encoding, then a stack of ``nlayers`` ``TransformerEncoderLayer`` blocks.

    Args:
        d_in: feature dimension of the input.
        d_model: feature dimension used inside the transformer.
        nhead: number of attention heads.
        d_hid: width of the feed-forward sub-layer.
        nlayers: number of encoder layers.
        seq_len: longest sequence supported (size of the positional table).
        dropout: dropout probability used in the encodings and encoder layers.
        batch_first: if True, tensors are (batch, seq, feature) instead of
            (seq, batch, feature).
        norm_first: forwarded to ``TransformerEncoderLayer``.
    """

    def __init__(
        self,
        d_in: int,
        d_model: int,
        nhead: int,
        d_hid: int,
        nlayers: int,
        seq_len: int = 5000,
        dropout: float = 0.0,
        batch_first: bool = False,
        norm_first: bool = False,
    ):
        super(SensorEncoder, self).__init__()
        self.model_type = "Sensor"
        self.d_model = d_model
        # Remembered so forward() can present a sequence-first tensor to the
        # positional encoder regardless of the transformer's layout.
        self.batch_first = batch_first

        self.has_linear_in = d_in != d_model
        if self.has_linear_in:
            self.linear_in = nn.Linear(d_in, d_model)

        self.pos_encoder = PositionalEncoding(d_model, dropout, seq_len)
        self.scale_encoder = ScaleEncoding(d_model, dropout)

        encoder_layers = TransformerEncoderLayer(
            d_model,
            nhead,
            dim_feedforward=d_hid,
            dropout=dropout,
            batch_first=batch_first,
            norm_first=norm_first,
        )
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)

        # NOTE(review): `proj` is never used in forward(); it is kept only so
        # existing state_dicts keep loading. It allocates
        # seq_len * d_model * d_model weights — consider removing it together
        # with a checkpoint migration.
        self.proj = nn.Linear(seq_len * d_model, d_model)

        # Truncated-normal init is the active scheme; `_reset_parameters`
        # (Xavier) is kept as an alternative but is not applied here.
        self.apply(self._init_weights)

    def _reset_parameters(self):
        """Xavier-uniform initialise every parameter with more than one dimension."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def _init_weights(self, m):
        """Per-module init: truncated normal for Linear, identity for LayerNorm."""
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
        """
        Arguments:
            x: tensor of shape (seq_len, batch_size, d_in), or
               (batch_size, seq_len, d_in) when ``batch_first=True``
            scale: tensor laid out like ``x`` with last dimension 2
        Returns:
            y: tensor with the same layout as ``x`` and last dimension d_model
        """
        if self.has_linear_in:
            x = self.linear_in(x)

        # The scale encoding is element-wise, so it works in either layout as
        # long as `scale` is laid out like `x`.
        x = self.scale_encoder(x, scale)

        # The positional encoder indexes positions along dim 0, so a
        # batch-first tensor must be viewed sequence-first while it is applied;
        # otherwise the encoding would vary per batch item, not per token.
        if self.batch_first:
            x = self.pos_encoder(x.transpose(0, 1)).transpose(0, 1)
        else:
            x = self.pos_encoder(x)

        x = self.transformer_encoder(x)

        return x


class SensorTimeEncoder(nn.Module):
    """Two-stage encoder: attention over time per sensor, then over all tokens.

    1. ``timeenc`` attends along the (compressed) time axis, treating every
       (batch, sensor) pair as an independent sequence.
    2. ``senorenc`` attends over the combined ``time x sensor`` token sequence,
       with a per-sensor scale embedding added.
    3. A linear ``classifier`` maps the flattened tokens to
       ``time_out * sens_out`` values.

    Args:
        d_in: feature dimension of the input tokens.
        d_model: feature dimension used in both encoders.
        nheadt: attention heads in the time encoder.
        nheads: attention heads in the sensor encoder.
        d_hid: feed-forward width in both encoders.
        nlayerst: number of layers in the time encoder.
        nlayerss: number of layers in the sensor encoder.
        time_in: uncompressed input length; the model is sized for
            ``time_in // compression`` time steps.
        time_out: number of output time steps.
        compression: divisor applied to ``time_in`` when sizing the model.
        sens_in: number of input sensors.
        sens_out: number of output sensors.
        dropout: dropout probability used in both encoders.
    """

    def __init__(
        self,
        d_in: int,
        d_model: int,
        nheadt: int,
        nheads: int,
        d_hid: int,
        nlayerst: int,
        nlayerss: int,
        time_in: int = 5000,
        time_out: int = 5000,
        compression: int = 3,
        sens_in: int = 5,
        sens_out: int = 5,
        dropout: float = 0.0,
    ):
        super(SensorTimeEncoder, self).__init__()

        self.Tin = time_in
        self.Tout = time_out

        # Compute the compressed length once. Writing
        # `sens_in * time_in // compression` evaluates as
        # `(sens_in * time_in) // compression`, which differs from the real
        # token count `sens_in * (time_in // compression)` whenever time_in is
        # not a multiple of compression, so the classifier could never match.
        compressed_len = time_in // compression
        n_tokens = sens_in * compressed_len

        self.timeenc = TimeEncoder(
            d_in=d_in,
            d_model=d_model,
            nhead=nheadt,
            d_hid=d_hid,
            nlayers=nlayerst,
            seq_len=compressed_len,
            dropout=dropout,
        )

        # Attribute name (sic) is kept: it is part of the state_dict keys.
        self.senorenc = SensorEncoder(
            d_in=d_model,
            d_model=d_model,
            nhead=nheads,
            d_hid=d_hid,
            nlayers=nlayerss,
            seq_len=n_tokens,
            dropout=dropout,
        )

        self.classifier = nn.Linear(n_tokens * d_model, time_out * sens_out)

    def forward(self, x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: tensor of shape (B, TCin, Sin, d_in); for the classifier to
               accept the result, TCin * Sin must equal
               sens_in * (time_in // compression)
            scale: tensor of shape (B, Sin, 2)
        Returns:
            y: tensor of shape (B, Tout, Sout)
        """

        B, TCin, Sin, dim = x.shape
        x = torch.permute(x, (1, 0, 2, 3))  # (TCin, B, Sin, d_in)

        # Time encoder: fold sensors into the batch axis so each
        # (batch, sensor) pair is encoded as its own sequence.
        x = x.flatten(start_dim=1, end_dim=2)  # (TCin, B * Sin, d_in)
        y = self.timeenc(x)  # (TCin, B * Sin, d_model)

        # Sensor encoder: unfold sensors again and merge them with time into
        # one token axis of length TCin * Sin.
        y = y.reshape(TCin, B, Sin, y.shape[-1])  # (TCin, B, Sin, d_model)
        y = y.transpose(1, 2)  # (TCin, Sin, B, d_model)
        y = y.reshape(TCin * Sin, B, -1)

        # Repeat each sensor's scale for every time step, matching y's order.
        scale = torch.permute(scale, (1, 0, 2))  # (Sin, B, 2)
        scale = scale.unsqueeze(0).expand(TCin, -1, -1, -1)  # (TCin, Sin, B, 2)
        scale = scale.reshape(TCin * Sin, B, 2)
        z = self.senorenc(y, scale)  # (TCin * Sin, B, d_model)

        # Flatten all tokens per batch item and project to the output grid.
        z = torch.permute(z, (1, 0, 2))  # (B, TCin * Sin, d_model)
        z = z.flatten(start_dim=1)
        z = self.classifier(F.relu(z))
        z = z.reshape(B, self.Tout, -1)  # (B, Tout, Sout)

        return z
