import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np

# --------------------------------------------------------
# 2D Sin-Cos Position Embedding
# --------------------------------------------------------
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """
    Build a 2D sin-cos embedding by encoding the two grid axes independently.

    embed_dim: total output dimension; must be even (half goes to each axis)
    grid: indexable with two entries; grid[0] and grid[1] are each handed to
        get_1d_sincos_pos_embed_from_grid
    out: (H*W, D); the first D/2 columns come from grid[0], the last D/2 from grid[1]
    """
    assert embed_dim % 2 == 0

    per_axis_dim = embed_dim // 2
    # One (H*W, D/2) encoding per axis, in the order grid[0] then grid[1].
    axis_embeddings = [
        get_1d_sincos_pos_embed_from_grid(per_axis_dim, grid[axis])
        for axis in (0, 1)
    ]
    return np.concatenate(axis_embeddings, axis=1)  # (H*W, D)


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    Encode scalar positions as sine/cosine features at geometrically spaced
    frequencies.

    embed_dim: output dimension for each position; must be even
    pos: positions to be encoded; any array-like (list, tuple, ndarray) of any
        shape, flattened to size (M,)
    out: (M, D); columns [0, D/2) hold the sines, columns [D/2, D) the cosines
    """
    assert embed_dim % 2 == 0

    # Frequencies 1 / 10000**(i / (D/2)) for i = 0 .. D/2 - 1.
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega  # (D/2,)

    # np.asarray lets plain Python sequences through (the docstring promises
    # "a list of positions"); ndarray inputs are passed along unchanged.
    pos = np.asarray(pos).reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb


def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    Fixed 2D sin-cos positional table for a square grid.

    grid_size: int of the grid height and width
    cls_token: when True, an all-zero row is prepended to the table
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/o or w/ cls_token)
    """
    coords = np.arange(grid_size, dtype=np.float32)
    # meshgrid receives the w coordinates first, then h; both are the same range here.
    mesh = np.stack(np.meshgrid(coords, coords), axis=0)
    mesh = mesh.reshape([2, 1, grid_size, grid_size])

    table = get_2d_sincos_pos_embed_from_grid(embed_dim, mesh)
    if not cls_token:
        return table

    zero_row = np.zeros([1, embed_dim])
    return np.concatenate([zero_row, table], axis=0)


def _trunc_normal_(tensor, mean, std, a, b):
    """Fill ``tensor`` in place with normal(mean, std) samples limited to [a, b].

    Sampling uses the inverse-CDF method: draw uniformly between the CDF values
    of the two bounds, map through the inverse error function, then scale by
    ``std`` and shift by ``mean``. ``a`` and ``b`` are absolute bounds, not
    multiples of ``std``.

    Args:
        tensor: tensor to overwrite (modified in place, without autograd tracking).
        mean: mean of the underlying normal distribution.
        std: standard deviation of the underlying normal distribution.
        a: lower bound of the values.
        b: upper bound of the values.

    Returns:
        The same ``tensor`` object, after filling.

    Warns:
        UserWarning: if ``mean`` lies more than ``2 * std`` outside ``[a, b]``.
    """
    # Function-scope import keeps this block self-contained.
    import warnings

    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        # Emitted through the warnings machinery (not print) so callers can
        # filter, capture, or escalate it.
        warnings.warn(
            "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
            "The distribution of values may be incorrect.",
            stacklevel=2,
        )

    with torch.no_grad():
        # CDF values of the lower and upper bounds.
        lower = norm_cdf((a - mean) / std)
        upper = norm_cdf((b - mean) / std)

        # Uniform in [2l-1, 2u-1], then erfinv gives a truncated standard normal
        # (up to the sqrt(2) factor applied next).
        tensor.uniform_(2 * lower - 1, 2 * upper - 1)
        tensor.erfinv_()

        # Rescale to the requested std and shift to the requested mean.
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Final clamp pins every value inside [a, b].
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    """Fill ``tensor`` in place from a normal(mean, std) truncated to [a, b].

    Thin public wrapper around ``_trunc_normal_`` supplying the defaults
    ``mean=0``, ``std=1``, ``a=-2``, ``b=2``. Returns the same tensor object.
    """
    return _trunc_normal_(tensor, mean, std, a, b)


class DistanceEncoder(nn.Module):
    """Two-layer MLP that lifts a size-1 last axis to an ``embed_dim`` vector.

    Layout: Linear(1, embed_dim // 4) -> ReLU -> Linear(embed_dim // 4, embed_dim).
    """

    def __init__(self, embed_dim):
        super().__init__()
        hidden_dim = embed_dim // 4
        layers = [
            nn.Linear(1, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, embed_dim),
        ]
        self.encoder = nn.Sequential(*layers)

    def forward(self, distances):
        """Run the MLP over the last axis of ``distances`` (which must have size 1)."""
        embedded = self.encoder(distances)
        return embedded


class CrossAttentionBlock(nn.Module):
    """Residual cross-attention followed by LayerNorm.

    Queries come from ``patch_features``; keys and values both come from
    ``retrieved_features``. Inputs are batch-first.
    """

    def __init__(self, embed_dim=1024, num_heads=8):
        super().__init__()
        self.cross_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, patch_features, retrieved_features):
        """Return LayerNorm(patch_features + attention(patch_features -> retrieved_features))."""
        # MultiheadAttention returns (output, weights); only the output is used.
        attended = self.cross_attn(
            patch_features, retrieved_features, retrieved_features
        )[0]
        residual = patch_features + attended
        return self.norm(residual)

class SelfAttentionBlock(nn.Module):
    """Residual self-attention with a fixed 2D sin-cos positional table.

    The positional table is added to the attention inputs (query, key and
    value); the residual connection uses the input without positions. The
    table is stored as a non-persistent buffer, so it is not saved in the
    state dict.
    """

    def __init__(self, embed_dim=1024, num_heads=8, num_patches=256):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.norm = nn.LayerNorm(embed_dim)

        # Side length of the patch grid; int() truncates if num_patches is
        # not a perfect square.
        side = int(num_patches**0.5)
        table = get_2d_sincos_pos_embed(embed_dim, side, cls_token=True)
        # (1, 1 + side*side, embed_dim): leading batch axis for broadcasting.
        table = torch.from_numpy(table).float().unsqueeze(0)
        self.register_buffer("pos_embedding", table, persistent=False)

    def forward(self, features):
        """Return LayerNorm(features + self-attention(features + pos_embedding))."""
        positioned = features + self.pos_embedding
        attended = self.self_attn(positioned, positioned, positioned)[0]
        return self.norm(features + attended)

class ProcessingStream(nn.Module):
    """Stack of alternating cross-attention and self-attention blocks.

    Retrieved features are first fused with an embedding of their distances
    and projected back to ``embed_dim``; the result serves as key/value for
    every cross-attention layer. In each layer, every patch token attends
    only to its own K retrieved entries (patches are folded into the batch
    axis), then the class token plus the updated patches go through
    self-attention.

    Args:
        embed_dim: token feature width E.
        num_heads: attention heads for every block.
        num_patches: number of patch tokens N (excluding the class token).
        depth: number of (cross-attention, self-attention) layer pairs.
        dist_embed_dim: width of the distance embedding concatenated to the
            retrieved features before projection. Defaults to 1024, the value
            that used to be hard-coded.
    """

    def __init__(self, embed_dim=1024, num_heads=8, num_patches=256, depth=4, dist_embed_dim=1024):
        super().__init__()
        self.cross_attention_layers = nn.ModuleList([
            CrossAttentionBlock(embed_dim, num_heads) for _ in range(depth)
        ])
        self.self_attention_layers = nn.ModuleList([
            SelfAttentionBlock(embed_dim, num_heads, num_patches=num_patches) for _ in range(depth)
        ])

        self.dist_embed_dim = dist_embed_dim
        self.distance_encoder = DistanceEncoder(self.dist_embed_dim)
        self.fusion_projection = nn.Linear(embed_dim + self.dist_embed_dim, embed_dim)

        self.num_patches = num_patches
        self.depth = depth

    def forward(self, features, retrieved_features, distances):
        """Run all layer pairs and return the updated token sequence.

        Args:
            features: (B, 1 + N, E) tokens, class token first.
            retrieved_features: (B, N, K, E) retrieved entries per patch; N must
                equal ``num_patches`` for the reshapes below to succeed.
            distances: (B, N, K), one scalar per retrieved entry.

        Returns:
            Tensor of shape (B, 1 + N, E).
        """
        B, _, K, E = retrieved_features.shape

        # (B, N, K) -> (B, N, K, 1) -> (B, N, K, dist_embed_dim)
        distance_embeddings = self.distance_encoder(distances.unsqueeze(-1))

        # Concatenate along features, then project back to E: (B, N, K, E)
        fused_kv_features = self.fusion_projection(
            torch.cat([retrieved_features, distance_embeddings], dim=-1)
        )

        # Fold patches into the batch axis so each patch sees only its own K entries.
        kv = fused_kv_features.reshape(B * self.num_patches, K, E)

        x = features  # (B, 1+N, E)
        for cross_block, self_block in zip(self.cross_attention_layers, self.self_attention_layers):
            cls_token = x[:, :1, :]
            # Each patch token becomes a length-1 query sequence.
            q = x[:, 1:, :].reshape(B * self.num_patches, 1, E)
            cross_attn_patches = cross_block(q, kv).reshape(B, self.num_patches, E)
            # The class token skips cross-attention and rejoins for self-attention.
            x = self_block(torch.cat([cls_token, cross_attn_patches], dim=1))

        return x

class MainClassifier(nn.Module):
    """
    Two-stream classifier (Main Classifier).

    A learnable class token is prepended to the input features. The resulting
    sequence is run through two separate ProcessingStream modules, one fed
    with the "normal" retrievals and one with the "abnormal" retrievals. The
    class-token outputs of both streams are concatenated and mapped to logits
    by an MLP head.
    """
    def __init__(self, embed_dim=1024, num_patches=256, num_heads=8, depth=4, num_classes=1, args=None):
        super().__init__()

        # `args` is accepted but not used by this constructor.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        trunc_normal_(self.cls_token, std=.02)

        # Two streams with identical configuration but separate weights.
        self.normal_stream = ProcessingStream(embed_dim, num_heads, num_patches, depth=depth)
        self.abnormal_stream = ProcessingStream(embed_dim, num_heads, num_patches, depth=depth)

        # The head consumes both class tokens side by side, hence 2 * embed_dim.
        fused_dim = embed_dim * 2
        self.fusion_head = nn.Sequential(
            nn.LayerNorm(fused_dim),
            nn.Linear(fused_dim, embed_dim),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(embed_dim, num_classes),
        )

        # Re-initialize every Linear / LayerNorm submodule built above.
        self.apply(self._initialize_weights)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Names reported as exempt from weight decay."""
        return {'pos_embedding', 'cls_token'}

    def _initialize_weights(self, m):
        """Xavier-uniform weights and zero bias for Linear; unit weight and zero bias for LayerNorm."""
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
            return
        if isinstance(m, nn.LayerNorm):
            nn.init.zeros_(m.bias)
            nn.init.ones_(m.weight)

    def forward(self, cur_features,
                cur_retrieved_features_from_normal, cur_distances_from_normal,
                cur_retrieved_features_from_abnormal, cur_distances_from_abnormal):
        """Return logits from the fused class tokens of both streams.

        The class token is broadcast over the batch and placed at index 0 of
        the token axis before the sequence enters each stream.
        """
        batch_size = cur_features.shape[0]
        tokens = torch.cat(
            (self.cls_token.expand(batch_size, -1, -1), cur_features), dim=1
        )

        normal_out = self.normal_stream(
            tokens, cur_retrieved_features_from_normal, cur_distances_from_normal
        )
        abnormal_out = self.abnormal_stream(
            tokens, cur_retrieved_features_from_abnormal, cur_distances_from_abnormal
        )

        # Index 0 along the token axis is the class token of each stream.
        fused_cls = torch.cat([normal_out[:, 0], abnormal_out[:, 0]], dim=-1)
        return self.fusion_head(fused_cls)

