import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import StratifiedKFold
from typing import List, Dict, Tuple, Any
from collections import Counter
import random

class ImmuneDataProcessor:
    """Clean, tokenize, balance and package peptide/MHC/TCR pairs for multi-task training.

    The vocabulary is four special tokens (``<PAD>``, ``<UNK>``, ``<SEP>``, ``<EOS>``,
    ids 0-3) followed by the 20 standard amino acids. Every encoded sequence is
    truncated/padded to exactly ``max_len`` tokens.
    """

    # Columns carried by every positive (augmented) and negative sample row.
    _SAMPLE_COLUMNS = ('epitope_clean', 'mhc_clean', 'cdr3_clean', 'label')

    def __init__(self, data_path: str, max_len: int = 150, random_seed: int = 42):
        """
        Args:
            data_path: CSV file read by ``load_and_process_data``.
            max_len: fixed length (in tokens) of every encoded input.
            random_seed: seed for the global ``random``/``numpy`` RNGs, for negative
                sampling and for every ``DataFrame.sample`` call.
        """
        self.data_path = data_path
        self.max_len = max_len
        self.random_seed = random_seed

        # Seed the global RNGs
        random.seed(random_seed)
        np.random.seed(random_seed)

        # Amino-acid vocabulary + special tokens
        self.amino_acids = ['A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L',
                            'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'Y']
        self.special_tokens = ['<PAD>', '<UNK>', '<SEP>', '<EOS>']

        self.vocab = self.special_tokens + self.amino_acids
        self.vocab_size = len(self.vocab)
        self.token_to_id = {token: i for i, token in enumerate(self.vocab)}
        self.id_to_token = {i: token for i, token in enumerate(self.vocab)}

        self.pad_id = self.token_to_id['<PAD>']
        self.unk_id = self.token_to_id['<UNK>']
        self.sep_id = self.token_to_id['<SEP>']
        self.eos_id = self.token_to_id['<EOS>']

        print(f"Vocabulary size: {self.vocab_size}")
        print(f"Special tokens: PAD={self.pad_id}, UNK={self.unk_id}, SEP={self.sep_id}, EOS={self.eos_id}")

    def clean_sequence(self, seq: str) -> str:
        """Upper-case ``seq`` and keep only the 20 standard amino-acid letters.

        Missing values (``NaN``/``None``) become the empty string.
        """
        if pd.isna(seq):
            return ""
        seq = str(seq).upper().strip()
        return ''.join([c for c in seq if c in self.amino_acids])

    def sequence_to_ids(self, seq: str) -> List[int]:
        """Map each character of ``seq`` to its id (``<UNK>`` for unknown characters)."""
        return [self.token_to_id.get(c, self.unk_id) for c in seq]

    def ids_to_sequence(self, ids: List[int]) -> str:
        """Join the tokens for ``ids``, skipping ``<PAD>``; other special tokens are kept literally."""
        return ''.join([self.id_to_token[i] for i in ids if i != self.pad_id])

    def _tokenize_with_sep(self, text: str) -> List[int]:
        """Encode ``text``, treating each literal ``'<SEP>'`` substring as one separator token.

        Every other character is looked up individually (``<UNK>`` if unknown).
        """
        sep_marker = '<SEP>'
        token_ids = []
        i = 0
        while i < len(text):
            if text.startswith(sep_marker, i):
                token_ids.append(self.sep_id)
                i += len(sep_marker)
            else:
                token_ids.append(self.token_to_id.get(text[i], self.unk_id))
                i += 1
        return token_ids

    def create_classification_input(self, pep: str, mhc: str, tcr: str, task_type: str) -> Tuple[List[int], List[int]]:
        """Build fixed-length token ids and attention mask for a classification task.

        Layouts: ``'PT'`` -> PEP<SEP>TCR, ``'PMT'`` -> PEP<SEP>MHC<SEP>TCR,
        ``'PM'`` -> PEP<SEP>MHC. The result is truncated to ``max_len`` and
        right-padded with ``<PAD>`` (mask 0).

        Raises:
            ValueError: for an unknown ``task_type``.
        """
        if task_type == 'PT':
            combined = pep + '<SEP>' + tcr
        elif task_type == 'PMT':
            combined = pep + '<SEP>' + mhc + '<SEP>' + tcr
        elif task_type == 'PM':
            combined = pep + '<SEP>' + mhc
        else:
            raise ValueError(f"Unknown task type: {task_type}")

        token_ids = self._tokenize_with_sep(combined)

        # Truncate
        if len(token_ids) > self.max_len:
            token_ids = token_ids[:self.max_len]

        attention_mask = [1] * len(token_ids)

        # Right-pad to max_len
        pad_count = self.max_len - len(token_ids)
        token_ids += [self.pad_id] * pad_count
        attention_mask += [0] * pad_count

        return token_ids, attention_mask

    def create_generation_input(self, context: str, target: str) -> Tuple[List[int], List[int], List[int]]:
        """Build a generation example: ``context`` tokens followed by ``target`` + ``<EOS>``.

        Returns ``(input_ids, attention_mask, target_for_loss)``, each exactly
        ``max_len`` long. ``target_for_loss`` equals ``input_ids`` with the context
        positions replaced by ``<PAD>``; it is NOT shifted here (the shift is
        presumably applied in the model/loss — confirm against the training code).
        Long targets are silently truncated so ``<EOS>`` always fits; if there is
        too little room, the context is truncated instead.
        """
        context_tokens = self._tokenize_with_sep(context)
        target_tokens = self.sequence_to_ids(target)

        # Target budget (incl. EOS): keep a 10-token buffer after the context and
        # allow up to 5 tokens of headroom beyond the target's own length.
        available_space = self.max_len - len(context_tokens) - 10
        max_target_len = min(len(target_tokens) + 5, available_space)

        if max_target_len < 10:
            # Too little room (or a very short target): reserve a fixed target budget
            # (at most 30 and never more than max_len) and cut the context instead.
            # Both bounds are clamped so a small max_len cannot produce a negative
            # slice (which would keep the wrong end of the context).
            target_space = max(1, min(len(target_tokens) + 5, 30, self.max_len))
            context_budget = max(0, self.max_len - target_space - 5)
            context_tokens = context_tokens[:context_budget]
            max_target_len = target_space

        # Silently truncate the target, leaving one slot for EOS.
        target_tokens = target_tokens[:max_target_len - 1]
        target_with_eos = target_tokens + [self.eos_id]

        # Final clamp guarantees the documented fixed length.
        combined_tokens = (context_tokens + target_with_eos)[:self.max_len]
        attention_mask = [1] * len(combined_tokens)
        target_for_loss = ([self.pad_id] * len(context_tokens) + target_with_eos)[:self.max_len]

        pad_count = self.max_len - len(combined_tokens)
        combined_tokens += [self.pad_id] * pad_count
        attention_mask += [0] * pad_count
        target_for_loss += [self.pad_id] * pad_count

        return combined_tokens, attention_mask, target_for_loss

    def load_and_process_data(self) -> pd.DataFrame:
        """Load ``self.data_path`` and return cleaned, deduplicated positive triples.

        Columns ``CDR3``/``MHC``/``Epitope`` are renamed to lower case. Sequences
        are cleaned to amino-acid letters. Rows are dropped if any sequence is
        empty or if lengths exceed 35 (CDR3) / 80 (MHC) / 30 (epitope). Duplicate
        (cdr3, mhc, epitope) triples are removed. Every remaining row gets ``label = 1``.

        Raises:
            ValueError: if ``cdr3``, ``mhc`` or ``epitope`` is missing after renaming.
        """
        print(f"Loading data from {self.data_path}")
        df = pd.read_csv(self.data_path)

        print(f"Original data: {df.shape[0]} samples")

        # Column-name mapping
        rename_candidates = {'CDR3': 'cdr3', 'MHC': 'mhc', 'Epitope': 'epitope'}
        column_mapping = {src: dst for src, dst in rename_candidates.items() if src in df.columns}

        if column_mapping:
            df = df.rename(columns=column_mapping)
            print(f"Mapped columns: {list(column_mapping.keys())} -> {list(column_mapping.values())}")

        # Check required columns
        required_cols = ['cdr3', 'mhc', 'epitope']
        missing_cols = [col for col in required_cols if col not in df.columns]
        if missing_cols:
            raise ValueError(f"Missing columns: {missing_cols}")

        # Clean sequences
        df['cdr3_clean'] = df['cdr3'].apply(self.clean_sequence)
        df['mhc_clean'] = df['mhc'].apply(self.clean_sequence)
        df['epitope_clean'] = df['epitope'].apply(self.clean_sequence)

        # Drop rows with an empty sequence
        before_filter = len(df)
        df = df[(df['cdr3_clean'].str.len() > 0) &
                (df['mhc_clean'].str.len() > 0) &
                (df['epitope_clean'].str.len() > 0)]
        if before_filter > len(df):
            print(f"Removed {before_filter - len(df)} samples with empty sequences")

        # Permissive length filter
        before_filter = len(df)
        df = df[(df['cdr3_clean'].str.len() <= 35) &
                (df['mhc_clean'].str.len() <= 80) &
                (df['epitope_clean'].str.len() <= 30)]
        if before_filter > len(df):
            print(f"Removed {before_filter - len(df)} samples with extreme lengths")

        # Deduplicate; take an explicit copy so adding 'label' below writes to an
        # owned frame rather than a filtered view (avoids chained-assignment warnings).
        before_dedup = len(df)
        df = df.drop_duplicates(subset=['cdr3_clean', 'mhc_clean', 'epitope_clean']).copy()
        if before_dedup > len(df):
            print(f"Removed {before_dedup - len(df)} duplicate samples")

        df['label'] = 1

        # Length statistics
        print(f"Final data: {len(df)} unique positive samples")
        print(f"Length ranges - TCR: {df['cdr3_clean'].str.len().min()}-{df['cdr3_clean'].str.len().max()}, "
              f"Peptide: {df['epitope_clean'].str.len().min()}-{df['epitope_clean'].str.len().max()}, "
              f"MHC: {df['mhc_clean'].str.len().min()}-{df['mhc_clean'].str.len().max()}")

        return df

    def augment_positive_samples(self, positive_df: pd.DataFrame, target_count: int) -> pd.DataFrame:
        """Resize positives to exactly ``target_count`` rows.

        If there are enough rows, a seeded random subset is returned. Otherwise the
        frame is repeated whole, and the remainder is a seeded random sample. Only
        the epitope/mhc/cdr3/label columns are returned, in every branch. An empty
        ``positive_df`` yields an empty frame (there is nothing to replicate).
        """
        positives = positive_df[list(self._SAMPLE_COLUMNS)]
        current_count = len(positives)
        if current_count >= target_count:
            return positives.sample(n=target_count, random_state=self.random_seed).reset_index(drop=True)

        if current_count == 0:
            return positives.reset_index(drop=True)

        full_copies, remainder = divmod(target_count, current_count)
        pieces = [positives] * full_copies
        if remainder > 0:
            pieces.append(positives.sample(n=remainder, random_state=self.random_seed))

        return pd.concat(pieces, ignore_index=True)

    def create_balanced_dataset(self, df: pd.DataFrame, negative_ratio: float = 1.0,
                                negative_expansion_factor: int = 10) -> pd.DataFrame:
        """Build a shuffled 1:1 positive/negative dataset.

        Negatives are random recombinations of the observed peptides, MHCs and TCRs.
        Observed positive triples and already-drawn negatives are skipped. Up to
        ``negative_expansion_factor`` negatives are drawn per original positive.
        Positives are then replicated to match the number of negatives obtained.

        Args:
            df: frame with ``epitope_clean``/``mhc_clean``/``cdr3_clean``/``label`` columns.
            negative_ratio: kept for interface compatibility; currently unused
                (the output is always balanced 1:1).
            negative_expansion_factor: negatives to attempt per original positive.

        Returns:
            Frame with the four sample columns. It is empty when there are no
            positives or when no new negative combination exists.
        """
        columns = list(self._SAMPLE_COLUMNS)
        positive_df = df[df['label'] == 1].copy()
        original_positive_count = len(positive_df)

        print(f"Creating balanced dataset with 1:{negative_expansion_factor} expansion strategy")
        print(f"Starting with {original_positive_count} positive samples")

        if original_positive_count == 0:
            print("No positive samples available; returning an empty dataset")
            return pd.DataFrame(columns=columns)

        unique_peps = positive_df['epitope_clean'].unique().tolist()
        unique_mhcs = positive_df['mhc_clean'].unique().tolist()
        unique_tcrs = positive_df['cdr3_clean'].unique().tolist()

        print(f"Available for negative generation: {len(unique_peps)} peptides, "
              f"{len(unique_mhcs)} MHCs, {len(unique_tcrs)} TCRs")

        # Step 1: draw negatives
        target_negative_count = original_positive_count * negative_expansion_factor
        existing_combinations = set(zip(positive_df['epitope_clean'],
                                        positive_df['mhc_clean'],
                                        positive_df['cdr3_clean']))

        print(f"Generating {target_negative_count} negative samples...")

        # A dedicated seeded RNG makes every call reproducible, independent of other
        # users of the global `random` module.
        rng = random.Random(self.random_seed)
        negative_samples = []
        attempts = 0
        max_attempts = target_negative_count * 20

        while len(negative_samples) < target_negative_count and attempts < max_attempts:
            pep = rng.choice(unique_peps)
            mhc = rng.choice(unique_mhcs)
            tcr = rng.choice(unique_tcrs)
            combo = (pep, mhc, tcr)

            if combo not in existing_combinations:
                negative_samples.append({
                    'epitope_clean': pep,
                    'mhc_clean': mhc,
                    'cdr3_clean': tcr,
                    'label': 0
                })
                existing_combinations.add(combo)

            attempts += 1

        print(f"Generated {len(negative_samples)} negative samples")

        # Step 2: replicate positives to match the negative count
        target_positive_count = len(negative_samples)
        print(f"Expanding positive samples to {target_positive_count}")

        augmented_positive_df = self.augment_positive_samples(positive_df, target_positive_count)

        # Step 3: merge (explicit columns keep the schema even with zero negatives)
        negative_df = pd.DataFrame(negative_samples, columns=columns)

        final_count = min(len(augmented_positive_df), len(negative_df))
        if len(augmented_positive_df) > final_count:
            augmented_positive_df = augmented_positive_df.sample(n=final_count, random_state=self.random_seed)
        if len(negative_df) > final_count:
            negative_df = negative_df.sample(n=final_count, random_state=self.random_seed)

        combined_df = pd.concat([augmented_positive_df, negative_df], ignore_index=True)
        combined_df = combined_df.sample(frac=1, random_state=self.random_seed).reset_index(drop=True)

        final_positive = int((combined_df['label'] == 1).sum())
        final_negative = int((combined_df['label'] == 0).sum())

        print(f"Final balanced dataset: {len(combined_df)} samples")
        print(f"  - Positive: {final_positive} (expansion: {final_positive/original_positive_count:.1f}x)")
        print(f"  - Negative: {final_negative} (expansion: {final_negative/original_positive_count:.1f}x)")
        if final_negative > 0:
            print(f"  - Balance ratio: {final_positive/final_negative:.3f}")
        else:
            print("  - Balance ratio: n/a (no negative samples could be generated)")

        return combined_df

    def create_five_task_dataset(self, df: pd.DataFrame) -> List[Dict]:
        """Encode each row for the PT/PMT/PM classification tasks and, for positives,
        the TCR- and peptide-generation tasks.

        Returns a list of per-row dicts. Generation keys (``tcr_gen_*``, ``pep_gen_*``)
        are present only when ``label == 1``.
        """
        dataset = []
        truncation_stats = {'tcr': 0, 'pep': 0}

        rows = zip(df['epitope_clean'], df['mhc_clean'], df['cdr3_clean'], df['label'])
        for pep, mhc, tcr, raw_label in rows:
            label = int(raw_label)

            sample = {
                'pep_seq': pep,
                'mhc_seq': mhc,
                'tcr_seq': tcr,
                'label': label
            }

            # Classification inputs
            sample['pt_input'], sample['pt_mask'] = self.create_classification_input(pep, mhc, tcr, 'PT')
            sample['pmt_input'], sample['pmt_mask'] = self.create_classification_input(pep, mhc, tcr, 'PMT')
            sample['pm_input'], sample['pm_mask'] = self.create_classification_input(pep, mhc, tcr, 'PM')

            # Generation inputs (positives only)
            if label == 1:
                # TCR generation: <PEP><SEP><MHC><SEP> -> TCR<EOS>
                tcr_context = pep + '<SEP>' + mhc + '<SEP>'
                sample['tcr_gen_input'], sample['tcr_gen_mask'], sample['tcr_gen_target'] = \
                    self.create_generation_input(tcr_context, tcr)

                # Peptide generation: <MHC><SEP><TCR><SEP> -> PEP<EOS>
                pep_context = mhc + '<SEP>' + tcr + '<SEP>'
                sample['pep_gen_input'], sample['pep_gen_mask'], sample['pep_gen_target'] = \
                    self.create_generation_input(pep_context, pep)

                # Count targets that lost tokens (non-pad target tokens < sequence + EOS).
                tcr_kept = sum(1 for tok in sample['tcr_gen_target'] if tok != self.pad_id)
                pep_kept = sum(1 for tok in sample['pep_gen_target'] if tok != self.pad_id)
                if tcr_kept < len(tcr) + 1:
                    truncation_stats['tcr'] += 1
                if pep_kept < len(pep) + 1:
                    truncation_stats['pep'] += 1

            dataset.append(sample)

        gen_count = sum(1 for s in dataset if 'tcr_gen_input' in s)

        print(f"Created {len(dataset)} samples, {gen_count} with generation tasks")
        if truncation_stats['tcr'] > 0 or truncation_stats['pep'] > 0:
            print(f"Sequence adaptation: TCR sequences: {truncation_stats['tcr']}, "
                  f"Peptide sequences: {truncation_stats['pep']}")

        return dataset

class MultiTaskDataset(Dataset):
    """Map-style PyTorch dataset over a list of pre-built sample dicts.

    The list is stored by reference (not copied), and items are returned as-is.
    """

    def __init__(self, data: List[Dict]):
        self.data = data

    def __getitem__(self, idx):
        samples = self.data
        return samples[idx]

    def __len__(self):
        samples = self.data
        return len(samples)

def collate_fn(batch):
    """Stack a list of sample dicts into a dict of ``torch.long`` tensors.

    Classification inputs/masks (``pt``, ``pmt``, ``pm``) and ``labels`` are stacked
    for every sample. Generation tensors (``tcr_gen_*`` then ``pep_gen_*``) are
    stacked only for samples with ``label == 1`` that carry ``tcr_gen_input``; these
    keys are absent when no such sample exists. ``positive_indices`` holds the batch
    positions of those generation rows (possibly empty).
    """
    def as_long(values):
        return torch.tensor(values, dtype=torch.long)

    out = {}
    for task in ('pt', 'pmt', 'pm'):
        for part in ('input', 'mask'):
            key = f'{task}_{part}'
            out[key] = as_long([sample[key] for sample in batch])

    out['labels'] = as_long([sample['label'] for sample in batch])

    gen_rows = []
    for position, sample in enumerate(batch):
        if sample['label'] == 1 and 'tcr_gen_input' in sample:
            gen_rows.append(position)

    if gen_rows:
        for prefix in ('tcr_gen', 'pep_gen'):
            for part in ('input', 'mask', 'target'):
                key = f'{prefix}_{part}'
                out[key] = as_long([batch[position][key] for position in gen_rows])

    out['positive_indices'] = as_long(gen_rows)

    return out

def create_cv_splits(df: pd.DataFrame, n_splits: int = 5, random_state: int = 42):
    """Return shuffled, label-stratified K-fold splits of ``df``.

    Each element is a ``(train_indices, val_indices)`` pair of positional indices,
    as produced by ``StratifiedKFold.split``. Rows are split independently, so if
    ``df`` contains duplicated rows (e.g. replicated positives), identical rows may
    appear on both sides of a fold.
    """
    labels = df['label']
    splitter = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state)
    folds = [(train_idx, val_idx) for train_idx, val_idx in splitter.split(df, labels)]
    return folds

if __name__ == "__main__":
    # Smoke-test the whole preprocessing pipeline on data.csv.
    proc = ImmuneDataProcessor("data.csv")
    cleaned_df = proc.load_and_process_data()
    balanced_df = proc.create_balanced_dataset(cleaned_df, negative_ratio=1.0)
    samples = proc.create_five_task_dataset(balanced_df)

    print(f"\nFinal dataset: {len(samples)} samples")

    # Inspect the first encoded sample.
    first = samples[0]
    print(f"Sample keys: {first.keys()}")

    if 'tcr_gen_input' in first:
        # Non-padding token counts (pad id is 0).
        input_len = sum(1 for tok in first['tcr_gen_input'] if tok != 0)
        target_len = sum(1 for tok in first['tcr_gen_target'] if tok != 0)
        print(f"Generation sample - input length: {input_len}, target length: {target_len}")