import os
# Restrict this process to GPUs 3, 4 and 5. The assignment sits ahead of the
# remaining imports -- presumably so it is already in effect when any
# CUDA-using library gets loaded (NOTE(review): confirm). It unconditionally
# overwrites whatever value the caller's environment already had.
os.environ["CUDA_VISIBLE_DEVICES"] = "3,4,5"

from utils.common_utils import set_random_seed
import argparse
from mine import get_layer_params
from utils.compress_utils import (
    calculate_group_importance_scores,
    allocate_group_compression_params,
    convert_to_compressor_format
)
from expert_group import expert_grouping
from compressor.deepseek_compressor import DeepSeekCompressor
from utils.evaluate_utils import run_lm_eval, ppl_eval_sharing
from accelerate import Accelerator


def parse_args(argv=None) -> argparse.Namespace:
    """Build the command-line interface and parse the arguments.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, which is
            what a plain ``parse_args()`` call did before this parameter
            existed.

    Returns:
        The populated ``argparse.Namespace``. Its attributes are ``seed``,
        ``model_name``, ``compression_ratio``, ``use_auto_weights``,
        ``auto_method``, ``smoothness_factor``, ``num_samples``,
        ``mine_epochs``, ``evaluate``, ``dataset_name``, ``dataset_config``,
        ``max_length``, ``stride`` and ``batch_size``.
    """
    parser = argparse.ArgumentParser(
        description="Run compression on DeepSeek MoE models.")

    # --- Compression settings ---
    parser.add_argument('--seed', type=int, default=42,
                        help='Random seed for reproducibility')
    parser.add_argument('--model_name', type=str, default='deepseek-ai/deepseek-moe-16b-base',
                        help='model name or path to the DeepSeek MoE model')
    parser.add_argument('--compression_ratio', type=float, default=0.6,
                        help='Target compression ratio for the model')
    parser.add_argument('--use_auto_weights', action='store_true', default=False,
                        help='Use automatic PCA weights for feature importance')
    parser.add_argument('--auto_method', type=str, default='pca',
                        choices=['pca', 'variance',
                                 'mutual_info', 'correlation'],
                        help='Method for automatic weight determination')
    parser.add_argument('--smoothness_factor', type=float, default=0.05,
                        help='Smoothness factor for group compression params allocation')
    parser.add_argument('--num_samples', type=int, default=256,
                        help='Number of samples for compression ratio calculation')
    parser.add_argument('--mine_epochs', type=int, default=150,
                        help='Number of epochs for MINE training')

    # --- Evaluation settings ---
    # Evaluation is on by default. A lone ``store_true`` flag whose default is
    # already True can never be switched off, so ``--no_evaluate`` is offered
    # as the explicit opt-out. Both flags write to the same ``evaluate``
    # attribute; the group rejects passing them together.
    eval_toggle = parser.add_mutually_exclusive_group()
    eval_toggle.add_argument('--evaluate', dest='evaluate', action='store_true',
                             default=True,
                             help='Evaluate the perplexity of the compressed model '
                                  '(enabled by default)')
    eval_toggle.add_argument('--no_evaluate', dest='evaluate', action='store_false',
                             help='Skip evaluation of the compressed model')
    parser.add_argument('--dataset_name', type=str, default='wikitext',
                        help='Name of the dataset to evaluate on.')
    parser.add_argument('--dataset_config', type=str,
                        default='wikitext-2-raw-v1', help='Configuration of the dataset.')
    parser.add_argument('--max_length', type=int, default=2048,
                        help='Maximum sequence length for sliding window.')
    parser.add_argument('--stride', type=int, default=512,
                        help='Stride for sliding window.')
    parser.add_argument('--batch_size', type=int, default=4,
                        help='Batch size for evaluation.')

    return parser.parse_args(argv)


def main():
    """Run the compression pipeline end to end and optionally evaluate it.

    The stages, in order:

    1. Group the experts (``expert_grouping``).
    2. Obtain per-layer MoE parameter targets (``get_layer_params``).
    3. Score each group's importance (``calculate_group_importance_scores``).
    4. Allocate parameters to each group (``allocate_group_compression_params``).
    5. Convert the grouping into the compressor's input format.
    6. Save the grouping and the allocation as ``.npy`` files under
       ``results/deepseek`` and compress the model with ``DeepSeekCompressor``.
    7. If ``args.evaluate`` is set, write the perplexity report to
       ``./results/ppl_eval_sharing.txt`` and run the lm-eval tasks.

    All settings come from ``parse_args()``; nothing is returned.
    """
    import numpy as np

    args = parse_args()
    set_random_seed(args.seed)

    print("OAGE-MoE Compression Pipeline Started...")

    # Create the output directory before any expensive work so a missing
    # folder cannot abort the run after the analysis has already finished.
    # This also creates the parent ``results`` directory used in step 7.
    artifact_dir = 'results/deepseek'
    os.makedirs(artifact_dir, exist_ok=True)

    # Step 1: expert grouping
    print("Step 1: Expert grouping...")
    feature_df, layer2group2expert = expert_grouping(
        model_name=args.model_name)
    print(layer2group2expert)
    print(feature_df.head())

    # Step 2: per-layer compression targets
    print("Step 2: Layer compression ratios...")
    # The second return value (the analysed layer indices) is not needed here.
    layer_moe_params, _layer_indices = get_layer_params(
        model_name=args.model_name,
        target_compression_ratio=args.compression_ratio,
        num_samples=args.num_samples,
        mine_epochs=args.mine_epochs
    )

    # Step 3: group importance scores
    print(
        f"Step 3: Group importance scores ({'auto' if args.use_auto_weights else 'default'} weights)...")
    layer_group_importance = calculate_group_importance_scores(
        feature_df=feature_df,
        layer2group2expert=layer2group2expert,
        importance_weights=None,
        use_auto_weights=args.use_auto_weights,
        auto_method=args.auto_method
    )

    # Step 4: allocate the parameter budget of every group
    print("Step 4: Allocate group params...")
    layer_group_params = allocate_group_compression_params(
        layer_group_importance=layer_group_importance,
        target_layer_moe_params=layer_moe_params,
        layer2group2expert=layer2group2expert,
        smoothness_factor=args.smoothness_factor  # smoothing factor
    )

    # Step 5: convert to the compressor's input format
    print("Step 5: Format conversion...")
    layers_groups_experts = convert_to_compressor_format(
        layer2group2expert=layer2group2expert
    )

    # Step 6: run the compression
    print("Step 6: Model compression...")

    # Persist the intermediate results, tagged with the ratio actually used so
    # runs at different ratios do not overwrite each other. To reuse them
    # later: np.load(path, allow_pickle=True).item()
    ratio_tag = args.compression_ratio
    np.save(f'{artifact_dir}/layers_groups_experts{ratio_tag}.npy',
            layers_groups_experts, allow_pickle=True)
    np.save(f'{artifact_dir}/layer_group_params{ratio_tag}.npy',
            layer_group_params, allow_pickle=True)

    compressor = DeepSeekCompressor(
        layers_expert_groups=layers_groups_experts,
        layer_group_params=layer_group_params,
        compression_ratio=args.compression_ratio,
    )

    compressed_model, tokenizer = compressor.compress_all_experts()

    print("Compression completed successfully!")

    if not args.evaluate:
        return

    print("Step 7: Evaluating perplexity...")
    # Honour --batch_size for the perplexity run; its default (4) matches the
    # value that used to be hard-coded here.
    result_str = ppl_eval_sharing(
        compressed_model, tokenizer, experiment_name="deepseek",
        datasets=['wikitext2', 'ptb', 'c4'], params_only=False,
        batch_size=args.batch_size)
    with open("./results/ppl_eval_sharing.txt", "w", encoding="utf-8") as f:
        f.write(result_str)

    run_lm_eval(
        compressed_model, tokenizer, batch_size=16,
        task_names=["openbookqa", "arc_easy", "winogrande", "hellaswag",
                    "arc_challenge", "piqa", "mathqa"],
        output_dir='./results')

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
