import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score
import os
import json
import time
import argparse
from typing import Dict, List, Any
from accelerate import Accelerator, DistributedDataParallelKwargs
import logging
import warnings
warnings.filterwarnings('ignore')

from data import ImmuneDataProcessor, MultiTaskDataset, collate_fn, create_cv_splits
from model import MultiTaskImmuneModel

# Configure root logging at INFO level and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class MultiStageTrainer:
    def __init__(self, config: Dict[str, Any]):
        self.config = config
        
        # 初始化Accelerator with DDP
        ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
        self.accelerator = Accelerator(
            gradient_accumulation_steps=config.get('gradient_accumulation_steps', 1),
            mixed_precision=config.get('mixed_precision', 'fp16'),
            log_with="tensorboard" if config.get('use_tensorboard', False) else None,
            project_dir=config.get('output_dir', 'outputs'),
            kwargs_handlers=[ddp_kwargs]
        )
        
        self.device = self.accelerator.device
        
        # 数据处理器
        self.data_processor = ImmuneDataProcessor(
            data_path=config['data_path'],
            max_len=config.get('max_len', 150),
            random_seed=config['seed']
        )
        
        # 创建反向词汇表
        self.vocab_dict = {v: k for k, v in self.data_processor.token_to_id.items()}
        
        # 损失函数
        self.classification_loss = nn.CrossEntropyLoss()
        self.generation_loss = nn.CrossEntropyLoss(ignore_index=0)  # ignore padding
        self.discriminator_loss = nn.BCELoss()
        
        # 存储指标
        self.all_metrics = []
        
        # 确保输出目录存在
        os.makedirs(config.get('output_dir', 'outputs'), exist_ok=True)
        
        # 训练阶段
        self.stage = 1  # 1: 分类+判别器, 2: TCR生成, 3: PEP生成

    def create_model(self) -> MultiTaskImmuneModel:
        """创建模型"""
        model = MultiTaskImmuneModel(
            vocab_size=self.data_processor.vocab_size,
            d_model=self.config.get('d_model', 512),
            max_len=self.config.get('max_len', 150),
            n_encoder_layers=self.config.get('n_encoder_layers', 6),
            n_decoder_layers=self.config.get('n_decoder_layers', 4),
            n_heads=self.config.get('n_heads', 8),
            dropout=self.config.get('dropout', 0.1),
            vocab_dict=self.vocab_dict
        )
        return model

    def freeze_modules(self, model, stage: int):
        """根据训练阶段冻结不同的模块"""
        # 先解冻所有参数
        for param in model.parameters():
            param.requires_grad = True
            
        if stage == 1:
            # 阶段1：训练分类器和判别器，冻结生成器
            self.accelerator.print("Stage 1: Training classifiers and discriminators")
            for param in model.tcr_generator.parameters():
                param.requires_grad = False
            for param in model.pep_generator.parameters():
                param.requires_grad = False
                
        elif stage == 2:
            # 阶段2：训练TCR生成器，冻结其他生成器和分类器
            self.accelerator.print("Stage 2: Training TCR generator")
            for param in model.pt_classifier.parameters():
                param.requires_grad = False
            for param in model.pmt_classifier.parameters():
                param.requires_grad = False
            for param in model.pm_classifier.parameters():
                param.requires_grad = False
            for param in model.discriminator.parameters():
                param.requires_grad = False
            for param in model.pep_generator.parameters():
                param.requires_grad = False
                
        elif stage == 3:
            # 阶段3：训练PEP生成器，冻结其他模块
            self.accelerator.print("Stage 3: Training PEP generator")
            for param in model.pt_classifier.parameters():
                param.requires_grad = False
            for param in model.pmt_classifier.parameters():
                param.requires_grad = False
            for param in model.pm_classifier.parameters():
                param.requires_grad = False
            for param in model.discriminator.parameters():
                param.requires_grad = False
            for param in model.tcr_generator.parameters():
                param.requires_grad = False

    def compute_classification_metrics(self, logits: torch.Tensor, labels: torch.Tensor, 
                                     confidence: torch.Tensor) -> Dict[str, float]:
        """计算分类指标"""
        with torch.no_grad():
            probs = torch.softmax(logits, dim=-1)
            preds = torch.argmax(logits, dim=-1)
            
            labels_cpu = labels.cpu().numpy()
            preds_cpu = preds.cpu().numpy()
            probs_cpu = probs[:, 1].cpu().numpy()
            confidence_cpu = confidence.squeeze().cpu().numpy()
            
            acc = accuracy_score(labels_cpu, preds_cpu)
            precision, recall, f1, _ = precision_recall_fscore_support(
                labels_cpu, preds_cpu, average='binary', zero_division=0
            )
            
            try:
                auc = roc_auc_score(labels_cpu, probs_cpu)
            except:
                auc = 0.0
            
            return {
                'accuracy': acc,
                'precision': precision,
                'recall': recall,
                'f1': f1,
                'auc': auc,
                'avg_confidence': float(confidence_cpu.mean())
            }

    def compute_generation_loss(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """计算生成损失（保持梯度）- 现在只进行一次右移"""
        # 单次右移用于next token prediction
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = targets[..., 1:].contiguous()

        # Flatten for loss computation
        shift_logits = shift_logits.view(-1, shift_logits.size(-1))
        shift_labels = shift_labels.view(-1)

        # Compute loss (保持梯度)
        loss = self.generation_loss(shift_logits, shift_labels)
        return loss

    def compute_generation_metrics(self, logits: torch.Tensor, targets: torch.Tensor) -> Dict[str, float]:
        """计算生成指标（仅用于评估，不计算梯度）"""
        with torch.no_grad():
            # 单次右移用于next token prediction
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = targets[..., 1:].contiguous()

            # Flatten for loss computation
            shift_logits = shift_logits.view(-1, shift_logits.size(-1))
            shift_labels = shift_labels.view(-1)

            # Compute loss
            loss = self.generation_loss(shift_logits, shift_labels)
            perplexity = torch.exp(loss)

            # Compute token accuracy
            preds = torch.argmax(shift_logits, dim=-1)
            mask = shift_labels != 0  # ignore padding

            if mask.sum() > 0:
                correct = (preds == shift_labels) & mask
                token_accuracy = correct.sum().float() / mask.sum().float()
            else:
                token_accuracy = torch.tensor(0.0)

            return {
                'loss': loss.item(),
                'perplexity': perplexity.item(),
                'token_accuracy': token_accuracy.item()
            }

    def compute_losses(self, outputs: Dict[str, torch.Tensor], batch: Dict[str, torch.Tensor], 
                      stage: int) -> Dict[str, torch.Tensor]:
        """根据训练阶段计算相应的损失"""
        losses = {}
        
        if stage == 1:
            # 阶段1：分类损失 + 判别器损失
            for task in ['pt', 'pmt', 'pm']:
                if f'{task}_logits' in outputs:
                    cls_loss = self.classification_loss(outputs[f'{task}_logits'], batch['labels'])
                    disc_loss = self.discriminator_loss(
                        outputs[f'{task}_discriminator'].squeeze(),
                        batch['labels'].float()
                    )
                    losses[f'{task}_classification'] = cls_loss
                    losses[f'{task}_discriminator'] = disc_loss * 0.1
                    
        elif stage == 2:
            # 阶段2：TCR生成损失（保持梯度）
            if 'tcr_gen_logits' in outputs:
                gen_loss = self.compute_generation_loss(
                    outputs['tcr_gen_logits'], 
                    outputs['tcr_gen_targets']
                )
                losses['tcr_generation'] = gen_loss
                
        elif stage == 3:
            # 阶段3：PEP生成损失（保持梯度）
            if 'pep_gen_logits' in outputs:
                gen_loss = self.compute_generation_loss(
                    outputs['pep_gen_logits'], 
                    outputs['pep_gen_targets']
                )
                losses['pep_generation'] = gen_loss
        
        return losses

    def train_epoch(self, model, train_loader, optimizer, scheduler, epoch: int, fold: int, stage: int) -> Dict[str, float]:
        """训练一个epoch"""
        model.train()
        total_loss = 0
        num_batches = 0
        
        # 根据阶段记录不同的指标
        task_losses = {}
        task_metrics = {
            'pt': [], 'pmt': [], 'pm': [],
            'tcr_gen': [], 'pep_gen': []
        }
        
        for batch_idx, batch in enumerate(train_loader):
            with self.accelerator.accumulate(model):
                try:
                    outputs = model(batch)
                    losses = self.compute_losses(outputs, batch, stage)
                    
                    if not losses:  # 如果没有损失，跳过这个batch
                        continue
                        
                    # 总损失
                    total_batch_loss = sum(losses.values())
                    
                    # 检查损失是否需要梯度
                    if not total_batch_loss.requires_grad:
                        self.accelerator.print(f"Warning: loss doesn't require grad at batch {batch_idx}")
                        continue
                    
                    # 反向传播
                    self.accelerator.backward(total_batch_loss)
                    
                    # 梯度裁剪
                    if self.accelerator.sync_gradients:
                        self.accelerator.clip_grad_norm_(model.parameters(), 1.0)
                    
                    optimizer.step()
                    if scheduler is not None:
                        scheduler.step()
                    optimizer.zero_grad()
                    
                    # 记录损失
                    for task_name, loss in losses.items():
                        if task_name not in task_losses:
                            task_losses[task_name] = 0
                        task_losses[task_name] += loss.item()
                    
                    # 计算指标
                    with torch.no_grad():
                        if stage == 1:
                            # 分类指标
                            for task in ['pt', 'pmt', 'pm']:
                                if f'{task}_logits' in outputs:
                                    metrics = self.compute_classification_metrics(
                                        outputs[f'{task}_logits'], 
                                        batch['labels'],
                                        outputs[f'{task}_confidence']
                                    )
                                    task_metrics[task].append(metrics)
                        
                        elif stage == 2 and 'tcr_gen_logits' in outputs:
                            # TCR生成指标
                            metrics = self.compute_generation_metrics(
                                outputs['tcr_gen_logits'],
                                outputs['tcr_gen_targets']
                            )
                            task_metrics['tcr_gen'].append(metrics)
                            
                        elif stage == 3 and 'pep_gen_logits' in outputs:
                            # PEP生成指标
                            metrics = self.compute_generation_metrics(
                                outputs['pep_gen_logits'],
                                outputs['pep_gen_targets']
                            )
                            task_metrics['pep_gen'].append(metrics)
                    
                    total_loss += total_batch_loss.item()
                    num_batches += 1
                    
                    if (batch_idx + 1) % self.config.get('log_interval', 50) == 0:
                        avg_loss = total_loss / max(num_batches, 1)
                        self.accelerator.print(
                            f"Stage {stage}, Epoch {epoch+1}, Batch {batch_idx+1}/{len(train_loader)}, Loss: {avg_loss:.4f}"
                        )
                        
                except Exception as e:
                    self.accelerator.print(f"Error in training batch {batch_idx}: {str(e)}")
                    continue
        
        if num_batches == 0:
            return {'total_loss': float('inf')}
        
        # 汇总epoch指标
        epoch_metrics = {
            'total_loss': total_loss / num_batches,
            'epoch': epoch,
            'fold': fold,
            'stage': stage,
            'phase': 'train'
        }
        
        # 各任务损失
        for task_name, loss_sum in task_losses.items():
            epoch_metrics[f'{task_name}_loss'] = loss_sum / num_batches
        
        # 分类任务指标
        for task in ['pt', 'pmt', 'pm']:
            if task_metrics[task]:
                for metric_name in ['accuracy', 'precision', 'recall', 'f1', 'auc', 'avg_confidence']:
                    values = [m[metric_name] for m in task_metrics[task]]
                    epoch_metrics[f'{task}_{metric_name}'] = np.mean(values)
        
        # 生成任务指标
        for task in ['tcr_gen', 'pep_gen']:
            if task_metrics[task]:
                for metric_name in ['token_accuracy', 'perplexity']:
                    values = [m[metric_name] for m in task_metrics[task]]
                    epoch_metrics[f'{task}_{metric_name}'] = np.mean(values)
        
        return epoch_metrics

    def validate_epoch(self, model, val_loader, epoch: int, fold: int, stage: int) -> Dict[str, float]:
        """验证一个epoch"""
        model.eval()
        total_loss = 0
        num_batches = 0
        
        task_losses = {}
        task_metrics = {
            'pt': [], 'pmt': [], 'pm': [],
            'tcr_gen': [], 'pep_gen': []
        }
        
        with torch.no_grad():
            for batch_idx, batch in enumerate(val_loader):
                try:
                    outputs = model(batch)
                    
                    # 计算损失（验证阶段用compute_generation_metrics来计算数值损失）
                    val_losses = {}
                    
                    if stage == 1:
                        # 阶段1：分类损失 + 判别器损失
                        for task in ['pt', 'pmt', 'pm']:
                            if f'{task}_logits' in outputs:
                                cls_loss = self.classification_loss(outputs[f'{task}_logits'], batch['labels'])
                                disc_loss = self.discriminator_loss(
                                    outputs[f'{task}_discriminator'].squeeze(),
                                    batch['labels'].float()
                                )
                                val_losses[f'{task}_classification'] = cls_loss
                                val_losses[f'{task}_discriminator'] = disc_loss * 0.1
                                
                    elif stage == 2 and 'tcr_gen_logits' in outputs:
                        # TCR生成损失
                        gen_metrics = self.compute_generation_metrics(
                            outputs['tcr_gen_logits'],
                            outputs['tcr_gen_targets']
                        )
                        val_losses['tcr_generation'] = torch.tensor(gen_metrics['loss'], device=self.device)
                        
                    elif stage == 3 and 'pep_gen_logits' in outputs:
                        # PEP生成损失
                        gen_metrics = self.compute_generation_metrics(
                            outputs['pep_gen_logits'],
                            outputs['pep_gen_targets']
                        )
                        val_losses['pep_generation'] = torch.tensor(gen_metrics['loss'], device=self.device)
                    
                    if not val_losses:
                        continue
                        
                    total_batch_loss = sum(val_losses.values())
                    
                    # 记录损失
                    for task_name, loss in val_losses.items():
                        if task_name not in task_losses:
                            task_losses[task_name] = 0
                        task_losses[task_name] += loss.item()
                    
                    # 计算指标
                    if stage == 1:
                        for task in ['pt', 'pmt', 'pm']:
                            if f'{task}_logits' in outputs:
                                metrics = self.compute_classification_metrics(
                                    outputs[f'{task}_logits'], 
                                    batch['labels'],
                                    outputs[f'{task}_confidence']
                                )
                                task_metrics[task].append(metrics)
                    
                    elif stage == 2 and 'tcr_gen_logits' in outputs:
                        metrics = self.compute_generation_metrics(
                            outputs['tcr_gen_logits'],
                            outputs['tcr_gen_targets']
                        )
                        task_metrics['tcr_gen'].append(metrics)
                    
                    elif stage == 3 and 'pep_gen_logits' in outputs:
                        metrics = self.compute_generation_metrics(
                            outputs['pep_gen_logits'],
                            outputs['pep_gen_targets']
                        )
                        task_metrics['pep_gen'].append(metrics)
                    
                    total_loss += total_batch_loss.item()
                    num_batches += 1
                    
                except Exception as e:
                    self.accelerator.print(f"Error in validation batch {batch_idx}: {str(e)}")
                    continue
        
        if num_batches == 0:
            return {'total_loss': float('inf')}
        
        # 汇总指标
        epoch_metrics = {
            'total_loss': total_loss / num_batches,
            'epoch': epoch,
            'fold': fold,
            'stage': stage,
            'phase': 'validation'
        }
        
        for task_name, loss_sum in task_losses.items():
            epoch_metrics[f'{task_name}_loss'] = loss_sum / num_batches
        
        for task in ['pt', 'pmt', 'pm']:
            if task_metrics[task]:
                for metric_name in ['accuracy', 'precision', 'recall', 'f1', 'auc', 'avg_confidence']:
                    values = [m[metric_name] for m in task_metrics[task]]
                    epoch_metrics[f'{task}_{metric_name}'] = np.mean(values)
        
        for task in ['tcr_gen', 'pep_gen']:
            if task_metrics[task]:
                for metric_name in ['token_accuracy', 'perplexity']:
                    values = [m[metric_name] for m in task_metrics[task]]
                    epoch_metrics[f'{task}_{metric_name}'] = np.mean(values)
        
        return epoch_metrics

    def save_metrics(self, metrics: Dict[str, float]):
        """保存指标到CSV"""
        self.all_metrics.append(metrics)
        
        if self.accelerator.is_main_process:
            df = pd.DataFrame(self.all_metrics)
            csv_path = os.path.join(self.config.get('output_dir', 'outputs'), 'training_metrics.csv')
            df.to_csv(csv_path, index=False)

    def print_metrics(self, metrics: Dict[str, float]):
        """打印指标"""
        phase = metrics['phase']
        epoch = metrics['epoch']
        fold = metrics['fold']
        stage = metrics['stage']
        
        self.accelerator.print(f"\n{phase.title()} Metrics - Stage {stage}, Fold {fold}, Epoch {epoch+1}:")
        self.accelerator.print("-" * 80)
        self.accelerator.print(f"Total Loss: {metrics.get('total_loss', 0):.4f}")
        
        # 分类任务
        for task in ['pt', 'pmt', 'pm']:
            if f'{task}_accuracy' in metrics:
                acc = metrics[f'{task}_accuracy']
                f1 = metrics.get(f'{task}_f1', 0)
                conf = metrics.get(f'{task}_avg_confidence', 0)
                self.accelerator.print(f"{task.upper():<15} - Acc: {acc:.4f}, F1: {f1:.4f}, Conf: {conf:.4f}")
        
        # 生成任务
        for task in ['tcr_gen', 'pep_gen']:
            if f'{task}_token_accuracy' in metrics:
                acc = metrics[f'{task}_token_accuracy']
                ppl = metrics.get(f'{task}_perplexity', 0)
                self.accelerator.print(f"{task.upper():<15} - Token Acc: {acc:.4f}, PPL: {ppl:.4f}")

    def train_stage(self, model, train_loader, val_loader, stage: int, fold: int):
        """训练一个阶段"""
        self.accelerator.print(f"\n=== Training Stage {stage} ===")
        
        # 冻结相应模块
        self.freeze_modules(model, stage)
        
        # 只为可训练参数创建优化器
        trainable_params = [p for p in model.parameters() if p.requires_grad]
        self.accelerator.print(f"Trainable parameters: {sum(p.numel() for p in trainable_params):,}")
        
        optimizer = optim.AdamW(
            trainable_params,
            lr=self.config.get(f'stage_{stage}_lr', self.config.get('learning_rate', 1e-4)),
            weight_decay=self.config.get('weight_decay', 0.01)
        )
        
        num_epochs = self.config.get(f'stage_{stage}_epochs', self.config.get('num_epochs', 5))
        num_training_steps = len(train_loader) * num_epochs
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_training_steps)
        
        # 准备优化器和调度器
        optimizer, scheduler = self.accelerator.prepare(optimizer, scheduler)
        
        best_val_loss = float('inf')
        
        for epoch in range(num_epochs):
            # 训练
            train_metrics = self.train_epoch(model, train_loader, optimizer, scheduler, epoch, fold, stage)
            self.save_metrics(train_metrics)
            self.print_metrics(train_metrics)
            
            # 验证
            val_metrics = self.validate_epoch(model, val_loader, epoch, fold, stage)
            self.save_metrics(val_metrics)
            self.print_metrics(val_metrics)
            
            # 保存最佳模型
            if val_metrics['total_loss'] < best_val_loss:
                best_val_loss = val_metrics['total_loss']
                if self.accelerator.is_main_process:
                    save_path = os.path.join(
                        self.config.get('output_dir', 'outputs'),
                        f'best_model_stage_{stage}_fold_{fold}.pt'
                    )
                    torch.save(self.accelerator.unwrap_model(model).state_dict(), save_path)
                    self.accelerator.print(f"Saved best model for stage {stage}, fold {fold} with loss {best_val_loss:.4f}")
        
        return best_val_loss

    def train_fold(self, train_dataset, val_dataset, fold: int):
        """训练一个fold的所有阶段"""
        self.accelerator.print(f"\n=== Fold {fold} ===")
        self.accelerator.print(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}")
        
        # 创建数据加载器
        train_loader = DataLoader(
            train_dataset,
            batch_size=self.config.get('batch_size', 128),
            shuffle=True,
            collate_fn=collate_fn,
            num_workers=self.config.get('num_workers', 12),
            pin_memory=True,
            persistent_workers=True,
            prefetch_factor=8
        )
        
        val_loader = DataLoader(
            val_dataset,
            batch_size=self.config.get('batch_size', 128),
            shuffle=False,
            collate_fn=collate_fn,
            num_workers=self.config.get('num_workers', 12),
            pin_memory=True,
            persistent_workers=True,
            prefetch_factor=8
        )
        
        # 创建模型
        model = self.create_model()
        
        # 使用accelerator准备
        model, train_loader, val_loader = self.accelerator.prepare(model, train_loader, val_loader)
        
        fold_results = {}
        
        # 三个训练阶段
        for stage in [1, 2, 3]:
            best_val_loss = self.train_stage(model, train_loader, val_loader, stage, fold)
            fold_results[f'stage_{stage}_best_loss'] = best_val_loss
        
        return fold_results

    def train(self):
        """主训练函数"""
        self.accelerator.print("Starting Multi-Stage Multi-Task Training")
        self.accelerator.print("=" * 50)
        
        # 加载数据
        df = self.data_processor.load_and_process_data()
        df_balanced = self.data_processor.create_balanced_dataset(df, negative_ratio=1.0)
        dataset = self.data_processor.create_five_task_dataset(df_balanced)
        
        # 创建交叉验证分割
        cv_splits = create_cv_splits(df_balanced, n_splits=self.config.get('n_folds', 5))
        
        all_fold_results = []
        
        for fold, (train_indices, val_indices) in enumerate(cv_splits):
            train_data = [dataset[i] for i in train_indices]
            val_data = [dataset[i] for i in val_indices]
            
            train_dataset = MultiTaskDataset(train_data)
            val_dataset = MultiTaskDataset(val_data)
            
            fold_results = self.train_fold(train_dataset, val_dataset, fold)
            fold_results['fold'] = fold
            all_fold_results.append(fold_results)
            
            self.accelerator.print(f"Fold {fold} completed: {fold_results}")
        
        # 最终结果
        self.accelerator.print(f"\nFinal Results Across All Folds:")
        for stage in [1, 2, 3]:
            losses = [result[f'stage_{stage}_best_loss'] for result in all_fold_results]
            self.accelerator.print(f"Stage {stage} - Mean Loss: {np.mean(losses):.4f} ± {np.std(losses):.4f}")
        
        # 保存结果
        if self.accelerator.is_main_process:
            results_path = os.path.join(self.config.get('output_dir', 'outputs'), 'training_results.json')
            with open(results_path, 'w') as f:
                json.dump(all_fold_results, f, indent=2)
            self.accelerator.print(f"Results saved to {results_path}")

def main():
    """Parse command-line options, assemble the config and start training.

    Only the data path, output directory, batch size, epoch count and learning
    rate are configurable from the command line; every other setting is fixed
    here. All three stages use the same epoch count and learning rate.
    """
    parser = argparse.ArgumentParser(description='Multi-Stage Multi-Task Immune Training')
    cli_options = (
        ('--data_path', str, 'data.csv', 'Data file path'),
        ('--output_dir', str, 'outputs', 'Output directory'),
        ('--batch_size', int, 256, 'Batch size'),
        ('--num_epochs', int, 50, 'Number of epochs per stage'),
        ('--learning_rate', float, 1e-5, 'Learning rate'),
    )
    for flag, value_type, default, help_text in cli_options:
        parser.add_argument(flag, type=value_type, default=default, help=help_text)

    args = parser.parse_args()

    # Base configuration.
    config = {
        'data_path': args.data_path,
        'output_dir': args.output_dir,
        'seed': 42,
        'n_folds': 3,
        'batch_size': args.batch_size,
        'num_epochs': args.num_epochs,
        'learning_rate': args.learning_rate,
        'weight_decay': 0.01,
        'max_len': 120,
        'd_model': 512,
        'n_encoder_layers': 6,
        'n_decoder_layers': 4,
        'n_heads': 8,
        'dropout': 0.1,
        'gradient_accumulation_steps': 4,
        'mixed_precision': 'bf16',
        'log_interval': 100,
        'num_workers': 12,
        'use_tensorboard': True,
    }

    # Per-stage settings: epochs first, then learning rates.
    stages = (1, 2, 3)
    config.update({f'stage_{stage}_epochs': args.num_epochs for stage in stages})
    config.update({f'stage_{stage}_lr': args.learning_rate for stage in stages})

    MultiStageTrainer(config).train()

# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()