import pickle
from tqdm import tqdm
import time
import torch
import argparse
import os
import numpy as np

from pytorch_lightning import Trainer
from pytorch_lightning.strategies import DDPStrategy, DeepSpeedStrategy

from utils import GLUEDataModule, windowed_mean
from models.llm_module import GLUETransformer

def save_pickle(data, folder, filename):
    """Pickle ``data`` into ``folder/filename``.

    Args:
        data: Object to serialize with ``pickle.dump``.
        folder: Destination directory; created if it does not exist yet.
        filename: Name of the file written inside ``folder``.
    """
    # Create the destination first so that open() cannot fail on a missing directory.
    os.makedirs(folder, exist_ok=True)
    target_path = os.path.join(folder, filename)
    with open(target_path, 'wb') as handle:
        pickle.dump(data, handle)
    print(f"Data saved to: {target_path}")

def parse_arguments():
    """Parse the command-line options of the multi-GPU fine-tuning script.

    Returns:
        argparse.Namespace: The parsed options read from ``sys.argv``.
    """
    parser = argparse.ArgumentParser(description='Multi-GPU Fine-tuning Argument Parser')
    add = parser.add_argument

    # Basic options
    add('--epochs', type=int, default=1, help='Number of training epochs')
    add('--samplesize', type=int, default=1024, help='Training data sample size')
    add('--samplesize_validation', type=int, default=128, help='Validation data sample size')
    add('--model_name', type=str, default='distilbert-base-cased', help='Name of the pre-trained model')
    add('--task', type=str, default='mnli', help='Task for model training')
    add('--full_parameter', action='store_true', help='True for full parameter fine-tuning')
    add('--algorithm', type=str, default='FO-AdamW', help='Algorithm to use ("FO-SGD", "FO-Adam", "FO-AdamW", "FO-Adagrad")')

    # Batch-size options
    add('--batchsize', type=int, default=128, help='Total batch size across all GPUs')
    add('--batchsize_limit', type=int, default=32, help='Batch size per GPU')
    add('--max_seq_length', type=int, default=256, help='Max sequence length for inputs')

    # Learning-rate and optimization options
    add('--lr', type=float, default=2e-3, help='Learning rate')
    add('--anneal', type=float, default=1.5, help='Annealing parameter')

    # Multi-GPU options
    add('--devices', type=str, default='auto', help='GPU devices (e.g., "0,1,2,3" or "auto")')
    add('--strategy', type=str, default='ddp', help='Distributed strategy ("ddp", "ddp_spawn", "deepspeed")')
    add('--num_workers', type=int, default=4, help='Number of data loading workers per GPU')

    # Miscellaneous options
    add('--results', type=str, default='results_parallel', help='Name of folder to store results')
    add('--soft_prompt', action='store_true', help='True for using soft prompt')
    add('--half_precision', action='store_true', help='Using half-precision fine-tuning')
    add('--gradient_clip_val', type=float, default=1.0, help='Gradient clipping value')
    add('--accumulate_grad_batches', type=int, default=1, help='Number of batches to accumulate gradients')

    # Options for ZO algorithms (only FO algorithms are supported at the moment)
    add('--q', type=int, default=2, help='q parameter used only for ZO-SVRG')
    add('--lr_mezosvrg_mb', type=float, default=1e-6, help='Mini-batch learning rate for MeZO-SVRG')
    add('--perturbation_scale', type=float, default=1e-3, help='Perturbation scale for SPSA estimators')
    add('--alpha', type=float, default=0.1, help='Alpha for ZO-FO-SVRG')

    return parser.parse_args()

def setup_devices(devices_str):
    """Convert the ``--devices`` command-line string into a device spec.

    Args:
        devices_str: Either ``"auto"`` or a comma-separated list of GPU
            indices such as ``"0,1,2,3"``. Surrounding whitespace and empty
            entries (e.g. a trailing comma) are ignored.

    Returns:
        The string ``"auto"``, or a list of integer device indices.

    Raises:
        ValueError: If an entry is not an integer or no index is given at all.
    """
    cleaned = devices_str.strip()
    if cleaned == "auto":
        return "auto"

    # Parse a device string such as "0,1,2,3"; skip empty tokens so that
    # inputs like "0,1," do not fail with an opaque int('') error.
    tokens = (token.strip() for token in cleaned.split(","))
    device_list = [int(token) for token in tokens if token]
    if not device_list:
        raise ValueError(
            f"No device indices found in {devices_str!r}; "
            'expected "auto" or a list such as "0,1,2,3"'
        )
    return device_list

def setup_strategy(strategy_name):
    """Build the distributed-training strategy for the given name.

    Args:
        strategy_name: ``"ddp"``, ``"ddp_spawn"`` or ``"deepspeed"``. Any
            other value selects Lightning's automatic strategy choice.

    Returns:
        A ``DDPStrategy`` or ``DeepSpeedStrategy`` instance, or the string
        ``"auto"`` for unrecognized names.
    """
    if strategy_name == "ddp":
        return DDPStrategy(find_unused_parameters=False)
    if strategy_name == "ddp_spawn":
        # The spawn launcher is selected through ``start_method``.
        # ``use_distributed_sampler`` is a Trainer option, not a strategy
        # option; extra DDPStrategy keyword arguments are presumably forwarded
        # to torch's DistributedDataParallel, which would reject it.
        return DDPStrategy(find_unused_parameters=False, start_method="spawn")
    if strategy_name == "deepspeed":
        return DeepSpeedStrategy()
    # Unknown names fall back to automatic strategy selection.
    return "auto"

def finetune_FO_parallel(devices, strategy, max_seq_length, model_name, task, samplesize, 
                         samplesize_validation, batchsize, batchsize_limit, lr, full_parameter, 
                         results_folder, soft_prompt, half_precision=False, num_workers=4,
                         gradient_clip_val=1.0, accumulate_grad_batches=1,
                         algorithm=None, epochs=None):
    """Fine-tune a model with a first-order (FO) optimizer on multiple GPUs.

    Builds the GLUE data module and the transformer module, trains with a
    Lightning ``Trainer``, pickles the collected metrics into
    ``results_folder`` and prints summary statistics.

    Args:
        devices: Device spec passed to ``Trainer`` (``"auto"`` or a list of
            GPU indices).
        strategy: Distributed strategy passed to ``Trainer``.
        max_seq_length: Maximum input sequence length for the data module.
        model_name: Name or path of the pre-trained model.
        task: GLUE task name.
        samplesize: Training sample size.
        samplesize_validation: Validation sample size.
        batchsize: Total batch size; only recorded in the results and the
            output file name, it is not passed to the data module or trainer.
        batchsize_limit: Per-GPU batch size used for training and evaluation.
        lr: Learning rate.
        full_parameter: Whether to fine-tune all parameters.
        results_folder: Directory that receives the pickled results.
        soft_prompt: Whether to use a soft prompt.
        half_precision: If true, train with ``'bf16-mixed'`` precision.
        num_workers: Number of data-loading workers.
        gradient_clip_val: Gradient clipping value passed to ``Trainer``.
        accumulate_grad_batches: Gradient accumulation steps.
        algorithm: Optimizer name such as ``"FO-AdamW"``. When ``None``, the
            module-level ``algorithm`` variable set by the script entry point
            is used (``"FO-AdamW"`` if that is missing too).
        epochs: Number of training epochs. When ``None``, the module-level
            ``epochs`` variable is used (``1`` if that is missing too).
    """
    # Historically this function read ``algorithm`` and ``epochs`` from module
    # globals that only exist when the file runs as a script. Keep that
    # fallback for existing callers, but avoid a NameError on import.
    module_globals = globals()
    if algorithm is None:
        algorithm = module_globals.get('algorithm', 'FO-AdamW')
    if epochs is None:
        epochs = module_globals.get('epochs', 1)

    print(f"Setting up multi-GPU training with devices: {devices}")
    print(f"Strategy: {strategy}")
    
    # Initialize the data module
    dm = GLUEDataModule(
        model_name_or_path=model_name,
        task_name=task,
        max_seq_length=max_seq_length,
        sample_size=samplesize,
        train_batch_size=batchsize_limit,  # batch size per GPU
        validation_sample_size=samplesize_validation,
        eval_batch_size=batchsize_limit,
        soft_prompt=soft_prompt,
        num_workers=num_workers
    )
    dm.setup("fit")
    
    # Derive the optimizer flags from the algorithm name. 'AdamW' contains
    # 'Adam', so plain Adam must explicitly exclude AdamW.
    use_SGD = 'SGD' in algorithm
    use_Adam = 'Adam' in algorithm and 'AdamW' not in algorithm
    use_AdamW = 'AdamW' in algorithm
    use_Adagrad = 'Adagrad' in algorithm
    
    # Initialize the model
    transformer = GLUETransformer(
        model_name_or_path=model_name,
        num_labels=dm.num_labels,
        eval_splits=dm.eval_splits,
        task_name=dm.task_name,
        learning_rate=lr,
        full_parameter=full_parameter,
        soft_prompt=soft_prompt,
        use_SGD=use_SGD,
        use_Adam=use_Adam,
        use_AdamW=use_AdamW,
        use_Adagrad=use_Adagrad
    )
    
    # Configure the Trainer
    trainer_kwargs = {
        'max_epochs': epochs,
        'accelerator': 'auto',
        'devices': devices,
        'strategy': strategy,
        'accumulate_grad_batches': accumulate_grad_batches,
        'gradient_clip_val': gradient_clip_val,
        'enable_progress_bar': True,
        'log_every_n_steps': 10,
    }
    
    if half_precision:
        trainer_kwargs['precision'] = 'bf16-mixed'
    
    trainer = Trainer(**trainer_kwargs)
    
    # Start training
    print("Starting multi-GPU training...")
    start_time = time.time()
    trainer.fit(transformer, datamodule=dm)
    end_time = time.time()
    total_training_time = end_time - start_time

    # With DDP every rank runs this function; only the global rank-zero
    # process saves and reports, so ranks do not race on the same output file.
    if not trainer.is_global_zero:
        return
    
    num_gpus = len(devices) if isinstance(devices, list) else 'auto'

    # Save the results
    dict_results = {
        'Model': model_name,
        'Task': task,
        'BS': batchsize,
        'BS_per_GPU': batchsize_limit,
        'LR': lr,
        'Algorithm': algorithm,
        'Devices': devices,
        'Strategy': str(strategy),
        'Tr_Loss': transformer.tr_loss,
        'Time': transformer.time,
        'Query': transformer.query,
        'Grad_Norm': transformer.grad_norm,
        'Overall_Tr_Time': total_training_time,
        'Val_Loss': transformer.val_loss_ls,
        'Val_Acc': transformer.val_acc,
        'Memory': transformer.memory_usage,
        'Num_GPUs': num_gpus
    }
    
    # Build the output file name. The 'facebook/' prefix is dropped as before;
    # any remaining '/' (other hub organizations) is replaced so the name does
    # not point into a non-existent subdirectory of the results folder.
    file_model_name = model_name.replace('facebook/', "").replace('/', '_')
    
    gpu_info = f"GPU{num_gpus}"
    file_name = f'{file_model_name}_{task}_{algorithm}_{gpu_info}_lr{str(lr)}_bs{str(batchsize)}_samplesize{str(samplesize)}_fullparam{str(full_parameter)}.pickle'
    
    save_pickle(dict_results, results_folder, file_name)
    
    # Print summary statistics
    print('Finished Task ' + task + ' with full parameter being ' + str(full_parameter))
    print('-----------------Statistics-----------------')
    # Window the training loss so that each window spans roughly one epoch.
    window_size_tr = int(np.ceil(len(transformer.tr_loss) / epochs))
    arr_tr_loss = windowed_mean(transformer.tr_loss, window_size_tr)
    print('Best Training Loss: ', np.nanmin(arr_tr_loss))
    window_size_val = 2
    arr_val_acc = windowed_mean(transformer.val_acc, window_size_val)
    print('Best Validation Accuracy: ', np.max(arr_val_acc))
    print('Peak Memory Usage (GB): ', np.max(transformer.memory_usage))
    print('Total queries: ', np.sum(transformer.query))
    print(f'Total training time: {total_training_time:.2f} seconds')

def print_experiment_info(args):
    """Print a banner summarizing the experiment configuration.

    Args:
        args: Parsed command-line namespace; the attributes listed below are
            read from it and printed one per line.
    """
    separator = '=' * 60
    settings = (
        ('Model', args.model_name),
        ('Task', args.task),
        ('Algorithm', args.algorithm),
        ('Devices', args.devices),
        ('Strategy', args.strategy),
        ('Epochs', args.epochs),
        ('Total Batch Size', args.batchsize),
        ('Batch Size per GPU', args.batchsize_limit),
        ('Learning Rate', args.lr),
        ('Sample Size', args.samplesize),
        ('Max Seq Length', args.max_seq_length),
        ('Full Parameter', args.full_parameter),
        ('Half Precision', args.half_precision),
        ('Results Folder', args.results),
    )

    print(separator)
    print('MULTI-GPU FINE-TUNING EXPERIMENT')
    print(separator)
    for label, value in settings:
        print(f'{label}: {value}')
    print(separator)

if __name__ == "__main__":
    args = parse_arguments()
    
    # finetune_FO_parallel looks these two names up in module scope, so they
    # must remain module-level variables.
    epochs = args.epochs
    algorithm = args.algorithm
    
    # Print the experiment configuration
    print_experiment_info(args)
    
    # Reject unsupported algorithms before any strategy objects are built.
    if not algorithm.startswith('FO'):
        print(f"Warning: {algorithm} is not yet supported for multi-GPU training.")
        print("Currently only FO algorithms are supported for multi-GPU.")
        # No single-GPU fallback exists yet, so say so instead of claiming one.
        print("No single-GPU fallback is implemented; exiting.")
        # SystemExit works even when the site-provided exit() helper is absent.
        raise SystemExit(1)
    
    # Set up devices and strategy
    devices = setup_devices(args.devices)
    strategy = setup_strategy(args.strategy)
    
    # Run multi-GPU training
    try:
        finetune_FO_parallel(
            devices=devices,
            strategy=strategy,
            max_seq_length=args.max_seq_length,
            model_name=args.model_name,
            task=args.task,
            samplesize=args.samplesize,
            samplesize_validation=args.samplesize_validation,
            batchsize=args.batchsize,
            batchsize_limit=args.batchsize_limit,
            lr=args.lr,
            full_parameter=args.full_parameter,
            results_folder=args.results,
            soft_prompt=args.soft_prompt,
            half_precision=args.half_precision,
            num_workers=args.num_workers,
            gradient_clip_val=args.gradient_clip_val,
            accumulate_grad_batches=args.accumulate_grad_batches
        )
        print("Multi-GPU training completed successfully!")
        
    except Exception as e:
        # Top-level boundary: add a hint for the user, then re-raise so the
        # traceback and non-zero exit status are preserved.
        print(f"Error during multi-GPU training: {e}")
        print("Please check your GPU configuration and try again.")
        raise 