import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import argparse
from typing import List, Dict, Tuple, Optional
import os
import json

from data import ImmuneDataProcessor
from model import MultiTaskImmuneModel

class MultiTaskInference:
    """Run inference with the stage-wise checkpoints of ``MultiTaskImmuneModel``.

    Stage 1 serves binding classification (tasks PT / PMT / PM), stage 2 serves
    TCR generation from peptide + MHC, and stage 3 serves peptide generation from
    MHC + TCR. Stages whose checkpoint file is missing are skipped at load time;
    calling the corresponding task then raises ``ValueError``.
    """

    def __init__(self, config: Dict):
        """Build the data processor / vocabularies and load every available stage model.

        Args:
            config: Inference settings. Keys read by this class: ``max_len``,
                ``model_dir``, ``fold``, ``d_model``, ``n_encoder_layers``,
                ``n_decoder_layers``, ``n_heads``, ``dropout`` (all optional).
        """
        self.config = config
        self.device = torch.device('cpu')  # CPU only, to avoid CUDA issues
        print(f"Using device: {self.device}")

        # Data processor: provides the vocabulary, special-token ids and
        # sequence cleaning / classification-input helpers used below.
        # NOTE(review): "dummy.csv" looks like a placeholder path — confirm that
        # ImmuneDataProcessor does not need real data at inference time.
        self.data_processor = ImmuneDataProcessor(
            data_path="dummy.csv",
            max_len=config.get('max_len', 150),
            random_seed=42
        )

        # Reverse vocabulary: token id -> token string.
        self.vocab_dict = {v: k for k, v in self.data_processor.token_to_id.items()}

        # Loaded models keyed by stage number (1, 2, 3).
        self.models = {}
        self.load_all_models()

    def load_all_models(self):
        """Load the best checkpoint of each stage (1-3) for the configured fold.

        Checkpoints are looked up at
        ``<model_dir>/best_model_stage_{stage}_fold_{fold}.pt``. Missing files are
        reported and skipped. Each loaded model is moved to ``self.device``, switched
        to eval mode and stored in ``self.models[stage]``.
        """
        model_dir = self.config.get('model_dir', 'outputs')
        fold = self.config.get('fold', 0)

        for stage in (1, 2, 3):
            model_path = os.path.join(model_dir, f'best_model_stage_{stage}_fold_{fold}.pt')

            if not os.path.exists(model_path):
                print(f"Stage {stage} model not found: {model_path}")
                continue

            print(f"Loading Stage {stage} model from {model_path}")

            model = MultiTaskImmuneModel(
                vocab_size=self.data_processor.vocab_size,
                d_model=self.config.get('d_model', 512),
                max_len=self.config.get('max_len', 150),
                n_encoder_layers=self.config.get('n_encoder_layers', 6),
                n_decoder_layers=self.config.get('n_decoder_layers', 4),
                n_heads=self.config.get('n_heads', 8),
                dropout=self.config.get('dropout', 0.1),
                vocab_dict=self.vocab_dict
            )

            # SECURITY NOTE: torch.load unpickles the file, which can execute
            # arbitrary code. Only load checkpoints from trusted locations
            # (recent PyTorch versions also accept weights_only=True).
            model.load_state_dict(torch.load(model_path, map_location=self.device))
            model = model.to(self.device)
            model.eval()

            self.models[stage] = model
            print(f"Stage {stage} model loaded successfully")

    def clean_and_validate_sequence(self, seq: str, seq_type: str) -> str:
        """Clean ``seq`` with the data processor and reject unusable input.

        Args:
            seq: Raw input sequence.
            seq_type: Human-readable label used in error messages.

        Returns:
            The cleaned sequence.

        Raises:
            ValueError: If ``seq`` is empty, is not a string (e.g. a NaN read from
                an empty CSV cell), or cleans to an empty string.
        """
        if not seq:
            raise ValueError(f"{seq_type} sequence cannot be empty")
        if not isinstance(seq, str):
            raise ValueError(f"{seq_type} sequence must be a string, got {type(seq).__name__}")

        cleaned = self.data_processor.clean_sequence(seq)
        if not cleaned:
            raise ValueError(f"Invalid {seq_type} sequence: {seq}")

        return cleaned

    def prepare_generation_input(self, context: str) -> Tuple[torch.Tensor, torch.Tensor]:
        """Tokenize a generation context into ``(input_ids, attention_mask)`` tensors.

        The context is tokenized character by character, except that the literal
        substring ``'<SEP>'`` maps to the separator id. Characters missing from the
        vocabulary, and any id outside ``[0, vocab_size)``, map to the unknown id.

        Returns:
            ``(context_tensor, mask_tensor)``: both long tensors of shape
            ``(1, num_tokens)`` on ``self.device``; the mask is all ones.
        """
        dp = self.data_processor
        sep = '<SEP>'
        context_tokens = []
        pos = 0
        while pos < len(context):
            if context.startswith(sep, pos):
                token_id = dp.sep_id
                pos += len(sep)
            else:
                token_id = dp.token_to_id.get(context[pos], dp.unk_id)
                pos += 1
            # Guard against ids the model's embedding table cannot index.
            if not 0 <= token_id < dp.vocab_size:
                token_id = dp.unk_id
            context_tokens.append(token_id)

        context_tensor = torch.tensor([context_tokens], dtype=torch.long).to(self.device)
        mask_tensor = torch.ones_like(context_tensor, dtype=torch.long)

        print(f"Context length: {len(context_tokens)}")
        return context_tensor, mask_tensor

    def improve_sampling_strategy(self, logits: torch.Tensor, temperature: float = 1.0,
                                  top_k: int = 10, verbose: bool = False) -> torch.Tensor:
        """Pick one token id from a 1-D logits vector using temperature + top-k sampling.

        Args:
            logits: Unnormalized scores over the vocabulary (1-D; the caller passes
                ``logits[0, -1, :]``).
            temperature: Softmax temperature. Values ``<= 0`` select the most likely
                token (greedy) instead of dividing by zero / inverting the scores.
            top_k: Keep only the ``top_k`` highest logits before sampling. Values
                ``<= 0`` or ``>=`` the vocabulary size disable the filter.
            verbose: Print the (up to) three most likely tokens.

        Returns:
            A 1-element long tensor holding the chosen token id.
        """
        neg_fill = torch.tensor(-1e9, dtype=logits.dtype, device=logits.device)
        greedy = temperature <= 0

        if not greedy and temperature != 1.0:
            logits = logits / temperature

        # Non-finite logits (inf / nan) are treated as "never pick this token".
        logits = torch.where(torch.isfinite(logits), logits, neg_fill)

        vocab_size = logits.size(-1)
        if 0 < top_k < vocab_size:
            values, indices = torch.topk(logits, top_k)
            # Keep only the top-k logits; everything else gets a huge negative score.
            filtered = torch.full_like(logits, -1e9)
            filtered.scatter_(-1, indices, values)
            logits = filtered

        probs = torch.softmax(logits, dim=-1)

        if verbose:
            n_show = min(3, vocab_size)  # topk(k) fails if k exceeds the vocabulary size
            top_probs, top_indices = torch.topk(probs, n_show)
            print(f"Top {n_show}: " + ", ".join(
                f"{self.vocab_dict.get(idx.item(), '<UNK>')}({p.item():.3f})"
                for p, idx in zip(top_probs, top_indices)
            ))

        if greedy:
            return torch.argmax(probs, dim=-1, keepdim=True)

        # Make sure the probabilities are finite and sum to 1 before sampling.
        probs = torch.where(torch.isfinite(probs), probs,
                            torch.tensor(0.0, dtype=probs.dtype, device=probs.device))
        probs = probs / (probs.sum(dim=-1, keepdim=True) + 1e-8)

        try:
            next_token = torch.multinomial(probs.unsqueeze(0), 1)  # shape (1, 1)
            return next_token.squeeze(0)
        except RuntimeError as e:
            print(f"Sampling error: {e}, using argmax")
            return torch.argmax(probs, dim=-1, keepdim=True)

    def generate_single_sequence(self, model_generator, context_tensor: torch.Tensor,
                                max_length: int = 50, temperature: float = 0.8,
                                top_k: int = 10, verbose: bool = False) -> List[int]:
        """Autoregressively sample one sequence after ``context_tensor``, stopping at EOS.

        Each step calls ``model_generator(current_input, attention_mask)``, samples
        from the logits at the last position and appends the token to the input.

        * EOS ends generation (EOS itself is not returned).
        * PAD / UNK / SEP are discarded without extending the input, so the next step
          resamples from the same input; such steps still count toward ``max_length``.
        * An out-of-range id, or any exception raised during a step, ends generation
          early (the exception is printed, not propagated).

        Returns:
            The generated token ids (context and special tokens excluded).
        """
        dp = self.data_processor
        skipped_ids = (dp.pad_id, dp.unk_id, dp.sep_id)
        generated_tokens = []
        current_input = context_tensor.clone()

        if verbose:
            print(f"Generating sequence: max_length={max_length}, temperature={temperature}, top_k={top_k}")

        # top_k never needs to exceed the vocabulary size.
        effective_top_k = min(top_k, dp.vocab_size)

        with torch.no_grad():
            for step in range(max_length):
                if verbose:
                    print(f"Step {step + 1}/{max_length}")

                try:
                    attention_mask = torch.ones_like(current_input)
                    logits = model_generator(current_input, attention_mask)

                    # Only the prediction for the next position matters.
                    next_token_logits = logits[0, -1, :]
                    next_token = self.improve_sampling_strategy(
                        next_token_logits,
                        temperature=temperature,
                        top_k=effective_top_k,
                        verbose=verbose
                    )

                    token_id = next_token.item()
                    token_str = self.vocab_dict.get(token_id, '<UNK>')

                    if verbose:
                        print(f"Generated: {token_str}")

                    if token_id < 0 or token_id >= dp.vocab_size:
                        print(f"Invalid token ID: {token_id}")
                        break

                    if token_id == dp.eos_id:
                        if verbose:
                            print("Generated EOS, stopping")
                        break

                    if token_id in skipped_ids:
                        if verbose:
                            print(f"Skipping special token: {token_str}")
                        continue

                    generated_tokens.append(token_id)

                    # Append as a (1, 1) tensor so it concatenates along the sequence dim.
                    current_input = torch.cat([current_input, next_token.reshape(1, 1)], dim=1)

                except Exception as e:
                    print(f"Error in step {step}: {str(e)}")
                    break

        if verbose:
            print(f"Generated {len(generated_tokens)} tokens")
        return generated_tokens

    def generate_multiple_candidates(self, model_generator, context_tensor: torch.Tensor,
                                   num_candidates: int = 1, max_length: int = 50,
                                   temperature: float = 0.8, top_k: int = 10,
                                   verbose: bool = False) -> List[List[int]]:
        """Sample ``num_candidates`` independent sequences from the same context.

        Returns:
            One token-id list per candidate, as produced by ``generate_single_sequence``.
        """
        candidates = []

        for i in range(num_candidates):
            if verbose:
                print(f"\n=== Generating candidate {i+1}/{num_candidates} ===")

            candidates.append(self.generate_single_sequence(
                model_generator, context_tensor, max_length, temperature, top_k, verbose
            ))

        return candidates

    def decode_sequence(self, token_ids: List[int]) -> str:
        """Convert token ids back to a string, stopping at the first EOS.

        PAD / UNK / SEP ids are dropped; unknown ids decode to ``'<UNK>'``.
        """
        dp = self.data_processor
        skipped_ids = (dp.pad_id, dp.unk_id, dp.sep_id)
        decoded_chars = []
        for token_id in token_ids:
            if token_id == dp.eos_id:
                break
            if token_id in skipped_ids:
                continue
            decoded_chars.append(self.vocab_dict.get(token_id, '<UNK>'))

        return ''.join(decoded_chars)

    def _generate_from_context(self, model_generator, context: str, header: Dict,
                               num_candidates: int, max_length: int,
                               temperature: float, top_k: int, verbose: bool) -> Dict:
        """Shared core of TCR / peptide generation: tokenize, sample, decode, report.

        ``header`` holds the leading result fields (task name and cleaned inputs) so
        the returned dict keeps the same key order for both generation tasks.
        """
        context_tensor, _ = self.prepare_generation_input(context)

        candidate_tokens_list = self.generate_multiple_candidates(
            model_generator,
            context_tensor,
            num_candidates=num_candidates,
            max_length=max_length,
            temperature=temperature,
            top_k=top_k,
            verbose=verbose
        )

        candidates = []
        for i, candidate_tokens in enumerate(candidate_tokens_list, start=1):
            generated_sequence = self.decode_sequence(candidate_tokens)
            candidates.append({
                'sequence': generated_sequence,
                'token_count': len(candidate_tokens),
                'final_length': len(generated_sequence)
            })
            print(f"Candidate {i}: {generated_sequence}")

        return {
            **header,
            'num_candidates': num_candidates,
            'candidates': candidates,
            'context': context,
            'parameters': {
                'max_length': max_length,
                'temperature': temperature,
                'top_k': top_k
            }
        }

    def generate_tcr(self, pep: str, mhc: str, num_candidates: int = 1, max_length: int = 50,
                    temperature: float = 0.8, top_k: int = 10, verbose: bool = False) -> Dict:
        """Generate TCR candidates for a peptide / MHC pair with the stage-2 model.

        Returns:
            A result dict (task, inputs, candidates, context, parameters), or
            ``{'error': message}`` if cleaning or generation fails.

        Raises:
            ValueError: If the stage-2 model was not loaded.
        """
        if 2 not in self.models:
            raise ValueError("TCR generation model (Stage 2) not loaded")

        model = self.models[2]

        try:
            pep = self.clean_and_validate_sequence(pep, 'peptide')
            mhc = self.clean_and_validate_sequence(mhc, 'mhc')

            print(f"Generating {num_candidates} TCR candidates for peptide: {pep}, MHC: {mhc}")

            # Context layout: <PEP><SEP><MHC><SEP>
            context = pep + '<SEP>' + mhc + '<SEP>'
            return self._generate_from_context(
                model.tcr_generator, context,
                {'task': 'TCR_Generation', 'peptide': pep, 'mhc': mhc},
                num_candidates, max_length, temperature, top_k, verbose
            )

        except Exception as e:
            import traceback
            traceback.print_exc()
            return {'error': str(e)}

    def generate_peptide(self, mhc: str, tcr: str, num_candidates: int = 1, max_length: int = 50,
                        temperature: float = 0.8, top_k: int = 10, verbose: bool = False) -> Dict:
        """Generate peptide candidates for an MHC / TCR pair with the stage-3 model.

        Returns:
            A result dict (task, inputs, candidates, context, parameters), or
            ``{'error': message}`` if cleaning or generation fails.

        Raises:
            ValueError: If the stage-3 model was not loaded.
        """
        if 3 not in self.models:
            raise ValueError("Peptide generation model (Stage 3) not loaded")

        model = self.models[3]

        try:
            mhc = self.clean_and_validate_sequence(mhc, 'mhc')
            tcr = self.clean_and_validate_sequence(tcr, 'tcr')

            print(f"Generating {num_candidates} peptide candidates for MHC: {mhc}, TCR: {tcr}")

            # Context layout: <MHC><SEP><TCR><SEP>
            context = mhc + '<SEP>' + tcr + '<SEP>'
            return self._generate_from_context(
                model.pep_generator, context,
                {'task': 'Peptide_Generation', 'mhc': mhc, 'tcr': tcr},
                num_candidates, max_length, temperature, top_k, verbose
            )

        except Exception as e:
            import traceback
            traceback.print_exc()
            return {'error': str(e)}

    # Task name -> attribute name of the matching classifier head on the model.
    _CLASSIFIER_ATTRS = {'PT': 'pt_classifier', 'PMT': 'pmt_classifier', 'PM': 'pm_classifier'}

    def predict_binding(self, pep: str, mhc: str = "", tcr: str = "", task: str = "PMT") -> Dict:
        """Predict binding for a peptide with optional MHC / TCR using the stage-1 model.

        Args:
            pep: Peptide sequence (required).
            mhc: MHC sequence; empty string when not used by ``task``.
            tcr: TCR sequence; empty string when not used by ``task``.
            task: One of ``'PT'``, ``'PMT'``, ``'PM'`` (case-insensitive).

        Returns:
            A dict with the prediction, binding probability, confidence and
            discriminator scores, or ``{'error': message}`` on failure.

        Raises:
            ValueError: If the stage-1 model was not loaded.
        """
        if 1 not in self.models:
            raise ValueError("Classification model (Stage 1) not loaded")

        model = self.models[1]
        task_key = task.upper()

        try:
            classifier_attr = self._CLASSIFIER_ATTRS.get(task_key)
            if classifier_attr is None:
                raise ValueError(f"Unknown task: {task}")

            pep = self.clean_and_validate_sequence(pep, 'peptide')
            mhc = self.clean_and_validate_sequence(mhc, 'mhc') if mhc else ''
            tcr = self.clean_and_validate_sequence(tcr, 'tcr') if tcr else ''

            input_ids, attention_mask = self.data_processor.create_classification_input(pep, mhc, tcr, task_key)

            input_tensor = torch.tensor([input_ids], dtype=torch.long).to(self.device)
            mask_tensor = torch.tensor([attention_mask], dtype=torch.long).to(self.device)

            with torch.no_grad():
                encoder_out = model.encoder(input_tensor, mask_tensor)
                cls_out = getattr(model, classifier_attr)(encoder_out, mask_tensor)
                disc_out = model.discriminator(cls_out['pooled'])

                probs = torch.softmax(cls_out['logits'], dim=-1)

            non_binding_prob = probs[0, 0].item()
            binding_prob = probs[0, 1].item()
            confidence_score = cls_out['confidence'][0].item()
            discriminator_score = disc_out[0].item()

            return {
                'task': task_key,
                'peptide': pep,
                'mhc': mhc,
                'tcr': tcr,
                'prediction': "Binding" if binding_prob > 0.5 else "Non-binding",
                'binding_probability': binding_prob,
                'confidence_score': confidence_score,
                'discriminator_score': discriminator_score,
                'class_probabilities': {
                    'non_binding': non_binding_prob,
                    'binding': binding_prob
                }
            }

        except Exception as e:
            print(f"Prediction error: {str(e)}")
            return {'error': str(e)}

    @staticmethod
    def _clean_record(item: Dict) -> Dict:
        """Return a copy of one input record with missing values removed.

        ``pandas.DataFrame.to_dict('records')`` represents empty CSV cells as NaN,
        a truthy float that would otherwise be treated as a sequence. ``None``,
        NaN and ``pd.NA`` entries are dropped so the callers' ``.get(key, default)``
        fallbacks apply; NumPy scalars are converted to plain Python scalars.
        """
        cleaned = {}
        for key, value in item.items():
            if isinstance(value, np.generic):
                value = value.item()
            if value is None or value is pd.NA:
                continue
            if isinstance(value, float) and np.isnan(value):
                continue
            cleaned[key] = value
        return cleaned

    @staticmethod
    def _generation_params(record: Dict) -> Dict:
        """Extract generation settings from a record, cast to the types the sampler needs.

        CSV columns that contain any empty cell are read as floats, so integer
        settings are cast back to ``int`` (``range``/``topk`` reject floats).
        """
        return {
            'num_candidates': int(record.get('num_candidates', 1)),
            'max_length': int(record.get('max_length', 50)),
            'temperature': float(record.get('temperature', 0.8)),
            'top_k': int(record.get('top_k', 10)),
        }

    def batch_inference(self, input_data: List[Dict], task: str) -> List[Dict]:
        """Run ``task`` on every record of ``input_data``.

        Args:
            input_data: Records with any of the keys ``peptide``, ``mhc``, ``tcr``
                and, for generation, ``num_candidates``, ``max_length``,
                ``temperature``, ``top_k``.
            task: ``'PT'``, ``'PMT'``, ``'PM'``, ``'TCR_GEN'`` or ``'PEP_GEN'``
                (case-insensitive).

        Returns:
            One result dict per record, in input order. Records that raise are
            reported as ``{'error': message, 'input': record}``.
        """
        results = []
        task_key = task.upper()

        for item in input_data:
            try:
                record = self._clean_record(item)
                if task_key in ('PT', 'PMT', 'PM'):
                    result = self.predict_binding(
                        pep=record.get('peptide', ''),
                        mhc=record.get('mhc', ''),
                        tcr=record.get('tcr', ''),
                        task=task
                    )
                elif task_key == 'TCR_GEN':
                    result = self.generate_tcr(
                        pep=record.get('peptide', ''),
                        mhc=record.get('mhc', ''),
                        **self._generation_params(record)
                    )
                elif task_key == 'PEP_GEN':
                    result = self.generate_peptide(
                        mhc=record.get('mhc', ''),
                        tcr=record.get('tcr', ''),
                        **self._generation_params(record)
                    )
                else:
                    result = {'error': f'Unknown task: {task}'}

                results.append(result)

            except Exception as e:
                results.append({'error': str(e), 'input': item})

        return results

def main():
    """Command-line entry point for single-item or batch (CSV) inference.

    With ``--input_file`` every CSV row is passed to
    ``MultiTaskInference.batch_inference`` and the results are written to
    ``--output_file`` (default ``inference_results_<task>.csv``). Otherwise a single
    prediction / generation is run from ``--peptide`` / ``--mhc`` / ``--tcr`` and the
    result is printed.
    """
    parser = argparse.ArgumentParser(description='Multi-Task Immune Inference')
    add = parser.add_argument

    add('--task', type=str, required=True,
        choices=['PT', 'PMT', 'PM', 'TCR_GEN', 'PEP_GEN'],
        help='Task type')
    add('--model_dir', type=str, default='outputs', help='Directory containing trained models')
    add('--fold', type=int, default=0, help='Which fold model to use')

    # Single-item inputs
    add('--peptide', type=str, help='Peptide sequence')
    add('--mhc', type=str, help='MHC sequence')
    add('--tcr', type=str, help='TCR sequence')

    # Batch inputs / outputs
    add('--input_file', type=str, help='Input CSV file for batch inference')
    add('--output_file', type=str, help='Output file for batch inference')

    # Generation settings
    add('--num_candidates', type=int, default=1, help='Number of candidates to generate')
    add('--max_length', type=int, default=50, help='Maximum generation length')
    add('--temperature', type=float, default=0.8, help='Generation temperature')
    add('--top_k', type=int, default=5, help='Top-k sampling')
    add('--verbose', action='store_true', help='Verbose output during generation')

    args = parser.parse_args()

    # Architecture settings used to rebuild each model before its checkpoint is loaded.
    config = dict(
        model_dir=args.model_dir,
        fold=args.fold,
        max_len=120,
        d_model=512,
        n_encoder_layers=6,
        n_decoder_layers=4,
        n_heads=8,
        dropout=0.1,
    )

    inferencer = MultiTaskInference(config)

    if args.input_file:
        print(f"Performing batch inference with task: {args.task}")
        records = pd.read_csv(args.input_file).to_dict('records')
        batch_results = inferencer.batch_inference(records, args.task)

        destination = args.output_file or f'inference_results_{args.task.lower()}.csv'
        pd.DataFrame(batch_results).to_csv(destination, index=False)
        print(f"Results saved to {destination}")
        return

    task = args.task
    sampling = dict(
        num_candidates=args.num_candidates,
        max_length=args.max_length,
        temperature=args.temperature,
        top_k=args.top_k,
        verbose=args.verbose,
    )

    if task in ('PT', 'PMT', 'PM'):
        if not args.peptide:
            print("Error: --peptide is required for classification tasks")
            return
        result = inferencer.predict_binding(
            pep=args.peptide,
            mhc=args.mhc or "",
            tcr=args.tcr or "",
            task=task
        )
    elif task == 'TCR_GEN':
        if not (args.peptide and args.mhc):
            print("Error: --peptide and --mhc are required for TCR generation")
            return
        result = inferencer.generate_tcr(pep=args.peptide, mhc=args.mhc, **sampling)
    elif task == 'PEP_GEN':
        if not (args.mhc and args.tcr):
            print("Error: --mhc and --tcr are required for peptide generation")
            return
        result = inferencer.generate_peptide(mhc=args.mhc, tcr=args.tcr, **sampling)

    print("\nInference Result:")
    print("=" * 50)
    for key, value in result.items():
        if key != 'candidates' or not isinstance(value, list):
            print(f"{key}: {value}")
            continue
        print(f"{key}:")
        for idx, candidate in enumerate(value, start=1):
            print(f"  Candidate {idx}: {candidate}")

if __name__ == "__main__":
    # Run the command-line interface only when this file is executed as a script.
    main()