import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Optional, Tuple, List
import math

class InputEncoder(nn.Module):
    """Minimal input encoder: token embedding plus learned position embedding.

    The two embeddings are summed, passed through dropout and then through
    layer normalization.

    Args:
        vocab_size: Number of rows in the token embedding table.
        d_model: Width of both embeddings and of the output.
        max_len: Number of rows in the position embedding table; position IDs
            passed to ``forward`` must be smaller than this value.
        dropout: Dropout probability applied to the summed embeddings.
            Defaults to 0.1, the value that used to be hard-coded.
    """

    def __init__(self, vocab_size: int, d_model: int, max_len: int = 1000,
                 dropout: float = 0.1):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model

        # Token embedding
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        # Position embedding
        self.position_embedding = nn.Embedding(max_len, d_model)

        # Simple fusion: dropout followed by layer normalization
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
        """Embed tokens and positions and fuse them.

        Args:
            tokens: [batch_size, seq_len] token IDs.
            positions: [batch_size, seq_len] position IDs.

        Returns:
            Tensor of shape [batch_size, seq_len, d_model].
        """
        token_emb = self.token_embedding(tokens)              # [B, L, d_model]
        position_emb = self.position_embedding(positions)     # [B, L, d_model]

        # Plain sum of the two embeddings (standard Transformer practice)
        combined = token_emb + position_emb

        output = self.dropout(combined)
        output = self.layer_norm(output)
        return output

class TransformerLayer(nn.Module):
    """Transformer layer: self-attention and a GELU feed-forward block.

    Each sub-block is wrapped as ``LayerNorm(x + Dropout(sublayer(x)))``,
    i.e. normalization is applied after the residual addition.

    Args:
        d_model: Model width.
        n_heads: Number of attention heads.
        d_ff: Hidden width of the feed-forward block.
        dropout: Dropout probability used inside attention, inside the
            feed-forward block and on both residual branches.
    """

    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=n_heads,
            dropout=dropout,
            batch_first=True,
        )
        ffn_stages = [
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
        ]
        self.feed_forward = nn.Sequential(*ffn_stages)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None,
                key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the layer on ``x``.

        Args:
            x: Input tensor; the attention module is built with
                ``batch_first=True``.
            attn_mask: Optional mask forwarded unchanged to
                ``nn.MultiheadAttention``.
            key_padding_mask: Optional mask forwarded unchanged to
                ``nn.MultiheadAttention``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Self-attention sub-block; the attention weights are discarded.
        attended = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=key_padding_mask,
            attn_mask=attn_mask,
        )[0]
        hidden = self.norm1(x + self.dropout(attended))

        # Feed-forward sub-block.
        return self.norm2(hidden + self.dropout(self.feed_forward(hidden)))

class SharedEncoder(nn.Module):
    """Encoder shared by all tasks.

    Consists of an ``InputEncoder`` (token + position embeddings) followed by
    ``n_layers`` ``TransformerLayer`` blocks whose feed-forward width is
    ``4 * d_model``.

    Args:
        vocab_size: Vocabulary size passed to the input encoder.
        d_model: Model width.
        max_len: Size of the position embedding table.
        n_layers: Number of Transformer layers.
        n_heads: Number of attention heads per layer.
        dropout: Dropout probability passed to every Transformer layer.
    """

    def __init__(self, vocab_size: int, d_model: int, max_len: int, n_layers: int,
                 n_heads: int, dropout: float = 0.1):
        super().__init__()
        self.d_model = d_model

        # Token + position embedding front end.
        self.input_encoder = InputEncoder(vocab_size, d_model, max_len)

        # Stack of Transformer layers.
        ffn_width = 4 * d_model
        blocks = []
        for _ in range(n_layers):
            blocks.append(TransformerLayer(d_model, n_heads, ffn_width, dropout))
        self.layers = nn.ModuleList(blocks)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Encode a batch of token IDs.

        Args:
            input_ids: 2-D tensor of token IDs (batch, sequence).
            attention_mask: Mask in which entries equal to 0 mark padded
                positions.

        Returns:
            The output of the last Transformer layer.
        """
        n_rows, n_cols = input_ids.shape

        # Positions 0..seq_len-1, shared by every row of the batch.
        positions = torch.arange(n_cols, device=input_ids.device)[None, :].expand(n_rows, -1)

        # Embed tokens and positions.
        hidden = self.input_encoder(input_ids, positions)

        # Key padding mask: True where the position is padding.
        pad_positions = attention_mask.eq(0)

        for block in self.layers:
            hidden = block(hidden, key_padding_mask=pad_positions)

        return hidden

class ClassificationDecoder(nn.Module):
    """Classification head on top of masked mean pooling.

    Args:
        d_model: Width of the encoder output.
        dropout: Dropout probability applied to the pooled representation.
    """

    def __init__(self, d_model: int, dropout: float = 0.1):
        super().__init__()
        self.pooler = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        # Two-way (binary) classification logits.
        self.classifier = nn.Linear(d_model, 2)
        # Single confidence score.
        self.confidence_head = nn.Linear(d_model, 1)

    def forward(self, encoder_output: torch.Tensor, attention_mask: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Pool the encoder output and compute the head outputs.

        Args:
            encoder_output: Encoder states; summed over dimension 1.
            attention_mask: Per-position weights for pooling; broadcast over
                the last dimension of ``encoder_output``.

        Returns:
            Dict with ``'logits'``, ``'confidence'`` (sigmoid of the
            confidence head) and ``'pooled'`` (the representation that is also
            fed to the discriminator).
        """
        # Masked mean over the sequence dimension; the clamp avoids a division
        # by zero when a row has no unmasked position.
        weights = attention_mask.unsqueeze(-1).float()
        token_total = (encoder_output * weights).sum(dim=1)
        token_count = weights.sum(dim=1).clamp(min=1e-9)

        pooled = self.dropout(torch.tanh(self.pooler(token_total / token_count)))

        return {
            'logits': self.classifier(pooled),
            'confidence': torch.sigmoid(self.confidence_head(pooled)),
            'pooled': pooled,  # consumed by the discriminator
        }

class Discriminator(nn.Module):
    """Discriminator: scores whether a triplet can bind.

    A three-layer MLP (d_model -> d_model // 2 -> d_model // 4 -> 1) with
    GELU activations and dropout, ending in a sigmoid.

    Args:
        d_model: Width of the pooled input representation.
        dropout: Dropout probability after each hidden activation.
    """

    def __init__(self, d_model: int, dropout: float = 0.1):
        super().__init__()
        half, quarter = d_model // 2, d_model // 4
        stages = [
            nn.Linear(d_model, half),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(half, quarter),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(quarter, 1),
            nn.Sigmoid(),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, pooled_representation: torch.Tensor) -> torch.Tensor:
        """Return the sigmoid score for each pooled representation."""
        score = self.net(pooled_representation)
        return score

class GenerationDecoder(nn.Module):
    """GPT-style generation decoder using a standard causal mask.

    Consists of an ``InputEncoder``, ``n_layers`` ``TransformerLayer`` blocks
    (feed-forward width ``4 * d_model``) and a linear projection to the
    vocabulary.

    Args:
        vocab_size: Vocabulary size (embedding rows and output logits).
        d_model: Model width.
        max_len: Size of the position embedding table.
        n_layers: Number of Transformer layers.
        n_heads: Number of attention heads per layer.
        dropout: Dropout probability passed to every Transformer layer.
    """

    def __init__(self, vocab_size: int, d_model: int, max_len: int, n_layers: int,
                 n_heads: int, dropout: float = 0.1):
        super().__init__()
        # Token + position embedding front end.
        self.input_encoder = InputEncoder(vocab_size, d_model, max_len)

        # Decoder layers.
        ffn_width = 4 * d_model
        blocks = []
        for _ in range(n_layers):
            blocks.append(TransformerLayer(d_model, n_heads, ffn_width, dropout))
        self.layers = nn.ModuleList(blocks)

        # Projection from hidden states to vocabulary logits.
        self.output_projection = nn.Linear(d_model, vocab_size)

    def create_causal_mask(self, seq_len: int, device: torch.device) -> torch.Tensor:
        """Build a ``seq_len x seq_len`` boolean mask that is True strictly
        above the main diagonal and False elsewhere."""
        all_true = torch.ones((seq_len, seq_len), dtype=torch.bool, device=device)
        return torch.triu(all_true, diagonal=1)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Compute vocabulary logits for every position.

        Args:
            input_ids: 2-D tensor of token IDs (batch, sequence).
            attention_mask: Mask in which entries equal to 0 mark padded
                positions.

        Returns:
            Logits produced by the output projection for each position.
        """
        n_rows, n_cols = input_ids.shape
        device = input_ids.device

        # Positions 0..seq_len-1, shared by every row of the batch.
        positions = torch.arange(n_cols, device=device)[None, :].expand(n_rows, -1)

        # Embed tokens and positions.
        hidden = self.input_encoder(input_ids, positions)

        # Causal mask plus key padding mask (True where the position is padding).
        future_mask = self.create_causal_mask(n_cols, device)
        pad_positions = attention_mask.eq(0)

        for block in self.layers:
            hidden = block(hidden, attn_mask=future_mask, key_padding_mask=pad_positions)

        return self.output_projection(hidden)

class MultiTaskImmuneModel(nn.Module):
    """Multi-task immune model.

    Combines one shared encoder, three classification heads (``pt``, ``pmt``,
    ``pm``), one discriminator applied to each head's pooled representation,
    and two autoregressive generators (``tcr`` and ``pep``).

    Args:
        vocab_size: Vocabulary size used by the encoder and both generators.
        d_model: Model width.
        max_len: Size of every position embedding table.
        n_encoder_layers: Number of layers in the shared encoder.
        n_decoder_layers: Number of layers in each generator.
        n_heads: Number of attention heads.
        dropout: Dropout probability passed to all sub-modules.
        vocab_dict: Optional id -> token mapping; stored as ``self.vocab_dict``
            (an empty dict when omitted).
    """

    def __init__(self, vocab_size: int = 24, d_model: int = 512, max_len: int = 150,
                 n_encoder_layers: int = 6, n_decoder_layers: int = 4, n_heads: int = 8,
                 dropout: float = 0.1, vocab_dict: Optional[Dict[int, str]] = None):
        super().__init__()

        self.vocab_size = vocab_size
        self.d_model = d_model
        self.max_len = max_len
        self.vocab_dict = vocab_dict or {}

        # Shared encoder
        self.encoder = SharedEncoder(vocab_size, d_model, max_len, n_encoder_layers, n_heads, dropout)

        # Three classification heads
        self.pt_classifier = ClassificationDecoder(d_model, dropout)
        self.pmt_classifier = ClassificationDecoder(d_model, dropout)
        self.pm_classifier = ClassificationDecoder(d_model, dropout)

        # Discriminator
        self.discriminator = Discriminator(d_model, dropout)

        # Two generation decoders
        self.tcr_generator = GenerationDecoder(vocab_size, d_model, max_len, n_decoder_layers, n_heads, dropout)
        self.pep_generator = GenerationDecoder(vocab_size, d_model, max_len, n_decoder_layers, n_heads, dropout)

        print(f"Simplified model initialized: vocab_size={vocab_size}, d_model={d_model}")
        print("Using standard causal masking for generation tasks")

    def forward(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Run all classification tasks and any generation task present.

        Args:
            batch: Must contain ``'{task}_input'`` and ``'{task}_mask'`` for
                each task in ``pt``, ``pmt``, ``pm``. If ``'tcr_gen_input'``
                (resp. ``'pep_gen_input'``) is present, ``'tcr_gen_mask'`` and
                ``'tcr_gen_target'`` (resp. the ``pep`` equivalents) are
                required as well.

        Returns:
            Dict with ``'{task}_logits'``, ``'{task}_confidence'``,
            ``'{task}_pooled'`` and ``'{task}_discriminator'`` for each
            classification task, plus ``'{name}_gen_logits'`` and
            ``'{name}_gen_targets'`` for each generation task that ran.

        Raises:
            KeyError: If a required key is missing from ``batch``.
        """
        outputs: Dict[str, torch.Tensor] = {}

        # The three classification tasks share the encoder.
        for task in ('pt', 'pmt', 'pm'):
            input_ids = batch[f'{task}_input']
            attention_mask = batch[f'{task}_mask']

            encoder_out = self.encoder(input_ids, attention_mask)

            classifier = getattr(self, f'{task}_classifier')
            cls_out = classifier(encoder_out, attention_mask)

            outputs[f'{task}_logits'] = cls_out['logits']
            outputs[f'{task}_confidence'] = cls_out['confidence']
            outputs[f'{task}_pooled'] = cls_out['pooled']

            # The discriminator scores the pooled representation.
            outputs[f'{task}_discriminator'] = self.discriminator(cls_out['pooled'])

        # Generation tasks run only when their inputs are in the batch; the
        # targets are passed through untouched so they travel with the logits.
        for name in ('tcr', 'pep'):
            if f'{name}_gen_input' not in batch:
                continue
            generator = getattr(self, f'{name}_generator')
            gen_input = batch[f'{name}_gen_input']
            gen_mask = batch[f'{name}_gen_mask']
            outputs[f'{name}_gen_logits'] = generator(gen_input, gen_mask)
            outputs[f'{name}_gen_targets'] = batch[f'{name}_gen_target']

        return outputs

    # find_context_length was removed; it is no longer needed.

    def generate_sequence(self, context_ids: torch.Tensor, max_length: int = 30,
                          generator_type: str = 'tcr', temperature: float = 1.0,
                          top_k: int = 0, top_p: float = 1.0,
                          eos_id: int = 3, pad_id: int = 0) -> torch.Tensor:
        """Autoregressively extend ``context_ids`` (inference).

        The module is switched to eval mode while sampling and every
        sub-module's previous training flag is restored afterwards, so calling
        this in the middle of training does not leave the model in eval mode.

        Args:
            context_ids: 2-D tensor of token IDs (batch, context length).
                Positions equal to ``pad_id`` are masked out.
            max_length: Maximum number of new tokens per row. The effective
                budget is additionally capped so that the sequence fed to the
                generator never exceeds ``self.max_len`` positions.
            generator_type: ``'tcr'`` selects ``tcr_generator``; any other
                value selects ``pep_generator``.
            temperature: Sampling temperature; 0 means greedy decoding.
            top_k: If > 0, sample only among the ``top_k`` highest logits.
            top_p: If < 1, apply nucleus filtering with this threshold.
            eos_id: Token ID that ends a row's generation.
            pad_id: Token ID used for padding, both in the context mask and
                for rows that have already produced ``eos_id``.

        Returns:
            Tensor of shape (batch, n_generated) holding only the newly
            generated tokens (the context is stripped).

        Raises:
            ValueError: If tokens are requested but the context is already
                longer than ``self.max_len``.
        """
        device = context_ids.device
        batch_size, context_length = context_ids.shape

        if max_length > 0 and context_length > self.max_len:
            raise ValueError(
                f"context length {context_length} exceeds max_len {self.max_len}"
            )

        generator = self.tcr_generator if generator_type == 'tcr' else self.pep_generator

        # Step s feeds context_length + s tokens to the generator, and its
        # position table has max_len rows, so at most
        # max_len - context_length + 1 steps are possible.
        n_steps = min(max_length, self.max_len - context_length + 1)

        generated = context_ids.clone()
        generated_mask = (context_ids != pad_id).long()

        # Per-row completion flag.
        done = torch.zeros(batch_size, dtype=torch.bool, device=device)

        # NOTE(review): logits are read from the last column and new tokens
        # are appended after it, so rows whose context ends in padding are
        # sampled from a padded position — confirm callers pass contexts
        # without trailing padding.
        prior_modes = [(module, module.training) for module in self.modules()]
        self.eval()
        try:
            with torch.no_grad():
                for _ in range(n_steps):
                    logits = generator(generated, generated_mask)
                    next_token_logits = logits[:, -1, :]  # logits at the last position

                    next_token = self.sample_next_token(
                        next_token_logits, temperature=temperature, top_k=top_k, top_p=top_p
                    )

                    # Rows that already finished only receive padding.
                    next_token = next_token.masked_fill(done, pad_id)

                    # Append the token; the mask is built from `done` before
                    # it is updated, so the EOS token itself stays unmasked.
                    generated = torch.cat([generated, next_token.unsqueeze(1)], dim=1)
                    generated_mask = torch.cat(
                        [generated_mask, (~done).long().unsqueeze(1)], dim=1
                    )

                    done = done | (next_token == eos_id)

                    # Stop early once every row has finished.
                    if done.all():
                        break
        finally:
            for module, was_training in prior_modes:
                module.training = was_training

        # Return only the generated part (context removed).
        return generated[:, context_length:]

    def sample_next_token(self, logits: torch.Tensor, temperature: float = 1.0,
                          top_k: int = 0, top_p: float = 1.0) -> torch.Tensor:
        """Pick the next token from ``logits`` along the last dimension.

        Args:
            logits: Unnormalized scores; the vocabulary is the last dimension.
            temperature: 0 selects the argmax; otherwise logits are divided by
                this value before filtering and sampling.
            top_k: If > 0, keep only the ``top_k`` largest logits (clamped to
                the vocabulary size).
            top_p: If < 1, keep the smallest prefix of the sorted distribution
                whose cumulative probability exceeds ``top_p``; the most
                likely token is always kept.

        Returns:
            Tensor of token indices with the last dimension removed.

        Raises:
            ValueError: If ``temperature`` is negative (it would invert the
                distribution).
        """
        if temperature < 0.0:
            raise ValueError(f"temperature must be non-negative, got {temperature}")
        if temperature == 0.0:
            return torch.argmax(logits, dim=-1)

        # Apply temperature.
        logits = logits / temperature

        # Top-k filtering: everything outside the top k becomes -inf.
        if top_k > 0:
            top_k = min(top_k, logits.size(-1))
            values, indices = torch.topk(logits, top_k, dim=-1)
            logits_filtered = torch.full_like(logits, float('-inf'))
            logits_filtered.scatter_(-1, indices, values)
            logits = logits_filtered

        # Top-p (nucleus) filtering.
        if top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

            remove_sorted = cumulative_probs > top_p
            # Shift right by one so the token that crosses the threshold is
            # kept, and never remove the most likely token.
            remove_sorted[..., 1:] = remove_sorted[..., :-1].clone()
            remove_sorted[..., 0] = False

            # Map the removal flags from sorted order back to vocabulary order.
            indices_to_remove = remove_sorted.scatter(-1, sorted_indices, remove_sorted)
            logits = logits.masked_fill(indices_to_remove, float('-inf'))

        # Sample from the filtered distribution.
        probs = F.softmax(logits, dim=-1)
        return torch.multinomial(probs, 1).squeeze(-1)

def count_parameters(model):
    """Count the parameters of ``model``.

    Returns:
        A ``(total, trainable)`` tuple of element counts, where ``trainable``
        covers only parameters with ``requires_grad`` set.
    """
    total_params = 0
    trainable_params = 0
    for param in model.parameters():
        size = param.numel()
        total_params += size
        if param.requires_grad:
            trainable_params += size
    return total_params, trainable_params

if __name__ == "__main__":
    # Smoke test: build the model with default hyper-parameters and report
    # its size.
    vocab_dict = {token_id: f'token_{token_id}' for token_id in range(24)}
    model = MultiTaskImmuneModel(vocab_dict=vocab_dict)
    param_counts = count_parameters(model)
    total, trainable = param_counts
    print(f"Total parameters: {total:,}")
    print(f"Trainable parameters: {trainable:,}")

    # Minimal forward pass on random token IDs with all-ones masks.
    batch_size = 4
    seq_len = 100
    batch_shape = (batch_size, seq_len)

    dummy_batch = {
        'pt_input': torch.randint(0, 24, batch_shape),
        'pt_mask': torch.ones(batch_shape),
        'pmt_input': torch.randint(0, 24, batch_shape),
        'pmt_mask': torch.ones(batch_shape),
        'pm_input': torch.randint(0, 24, batch_shape),
        'pm_mask': torch.ones(batch_shape),
        'labels': torch.randint(0, 2, (batch_size,)),
    }

    # No gradients are needed for the smoke test.
    with torch.no_grad():
        outputs = model(dummy_batch)
    print("Model forward pass successful!")
    print(f"Output keys: {list(outputs.keys())}")