import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from .attention import SelfAttention, CrossAttention

class TimeEmbedding(nn.Module):
    """Widen a timestep embedding four-fold with a two-layer SiLU MLP.

    Maps (..., n_embed) -> (..., 4 * n_embed); for example 320 -> 1280.

    Args:
        n_embed: Size of the last axis of the incoming time embedding.
    """

    def __init__(self, n_embed: int):
        super().__init__()
        hidden = 4 * n_embed
        self.linear_1 = nn.Linear(n_embed, hidden)
        self.linear_2 = nn.Linear(hidden, hidden)

    def forward(self, x: torch.Tensor):
        activated = F.silu(self.linear_1(x))  # (..., 4 * n_embed)
        return self.linear_2(activated)  # (..., 4 * n_embed)
    
class ContextEmbedding(nn.Module):
    """Refine a context embedding with a width-preserving two-layer SiLU MLP.

    The hidden layer is ``4 * n_embed`` wide, but the output is projected back
    to ``n_embed``, so the shape is preserved: (..., n_embed) -> (..., n_embed).

    Args:
        n_embed: Size of the last axis of the context tensor.
    """

    def __init__(self, n_embed: int):
        super().__init__()
        self.linear_1 = nn.Linear(n_embed, 4 * n_embed)
        self.linear_2 = nn.Linear(4 * n_embed, n_embed)

    def forward(self, x: torch.Tensor):
        x = F.silu(self.linear_1(x))  # (..., 4 * n_embed)
        x = self.linear_2(x)  # (..., n_embed): back to the input width
        return x
    
class PositionEmbedding(nn.Module):
    """Add fixed sinusoidal position encodings along the second axis of ``x``.

    A ``(max_positions, embed_dim)`` table is precomputed once: even columns
    hold ``sin(pos * 10000^(-2i / embed_dim))`` and odd columns the matching
    ``cos`` terms. It is registered as a buffer, so it moves with the module
    across devices and is saved in ``state_dict``, but it is not trained.

    Args:
        embed_dim: Size of the input's last axis. Odd widths are supported:
            the final (even) column then has no cosine partner.
        max_positions: Number of table rows; the input's second axis must not
            be longer than this.
    """

    def __init__(self, embed_dim=16, max_positions=64):
        super().__init__()
        self.embed_dim = embed_dim
        self.max_positions = max_positions

        positions = torch.arange(0, max_positions).float().unsqueeze(1)  # (max_positions, 1)
        # One frequency per even column.
        div_term = torch.exp(torch.arange(0, embed_dim, 2).float() * -(np.log(10000.0) / embed_dim))
        angles = positions * div_term.unsqueeze(0)  # (max_positions, ceil(embed_dim / 2))

        pos_enc = torch.zeros(max_positions, embed_dim)
        pos_enc[:, 0::2] = torch.sin(angles)
        # There are only embed_dim // 2 odd columns. For odd embed_dim that is one
        # fewer than the number of frequencies, so drop the surplus frequency
        # instead of failing on a shape mismatch.
        pos_enc[:, 1::2] = torch.cos(angles[:, : embed_dim // 2])

        # Register the position embeddings as a buffer
        self.register_buffer('pos_enc', pos_enc)

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        Assumes ``x`` is (batch_size, seq_len, embed_dim) with
        seq_len <= max_positions.
        """
        return x + self.pos_enc[: x.size(1), :].unsqueeze(0)

class ResidualBlock(nn.Module):
    """Convolutional residual block conditioned on a time embedding.

    Main path: GroupNorm -> SiLU -> 3x3 conv, then add a per-channel offset
    computed from ``time`` (SiLU -> Linear), then GroupNorm -> SiLU -> 3x3 conv.
    Skip path: identity when the channel counts match, otherwise a 1x1 conv.

    Args:
        in_channels: Channels of ``feature``. GroupNorm(16, ...) requires a
            multiple of 16.
        out_channels: Output channels. Also a multiple of 16.
        n_time: Size of the last axis of ``time``.
    """

    def __init__(self, in_channels, out_channels, n_time=1280):
        super().__init__()
        self.groupnorm_feature = nn.GroupNorm(16, in_channels)  # original: 32
        self.conv_feature = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.linear_time = nn.Linear(n_time, out_channels)

        self.groupnorm_merged = nn.GroupNorm(16, out_channels)  # original: 32
        self.conv_merged = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)

        self.residual_layer = (
            nn.Identity()
            if in_channels == out_channels
            else nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
        )

    def forward(self, feature, time):
        # presumably feature is (N, in_channels, H, W) and time is (N, n_time) -- confirm with callers
        skip = self.residual_layer(feature)

        hidden = self.conv_feature(F.silu(self.groupnorm_feature(feature)))
        time_bias = self.linear_time(F.silu(time))
        # Broadcast the time projection over the two spatial axes.
        hidden = hidden + time_bias[..., None, None]

        hidden = self.conv_merged(F.silu(self.groupnorm_merged(hidden)))
        return hidden + skip

class SpatialTransformer(nn.Module):
    """Spatial transformer network: warp the input by a learned affine map.

    A localization network regresses a 2x3 affine matrix ``theta`` from the
    input itself. The input is then bilinearly resampled (``grid_sample``
    default mode) on the grid that ``theta`` defines. The regressor starts at
    the identity transform, so at initialization the output equals the input
    up to float rounding.

    Args:
        in_channels: Channels of the input feature map (kept unchanged).
        dim: Side length of the (assumed square) input. The localization
            network floor-halves the spatial size twice, so the regressor
            expects ``in_channels * (dim // 4) ** 2`` features per sample.
    """

    def __init__(self, in_channels, dim):
        super().__init__()

        # Localization network: two (conv -> 2x2 max-pool -> ReLU) stages.
        # The convs keep the size and each pool floor-halves it:
        # dim -> dim // 2 -> dim // 4.
        self.localization = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
            nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
        )

        # Flattened localization output size per sample. Integer floor division
        # matches MaxPool2d's output size. A float formula like (dim / 4) ** 2
        # over-counts whenever dim is not a multiple of 4.
        self.loc_in = int(in_channels * (dim // 4) ** 2)

        # Regressor for the 2x3 affine matrix.
        self.fc_loc = nn.Sequential(
            nn.Linear(self.loc_in, 32),
            nn.ReLU(True),
            nn.Linear(32, 6),
        )

        # Initialize the last layer to emit the identity transform for any input.
        with torch.no_grad():
            self.fc_loc[2].weight.zero_()
            self.fc_loc[2].bias.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))

    def stn(self, x):
        """Warp ``x`` (N, C, H, W) by the affine transform predicted from ``x``."""
        xs = self.localization(x)
        # Keep the batch axis explicit. A spatial-size mismatch then fails in
        # fc_loc instead of silently folding samples into one another.
        xs = xs.flatten(1)  # (N, loc_in)
        theta = self.fc_loc(xs).view(-1, 2, 3)

        # align_corners=False is PyTorch's default. Passing it explicitly keeps
        # that behavior and silences the "default has changed" warning.
        grid = F.affine_grid(theta, x.size(), align_corners=False)
        return F.grid_sample(x, grid, align_corners=False)

    def forward(self, x):
        return self.stn(x)
    

class AttentionBlock(nn.Module):
    """Spatial-transformer warp followed by a cross-attention transformer layer.

    Pipeline:
      1. Warp the input with ``SpatialTransformer``. The warped map is also the
         long residual.
      2. GroupNorm + 1x1 conv, then flatten to tokens of shape (n, h*w, c).
      3. LayerNorm -> cross-attention against ``context``, with a residual.
      4. LayerNorm -> GEGLU feed-forward (4x width), with a residual.
      5. Restore (n, c, h, w), apply a 1x1 conv and add the long residual.

    Args:
        n_head: Number of attention heads.
        n_embd: Per-head width. The channel count is ``n_head * n_embd``
            (a multiple of 16 for GroupNorm).
        dim: Spatial size passed to ``SpatialTransformer``.
        d_context: Context width passed to ``CrossAttention``. Presumably this
            is the last dimension of ``context``; confirm against CrossAttention.
    """

    def __init__(self, n_head: int, n_embd: int, dim: int, d_context=320):
        super().__init__()
        width = n_head * n_embd

        self.groupnorm = nn.GroupNorm(16, width, eps=1e-6)  # original: 32
        self.conv_input = nn.Conv2d(width, width, kernel_size=1, padding=0)

        self.layernorm_1 = nn.LayerNorm(width)
        self.attention = CrossAttention(n_head, width, d_context, in_proj_bias=False)
        self.layernorm_2 = nn.LayerNorm(width)
        # GEGLU: one projection produces both the value and the gate halves.
        self.linear_geglu_1 = nn.Linear(width, 4 * width * 2)
        self.linear_geglu_2 = nn.Linear(4 * width, width)

        self.conv_output = nn.Conv2d(width, width, kernel_size=1, padding=0)

        self.spatial_transformer = SpatialTransformer(width, dim)

    def forward(self, x, context):
        warped = self.spatial_transformer(x)

        projected = self.conv_input(self.groupnorm(warped))
        n, c, h, w = projected.shape
        tokens = projected.view((n, c, h * w)).transpose(1, 2)  # (n, hw, c)

        # Cross attention with the conditional embedding, plus residual.
        attended = self.attention(self.layernorm_1(tokens), context)
        tokens = attended + tokens

        # GEGLU feed-forward, plus residual.
        value, gate = self.linear_geglu_1(self.layernorm_2(tokens)).chunk(2, dim=-1)
        tokens = self.linear_geglu_2(value * F.gelu(gate)) + tokens

        feature_map = tokens.transpose(1, 2).view((n, c, h, w))  # (n, c, h, w)
        return self.conv_output(feature_map) + warped
    
class Upsample(nn.Module):
    """Double the spatial resolution (nearest neighbour), then apply a 3x3 conv.

    Args:
        channels: Input and output channel count.
    """

    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        enlarged = F.interpolate(x, scale_factor=2, mode='nearest')  # (N, C, 2H, 2W)
        return self.conv(enlarged)
    
    
class SwitchSequential(nn.Sequential):
    """``nn.Sequential`` that routes extra inputs to the layers that need them.

    - ``AttentionBlock`` is called as ``layer(x, context)``.
    - ``ResidualBlock`` is called as ``layer(x, time)``.
    - ``PositionEmbedding`` sees ``x`` with every axis after the second
      flattened. For a 4-D (N, C, H, W) input that is (N, C, H*W). The result
      is reshaped back to the original shape.
    - Any other layer is called as ``layer(x)``.
    """

    def forward(self, x, context, time):
        for module in self:
            if isinstance(module, AttentionBlock):
                x = module(x, context)
                continue
            if isinstance(module, ResidualBlock):
                x = module(x, time)
                continue
            if isinstance(module, PositionEmbedding):
                original_shape = x.shape
                flattened = x.reshape(*original_shape[:2], -1)
                x = module(flattened).reshape(original_shape)
                continue
            x = module(x)
        return x

class UNet(nn.Module):
    """Small two-level U-Net with time-conditioned residual blocks and cross-attention.

    Layout (channel counts taken from the constructors below):
      encoders   : conv 1->64, stride-2 conv 64->64 (halves H and W),
                   PositionEmbedding(16, 64); then ResidualBlock(64, 64).
      bottleneck : ResidualBlock -> AttentionBlock(4 heads x 16) -> ResidualBlock,
                   all at 64 channels.
      decoders   : each concatenates the most recent unused encoder output on
                   dim 1 (64 + 64 = 128 channels) before ResidualBlock(128, 64).
                   The last decoder ends with Upsample.

    NOTE(review): the hard-coded sizes (PositionEmbedding width 16,
    AttentionBlock dim=4) appear to assume a 1x8x8 input, i.e. 4x4 after the
    stride-2 conv. Confirm against callers.

    Args:
        d_context: Context width forwarded to every AttentionBlock.
    """

    def __init__(self, d_context=320):
        super().__init__()
        self.encoders = nn.ModuleList([
            SwitchSequential(
                nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1),
                nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),
                PositionEmbedding(16, 64),
            ),
            SwitchSequential(ResidualBlock(64, 64)),
        ])
        self.bottleneck = SwitchSequential(
            ResidualBlock(64, 64),
            AttentionBlock(4, 16, dim=4, d_context=d_context),
            ResidualBlock(64, 64),
        )
        self.decoders = nn.ModuleList([
            SwitchSequential(
                ResidualBlock(128, 64),
                AttentionBlock(4, 16, dim=4, d_context=d_context),
            ),
            SwitchSequential(ResidualBlock(128, 64), Upsample(64)),
        ])

    def forward(self, x, context, time):
        skips = []
        for stage in self.encoders:
            x = stage(x, context, time)
            skips.append(x)

        x = self.bottleneck(x, context, time)

        for stage in self.decoders:
            # Skips are consumed LIFO: the deepest encoder output pairs with
            # the first decoder.
            latest_skip = skips.pop()
            x = stage(torch.cat((x, latest_skip), dim=1), context, time)

        return x


class FinalLayer(nn.Module):
    """Output head: GroupNorm(16 groups) -> SiLU -> 3x3 conv.

    Args:
        in_channels: Input channels (a multiple of 16 for GroupNorm).
        out_channels: Output channels.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.groupnorm = nn.GroupNorm(16, in_channels)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        normalized = self.groupnorm(x)
        return self.conv(F.silu(normalized))
        

class Diffusion(nn.Module):
    """Noise-prediction network: time embedding + optional context MLP + U-Net + head.

    Args:
        d_context: Width of the conditioning ``prompt`` features, forwarded to
            the U-Net's attention blocks.
        context_embed: When True, ``prompt`` first passes through a
            ``ContextEmbedding`` MLP. The ``context_embedding`` attribute only
            exists in that case.
    """

    def __init__(self, d_context=320, context_embed=False):
        super().__init__()
        self.time_embedding = TimeEmbedding(320)  # 320 -> 1280, ResidualBlock's default n_time
        if context_embed:
            self.context_embedding = ContextEmbedding(d_context)
        self.unet = UNet(d_context=d_context)
        self.final = FinalLayer(64, 1)  # 64 U-Net channels -> 1 output channel

    def forward(self, latent: torch.Tensor, prompt: torch.Tensor, time: torch.Tensor):
        # time's last axis must be 320 (TimeEmbedding input width).
        time_features = self.time_embedding(time)
        context = self.context_embedding(prompt) if hasattr(self, 'context_embedding') else prompt
        features = self.unet(latent, context, time_features)
        return self.final(features)