'''
MIT License
Copyright (c) 2023 UNITES Lab
This file is modified from (https://github.com/UNITES-Lab/MC-SMoE)
'''
import math
import os
from dataclasses import dataclass
from typing import Optional, Tuple, List, Union, Dict
import re
from omegaconf import OmegaConf as om
import time
import torch
import wandb
from datasets import load_dataset
from evaluate import load
from fire import Fire
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import (
    OlmoeConfig,
    Qwen2MoeConfig,
    PretrainedConfig,
    AutoTokenizer,
    AutoModelForCausalLM
)

from mergemoe.data import (
    WinoGrandePreProcessor,
    HellaSwagPreProcessor,
    PIQAPreProcessor, 
    ARCPreProcessor,
    COPAPreProcessor,
    SQUADPreProcessor,
    MRPCPreProcessor, 
    build_index_for_dataset,
    tokenize_batch,
    TASK_MAPPING_DATASET_ARGUMENTS,
    DataCollatorForLanguageModeling,

)

from mergemoe.merging import (
    ExpertsGrouperForOLMoE,
    ExpertsGrouperForDeepseek,
    ExpertsGrouperForQwen,
    OLMoE_merge_by_groups_with_usage_frequency_weighting,
    OLMoE_merge_by_groups_with_ACT,
    OLMoE_merge_by_groups,
    OLMoE_merge_by_groups_within_and_across_models,
    OLMoE_merge_by_Linear,
    Deepseek_merge_by_groups_with_usage_frequency_weighting,
    Deepseek_merge_by_groups_with_ACT,
    Deepseek_merge_by_groups,
    Deepseek_merge_by_groups_within_and_across_models,
    Qwen_merge_by_groups_with_usage_frequency_weighting,
    Qwen_merge_by_groups_with_ACT,
    Qwen_merge_by_groups,
    Qwen_merge_by_groups_within_and_across_models,
)
from mergemoe.utils.sparsity import compute_weight_stable_rank
from mergemoe.utils.configuration_deepseek import DeepseekConfig

def sanitize_merging_layers(layers: Union[str, List, int]):
    """
    Sanitize and format merging layers input into a list of integers.

    Args:
        layers: Layer specification in one of the following forms:
            - ``None``: falls back to the default ``[1, 3, 5, 7, 9, 11]``.
            - ``str``: comma-separated indices such as ``"1,3,5"``. Whitespace
              around an index is tolerated and empty fields (for example a
              trailing comma) are skipped; an empty or blank string yields ``[]``.
            - ``int``: a single layer index.
            - ``list``: returned unchanged.
            - ``tuple``: converted to a list (command-line parsers may hand
              over ``1,3,5`` as a tuple rather than a string).

    Returns:
        List of integers representing layer indices

    Raises:
        ValueError: If a non-empty field of a string input is not an integer.
    """
    if layers is None:
        return list(range(1, 12, 2))
    if isinstance(layers, str):
        # Skip blank fields so "1,2," or "" do not blow up on int("").
        return [int(field) for field in layers.split(",") if field.strip()]
    if isinstance(layers, int):
        return [layers]
    if isinstance(layers, tuple):
        return list(layers)
    return layers


def save_stable_rank_to(
        core_experts: Dict[str, List[int]], state_dict: Dict[str, torch.Tensor], save_dir: str, save_name: str
):
    """Compute stable ranks of the core experts' weights and save them to disk.

    For every MLP name in ``core_experts`` and every core expert index listed
    for it, the weights ``{mlp_name}.experts.expert_{idx}.wi.weight`` and
    ``{mlp_name}.experts.expert_{idx}.wo.weight`` are looked up in
    ``state_dict``, cast to float and passed to ``compute_weight_stable_rank``.

    Args:
        core_experts: Mapping from MLP name to the indices of its core experts.
        state_dict: Mapping from parameter name to tensor; must contain the
            ``wi``/``wo`` weight entries described above.
        save_dir: Directory that receives the output files (created if missing).
        save_name: File name stem; results go to ``{save_name}.wi.pt`` and
            ``{save_name}.wo.pt``.

    Raises:
        KeyError: If an expected weight name is absent from ``state_dict``.
    """
    wi_stable_rank_dict = {}
    wo_stable_rank_dict = {}
    for mlp_name in tqdm(core_experts, desc="Computing stable rank"):
        wi_ranks = []
        wo_ranks = []
        for core_idx in core_experts[mlp_name]:
            expert_prefix = f"{mlp_name}.experts.expert_{core_idx}"
            wi_ranks.append(
                compute_weight_stable_rank(state_dict[f"{expert_prefix}.wi.weight"].float())
            )
            wo_ranks.append(
                compute_weight_stable_rank(state_dict[f"{expert_prefix}.wo.weight"].float())
            )
        # One tensor per MLP, ordered like core_experts[mlp_name].
        wi_stable_rank_dict[mlp_name] = torch.tensor(wi_ranks)
        wo_stable_rank_dict[mlp_name] = torch.tensor(wo_ranks)
    # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
    os.makedirs(save_dir, exist_ok=True)
    torch.save(wi_stable_rank_dict, os.path.join(save_dir, f"{save_name}.wi.pt"))
    torch.save(wo_stable_rank_dict, os.path.join(save_dir, f"{save_name}.wo.pt"))

def prepare_processor(task, train = True):
    """Get appropriate data processor for specified task.

    Args:
        task: Name of the benchmark task. Supported names are the keys of the
            mapping below (winogrande, hellaswag, piqa, arc_easy,
            arc_challenge, copa, squad, mrpc).
        train: Flag indicating training mode (not used in current implementation)

    Returns:
        An instance of the processor class registered for ``task``, built
        with ``benchmark=task``.

    Raises:
        KeyError: If ``task`` is not one of the supported names.
    """
    dataset_mapping = {
        "winogrande": WinoGrandePreProcessor,
        "hellaswag": HellaSwagPreProcessor,
        "piqa": PIQAPreProcessor,
        "arc_easy": ARCPreProcessor,
        "arc_challenge": ARCPreProcessor,
        "copa": COPAPreProcessor,
        "squad": SQUADPreProcessor,
        "mrpc": MRPCPreProcessor,
    }
    try:
        processor_cls = dataset_mapping[task]
    except KeyError:
        # Keep the KeyError type callers may rely on, but say what is allowed.
        supported = ", ".join(sorted(dataset_mapping))
        raise KeyError(
            f"Unsupported task {task!r}; supported tasks: {supported}"
        ) from None
    return processor_cls(benchmark=task)

def load_grouper(
        model_type: str,
        config: Union[OlmoeConfig, Qwen2MoeConfig, DeepseekConfig, PretrainedConfig],
        similarity_fn: str,
        similarity_base: str
):
    """Build the experts grouper that matches the given MoE architecture.

    Args:
        model_type: Architecture name; one of "olmoe", "deepseek" or "qwen".
        config: Model configuration forwarded to the grouper constructor.
        similarity_fn: Similarity function name forwarded to the grouper.
        similarity_base: Similarity base name forwarded to the grouper.

    Returns:
        The grouper instance constructed for ``model_type``.

    Raises:
        ValueError: If ``model_type`` is none of the supported names.
    """
    known_groupers = (
        ("olmoe", ExpertsGrouperForOLMoE),
        ("deepseek", ExpertsGrouperForDeepseek),
        ("qwen", ExpertsGrouperForQwen),
    )
    for architecture, grouper_cls in known_groupers:
        if model_type == architecture:
            return grouper_cls(
                config=config,
                similarity_fn=similarity_fn,
                similarity_base=similarity_base,
            )
    raise ValueError("Unavailable model_type")


def merge_model_with_strategy(
        model,
        model_type,
        merging_strategy,
        grouper,
        merging_layers,
        dataloader_for_merging,
):
    """Merge the experts of ``model`` using the requested strategy.

    The architecture (``model_type``) selects a family of merge functions and
    ``merging_strategy`` selects which member of that family is called:

        - "mcsmoe":      usage-frequency-weighted merge (``strategy="normal"``,
                         no permutation).
        - "average":     plain group merge without permutation.
        - "repair":      group merge with "activation-matching" permutation,
                         using ``dataloader_for_merging``.
        - "git-rebasin": group merge with "weight-matching" permutation.
        - "zipit":       within-and-across-models merge, using
                         ``dataloader_for_merging``.
        - "linear":      ``OLMoE_merge_by_Linear`` on the first batch
                         (OLMoE only).
        - anything else: the ``*_merge_by_groups_with_ACT`` function on the
                         first batch of ``dataloader_for_merging``.

    NOTE: for "deepseek" and "qwen" there is no linear variant, so "linear"
    falls through to the ACT merge, exactly as before this refactoring.

    Args:
        model: Model whose experts are merged.
        model_type: One of "olmoe", "deepseek" or "qwen".
        merging_strategy: Strategy name, see the list above.
        grouper: Experts grouper passed through to the merge function.
        merging_layers: Layer indices passed through to the merge function.
        dataloader_for_merging: Data source for the strategies that need
            samples; only its first batch is used by the batch-based ones.

    Returns:
        Whatever the selected merge function returns (the merged model).

    Raises:
        ValueError: If ``model_type`` is not supported. Previously the model
            was returned unmerged without any notice.
    """
    if model_type not in ("olmoe", "deepseek", "qwen"):
        raise ValueError(f"Unavailable model_type: {model_type!r}")

    # Pick the architecture-specific family of merge functions.
    if model_type == "olmoe":
        usage_weighted_fn = OLMoE_merge_by_groups_with_usage_frequency_weighting
        groups_fn = OLMoE_merge_by_groups
        within_across_fn = OLMoE_merge_by_groups_within_and_across_models
        act_fn = OLMoE_merge_by_groups_with_ACT
        linear_fn = OLMoE_merge_by_Linear
    elif model_type == "deepseek":
        usage_weighted_fn = Deepseek_merge_by_groups_with_usage_frequency_weighting
        groups_fn = Deepseek_merge_by_groups
        within_across_fn = Deepseek_merge_by_groups_within_and_across_models
        act_fn = Deepseek_merge_by_groups_with_ACT
        linear_fn = None
    else:
        usage_weighted_fn = Qwen_merge_by_groups_with_usage_frequency_weighting
        groups_fn = Qwen_merge_by_groups
        within_across_fn = Qwen_merge_by_groups_within_and_across_models
        act_fn = Qwen_merge_by_groups_with_ACT
        linear_fn = None

    # Arguments every merge function receives.
    common = dict(model=model, grouper=grouper, merging_layers=merging_layers)

    if merging_strategy == "mcsmoe":
        return usage_weighted_fn(
            **common,
            strategy="normal",
            permute=False,
            within_and_across_models=False,
        )
    if merging_strategy == "average":
        return groups_fn(**common, permute=False)
    if merging_strategy == "repair":
        return groups_fn(
            **common,
            permute=True,
            permute_strategy="activation-matching",
            dataloader=dataloader_for_merging,
        )
    if merging_strategy == "git-rebasin":
        return groups_fn(
            **common,
            permute=True,
            permute_strategy="weight-matching",
        )
    if merging_strategy == "zipit":
        return within_across_fn(
            **common,
            dataloader=dataloader_for_merging,
            dominant_alone=False,
            usage_weighted=False,
        )

    # Remaining strategies work on a single batch of samples.
    batch = next(iter(dataloader_for_merging))
    if merging_strategy == "linear" and linear_fn is not None:
        return linear_fn(**common, batch=batch)
    return act_fn(**common, batch=batch)




def merge_model(
        output_dir: Optional[str] = None,
        checkpoint: Optional[str] = None,
        model_type: Optional[str] = None,
        task: Optional[str] = "sst2",
        merging_strategy: Optional[str] = "ours",
        num_samples_for_merging: Optional[int] = 32,
        num_groups: Optional[int] = 16,
        group_capacity: Optional[int] = None,
        merging_layers: Optional[Union[str, List, int]] = None,
):
    """Main function for merging experts in MoE models.

    Loads the checkpoint, samples ``num_samples_for_merging`` tokenized
    training examples of ``task``, groups the experts, merges them with
    ``merging_strategy`` and saves the result to
    ``<output_dir>/<merging_strategy>``.

    Args:
        output_dir: Output directory for saving merged model
        checkpoint: Path to pre-trained model checkpoint
        model_type: Type of MoE architecture (olmoe/deepseek/qwen)
        task: Benchmark task name. It is used as a key into
            ``TASK_MAPPING_DATASET_ARGUMENTS`` and ``prepare_processor``.
            NOTE(review): the default "sst2" is not among the tasks handled
            by ``prepare_processor`` in this file -- confirm the default.
        merging_strategy: Strategy for expert merging
        num_samples_for_merging: Number of samples to use for merging
        num_groups: Target number of expert groups
        group_capacity: Maximum experts per group
        merging_layers: Layers to apply merging on

    Raises:
        ValueError: If ``output_dir``, ``checkpoint`` or ``model_type`` is
            not given, or if ``model_type`` is not supported by the grouper.
    """
    # Validate everything cheap first, so no directory is created and no model
    # or dataset is loaded for a call that cannot succeed.
    if output_dir is None:
        raise ValueError("output_dir must be specified")
    if checkpoint is None:
        raise ValueError("checkpoint must be specified")
    if model_type is None:
        raise ValueError("model_type must be specified")
    output_dir = os.path.join(output_dir, merging_strategy)
    os.makedirs(output_dir, exist_ok=True)
    merging_layers = sanitize_merging_layers(merging_layers)

    model = AutoModelForCausalLM.from_pretrained(checkpoint, trust_remote_code=True, torch_dtype=torch.bfloat16)

    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,
        keys_to_ignore=["answer_idx", "choice_idx", "idx"]
    )

    raw_dataset = load_dataset(*TASK_MAPPING_DATASET_ARGUMENTS[task])

    train_dataset = build_index_for_dataset(raw_dataset["train"])
    processor = prepare_processor(task=task)
    train_dataset = train_dataset.map(
        processor,
        batched=True,
        num_proc=8,
        remove_columns=train_dataset.column_names
    )

    tokenized_train_dataset = train_dataset.map(
        lambda x: tokenize_batch(tokenizer=tokenizer, batch=x),
        num_proc=8,
        batched=True,
        remove_columns=train_dataset.column_names,
        load_from_cache_file=False
    )

    # Randomly select (with a fixed seed) a subset of the training set for merging.
    dataset_for_merging = tokenized_train_dataset.shuffle(seed=2333).select(range(num_samples_for_merging))

    # batch_size equals the subset size, so the loader yields the whole
    # subset as its first batch.
    dataloader_for_merging = DataLoader(
        dataset_for_merging,
        shuffle=False,
        collate_fn=data_collator,
        batch_size=num_samples_for_merging,
        num_workers=3
    )

    print(f"Merging experts by usage weighted averaging with {merging_strategy} usage")
    print(f"Number of groups: {num_groups}")
    print(f"Merge layers: {merging_layers}")
    print(f"Number of parameters before merging: {model.num_parameters()}")

    # Initialize the grouper for the chosen merging strategy.
    # Note: "linear" is not an actual merging strategy but an ablation on the
    # compression errors, and only OLMoE supports it.
    # "ours" and "linear" group experts by L2 distance on the gate/up weights;
    # every other strategy groups by cosine similarity of the router logits.
    if merging_strategy in ("ours", "linear"):
        similarity_fn, similarity_base = "l2", "gate-up-weight"
    else:
        similarity_fn, similarity_base = "cosine", "router-logits"
    grouper = load_grouper(
        model_type=model_type,
        config=model.config,
        similarity_fn=similarity_fn,
        similarity_base=similarity_base
    )

    start_time = time.time()
    grouper.set_avg_num_merged_experts(num_groups)

    grouper.compute_all_similarities(
        model=model,
        batch=next(iter(dataloader_for_merging)),
        merging_layers=merging_layers,
    )

    grouper.compute_all_usages(
        model=model,
        batch=next(iter(dataloader_for_merging)),
        merging_layers=merging_layers,
    )
    # The return value is not needed here; the call is kept because the merge
    # step below receives the same grouper object.
    grouper.group_experts_into_clusters_by_routing_guided_globally(
        average_num_groups=num_groups,
        merging_layers=merging_layers,
        layer_group_capacity=group_capacity,
    )

    model = merge_model_with_strategy(
        model=model,
        model_type=model_type,
        merging_strategy=merging_strategy,
        grouper=grouper,
        merging_layers=merging_layers,
        dataloader_for_merging=dataloader_for_merging,
    )
    print("merging time:", time.time() - start_time)
    print(f"Number of parameters after merging: {model.num_parameters()}")
    print(f"Saving merged model to {output_dir}")
    # Time the save on its own; measuring from start_time would also count
    # the merging time reported above.
    save_start_time = time.time()
    model.save_pretrained(output_dir, safe_serialization=False)
    print("saving time:", time.time() - save_start_time)

# Script entry point: when run directly (not imported), hand merge_model to
# python-fire's Fire() so it can be driven from the command line.
if __name__ == "__main__":
    Fire(merge_model)
