from typing import Any, Dict, List, Optional, Tuple, NamedTuple
import torch
from torch import nn
from torch import Tensor
import torch.nn.functional as F
from gvp_transformer_encoder_without2d import GVPTransformerEncoder
from transformer_src.transformer_decoder import TransformerDecoder
from util import rotate, CoordBatchConverter, load_structure

class MultiLayerCrossAttention(nn.Module):
    """Stack of multi-head attention + feed-forward layers over one joint sequence.

    The input is a single concatenated sequence laid out as
    ``[protein positions | one extra slot | RNA positions]``.  The padding
    masks of the two segments are merged into one key-padding mask; the extra
    slot between them is never masked.  (In ``CARD.forward`` that slot lines
    up with the global protein embedding appended after the local ones.)

    Args:
        num_layers: number of attention / feed-forward layers.
        dim: embedding dimension of the sequence.
        num_heads: attention heads per layer.
        dropout: dropout probability used inside attention and on residuals.
    """

    def __init__(self, num_layers=2, dim=512, num_heads=8, dropout=0.1):
        super().__init__()
        # NOTE: submodule names and creation order are kept stable so that
        # state_dict keys and seeded initialisation stay the same.
        self.layers = nn.ModuleList([
            nn.MultiheadAttention(dim, num_heads, dropout=dropout)
            for _ in range(num_layers)
        ])
        self.layer_norms = nn.ModuleList([
            nn.LayerNorm(dim) for _ in range(num_layers)
        ])
        self.dropouts = nn.ModuleList([
            nn.Dropout(dropout) for _ in range(num_layers)
        ])
        self.k_proj = nn.Linear(dim, dim)
        self.q_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)

        self.feedforward = nn.ModuleList([
            nn.Sequential(
                nn.LayerNorm(dim),
                nn.Linear(dim, dim*4),
                nn.GELU(),
                nn.Linear(dim*4, dim)
            ) for _ in range(num_layers)
        ])

    @staticmethod
    def _combine_padding_masks(x, rna_padding_mask, protein_padding_mask):
        """Merge the two segment masks into one (batch, seq_len) boolean mask.

        Returns ``None`` when neither mask is given.  When only one mask is
        given, the length of the other segment is derived from ``x`` so the
        combined mask still spans the whole sequence; the segment without a
        mask is left fully unmasked.
        """
        if rna_padding_mask is None and protein_padding_mask is None:
            return None

        total_len = x.shape[1]
        if protein_padding_mask is not None and rna_padding_mask is not None:
            protein_len = protein_padding_mask.shape[1]
            rna_len = rna_padding_mask.shape[1]
        elif protein_padding_mask is not None:
            protein_len = protein_padding_mask.shape[1]
            # Clamp at 0 so an undersized ``x`` still triggers the attention
            # layer's own shape check instead of silently mis-slicing.
            rna_len = max(total_len - protein_len - 1, 0)
        else:
            rna_len = rna_padding_mask.shape[1]
            protein_len = max(total_len - rna_len - 1, 0)

        combined = torch.zeros(
            (x.shape[0], protein_len + 1 + rna_len),
            dtype=torch.bool,
            device=x.device
        )
        if protein_padding_mask is not None:
            combined[:, :protein_len] = protein_padding_mask
        if rna_padding_mask is not None:
            # Index ``protein_len`` itself is the always-visible extra slot.
            combined[:, protein_len+1:] = rna_padding_mask
        return combined

    def forward(self, x, rna_padding_mask=None, protein_padding_mask=None):
        """Run the attention stack over ``x``.

        Args:
            x: tensor of shape (batch, seq_len, dim).
            rna_padding_mask: optional boolean mask for the trailing RNA
                segment, True at padded positions.
            protein_padding_mask: optional boolean mask for the leading
                protein segment, True at padded positions.

        Returns:
            Tensor of shape (batch, seq_len, dim).
        """
        combined_padding_mask = self._combine_padding_masks(
            x, rna_padding_mask, protein_padding_mask
        )

        # The attention layers are built without batch_first, so they take
        # (seq_len, batch, dim).
        x = x.transpose(0, 1)

        # NOTE(review): q/k/v are projected once from the *input* and the same
        # tensors are reused by every layer, so later layers do not attend
        # over the updated ``x``.  Left as is because changing it would alter
        # the outputs of already-trained weights -- confirm intent.
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        for layer, norm, dropout, ffn in zip(self.layers, self.layer_norms, self.dropouts, self.feedforward):
            x = norm(x)
            attn_output, _ = layer(q, k, v, key_padding_mask=combined_padding_mask)
            x = x + dropout(attn_output)
            x = x + dropout(ffn(x))

        return x.transpose(0, 1)  # back to (batch, seq_len, dim)

class CARD(nn.Module):
    """Encoder-decoder sequence model conditioned on 3D coordinates.

    A ``GVPTransformerEncoder`` encodes the coordinates, optionally followed
    by fusion with projected protein embeddings through
    ``MultiLayerCrossAttention``; a ``TransformerDecoder`` then produces token
    logits, either teacher-forced or by autoregressive sampling.

    Args:
        args: configuration object; this class reads ``encoder_embed_dim``,
            ``decoder_embed_dim``, ``attn_layer`` and ``protein_len`` and
            forwards ``args`` to the encoder and decoder constructors.
        alphabet: token dictionary; must support ``len()`` and expose
            ``padding_idx``.
        pro_dim: feature size of the incoming protein embeddings.
    """

    def __init__(self, args, alphabet, pro_dim = 1280):
        super().__init__()
        encoder_embed_tokens = self.build_embedding(
            args,
            alphabet,
            args.encoder_embed_dim,
        )
        decoder_embed_tokens = self.build_embedding(
            args,
            alphabet,
            args.decoder_embed_dim,
        )
        encoder = self.build_encoder(args, alphabet, encoder_embed_tokens)
        decoder = self.build_decoder(args, alphabet, decoder_embed_tokens)
        self.args = args
        self.encoder = encoder
        self.decoder = decoder
        self.protein_repr_projection = nn.Linear(pro_dim, 512)
        # NOTE: the attribute name (including its typo) is part of the
        # state_dict keys, so it is kept for checkpoint compatibility.
        self.corss_attention = MultiLayerCrossAttention(num_layers=args.attn_layer, dim=512, num_heads=8, dropout=0.1)
        self.contrastive_head = nn.Sequential(
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512)
        )

    @classmethod
    def build_encoder(cls, args, src_dict, embed_tokens):
        """Construct the GVP-Transformer encoder."""
        return GVPTransformerEncoder(args, src_dict, embed_tokens)

    @classmethod
    def build_decoder(cls, args, tgt_dict, embed_tokens):
        """Construct the Transformer decoder."""
        return TransformerDecoder(
            args,
            tgt_dict,
            embed_tokens,
        )

    @classmethod
    def build_embedding(cls, args, dictionary, embed_dim):
        """Create a token embedding with N(0, embed_dim**-0.5) weights.

        The row at ``dictionary.padding_idx`` is reset to zero after the
        random initialisation.
        """
        num_embeddings = len(dictionary)
        padding_idx = dictionary.padding_idx
        emb = nn.Embedding(num_embeddings, embed_dim, padding_idx)
        nn.init.normal_(emb.weight, mean=0, std=embed_dim**-0.5)
        nn.init.constant_(emb.weight[padding_idx], 0)
        return emb

    def _fuse_protein(self, encoder_out, padding_mask,
                      protein_embeddings_local, protein_embeddings_global,
                      protein_padding_mask):
        """Mix protein embeddings into ``encoder_out['encoder_out'][0]`` in place."""
        # Encoder output is sequence-first; attention fusion works batch-first.
        rna_embeddings = encoder_out['encoder_out'][0].transpose(0, 1)
        # Local embeddings first, then the global one, along the sequence axis.
        protein_embeddings = torch.cat(
            [protein_embeddings_local, protein_embeddings_global], dim=-2
        )
        protein_embeddings = self.protein_repr_projection(protein_embeddings)
        x = torch.cat([protein_embeddings, rna_embeddings], dim=1)

        fused = self.corss_attention(x, padding_mask, protein_padding_mask)
        # Drop the leading ``args.protein_len`` positions.
        # NOTE(review): presumably this equals the full protein segment length
        # (local + global) so only RNA positions remain -- confirm with config.
        fused = fused[:, self.args.protein_len:]
        encoder_out['encoder_out'][0] = fused.transpose(0, 1)

    def _teacher_forced_logits(self, batch_coords, encoder_out, target_seqs):
        """Decode in a single pass with ``<cath>`` + ``target_seqs`` as input."""
        dictionary = self.decoder.dictionary
        tokens = torch.full(
            (batch_coords.shape[0], batch_coords.shape[1]-1),
            dictionary.get_idx("<mask>"),
            dtype=int,
            device=batch_coords.device,
        )
        tokens[:, 0] = dictionary.get_idx("<cath>")
        tokens[:, 1:] = target_seqs
        logits, _ = self.decoder(
            tokens,
            encoder_out=encoder_out,
            incremental_state=None
        )
        logits = logits.transpose(1, 2)
        # Drop the output of the last input position.
        return logits[:, :-1, :]

    def _sampled_logits(self, batch_coords, encoder_out, temperature):
        """Decode autoregressively, sampling one token per step.

        Returns the per-step logits stacked along dim 1.
        """
        dictionary = self.decoder.dictionary
        seq_len = batch_coords.shape[1]
        sampled_tokens = torch.full(
            (batch_coords.shape[0], seq_len),
            dictionary.get_idx("<mask>"),
            dtype=int,
            device=batch_coords.device,
        )
        sampled_tokens[:, 0] = dictionary.get_idx("<cath>")

        incremental_state = dict()
        logits_list = []
        for i in range(1, seq_len-1):
            logits, _ = self.decoder(
                sampled_tokens[:, :i],  # tokens generated so far
                encoder_out=encoder_out,
                incremental_state=incremental_state,
            )
            logits = logits.transpose(1, 2).squeeze(1)
            logits_list.append(logits)

            probs = F.softmax(logits/temperature, dim=-1)
            # Step i consumed tokens [:i], so its sample belongs in slot i,
            # which is the last token the next step will see.  Writing to
            # slot i+1 (as before) fed ``<mask>`` to the second step and every
            # later sample one step late, unlike the teacher-forced layout.
            sampled_tokens[:, i] = torch.multinomial(probs, 1).squeeze(-1)

        return torch.stack(logits_list, dim=1)

    def _global_rna_embedding(self, encoder_out):
        """Mean-pool the encoder output over the sequence and project it."""
        rna_output = encoder_out['encoder_out'][0].transpose(0, 1)
        # NOTE(review): this is an unmasked mean, so padded positions are
        # included.  A padding-masked mean used to be computed here but was
        # immediately overwritten by this plain mean; only the effective
        # behaviour is kept.  Confirm which pooling was intended.
        pooled = torch.mean(rna_output, dim=1)
        return self.contrastive_head(pooled)

    def forward(self, batch_coords, 
                confidence, padding_mask, 
                protein_embeddings_local, protein_embeddings_global, 
                protein_padding_mask, target_seqs=None, use_protein=True, 
                temperature=0.00001,need_embedding=False):
        """Encode the structure and return decoder logits.

        Args:
            batch_coords: coordinate tensor indexed as
                (batch, position, atom, xyz); atoms 0-2 are passed to the
                encoder as backbone, the full tensor as all-atom input.
            confidence: passed through to the encoder.
            padding_mask: padding mask for the structure positions; passed to
                the encoder and, when ``use_protein`` is set, to the fusion
                attention.
            protein_embeddings_local: per-position protein embeddings
                (only read when ``use_protein`` is True).
            protein_embeddings_global: global protein embedding, concatenated
                after the local ones along the sequence axis.
            protein_padding_mask: padding mask for the local protein
                embeddings.
            target_seqs: if given, decode teacher-forced on these tokens;
                otherwise sample autoregressively.
            use_protein: whether to fuse protein embeddings into the encoder
                output.
            temperature: softmax temperature used when sampling.
            need_embedding: if True, also return a pooled, projected
                embedding of the (possibly fused) encoder output.

        Returns:
            ``logits`` or, when ``need_embedding`` is True,
            ``(logits, rna_global_embedding)``.
        """
        encoder_out = self.encoder(
            batch_coords[:, :, [0, 1, 2], :],  # Backbone atoms
            batch_coords[:, :, :, :],          # All atoms for dihedral angles
            padding_mask,
            confidence
        )

        if use_protein:
            self._fuse_protein(
                encoder_out, padding_mask,
                protein_embeddings_local, protein_embeddings_global,
                protein_padding_mask,
            )

        if target_seqs is not None:
            logits = self._teacher_forced_logits(batch_coords, encoder_out, target_seqs)
        else:
            logits = self._sampled_logits(batch_coords, encoder_out, temperature)

        if need_embedding:
            return logits, self._global_rna_embedding(encoder_out)

        return logits