import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.nn import CrossEntropyLoss
from einops import rearrange
from einops.layers.torch import Rearrange
from timm.models.layers import DropPath, to_2tuple


class PatchEmbed(nn.Module):
    """
    Turn a 2D input into a sequence of patch tokens.

    A convolution whose kernel size equals its stride cuts the input into
    non-overlapping ``patch_size`` tiles and projects each tile to an
    ``embed_dim``-dimensional vector.
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        # Both sizes are normalized to (height, width) pairs.
        self.img_size = to_2tuple(img_size)
        self.patch_size = to_2tuple(patch_size)

        rows = self.img_size[0] // self.patch_size[0]
        cols = self.img_size[1] // self.patch_size[1]
        self.grid_size = (rows, cols)
        self.num_patches = rows * cols

        # kernel == stride: every tile is visited exactly once.
        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
        )

    def forward(self, x):
        """
        Embed ``x`` of shape (B, C, H, W) into tokens of shape
        (B, num_patches, embed_dim).

        Inputs whose spatial size differs from ``img_size`` are bilinearly
        resized to it first, rather than rejected.
        """
        _, _, height, width = x.shape
        size_matches = height == self.img_size[0] and width == self.img_size[1]
        if not size_matches:
            x = F.interpolate(x, size=self.img_size, mode='bilinear', align_corners=True)

        tokens = self.proj(x)        # (B, embed_dim, rows, cols)
        tokens = tokens.flatten(2)   # (B, embed_dim, rows * cols)
        return tokens.transpose(1, 2)


class Attention(nn.Module):
    """
    Multi-head scaled dot-product self-attention.

    Queries, keys and values come from a single fused linear projection;
    the per-head results are concatenated and passed through an output
    projection.
    """
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # A falsy qk_scale (None or 0) selects the usual 1/sqrt(head_dim).
        if qk_scale:
            self.scale = qk_scale
        else:
            self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """
        Apply self-attention to ``x`` of shape (B, N, C); the output has the
        same shape.
        """
        batch, tokens, channels = x.shape
        per_head = channels // self.num_heads

        # (B, N, 3C) -> (3, B, heads, N, per_head)
        fused = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, per_head)
        q, k, v = fused.permute(2, 0, 3, 1, 4).unbind(0)

        weights = (q @ k.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(weights.softmax(dim=-1))

        # Merge the heads back into the channel dimension.
        merged = (weights @ v).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(merged))


class MLP(nn.Module):
    """
    Two-layer feed-forward network: linear -> activation -> dropout ->
    linear -> dropout.

    ``hidden_features`` and ``out_features`` fall back to ``in_features``
    when they are not given (or are otherwise falsy).
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden_width = hidden_features if hidden_features else in_features
        out_width = out_features if out_features else in_features

        self.fc1 = nn.Linear(in_features, hidden_width)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_width, out_width)
        # One dropout module is shared by both dropout positions.
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Run ``x`` through the feed-forward stack and return the result."""
        for stage in (self.fc1, self.act, self.drop, self.fc2, self.drop):
            x = stage(x)
        return x


class TransformerBlock(nn.Module):
    """
    Pre-norm transformer encoder layer.

    Two residual branches are applied in sequence: normalized self-attention,
    then a normalized feed-forward network. Each branch output passes through
    the same stochastic-depth module before being added back.
    """
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0.,
                 attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )

        # Stochastic depth is only instantiated when it would do something.
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

        self.norm2 = norm_layer(dim)
        self.mlp = MLP(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )

    def forward(self, x):
        """Apply the attention and feed-forward residual branches to ``x``."""
        attended = self.attn(self.norm1(x))
        x = x + self.drop_path(attended)

        fed_forward = self.mlp(self.norm2(x))
        return x + self.drop_path(fed_forward)


class ResNetBlock(nn.Module):
    """
    Two-convolution residual block (3x3 conv -> BN -> ReLU -> 3x3 conv -> BN,
    plus shortcut, then ReLU).

    The shortcut is an empty ``nn.Sequential`` (a pass-through) unless the
    stride or channel count changes, in which case a 1x1 convolution with
    batch norm matches the shapes.
    """
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        shape_preserved = stride == 1 and in_channels == out_channels
        if shape_preserved:
            self.shortcut = nn.Sequential()
        else:
            # Project the input so it can be added to the main branch.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        out += self.shortcut(x)
        return self.relu(out)


class Encoder(nn.Module):
    """
    ResNet-based encoder

    A strided 7x7 convolution followed by max pooling forms the stem; four
    ``ResNetBlock`` stages follow, the last three with stride 2.

    Args:
        in_channels: number of channels of the input tensor.
        features: channel widths of the four stages. Only the first four
            entries are used; the stem outputs ``features[0]`` channels.

    Raises:
        ValueError: if ``features`` has fewer than four entries.
    """
    def __init__(self, in_channels=3, features=(64, 128, 256, 512)):
        super().__init__()
        # Immutable default plus an up-front length check: a short sequence
        # would otherwise surface as a bare IndexError further down.
        features = tuple(features)
        if len(features) < 4:
            raise ValueError(
                f"features must provide at least 4 channel sizes, got {len(features)}"
            )

        self.initial = nn.Sequential(
            nn.Conv2d(in_channels, features[0], kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(features[0]),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )

        self.encoder1 = ResNetBlock(features[0], features[0])
        self.encoder2 = ResNetBlock(features[0], features[1], stride=2)
        self.encoder3 = ResNetBlock(features[1], features[2], stride=2)
        self.encoder4 = ResNetBlock(features[2], features[3], stride=2)

        # Kept as an attribute for backward compatibility; refreshed on
        # every forward pass.
        self.skip_connections = []

    def forward(self, x):
        """
        Encode ``x`` and collect the intermediate feature maps.

        Returns:
            A tuple ``(encoded, skips)`` where ``encoded`` is the output of
            the last stage and ``skips`` is a list holding, in order, the
            outputs of the stem, ``encoder1``, ``encoder2`` and ``encoder3``.
            The same list is also stored on ``self.skip_connections``.
        """
        skips = self.skip_connections = []

        x = self.initial(x)
        skips.append(x)

        for stage in (self.encoder1, self.encoder2, self.encoder3):
            x = stage(x)
            skips.append(x)

        x = self.encoder4(x)

        return x, skips


class DecoderBlock(nn.Module):
    """
    One U-Net decoder stage.

    The input is upsampled 2x by a transposed convolution that also halves
    its channel count, concatenated with a skip feature map along the channel
    axis, and refined by two 3x3 conv -> BN -> ReLU layers.
    """
    def __init__(self, in_channels, skip_channels, out_channels):
        super().__init__()
        up_channels = in_channels // 2
        self.upsample = nn.ConvTranspose2d(in_channels, up_channels, kernel_size=2, stride=2)

        refine_layers = [
            nn.Conv2d(up_channels + skip_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*refine_layers)

    def forward(self, x, skip):
        """
        Upsample ``x``, fuse it with ``skip`` and return the refined map.

        If ``skip`` does not share the upsampled spatial size it is
        bilinearly resized to match before concatenation.
        """
        upsampled = self.upsample(x)
        target_hw = upsampled.shape[2:]

        if skip.shape[2:] != target_hw:
            skip = F.interpolate(skip, size=target_hw, mode='bilinear', align_corners=True)

        fused = torch.cat([upsampled, skip], dim=1)
        return self.conv(fused)


class TransUNet(nn.Module):
    """
    TransUNet for image reconstruction

    Architecture:
    - ResNet-based CNN encoder
    - Transformer blocks for global context modeling
    - U-Net style decoder with skip connections

    Args:
        img_size: nominal (square) input resolution as an int. The encoder
            reduces it by a factor of 32, so it must be at least 32.
        in_channels: channels of the input image.
        out_channels: channels of the reconstructed output.
        patch_size: stored on the instance; the transformer itself always
            tokenizes the encoder output with a patch size of 1.
        embed_dim: transformer token width.
        depth: number of transformer blocks.
        num_heads: attention heads per transformer block.
        mlp_ratio: hidden/width ratio of the transformer feed-forward layers.
        encoder_features: channel widths of the encoder stages.
        decoder_features: output widths of the decoder stages. If shorter
            than the number of decoder stages the last value is repeated;
            surplus entries are ignored.

    Raises:
        ValueError: if ``img_size`` is smaller than 32 or
            ``decoder_features`` is empty.
    """
    def __init__(self, img_size=224, in_channels=3, out_channels=3, patch_size=16,
                 embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.,
                 encoder_features=(64, 128, 256, 512), decoder_features=(256, 128, 64, 32)):
        super().__init__()

        # Work on private list copies so caller-owned sequences (and the
        # immutable defaults) are never aliased or modified.
        encoder_features = list(encoder_features)
        decoder_features = list(decoder_features)
        if not decoder_features:
            raise ValueError("decoder_features must contain at least one channel size")

        # ResNet Encoder
        self.encoder = Encoder(in_channels, encoder_features)

        # Size after the encoder's downsampling operations:
        # Initial: stride 2 + MaxPool stride 2 = 1/4
        # encoder2: stride 2 = 1/8
        # encoder3: stride 2 = 1/16
        # encoder4: stride 2 = 1/32
        self.patch_size = patch_size

        # Input to PatchEmbed is the feature map after encoder downsampling
        patch_embed_size = img_size // 32
        if patch_embed_size < 1:
            # A zero-sized token grid cannot be processed by the transformer.
            raise ValueError(f"img_size must be at least 32, got {img_size}")

        # Transformer
        self.embed_dim = embed_dim
        self.patch_embed = PatchEmbed(
            img_size=patch_embed_size,
            patch_size=1,  # Use patch_size=1 for already downsampled features
            in_chans=encoder_features[-1],
            embed_dim=embed_dim
        )

        # Number of tokens produced from the encoder feature map
        self.num_patches = self.patch_embed.num_patches

        # Position embedding (one extra slot for the class token)
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, embed_dim))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        # Transformer blocks; stochastic depth (0.1) is enabled only for
        # blocks with index > depth - 4.
        self.blocks = nn.ModuleList([
            TransformerBlock(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio,
                qkv_bias=True, drop=0., attn_drop=0., drop_path=0.1 if i > depth - 4 else 0.,
                norm_layer=nn.LayerNorm)
            for i in range(depth)
        ])

        self.norm = nn.LayerNorm(embed_dim)

        # Project tokens back to encoder width and restore the 2D layout
        h, w = patch_embed_size, patch_embed_size  # Size after encoder
        self.transformer_to_cnn = nn.Sequential(
            nn.Linear(embed_dim, encoder_features[-1]),
            Rearrange('b (h w) c -> b c h w', h=h, w=w)
        )

        # U-Net decoder with skip connections. Skips are consumed deepest
        # first; the stem output shares the first stage's channel count.
        encoder_channels_reversed = list(reversed(encoder_features))
        skip_channels = encoder_channels_reversed[1:] + [encoder_features[0]]
        num_stages = len(skip_channels)

        # Fit decoder widths to the number of stages: pad by repeating the
        # last width, and drop surplus entries so that final_conv is sized
        # from the width the last decoder stage actually produces.
        decoder_channels = decoder_features[:num_stages]
        if len(decoder_channels) < num_stages:
            decoder_channels += [decoder_channels[-1]] * (num_stages - len(decoder_channels))

        # Stage i consumes the previous stage's output (the encoder output
        # for the first stage).
        stage_inputs = [encoder_channels_reversed[0]] + decoder_channels[:-1]
        self.decoder_blocks = nn.ModuleList([
            DecoderBlock(in_channels=in_ch, skip_channels=skip_ch, out_channels=out_ch)
            for in_ch, skip_ch, out_ch in zip(stage_inputs, skip_channels, decoder_channels)
        ])

        # Final layer
        self.final_conv = nn.Conv2d(decoder_channels[-1], out_channels, kernel_size=1)

        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        """Initialize the embeddings, then every Linear/LayerNorm submodule."""
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        self.apply(self._init_transformer_weights)

    def _init_transformer_weights(self, m):
        """Per-module initializer used with ``nn.Module.apply``."""
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward_transformer(self, x):
        """
        Run the encoder feature map through the transformer and return a
        feature map with the encoder's final channel count.
        """
        # Tokenize the feature map
        x = self.patch_embed(x)

        # Prepend the class token
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_token, x), dim=1)

        # Add position embedding
        x = x + self.pos_embed

        # Apply transformer blocks
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)

        # Drop the class token; only spatial tokens are reconstructed
        x = x[:, 1:]

        # Transform back to a feature map
        x = self.transformer_to_cnn(x)

        return x

    def forward(self, x):
        """
        Reconstruct an image from ``x``.

        Returns a tensor with ``out_channels`` channels whose spatial size
        matches that of ``x`` (bilinearly resized at the end if necessary).
        """
        # Remember the input size for the final resize
        input_size = x.shape[2:]

        # ResNet encoder
        encoded, skip_connections = self.encoder(x)

        # Transformer modeling for global context
        x = self.forward_transformer(encoded)

        # Decode, consuming the encoder skips from deepest to shallowest
        for decoder_block, skip in zip(self.decoder_blocks, reversed(skip_connections)):
            x = decoder_block(x, skip)

        # Final convolution
        x = self.final_conv(x)

        # Ensure output size matches input size
        if x.shape[2:] != input_size:
            x = F.interpolate(x, size=input_size, mode='bilinear', align_corners=True)

        return x


class SSIMLoss(nn.Module):
    """
    Reconstruction loss blending MSE with a structural-similarity term.

    The result is ``alpha * MSE + (1 - alpha) * (1 - mean SSIM)``, where the
    SSIM statistics are taken over 3x3 average-pooled neighbourhoods.
    """
    def __init__(self, alpha=0.05):
        super().__init__()
        self.alpha = alpha
        self.mse = nn.MSELoss()

    def forward(self, x, y):
        """Return the scalar combined loss between ``x`` and ``y``."""
        def local_mean(t):
            # 3x3 box filter that keeps the spatial size.
            return F.avg_pool2d(t, kernel_size=3, stride=1, padding=1)

        pixel_term = self.mse(x, y)

        # Stabilizing constants of the SSIM formula.
        c1 = 0.01 ** 2
        c2 = 0.03 ** 2

        mean_x = local_mean(x)
        mean_y = local_mean(y)
        mean_x_sq = mean_x ** 2
        mean_y_sq = mean_y ** 2
        mean_xy = mean_x * mean_y

        # Local (co)variances via E[ab] - E[a]E[b].
        var_x = local_mean(x ** 2) - mean_x_sq
        var_y = local_mean(y ** 2) - mean_y_sq
        cov_xy = local_mean(x * y) - mean_xy

        numerator = (2 * mean_xy + c1) * (2 * cov_xy + c2)
        denominator = (mean_x_sq + mean_y_sq + c1) * (var_x + var_y + c2)
        ssim_map = numerator / denominator
        structural_term = 1 - ssim_map.mean()

        return self.alpha * pixel_term + (1 - self.alpha) * structural_term