import math
import torch
import torch.nn as nn

import utils

class UnscrambleCNN(nn.Module):
    """
    Per-piece encoder: Conv2d -> BatchNorm2d -> ReLU -> MaxPool2d(2), then a
    two-layer MLP mapping the flattened feature map to ``out_dim``.

    Args:
        in_channels: number of channels of each piece.
        num_pieces: divisor of ``image_size``; each piece is taken to be a
            square of side ``image_size // num_pieces``.
        image_size: side length of the full image.
        hidden_channels: number of output channels of the convolution.
        kernel_size, stride, padding: forwarded unchanged to ``nn.Conv2d``.
        out_dim: size of the embedding produced for every piece.
    """
    def __init__(
        self, in_channels, num_pieces, image_size,
        hidden_channels, kernel_size, stride, padding, out_dim
    ):
        super().__init__()

        piece_size = image_size // num_pieces

        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, hidden_channels, kernel_size, stride, padding),
            nn.BatchNorm2d(hidden_channels),
            nn.ReLU(True),
            nn.MaxPool2d(2),
        )

        # Measure the flattened feature size instead of assuming the
        # convolution keeps the spatial size (which only holds for
        # "same"-style kernel/stride/padding combinations). Only the conv and
        # pooling layers are probed: BatchNorm and ReLU do not change the
        # shape, and skipping BatchNorm leaves its running statistics
        # untouched. For size-preserving convolutions this equals
        # (piece_size // 2) ** 2 * hidden_channels.
        with torch.no_grad():
            probe = torch.zeros(1, in_channels, piece_size, piece_size)
            dim_after_conv = self.conv[3](self.conv[0](probe)).numel()

        mlp_hid_dim = (dim_after_conv + out_dim) // 2

        self.mlp = nn.Sequential(
            nn.Linear(dim_after_conv, mlp_hid_dim),
            nn.ReLU(True),
            nn.Linear(mlp_hid_dim, out_dim)
        )

    def forward(self, batch_pieces):
        """
        Args:
            batch_pieces: shape [bs, num_pieces**2, c, h, w]

        Returns:
            shape [bs, num_pieces**2, out_dim]
        """
        # Remember the leading (batch-like) dims so they can be restored
        # after the pieces have been processed as one flat batch.
        batch_shape = batch_pieces.shape[:-3]
        batch_pieces = batch_pieces.flatten(end_dim=-4)

        conv_pieces = self.conv(batch_pieces) # [bs*n, hidden_c, h, w]
        conv_pieces_flatten = conv_pieces.flatten(start_dim=-3)
        pieces_embd = self.mlp(conv_pieces_flatten) # [bs*n, out_dim]
        pieces_embd = pieces_embd.unflatten(0, batch_shape) # [bs, n, out_dim]

        return pieces_embd
    
class SortMnistCNN(nn.Module):
    """
    Per-piece encoder made of two Conv2d -> BatchNorm2d -> ReLU -> MaxPool2d(2)
    stages followed by a single linear projection to ``out_dim``.

    The linear layer expects
    ``(image_size // 4) ** 2 * hidden_channels2 * num_digits`` input features.
    NOTE(review): presumably each piece is ``num_digits`` digit images placed
    side by side and both convolutions preserve the spatial size -- confirm
    against the dataset/config.
    """
    def __init__(
            self, in_channels, num_digits, image_size,
            hidden_channels1, kernel_size1, stride1, padding1,
            hidden_channels2, kernel_size2, stride2, padding2, out_dim,
        ):
        super().__init__()

        first_stage = [
            nn.Conv2d(in_channels, hidden_channels1, kernel_size1, stride1, padding1),
            nn.BatchNorm2d(hidden_channels1),
            nn.ReLU(True),
            nn.MaxPool2d(2),
        ]
        second_stage = [
            nn.Conv2d(hidden_channels1, hidden_channels2, kernel_size2, stride2, padding2),
            nn.BatchNorm2d(hidden_channels2),
            nn.ReLU(True),
            nn.MaxPool2d(2),
        ]
        self.conv = nn.Sequential(*first_stage, *second_stage)

        # Two 2x2 poolings shrink each side by a factor of four.
        pooled_side = image_size // 4
        flat_features = pooled_side * pooled_side * hidden_channels2 * num_digits

        self.mlp = nn.Linear(flat_features, out_dim)

    def forward(self, pieces):
        """
        Args:
            pieces: shape [bs, num_pieces, c, h, w]

        Returns:
            shape [bs, num_pieces, out_dim]
        """
        # Everything in front of (c, h, w) is folded into one batch axis for
        # the CNN and unfolded again on the way out.
        lead_shape = pieces.shape[:-3]
        stacked = pieces.flatten(end_dim=-4)

        features = self.conv(stacked).flatten(start_dim=-3) # [bs*n, hidden_c*h*w]
        embedded = self.mlp(features) # [bs*n, out_dim]

        return embedded.unflatten(0, lead_shape) # [bs, n, out_dim]

class TimeEmbedding(nn.Module):
    """
    Sinusoidal position embeddings for time t

    The first ``dim // 2`` features are sines and the next ``dim // 2`` are
    cosines of ``t * inv_freq``, where ``inv_freq`` decays geometrically from
    1 down to ``1 / max_period``. When ``dim`` is odd the last feature is a
    zero pad so the output always has exactly ``dim`` features.

    Args:
        dim: size of the embedding produced for every time value.
        max_period: ratio between the largest and the smallest frequency.
    """
    def __init__(self, dim, max_period=10000):
        super().__init__()
        self.dim = dim
        half_dim = self.dim // 2
        # With a single frequency (dim 2 or 3) ``half_dim - 1`` is zero;
        # clamp the denominator so construction does not divide by zero.
        # For half_dim >= 2 the value is unchanged.
        denom = max(half_dim - 1, 1)
        self.inv_freq = torch.exp(torch.arange(half_dim) * (-math.log(max_period) / denom))

    def forward(self, input):
        """
        Args:
            input: tensor of time values, any shape

        Returns:
            float tensor of shape [*input.shape, dim]
        """
        shape = input.shape
        input = input.reshape(-1).float()
        sinusoid_in = torch.outer(input, self.inv_freq.to(input.device))
        pos_emb = torch.cat([sinusoid_in.sin(), sinusoid_in.cos()], dim=-1)

        # sin/cos supply 2 * (dim // 2) features; an odd ``dim`` is one short,
        # so pad with zeros to make the reshape below valid.
        missing = self.dim - pos_emb.size(-1)
        if missing > 0:
            filler = torch.zeros(
                pos_emb.size(0), missing, dtype=pos_emb.dtype, device=pos_emb.device
            )
            pos_emb = torch.cat([pos_emb, filler], dim=-1)

        pos_emb = pos_emb.reshape(*shape, self.dim)

        return pos_emb
    
class PositionalEncoding(nn.Module):
    """
    Adds fixed sinusoidal positional encodings to a batch-first sequence and
    applies dropout to the sum.

    Args:
        d_model: embedding size of the sequence elements (even or odd).
        dropout: dropout probability applied after the addition.
        max_len: longest sequence length for which encodings are precomputed.
    """
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Create a long enough 'pe' matrix with dimensions [max_len, d_model]
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term # [max_len, ceil(d_model / 2)]

        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd-index column than even-index
        # columns, so the cosine table is trimmed to d_model // 2 columns.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Add a dimension for batch size
        self.pe = pe.unsqueeze(0)

    def forward(self, x):
        """
        Args:
            x: Tensor, shape [batch_size, seq_len, embedding_dim]

        Returns:
            Tensor of the same shape as ``x``.

        Raises:
            ValueError: if ``seq_len`` exceeds the precomputed ``max_len``.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {self.pe.size(1)}"
            )

        # Add positional encoding to the input embedding
        x = x + self.pe[:, :seq_len, :].to(x.device)
        return self.dropout(x)

class ReverseDiffusion(nn.Module):
    """ 
    p_{theta}(x_{t-1} | x_t)

    Embeds every object with a dataset-specific CNN, reorders the embeddings
    according to the given permutations, adds a sinusoidal time embedding,
    runs a transformer encoder over the sequence and turns the encoder output
    into logits. The layout of the logits is selected by ``d_out_adjust``:

    * ``"-1"``: the last token's output is concatenated to each of the first
      ``n - 1`` outputs and scored, giving ``n - 1`` logits.
    * ``"0"``: every token is scored, giving ``n`` logits.
    * ``"1"``: one zero token is appended before the encoder, giving
      ``n + 1`` logits.
    * ``"square"``: ``n`` zero tokens are appended (with an attention mask)
      and the logits are the ``n x n`` dot products between the first ``n``
      and the last ``n`` encoder outputs.

    Args:
        dataset: ``"unscramble-CIFAR10"`` or ``"sort-MNIST"``; selects the
            piece encoder.
        in_channels, num_pieces, image_size, hidden_channels1, kernel_size1,
        stride1, padding1: forwarded to the piece encoder.
        hidden_channels2, kernel_size2, stride2, padding2, num_digits: only
            used by the ``"sort-MNIST"`` encoder.
        d_model: embedding size used throughout the transformer.
        nhead: number of attention heads.
        d_hid: feed-forward size of the encoder layers and hidden size of the
            output head.
        nlayers: number of encoder layers.
        dropout: dropout probability of the positional encoder.
        d_out_adjust: one of ``"-1"``, ``"0"``, ``"1"``, ``"square"``.
        use_pos_enc: whether to add positional encodings before the encoder.

    Raises:
        NotImplementedError: if ``dataset`` is not supported.
        ValueError: if ``d_out_adjust`` is not one of the supported modes.
    """
    def __init__(
        self, dataset, in_channels, num_pieces, image_size,
        hidden_channels1, kernel_size1, stride1, padding1,
        hidden_channels2, kernel_size2, stride2, padding2, num_digits,
        d_model, nhead, d_hid, nlayers, dropout, d_out_adjust, # -1, 0, 1, square
        use_pos_enc=True,
    ):
        super(ReverseDiffusion, self).__init__()

        self.time_embd = TimeEmbedding(d_model)

        if dataset == "unscramble-CIFAR10":
            self.pieces_embd = UnscrambleCNN(
                in_channels, num_pieces, image_size,
                hidden_channels1, kernel_size1, stride1, padding1, d_model
            )
        
        elif dataset == "sort-MNIST":
            self.pieces_embd = SortMnistCNN(
                in_channels, num_digits, image_size,
                hidden_channels1, kernel_size1, stride1, padding1,
                hidden_channels2, kernel_size2, stride2, padding2, d_model
            )

        else:
            raise NotImplementedError

        self.use_pos_enc = use_pos_enc

        if use_pos_enc:
            self.pos_encoder = PositionalEncoding(d_model, dropout)

        # NOTE(review): ``dropout`` is not forwarded here, so the encoder
        # layers keep nn.TransformerEncoderLayer's default dropout -- confirm
        # whether that is intended before changing it.
        encoder_layers = nn.TransformerEncoderLayer(d_model, nhead, d_hid, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, nlayers)

        if d_out_adjust == "-1":
            # Head over [token_i ; last_token], hence the doubled widths.
            self.dmodel_mlp = nn.Sequential(
                nn.Linear(2 * d_model, 2 * d_hid),
                nn.ReLU(),
                nn.Linear(2 * d_hid, 1)
            )
        elif d_out_adjust in ("0", "1"):
            self.dmodel_mlp = nn.Sequential(
                nn.Linear(d_model, d_hid),
                nn.ReLU(),
                nn.Linear(d_hid, 1)
            )
        elif d_out_adjust != "square":
            # "square" needs no head (logits are dot products). Any other
            # value would leave ``dmodel_mlp`` undefined and only fail later
            # inside forward(), so reject it here.
            raise ValueError(
                f"d_out_adjust must be one of '-1', '0', '1', 'square', got {d_out_adjust!r}"
            )

        self.d_model = d_model
        self.d_out_adjust = d_out_adjust

    def training_patch_embd(self, src, time, x_start):
        """
        Args:
            src: permutations, [N, T, b, n]
            time: [N, T, b]
            x_start: [b, n, c, h, w]

        Returns:
            embedding of patches, shape [NTb, n, d]
        """
        time = self.time_embd(time).unsqueeze(-2) # [N, T, b, 1, d_model]

        patch_embd = self.pieces_embd(x_start) # [b, n, d_model]
        # utils.permute_embd is defined outside this file; presumably it
        # gathers the embeddings in the order given by ``src`` -- TODO confirm.
        src = utils.permute_embd(src, patch_embd) # [N, T, b, n, d_model]

        src = src + time # [N, T, b, n, d_model]
        src = torch.flatten(src, end_dim=-3) # [NTb, n, d_model]

        return src
        
    def eval_patch_embd(self, src, time, x_original):
        """
        Args:
            src: permutations [batch, beam, n]
            time: [batch]
            x_original: [batch, n, c, h, w]

        Returns:
            embedding of patches, shape [batch*beam, n, d]
        """
        time = self.time_embd(time).unsqueeze(-2).unsqueeze(-2) # [batch, 1, 1, d_model]

        patch_embd = self.pieces_embd(x_original) # [batch, n, d_model]
        patch_embd = patch_embd.unsqueeze(-3) # [batch, 1, n, d_model]
        src = utils.permute_embd(src, patch_embd) # [batch, beam, n, d_model]

        src = src + time # [batch, beam, n, d_model]
        src = torch.flatten(src, end_dim=-3) # [batch*beam, n, d]

        return src

    def forward(self, src, time, x_start):
        """
        Args:
            src: [N, T, b, n], or [b, beam, n]
            time: [N, T, b], or [b]
            x_start: [b, n, object_shape]

        Returns:
            logits of x_{t-1}: shape [*src.shape[:-1], n, n] for "square",
            otherwise [*src.shape[:-1], m] with m = n - 1, n or n + 1 for
            d_out_adjust "-1", "0" or "1".
        """
        device = src.device
        batch_shape = src.shape[:-1]
        n = src.size(-1)

        # The training layout ([N, T, b, n]) and the evaluation layout
        # ([b, beam, n]) are told apart by the module's train/eval mode.
        if self.training:
            src = self.training_patch_embd(src, time, x_start)
        else:
            src = self.eval_patch_embd(src, time, x_start)

        batch_shape_flattened = src.size(0)
        mask = None
        # Pad tokens and the mask use the embedded sequence's dtype so that
        # concatenation and attention do not rely on implicit type promotion.
        dtype = src.dtype

        if self.d_out_adjust == "1":
            pad = torch.zeros(batch_shape_flattened, 1, self.d_model, dtype=dtype, device=device)
            src = torch.cat([src, pad], dim=-2) # [NTb, n+1, d_model]

        if self.d_out_adjust == "square":
            pad = torch.zeros(batch_shape_flattened, n, self.d_model, dtype=dtype, device=device)
            src = torch.cat([src, pad], dim=-2) # [NTb, 2n, d_model]

            # Create mask, shape [2n, 2n]. 0 = visible, -inf = hidden.
            # Left half: the n real tokens are visible to every row.
            # Upper right: real tokens never attend to pad tokens.
            # Lower right: pad token i attends only to pad tokens j < i.
            mask_up_right = torch.full((n, n), float("-inf"), dtype=dtype, device=device)
            mask_bottom_right = torch.triu(
                torch.full((n, n), float("-inf"), dtype=dtype, device=device)
            )
            mask_right = torch.cat([mask_up_right, mask_bottom_right], dim=0)
            mask_left = torch.zeros((2 * n, n), dtype=dtype, device=device)
            mask = torch.cat([mask_left, mask_right], dim=-1)
        
        if self.use_pos_enc:
            src = self.pos_encoder(src)
        
        out = self.transformer_encoder(src, mask=mask) # [NTb, 2n, d_model]

        if self.d_out_adjust == "-1":
            # Pair each of the first n-1 tokens with the last token's output.
            last = out[:, [-1], :].repeat(1, n - 1, 1) # [NTb, n-1, d_model]
            front = out[:, :-1, :] # [NTb, n-1, d_model]
            out = torch.cat([front, last], dim=-1) # [NTb, n-1, 2d_model]

        if self.d_out_adjust == "square":
            row, col = torch.split(out, [n, n], dim=-2) # [NTb, n, d]
            combined_out = torch.matmul(row, col.transpose(-1, -2)) # [NTb, n, n]
            combined_out = combined_out.unflatten(0, batch_shape)
            return combined_out
        else:
            out = self.dmodel_mlp(out) # [NTb, n(+-1), 1]
            out = out.unflatten(0, batch_shape).flatten(start_dim=-2) # [N, T, b, n(+-1)]
            return out
    