#!/usr/bin/env python3
"""
annDNA Model Definition
Unified model for seq, struct, full
"""

import torch
import torch.nn as nn


class annDNA(nn.Module):
    """Transformer encoder with learned position embeddings and an MLM head.

    Token ids are embedded, summed with learned absolute position
    embeddings, passed through a stack of ``nn.TransformerEncoderLayer``
    blocks and projected back to vocabulary logits.

    Args:
        vocab_size: Number of distinct token ids.
        d_model: Width of the embeddings and of every encoder layer.
        nhead: Number of attention heads per encoder layer.
        num_layers: Number of stacked encoder layers.
        max_seq_len: Longest sequence length that has a position embedding.
        dropout: Dropout probability used after the embeddings and inside
            every encoder layer.
    """

    def __init__(self, vocab_size, d_model=512, nhead=8, num_layers=6, max_seq_len=1002,
                 dropout=0.1):
        super().__init__()

        self.d_model = d_model
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len
        self.nhead = nhead
        self.num_layers = num_layers

        # Embeddings
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.pos_embedding = nn.Embedding(max_seq_len, d_model)

        # Transformer
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=d_model * 4,
            dropout=dropout,
            activation='gelu',
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Output
        self.mlm_head = nn.Linear(d_model, vocab_size)
        self.layer_norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

        # Initialize
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise one submodule; invoked for every submodule via ``apply``.

        ``nn.Linear`` and ``nn.Embedding`` weights are drawn from
        N(0, 0.02^2), linear biases are zeroed, and ``nn.LayerNorm`` is reset
        to the identity transform (weight 1, bias 0).
        """
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            nn.init.zeros_(module.bias)
            nn.init.ones_(module.weight)

    def _embed(self, input_ids):
        """Return the normalised sum of token and position embeddings.

        Args:
            input_ids: Integer tensor of shape ``(batch, seq_len)``.

        Returns:
            Tensor of shape ``(batch, seq_len, d_model)``.

        Raises:
            ValueError: If ``seq_len`` exceeds ``max_seq_len``.
        """
        _, seq_len = input_ids.shape
        if seq_len > self.max_seq_len:
            # Fail early with a readable message instead of an out-of-range
            # lookup inside the position embedding table.
            raise ValueError(
                f"Sequence length {seq_len} exceeds max_seq_len {self.max_seq_len}"
            )

        # Positions are identical for every row, so a 1-D index is enough;
        # the resulting (seq_len, d_model) tensor broadcasts over the batch.
        positions = torch.arange(seq_len, device=input_ids.device)
        summed = self.token_embedding(input_ids) + self.pos_embedding(positions)
        return self.layer_norm(self.dropout(summed))

    @staticmethod
    def _key_padding_mask(attention_mask):
        """Convert an attention mask into a boolean key-padding mask.

        Positions where ``attention_mask`` equals 0 become ``True`` (ignored
        by attention). ``None`` is passed through unchanged.
        """
        if attention_mask is None:
            return None
        return attention_mask == 0

    def forward(self, input_ids, attention_mask=None):
        """Compute MLM logits.

        Args:
            input_ids: Integer tensor of shape ``(batch, seq_len)``.
            attention_mask: Optional tensor of shape ``(batch, seq_len)``;
                entries equal to 0 mark positions that attention must ignore.

        Returns:
            Logits of shape ``(batch, seq_len, vocab_size)``.

        Raises:
            ValueError: If ``seq_len`` exceeds ``max_seq_len``.
        """
        embeddings = self._embed(input_ids)
        mask = self._key_padding_mask(attention_mask)
        hidden_states = self.transformer(embeddings, src_key_padding_mask=mask)
        return self.mlm_head(hidden_states)

    def get_num_params(self):
        """Return the total number of elements across all parameters."""
        return sum(p.numel() for p in self.parameters())

    def forward_with_attention(self, input_ids, attention_mask=None):
        """Forward pass that also returns per-layer attention weights.

        The encoder layers are stepped through manually so that the
        self-attention weights can be captured.

        Args:
            input_ids: Integer tensor of shape ``(batch, seq_len)``.
            attention_mask: Optional tensor of shape ``(batch, seq_len)``;
                entries equal to 0 mark positions that attention must ignore.

        Returns:
            Tuple ``(logits, attentions)`` where ``logits`` has shape
            ``(batch, seq_len, vocab_size)`` and ``attentions`` is a list with
            one un-averaged weight tensor per layer, each of shape
            ``(batch, nhead, seq_len, seq_len)``.

        Raises:
            ValueError: If ``seq_len`` exceeds ``max_seq_len``.
        """
        hidden_states = self._embed(input_ids)
        key_padding_mask = self._key_padding_mask(attention_mask)
        all_attentions = []

        # The layers are built without ``norm_first``, i.e. with PyTorch's
        # default post-norm ordering, which is what is reproduced below.
        for layer in self.transformer.layers:
            # Self-attention, keeping the per-head weights.
            attn_output, attn_weights = layer.self_attn(
                query=hidden_states,
                key=hidden_states,
                value=hidden_states,
                key_padding_mask=key_padding_mask,
                need_weights=True,
                average_attn_weights=False,
            )
            all_attentions.append(attn_weights)

            # Residual connection and layer norm
            hidden_states = layer.norm1(hidden_states + layer.dropout1(attn_output))

            # Feed forward
            ff_output = layer.linear2(layer.dropout(layer.activation(layer.linear1(hidden_states))))
            hidden_states = layer.norm2(hidden_states + layer.dropout2(ff_output))

        logits = self.mlm_head(hidden_states)
        return logits, all_attentions

    def get_attention(self, input_ids, attention_mask=None, layer_idx=None, head_idx=None):
        """Extract attention scores as a NumPy array.

        The model is switched to eval mode for the extraction and its
        previous train/eval mode is restored afterwards, so calling this in
        the middle of training does not leave dropout disabled.

        Args:
            input_ids: Integer tensor of shape ``(batch, seq_len)``.
            attention_mask: Optional tensor of shape ``(batch, seq_len)``;
                entries equal to 0 mark positions that attention must ignore.
            layer_idx: None (all layers), int (specific layer), -1 (last layer)
            head_idx: None (all heads), int (specific head)

        Returns:
            numpy array of attention scores. With no selection the shape is
            ``(num_layers, batch, nhead, seq_len, seq_len)``; selecting a
            layer and/or a head removes the corresponding axis.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                _, attentions = self.forward_with_attention(input_ids, attention_mask)
        finally:
            self.train(was_training)

        selected = torch.stack(attentions, dim=0)

        # Select on the original device first so that only the requested
        # slice is copied to the CPU. Negative indices (e.g. -1 for the last
        # layer) are handled by regular tensor indexing.
        if layer_idx is not None:
            selected = selected[layer_idx]

        if head_idx is not None:
            selected = selected[..., head_idx, :, :]

        return selected.cpu().numpy()

    def get_model_info(self):
        """Return the model hyper-parameters and formatted parameter counts."""
        num_params = self.get_num_params()
        return {
            'vocab_size': self.vocab_size,
            'd_model': self.d_model,
            'nhead': self.nhead,
            'num_layers': self.num_layers,
            'max_seq_len': self.max_seq_len,
            'num_parameters': f"{num_params:,}",
            'num_parameters_M': f"{num_params/1e6:.1f}M"
        }
