import os
import gc
import sys
import time
import pickle
from typing import Optional, List

import logging
import torch
from fire import Fire
from transformers import AutoTokenizer
from transformers.models.deepseek_v2.modeling_deepseek_v2 import DeepseekV2ForCausalLM

from eval import evaluate_fewshot, get_calib_dataloder
from .merging.grouping_deepseek import (
    ExpertsGrouperForDeepseekMoE,
    FineGrainedExpertsGrouperForDeepseekMoE,
    merge_by_groups_with_usage_weighted,
    merge_by_groups_within_and_across_models,
    merge_by_feature_selection,
    merge_deepseek_by_norm_drop_fre,
)

# Module-level logger. NOTE(review): the visible code logs through the root
# `logging.*` functions instead of this logger, so it appears unused here.
logger = logging.getLogger(__name__)




class TAMPArgs:
    """Bag of settings for one TAMP expert-merging + evaluation run.

    The constructor copies every argument, unchanged, onto an attribute of the
    same name; nothing is validated or derived. The defaults here differ from
    those of ``run_deepseek_tamp``, which passes every field explicitly.
    """

    def __init__(
        self,
        task,
        num_average_groups: Optional[List[int]] = None,
        model_name: Optional[str] = "deepseek-ai/deepseek-moe-16b-base",
        dominant: Optional[str] = "knowledge",
        similarity_base: Optional[str] = "router-logits",
        merge: Optional[str] = "zipit",
        mode: Optional[str] = "normal",
        calib_set: Optional[str] = "c4",
        n_sentences: Optional[int] = 32,
        train_batch_size: Optional[int] = 4,
        eval_batch_size: Optional[int] = 32,
        seed: Optional[int] = 42,
        partition: Optional[int] = 1,
        start_layer: Optional[int] = 1,
        output_path: Optional[str] = None,
        result_path: Optional[str] = None,
        model_path: Optional[str] = None,
        group_limit: Optional[int] = 4,
        data_limit: Optional[int] = 50000,
        num_fewshot: Optional[int] = 0,
        try_oracle: Optional[bool] = False,
        random_start_center: Optional[bool] = False,
        weight: Optional[str] = None,
        cluster: Optional[str] = "kmeans",
        linkage: Optional[str] = "ward",
        hierarchical_stopping_metric: Optional[str] = "silhouette",
        overlap_metric: Optional[str] = "kl-divergence",
        group: Optional[str] = "expert",
        dynamic_group: Optional[bool] = False,
        opt_path: Optional[str] = None,
    ):
        # Keyword order mirrors the signature, so the instance attributes are
        # created (and listed by ``vars(self)``) in parameter order.
        fields = dict(
            task=task,
            num_average_groups=num_average_groups,
            model_name=model_name,
            dominant=dominant,
            similarity_base=similarity_base,
            merge=merge,
            mode=mode,
            calib_set=calib_set,
            n_sentences=n_sentences,
            train_batch_size=train_batch_size,
            eval_batch_size=eval_batch_size,
            seed=seed,
            partition=partition,
            start_layer=start_layer,
            output_path=output_path,
            result_path=result_path,
            model_path=model_path,
            group_limit=group_limit,
            data_limit=data_limit,
            num_fewshot=num_fewshot,
            try_oracle=try_oracle,
            random_start_center=random_start_center,
            weight=weight,
            cluster=cluster,
            linkage=linkage,
            hierarchical_stopping_metric=hierarchical_stopping_metric,
            overlap_metric=overlap_metric,
            group=group,
            dynamic_group=dynamic_group,
            opt_path=opt_path,
        )
        for field_name, field_value in fields.items():
            setattr(self, field_name, field_value)


def get_dataloader(args, tokenizer):
    """Return the calibration dataloader used to gather merging statistics.

    Thin wrapper over ``get_calib_dataloder``: the dataset comes from
    ``args.calib_set``, the number of blocks from ``args.n_sentences`` and the
    batch size from ``args.train_batch_size``. ``max_block_size`` is fixed at
    2048 and ``num_workers`` at 4.
    """
    loader_kwargs = {
        "dataset": args.calib_set,
        "tokenizer": tokenizer,
        "max_block_size": 2048,
        "n_blocks_for_stat": args.n_sentences,
        "batch_size": args.train_batch_size,
        "num_workers": 4,
    }
    return get_calib_dataloder(**loader_kwargs)

def get_grouper(args, config):
    """Build the expert grouper selected by ``args.group``.

    ``"module"`` yields a ``FineGrainedExpertsGrouperForDeepseekMoE``. Any
    other value yields an ``ExpertsGrouperForDeepseekMoE``, which also receives
    ``group_limit``, ``data_limit``, ``random_start_center``,
    ``overlap_metric`` and ``dynamic_group`` from *args*.
    """
    # Constructor arguments understood by both grouper flavours.
    common = dict(
        config=config,
        similarity_base=args.similarity_base,
        start_layer=args.start_layer,
        cluster=args.cluster,
        linkage=args.linkage,
        hierarchical_stopping_metric=args.hierarchical_stopping_metric,
        opt_path=args.opt_path,
    )
    if args.group == "module":
        return FineGrainedExpertsGrouperForDeepseekMoE(**common)

    return ExpertsGrouperForDeepseekMoE(
        group_limit=args.group_limit,
        data_limit=args.data_limit,
        random_start_center=args.random_start_center,
        overlap_metric=args.overlap_metric,
        dynamic_group=args.dynamic_group,
        **common,
    )

def evaluation(args, model, tokenizer):
    """Run few-shot evaluation of *model* on every task in ``args.task``.

    ``args.task`` may be a single task name (``str``) or an iterable of task
    names. Each task is evaluated separately with the same ``num_fewshot``,
    ``eval_batch_size`` and ``output_path`` (``args.result_path``).

    The parent directory of ``args.result_path`` is created first when the
    path has one. A bare filename or ``None`` needs no directory, so none is
    created.
    """
    result_dir = os.path.dirname(args.result_path) if args.result_path else ""
    # os.makedirs("") raises, so only create a directory when there is one.
    if result_dir:
        os.makedirs(result_dir, exist_ok=True)

    # Treat a lone task name as a one-element list so both cases share the
    # same call (and both honour eval_batch_size).
    tasks = [args.task] if isinstance(args.task, str) else args.task
    for task_name in tasks:
        evaluate_fewshot(
            model,
            tokenizer=tokenizer,
            task=task_name,
            num_fewshot=args.num_fewshot,
            eval_batch_size=args.eval_batch_size,
            output_path=args.result_path,
            log=True,
        )

def print_usage_frequency(usage_dict):
    """Print one line per key of *usage_dict*.

    Every element of ``usage_dict[key]`` is rendered as
    ``round(element.item(), 4)`` followed by a comma, so each line ends with
    a trailing comma before the newline.
    """
    for key in usage_dict:
        print("".join(f"{round(entry.item(), 4)}," for entry in usage_dict[key]))


def run_deepseek_tamp(
        task: str,
        num_average_groups: Optional[List[int]] = None,
        model_name: Optional[str] = "deepseek-ai/deepseek-moe-16b-base",
        dominant: Optional[str] = "no",
        group: Optional[str] = "expert",
        similarity_base: Optional[str] = "router-logits",
        merge: Optional[str] = "norm_drop_freq",
        mode: Optional[str] = "normal",
        n_sentences: Optional[int] = 32,
        calib_set: Optional[str] = "c4",
        train_batch_size: Optional[int] = 16,
        eval_batch_size: Optional[int] = 32,
        seed: Optional[int] = 42,
        partition: Optional[int] = 1,
        start_layer: Optional[int] = 0,
        output_path: Optional[str] = None,
        result_path: Optional[str] = None,
        model_path: Optional[str] = None,
        group_limit: Optional[int] = 4,
        data_limit: Optional[int] = 1000000,
        num_fewshot: Optional[int] = 0,
        try_oracle: Optional[bool] = False,
        random_start_center: Optional[bool] = False,
        weight: Optional[str] = None,
        cluster: Optional[str] = "kmeans",
        linkage: Optional[str] = "ward",
        hierarchical_stopping_metric: Optional[str] = "silhouette",
        ingredient: Optional[str] = "act",
        overlap_metric: Optional[str] = "cosine",
        dynamic_group: Optional[bool] = False,
        opt_path: Optional[str] = None,

        # Must be supplied by the caller; this function does not load them.
        model = None,
        tokenizer = None,
):
    """Group the experts of *model* and merge them, returning the merged model.

    The run has two stages, both driven by calibration data from
    ``calib_set``:

    1. Grouping, selected by *dominant*: ``"random"``, ``"frequency"``,
       ``"routing-score"``, ``"knowledge"`` or ``"no"``. This stage is skipped
       entirely when *merge* is ``"fsm"`` or ``"no"``.
    2. Merging. ``group == "module"`` uses ``grouper.merge_weighted``. The
       ``"knowledge"`` path gets its model back from
       ``grouper.all_in_one_knowledge_dominant`` in stage 1. Otherwise *merge*
       selects the helper:
       - ``"no"`` skips merging;
       - ``"freq"`` and ``"fsm"`` each use a dedicated helper;
       - ``"norm_drop_fre"`` / ``"norm_drop_freq"`` use
         ``merge_deepseek_by_norm_drop_fre``;
       - anything else goes to ``merge_by_groups_within_and_across_models``.

    Grouping results and parameter counts are printed. Evaluation is not run
    here.

    Args:
        task: evaluation task(s); only stored in the run settings and logged.
        model: an already-loaded DeepSeek-MoE model (required).
        tokenizer: the matching tokenizer, used for the calibration data
            (required).
        (Remaining arguments are forwarded to ``TAMPArgs`` / the grouper and
        merge helpers.)

    Returns:
        The merged model (possibly a new object returned by a merge helper).

    Raises:
        ValueError: if *model* or *tokenizer* is missing, or *dominant* is
            unknown.
    """
    if model is None or tokenizer is None:
        # Without this guard the first failure is an opaque AttributeError on
        # model.eval(), e.g. whenever the function is invoked from the CLI.
        raise ValueError(
            "run_deepseek_tamp needs an already-loaded model and tokenizer "
            f"(model given: {model is not None}, tokenizer given: {tokenizer is not None})"
        )

    logging.info(f"Merge model {model_name} with {num_average_groups} group, {dominant} dominant + {similarity_base} grouping + {merge} {mode} merge with ingredient {ingredient}, evaluate on {task}")
    logging.info(f"Oracle: {try_oracle}, weight: {weight}")
    logging.info(f"Cluster: {cluster}, linkage: {linkage}, hierarchical_stopping_metric: {hierarchical_stopping_metric}, overlap_metric: {overlap_metric}")
    logging.info(f"Group: {group}, dynamic_group: {dynamic_group}, opt_path: {opt_path}.")

    args = TAMPArgs(
        task=task,
        num_average_groups=num_average_groups,
        model_name=model_name,
        dominant=dominant,
        similarity_base=similarity_base,
        merge=merge,
        mode=mode,
        calib_set=calib_set,
        n_sentences=n_sentences,
        train_batch_size=train_batch_size,
        eval_batch_size=eval_batch_size,
        seed=seed,
        partition=partition,
        start_layer=start_layer,
        output_path=output_path,
        result_path=result_path,
        model_path=model_path,
        group_limit=group_limit,
        data_limit=data_limit,
        num_fewshot=num_fewshot,
        try_oracle=try_oracle,
        random_start_center=random_start_center,
        weight=weight,
        cluster=cluster,
        linkage=linkage,
        hierarchical_stopping_metric=hierarchical_stopping_metric,
        overlap_metric=overlap_metric,
        group=group,
        dynamic_group=dynamic_group,
        opt_path=opt_path,
    )
    torch.manual_seed(args.seed)

    model.eval()
    dataloader_for_merging = get_dataloader(args, tokenizer)
    grouper = get_grouper(args, model.config)
    # Layers [start_layer, num_hidden_layers): passed to every helper below
    # that takes an explicit layer list.
    merging_layers = list(range(start_layer, model.config.num_hidden_layers))

    logging.info("Number of parameters before merging: %sB", model.num_parameters() / 1e9)
    logging.info("Merging into average %s groups...", num_average_groups)
    group_st = time.time()
    if merge == "freq" or dominant == "frequency":
        grouper.compute_all_usages(model, dataloader_for_merging)
        print_usage_frequency(grouper._usage_frequency_state_dict)

    # ---- Stage 1: grouping ------------------------------------------------
    dom_experts = None
    if merge == "fsm" or merge == "no":
        pass  # no separate grouping stage for these merge modes
    elif dominant == "random":
        grouper.group_experts_randomly(num_groups=num_average_groups)
        dom_experts = None
    elif dominant == "frequency":
        if similarity_base != "no":
            grouper.compute_all_similarities(model, dataloader_for_merging)
        dom_experts = grouper.group_experts_globally_from_dominant_experts(
            num_average_groups=num_average_groups, merging_layers=merging_layers
        )
    elif dominant == "routing-score":
        grouper.compute_all_usages(model, dataloader_for_merging, mode="routing-score")
        print_usage_frequency(grouper._usage_frequency_state_dict)
        dom_experts = grouper.group_experts_globally_from_dominant_experts(
            num_average_groups=num_average_groups, merging_layers=merging_layers
        )
    elif dominant == "knowledge":
        grouper.compute_all_similarities(model, dataloader_for_merging)
        # This call receives merge/mode and returns the model used from here on.
        model = grouper.all_in_one_knowledge_dominant(
            model=model,
            dataloader=dataloader_for_merging,
            merge=merge,
            mode=mode,
            num_groups=num_average_groups,
        )
        dom_experts = grouper.core_experts
    elif dominant == "no":
        dom_experts = grouper.cluster_experts(model=model, dataloader=dataloader_for_merging)
    else:
        raise ValueError(f"Unknown dominant: {dominant}")

    # ---- Stage 2: merging -------------------------------------------------
    if group == "module":
        model = grouper.merge_weighted(model, merge, core_experts=dom_experts)
    elif dominant != "knowledge":
        if merge == "no":
            pass
        elif merge == "freq":
            model = merge_by_groups_with_usage_weighted(
                model, grouper=grouper, merging_layers=merging_layers
            )
        elif merge == "fsm":
            model, dom_experts = merge_by_feature_selection(
                model, grouper=grouper, dataloader=dataloader_for_merging, num_groups=num_average_groups, mode=mode
            )
        elif merge in ("norm_drop_fre", "norm_drop_freq"):
            # The signature default is "norm_drop_freq" while this branch
            # originally matched only "norm_drop_fre"; accept both so the
            # default no longer falls through to the generic merger.
            model = merge_deepseek_by_norm_drop_fre(
                model,
                grouper=grouper,
            )
        else:
            model = merge_by_groups_within_and_across_models(
                deepseek_model=model,
                grouper=grouper,
                dataloader=dataloader_for_merging,
                merge=merge,
                mode=mode,
                partition=partition,
                core_experts=dom_experts,
                dominant_alone=False,
                usage_weighted=False,
                ingredient=ingredient,
            )

    print(f"Merging time: {time.time() - group_st:.2f} seconds")

    # ---- Report grouping --------------------------------------------------
    if merge != "no":
        print(f"========= Grouping results ========= ")
        if group == "module":
            for name, state in grouper.group_state_dict().items():
                print(f"Group {name} - ", end='')
                for module_name, label in state.items():
                    print(f"{module_name}: {label.tolist()}, ", end='')
                print()
        else:
            for name, state in grouper.group_state_dict().items():
                if dom_experts is None:
                    print(f"Group {name}: {state.tolist()}")
                else:
                    print(f"Group {name}: {state.tolist()} (DOMs are {dom_experts[name]}, {len(dom_experts[name])})")

    # Drop our reference to the grouper before returning the merged model.
    del grouper

    print("Number of parameters after merging:", model.num_parameters())

    return model


if __name__ == "__main__":
    # Expose run_deepseek_tamp's parameters as command-line flags via python-fire.
    # NOTE(review): `model` and `tokenizer` cannot meaningfully be supplied from
    # the command line and nothing in this file loads them, so a CLI run fails
    # until a loader is added. Also, the relative `from .merging...` import at the
    # top means this file must be executed as a module (`python -m <pkg>.<mod>`),
    # not as a plain script.
    Fire(run_deepseek_tamp)
