import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np
import pandas as pd
from sklearn.metrics import (accuracy_score, precision_recall_fscore_support, roc_auc_score,
                           confusion_matrix, classification_report, roc_curve, auc)
import os
import json
import argparse
from typing import Dict, List, Any
import matplotlib.pyplot as plt
import seaborn as sns
from accelerate import Accelerator
import warnings
warnings.filterwarnings('ignore')

from data import ImmuneDataProcessor, MultiTaskDataset, collate_fn, create_cv_splits
from model import MultiTaskImmuneModel

class MultiTaskTester:
    def __init__(self, config: Dict[str, Any]):
        self.config = config
        
        # 初始化Accelerator
        self.accelerator = Accelerator()
        self.device = self.accelerator.device
        
        # 数据处理器
        self.data_processor = ImmuneDataProcessor(
            data_path=config['data_path'],
            max_len=config.get('max_len', 150),
            random_seed=config['seed']
        )
        
        # 创建反向词汇表
        self.vocab_dict = {v: k for k, v in self.data_processor.token_to_id.items()}
        
        # 损失函数
        self.classification_loss = nn.CrossEntropyLoss()
        self.generation_loss = nn.CrossEntropyLoss(ignore_index=0)
        
        # 存储测试结果
        self.test_results = []

    def load_model(self, model_path: str) -> MultiTaskImmuneModel:
        """加载训练好的模型"""
        model = MultiTaskImmuneModel(
            vocab_size=self.data_processor.vocab_size,
            d_model=self.config.get('d_model', 512),
            max_len=self.config.get('max_len', 150),
            n_encoder_layers=self.config.get('n_encoder_layers', 6),
            n_decoder_layers=self.config.get('n_decoder_layers', 4),
            n_heads=self.config.get('n_heads', 8),
            dropout=self.config.get('dropout', 0.1),
            vocab_dict=self.vocab_dict
        )
        
        # 加载权重
        model.load_state_dict(torch.load(model_path, map_location='cpu'))
        model = model.to(self.device)
        model.eval()
        
        return model

    def compute_comprehensive_classification_metrics(self, logits: torch.Tensor, labels: torch.Tensor, 
                                                   confidence: torch.Tensor, task_name: str) -> Dict[str, float]:
        """计算全面的分类指标"""
        with torch.no_grad():
            probs = torch.softmax(logits, dim=-1)
            preds = torch.argmax(logits, dim=-1)
            
            labels_cpu = labels.cpu().numpy()
            preds_cpu = preds.cpu().numpy()
            probs_cpu = probs.cpu().numpy()
            confidence_cpu = confidence.squeeze().cpu().numpy()
            
            # 基本分类指标
            acc = accuracy_score(labels_cpu, preds_cpu)
            precision, recall, f1, support = precision_recall_fscore_support(
                labels_cpu, preds_cpu, average=None, zero_division=0
            )
            
            # 宏平均和微平均
            precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support(
                labels_cpu, preds_cpu, average='macro', zero_division=0
            )
            precision_micro, recall_micro, f1_micro, _ = precision_recall_fscore_support(
                labels_cpu, preds_cpu, average='micro', zero_division=0
            )
            
            # ROC AUC
            try:
                auc_score = roc_auc_score(labels_cpu, probs_cpu[:, 1])
                fpr, tpr, _ = roc_curve(labels_cpu, probs_cpu[:, 1])
                auc_manual = auc(fpr, tpr)
            except:
                auc_score = 0.0
                auc_manual = 0.0
            
            # 混淆矩阵
            cm = confusion_matrix(labels_cpu, preds_cpu)
            tn, fp, fn, tp = cm.ravel() if cm.size == 4 else (0, 0, 0, 0)
            
            # 特异性和敏感性
            specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
            sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
            
            # 平衡准确率
            balanced_accuracy = (sensitivity + specificity) / 2
            
            # Matthews correlation coefficient
            mcc_denom = np.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
            mcc = ((tp * tn) - (fp * fn)) / mcc_denom if mcc_denom != 0 else 0
            
            return {
                f'{task_name}_accuracy': acc,
                f'{task_name}_precision_class_0': precision[0] if len(precision) > 0 else 0,
                f'{task_name}_precision_class_1': precision[1] if len(precision) > 1 else 0,
                f'{task_name}_recall_class_0': recall[0] if len(recall) > 0 else 0,
                f'{task_name}_recall_class_1': recall[1] if len(recall) > 1 else 0,
                f'{task_name}_f1_class_0': f1[0] if len(f1) > 0 else 0,
                f'{task_name}_f1_class_1': f1[1] if len(f1) > 1 else 0,
                f'{task_name}_precision_macro': precision_macro,
                f'{task_name}_recall_macro': recall_macro,
                f'{task_name}_f1_macro': f1_macro,
                f'{task_name}_precision_micro': precision_micro,
                f'{task_name}_recall_micro': recall_micro,
                f'{task_name}_f1_micro': f1_micro,
                f'{task_name}_auc': auc_score,
                f'{task_name}_auc_manual': auc_manual,
                f'{task_name}_specificity': specificity,
                f'{task_name}_sensitivity': sensitivity,
                f'{task_name}_balanced_accuracy': balanced_accuracy,
                f'{task_name}_mcc': mcc,
                f'{task_name}_avg_confidence': float(confidence_cpu.mean()),
                f'{task_name}_std_confidence': float(confidence_cpu.std()),
                f'{task_name}_true_positive': int(tp),
                f'{task_name}_true_negative': int(tn),
                f'{task_name}_false_positive': int(fp),
                f'{task_name}_false_negative': int(fn),
                f'{task_name}_support_class_0': int(support[0]) if len(support) > 0 else 0,
                f'{task_name}_support_class_1': int(support[1]) if len(support) > 1 else 0,
            }

    def compute_comprehensive_generation_metrics(self, logits: torch.Tensor, targets: torch.Tensor, 
                                               task_name: str) -> Dict[str, float]:
        """计算全面的生成指标"""
        with torch.no_grad():
            # Shift for next token prediction
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = targets[..., 1:].contiguous()
            
            # Flatten
            shift_logits = shift_logits.view(-1, shift_logits.size(-1))
            shift_labels = shift_labels.view(-1)
            
            # 计算损失
            loss = self.generation_loss(shift_logits, shift_labels)
            perplexity = torch.exp(loss)
            
            # Token级别准确率
            preds = torch.argmax(shift_logits, dim=-1)
            mask = shift_labels != 0  # ignore padding
            
            if mask.sum() > 0:
                correct = (preds == shift_labels) & mask
                token_accuracy = correct.sum().float() / mask.sum().float()
                
                # 各个token位置的准确率
                seq_len = targets.size(1) - 1  # exclude first token
                pos_accuracies = []
                for pos in range(min(seq_len, 20)):  # 前20个位置
                    pos_mask = mask.view(-1, seq_len)[:, pos] if pos < seq_len else torch.zeros_like(mask[:seq_len])
                    if pos_mask.sum() > 0:
                        pos_correct = (preds.view(-1, seq_len)[:, pos] == shift_labels.view(-1, seq_len)[:, pos]) & pos_mask
                        pos_acc = pos_correct.sum().float() / pos_mask.sum().float()
                        pos_accuracies.append(pos_acc.item())
                    else:
                        pos_accuracies.append(0.0)
            else:
                token_accuracy = torch.tensor(0.0)
                pos_accuracies = [0.0] * 20
            
            # 序列级别指标
            batch_size = targets.size(0)
            seq_accuracies = []
            for b in range(batch_size):
                seq_targets = targets[b, 1:]  # remove first token
                seq_preds = torch.argmax(logits[b, :-1, :], dim=-1)
                seq_mask = seq_targets != 0
                if seq_mask.sum() > 0:
                    seq_correct = (seq_preds == seq_targets) & seq_mask
                    seq_acc = seq_correct.sum().float() / seq_mask.sum().float()
                    seq_accuracies.append(seq_acc.item())
            
            perfect_sequences = sum(1 for acc in seq_accuracies if acc == 1.0)
            avg_seq_accuracy = np.mean(seq_accuracies) if seq_accuracies else 0.0
            
            # 构建结果字典
            result = {
                f'{task_name}_loss': loss.item(),
                f'{task_name}_perplexity': perplexity.item(),
                f'{task_name}_token_accuracy': token_accuracy.item(),
                f'{task_name}_sequence_accuracy': avg_seq_accuracy,
                f'{task_name}_perfect_sequences': perfect_sequences,
                f'{task_name}_total_sequences': len(seq_accuracies),
                f'{task_name}_perfect_sequence_ratio': perfect_sequences / max(len(seq_accuracies), 1),
            }
            
            # 添加位置准确率
            for i, pos_acc in enumerate(pos_accuracies):
                result[f'{task_name}_pos_{i}_accuracy'] = pos_acc
            
            return result

    def test_model_comprehensive(self, model: MultiTaskImmuneModel, test_loader: DataLoader, 
                               fold: int, stage: int) -> Dict[str, float]:
        """全面测试模型"""
        model.eval()
        
        all_outputs = {
            'pt_logits': [], 'pt_labels': [], 'pt_confidence': [],
            'pmt_logits': [], 'pmt_labels': [], 'pmt_confidence': [],
            'pm_logits': [], 'pm_labels': [], 'pm_confidence': [],
            'tcr_gen_logits': [], 'tcr_gen_targets': [],
            'pep_gen_logits': [], 'pep_gen_targets': []
        }
        
        with torch.no_grad():
            for batch in test_loader:
                outputs = model(batch)
                
                # 收集分类结果
                for task in ['pt', 'pmt', 'pm']:
                    if f'{task}_logits' in outputs:
                        all_outputs[f'{task}_logits'].append(outputs[f'{task}_logits'].cpu())
                        all_outputs[f'{task}_labels'].append(batch['labels'].cpu())
                        all_outputs[f'{task}_confidence'].append(outputs[f'{task}_confidence'].cpu())
                
                # 收集生成结果
                if 'tcr_gen_logits' in outputs:
                    all_outputs['tcr_gen_logits'].append(outputs['tcr_gen_logits'].cpu())
                    all_outputs['tcr_gen_targets'].append(outputs['tcr_gen_targets'].cpu())
                
                if 'pep_gen_logits' in outputs:
                    all_outputs['pep_gen_logits'].append(outputs['pep_gen_logits'].cpu())
                    all_outputs['pep_gen_targets'].append(outputs['pep_gen_targets'].cpu())
        
        # 合并所有结果
        merged_outputs = {}
        for key, value_list in all_outputs.items():
            if value_list:
                merged_outputs[key] = torch.cat(value_list, dim=0)
        
        # 计算全面指标
        comprehensive_metrics = {
            'fold': fold,
            'stage': stage,
            'phase': 'test'
        }
        
        # 分类任务指标
        for task in ['pt', 'pmt', 'pm']:
            if f'{task}_logits' in merged_outputs:
                task_metrics = self.compute_comprehensive_classification_metrics(
                    merged_outputs[f'{task}_logits'],
                    merged_outputs[f'{task}_labels'],
                    merged_outputs[f'{task}_confidence'],
                    task
                )
                comprehensive_metrics.update(task_metrics)
        
        # 生成任务指标
        if 'tcr_gen_logits' in merged_outputs:
            tcr_metrics = self.compute_comprehensive_generation_metrics(
                merged_outputs['tcr_gen_logits'],
                merged_outputs['tcr_gen_targets'],
                'tcr_gen'
            )
            comprehensive_metrics.update(tcr_metrics)
        
        if 'pep_gen_logits' in merged_outputs:
            pep_metrics = self.compute_comprehensive_generation_metrics(
                merged_outputs['pep_gen_logits'],
                merged_outputs['pep_gen_targets'],
                'pep_gen'
            )
            comprehensive_metrics.update(pep_metrics)
        
        return comprehensive_metrics

    def test_all_folds(self):
        """测试所有fold和所有阶段"""
        self.accelerator.print("Starting Comprehensive Testing")
        self.accelerator.print("=" * 50)
        
        # 加载数据
        df = self.data_processor.load_and_process_data()
        df_balanced = self.data_processor.create_balanced_dataset(df, negative_ratio=1.0)
        dataset = self.data_processor.create_five_task_dataset(df_balanced)
        
        # 创建交叉验证分割
        cv_splits = create_cv_splits(df_balanced, n_splits=self.config.get('n_folds', 5))
        
        all_test_results = []
        
        for fold, (train_indices, test_indices) in enumerate(cv_splits):
            self.accelerator.print(f"\nTesting Fold {fold}")
            
            # 准备测试数据
            test_data = [dataset[i] for i in test_indices]
            test_dataset = MultiTaskDataset(test_data)
            
            test_loader = DataLoader(
                test_dataset,
                batch_size=self.config.get('batch_size', 16),
                shuffle=False,
                collate_fn=collate_fn,
                num_workers=self.config.get('num_workers', 4)
            )
            
            test_loader = self.accelerator.prepare(test_loader)
            
            # 测试每个阶段的模型
            for stage in [1, 2, 3]:
                model_path = os.path.join(
                    self.config.get('output_dir', 'outputs'),
                    f'best_model_stage_{stage}_fold_{fold}.pt'
                )
                
                if os.path.exists(model_path):
                    self.accelerator.print(f"Testing Stage {stage} model: {model_path}")
                    
                    # 加载模型
                    model = self.load_model(model_path)
                    model = self.accelerator.prepare(model)
                    
                    # 测试模型
                    test_metrics = self.test_model_comprehensive(model, test_loader, fold, stage)
                    all_test_results.append(test_metrics)
                    
                    self.accelerator.print(f"Stage {stage} completed")
                else:
                    self.accelerator.print(f"Model not found: {model_path}")
        
        # 保存详细结果
        if self.accelerator.is_main_process:
            # 保存原始结果
            results_df = pd.DataFrame(all_test_results)
            results_path = os.path.join(self.config.get('output_dir', 'outputs'), 'comprehensive_test_results.csv')
            results_df.to_csv(results_path, index=False)
            
            # 生成汇总统计
            self.generate_summary_statistics(results_df)
            
            self.accelerator.print(f"Comprehensive test results saved to {results_path}")
    
    def generate_summary_statistics(self, results_df: pd.DataFrame):
        """生成汇总统计"""
        summary_stats = []
        
        # 按任务和阶段分组
        for stage in [1, 2, 3]:
            stage_data = results_df[results_df['stage'] == stage]
            
            if len(stage_data) == 0:
                continue
                
            stage_summary = {'stage': stage}
            
            # 分类任务统计
            for task in ['pt', 'pmt', 'pm']:
                for metric in ['accuracy', 'f1_macro', 'auc', 'balanced_accuracy', 'mcc']:
                    col_name = f'{task}_{metric}'
                    if col_name in stage_data.columns:
                        values = stage_data[col_name].dropna()
                        if len(values) > 0:
                            stage_summary[f'{col_name}_mean'] = values.mean()
                            stage_summary[f'{col_name}_std'] = values.std()
                            stage_summary[f'{col_name}_min'] = values.min()
                            stage_summary[f'{col_name}_max'] = values.max()
            
            # 生成任务统计
            for task in ['tcr_gen', 'pep_gen']:
                for metric in ['token_accuracy', 'sequence_accuracy', 'perfect_sequence_ratio', 'perplexity']:
                    col_name = f'{task}_{metric}'
                    if col_name in stage_data.columns:
                        values = stage_data[col_name].dropna()
                        if len(values) > 0:
                            stage_summary[f'{col_name}_mean'] = values.mean()
                            stage_summary[f'{col_name}_std'] = values.std()
                            stage_summary[f'{col_name}_min'] = values.min()
                            stage_summary[f'{col_name}_max'] = values.max()
            
            summary_stats.append(stage_summary)
        
        # 保存汇总统计
        summary_df = pd.DataFrame(summary_stats)
        summary_path = os.path.join(self.config.get('output_dir', 'outputs'), 'test_summary_statistics.csv')
        summary_df.to_csv(summary_path, index=False)
        
        # 打印重要指标
        self.accelerator.print("\n" + "="*60)
        self.accelerator.print("SUMMARY STATISTICS")
        self.accelerator.print("="*60)
        
        for _, row in summary_df.iterrows():
            stage = int(row['stage'])
            self.accelerator.print(f"\nStage {stage}:")
            
            # 分类任务
            if stage == 1:
                for task in ['pt', 'pmt', 'pm']:
                    acc_col = f'{task}_accuracy_mean'
                    f1_col = f'{task}_f1_macro_mean'
                    if acc_col in row and pd.notna(row[acc_col]):
                        acc_mean = row[acc_col]
                        acc_std = row.get(f'{task}_accuracy_std', 0)
                        f1_mean = row.get(f1_col, 0)
                        f1_std = row.get(f'{task}_f1_macro_std', 0)
                        self.accelerator.print(f"  {task.upper()}: Acc={acc_mean:.4f}±{acc_std:.4f}, F1={f1_mean:.4f}±{f1_std:.4f}")
            
            # 生成任务
            elif stage == 2:
                tcr_acc_col = 'tcr_gen_token_accuracy_mean'
                if tcr_acc_col in row and pd.notna(row[tcr_acc_col]):
                    tcr_acc = row[tcr_acc_col]
                    tcr_std = row.get('tcr_gen_token_accuracy_std', 0)
                    tcr_seq = row.get('tcr_gen_sequence_accuracy_mean', 0)
                    self.accelerator.print(f"  TCR Gen: Token Acc={tcr_acc:.4f}±{tcr_std:.4f}, Seq Acc={tcr_seq:.4f}")
            
            elif stage == 3:
                pep_acc_col = 'pep_gen_token_accuracy_mean'
                if pep_acc_col in row and pd.notna(row[pep_acc_col]):
                    pep_acc = row[pep_acc_col]
                    pep_std = row.get('pep_gen_token_accuracy_std', 0)
                    pep_seq = row.get('pep_gen_sequence_accuracy_mean', 0)
                    self.accelerator.print(f"  PEP Gen: Token Acc={pep_acc:.4f}±{pep_std:.4f}, Seq Acc={pep_seq:.4f}")

def main():
    """Parse the command-line options, assemble the config and test all folds."""
    cli = argparse.ArgumentParser(description='Comprehensive Multi-Task Testing')
    option_specs = (
        ('--data_path', str, 'data.csv', 'Data file path'),
        ('--output_dir', str, 'outputs', 'Output directory with trained models'),
        ('--batch_size', int, 16, 'Batch size for testing'),
    )
    for flag, kind, default, help_text in option_specs:
        cli.add_argument(flag, type=kind, default=default, help=help_text)

    opts = cli.parse_args()

    # Only the path, output directory and batch size come from the CLI;
    # the remaining settings are fixed here.
    config = dict(
        data_path=opts.data_path,
        output_dir=opts.output_dir,
        seed=42,
        n_folds=5,
        batch_size=opts.batch_size,
        max_len=120,
        d_model=512,
        n_encoder_layers=6,
        n_decoder_layers=4,
        n_heads=8,
        dropout=0.1,
        num_workers=4,
    )

    MultiTaskTester(config).test_all_folds()

# Run the test pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()