import os
import sys
import json
import time
import torch
import pickle
import logging
import itertools
from fire import Fire
from tqdm import tqdm
from typing import Optional
from transformers import Qwen2MoeForCausalLM, AutoTokenizer, AutoModelForCausalLM

from eval import get_minipile_dataloder, evaluate_minipile_perplexity, evaluate_fewshot, get_calib_dataloder
from .merging.grouping_qwen import (
    ExpertsGrouperForQwen2MoE, 
    FineGrainedExpertsGrouperForQwen,
    merge_by_groups_with_usage_weighted, merge_by_groups_within_and_across_models, 
    merge_by_feature_selection
)

# Module-level logger. NOTE(review): the functions in this file currently log
# through the root `logging` module (logging.info) rather than this logger.
logger = logging.getLogger(__name__)




class TAMPArgs:
    """Plain attribute bag holding every TAMP grouping/merging/evaluation option.

    Each constructor argument is stored unchanged as an attribute of the same
    name (e.g. ``args.calib_set``, ``args.group_limit``), in declaration order.

    Note: the defaults for ``cluster`` ("kmeans") and ``overlap_metric``
    ("kl-divergence") differ from the defaults declared by ``run_qwen_tamp``.
    """

    def __init__(
        self,
        task,
        num_average_groups: Optional[int] = None,
        model_name: Optional[str] = "Qwen/Qwen1.5-MoE-A2.7B-Chat",
        dominant: Optional[str] = "knowledge",
        similarity_base: Optional[str] = "router-logits",
        merge: Optional[str] = "zipit",
        mode: Optional[str] = "normal",
        max_block_size: Optional[int] = 2048,
        calib_set: Optional[str] = "c4",
        n_sentences: Optional[int] = 32,
        train_batch_size: Optional[int] = 16,
        eval_batch_size: Optional[int] = 32,
        seed: Optional[int] = 42,
        partition: Optional[int] = 1,
        start_layer: Optional[int] = 0,
        output_path: Optional[str] = None,
        result_path: Optional[str] = None,
        model_path: Optional[str] = None,
        group_limit: Optional[int] = 4,
        data_limit: Optional[int] = 50000,
        num_fewshot: Optional[int] = 0,
        try_oracle: Optional[bool] = False,
        random_start_center: Optional[bool] = False,
        weight: Optional[str] = None,
        cluster: Optional[str] = "kmeans",
        linkage: Optional[str] = "ward",
        hierarchical_stopping_metric: Optional[str] = "silhouette",
        overlap_metric: Optional[str] = "kl-divergence",
        group: Optional[str] = "expert",
        dynamic_group: Optional[bool] = False,
        opt_path: Optional[str] = None,
    ):
        # Snapshot the arguments before any other local name is bound: right now
        # locals() holds exactly `self` plus the constructor parameters, in
        # declaration order, so every option becomes an attribute of the same name.
        options = dict(locals())
        del options["self"]
        for option_name, option_value in options.items():
            setattr(self, option_name, option_value)

def get_dataloader(args, tokenizer):
    """Build the calibration dataloader used for grouping/merging.

    Reads ``calib_set``, ``max_block_size``, ``n_sentences``, ``train_batch_size``
    and ``seed`` from ``args`` and forwards them to ``get_calib_dataloder``
    (with a fixed ``num_workers=8``).
    """
    loader_kwargs = dict(
        dataset=args.calib_set,
        tokenizer=tokenizer,
        max_block_size=args.max_block_size,
        n_blocks_for_stat=args.n_sentences,
        batch_size=args.train_batch_size,
        num_workers=8,
        seed=args.seed,
    )
    return get_calib_dataloder(**loader_kwargs)

def get_grouper(args, config):
    """Construct the expert grouper selected by ``args.group``.

    * ``"module"`` -> ``FineGrainedExpertsGrouperForQwen``
    * ``"expert"`` -> ``ExpertsGrouperForQwen2MoE`` (additionally configured with
      ``group_limit``, ``data_limit``, ``random_start_center``, ``overlap_metric``
      and ``dynamic_group``)

    Raises:
        ValueError: if ``args.group`` is neither ``"module"`` nor ``"expert"``.
    """
    kind = args.group
    if kind not in ("module", "expert"):
        raise ValueError(f"Grouping method `{kind}` is not supported")

    # Options understood by both grouper classes.
    shared = dict(
        config=config,
        similarity_base=args.similarity_base,
        start_layer=args.start_layer,
        cluster=args.cluster,
        linkage=args.linkage,
        hierarchical_stopping_metric=args.hierarchical_stopping_metric,
    )
    if kind == "module":
        return FineGrainedExpertsGrouperForQwen(**shared)
    return ExpertsGrouperForQwen2MoE(
        **shared,
        group_limit=args.group_limit,
        data_limit=args.data_limit,
        random_start_center=args.random_start_center,
        overlap_metric=args.overlap_metric,
        dynamic_group=args.dynamic_group,
    )

def evaluation(args, model, tokenizer, tasks: Optional[list] = None):
    """Run few-shot evaluation of ``model`` on a list of tasks.

    Args:
        args: Config object (e.g. ``TAMPArgs``); reads ``result_path``,
            ``num_fewshot`` and ``eval_batch_size``. ``args.task`` is not used
            to choose the tasks.
        model: Model passed to ``evaluate_fewshot``.
        tokenizer: Tokenizer passed to ``evaluate_fewshot``.
        tasks: Task names to evaluate, in order. When ``None`` (the default) the
            fixed benchmark suite wikitext, winogrande, arc_challenge, arc_easy,
            boolq, hellaswag, mmlu, openbookqa, rte is used.
    """
    result_path = args.result_path
    if result_path:
        # Create the parent directory of the result file, if it names one.
        # os.path.dirname returns "" for a bare filename, and os.makedirs("")
        # would raise, so skip that case; exist_ok avoids a check-then-create race.
        result_dir = os.path.dirname(result_path)
        if result_dir:
            os.makedirs(result_dir, exist_ok=True)

    if tasks is None:
        tasks = ["wikitext", "winogrande", "arc_challenge", "arc_easy", "boolq", "hellaswag", "mmlu", "openbookqa", "rte"]

    for task_name in tasks:
        evaluate_fewshot(
            model,
            tokenizer=tokenizer,
            task=task_name,
            num_fewshot=args.num_fewshot,
            eval_batch_size=args.eval_batch_size,
            output_path=result_path,
            log=True,
        )

def print_usage_frequency(usage_dict):
    """Print one line per entry of ``usage_dict`` with its values rounded to 4 decimals.

    Every element must provide ``.item()`` (presumably scalar tensors — confirm
    against the grouper). Each printed value, including the last one on a line,
    is followed by a comma.
    """
    for values in usage_dict.values():
        line = "".join(f"{round(v.item(), 4)!s}," for v in values)
        print(line)

# Pre-computed artifacts for `dominant="best"`: (group-state-dict pickle, dominant-experts pickle),
# keyed by the number of groups.
_BEST_GROUPING_PATHS = {
    45: (
        "./analysis/info/qwen/qwen_45e_w+eo_cluster_random_center_freq_group_dict.pkl",
        "./analysis/info/qwen/qwen_45e_w+eo_cluster_random_center_freq_dom_experts_dict.pkl",
    ),
    30: (
        "./analysis/info/qwen/qwen_30e_w+eo_cluster_random_center_freq_group_dict.pkl",
        "./analysis/info/qwen/qwen_30e_w+eo_cluster_random_center_freq_dom_experts_dict.pkl",
    ),
}

# Initial cluster centers for `dominant="init"`, keyed by the number of groups.
_INIT_CENTER_PATHS = {
    45: "analysis/info/qwen/qwen_45e_best_initial_center.pkl",
    30: "analysis/info/qwen/qwen_30e_best_initial_center.pkl",
}


def _load_pickle(path):
    """Unpickle and return the object stored at ``path``, closing the file afterwards.

    NOTE: unpickling can execute arbitrary code; only use this on trusted,
    locally produced artifacts.
    """
    with open(path, "rb") as f:
        return pickle.load(f)


def _print_grouping_results(grouper, group, dom_experts):
    """Print the grouping stored in ``grouper`` (plus dominant experts, when known)."""
    print("[TAMP] ========= Grouping results ========= ")
    if group == "module":
        for name, state in grouper.group_state_dict().items():
            print(f"Group {name} - ", end='')
            for module_name, label in state.items():
                print(f"{module_name}: {label.tolist()}, ", end='')
            print()
    else:
        for name, state in grouper.group_state_dict().items():
            if dom_experts is None:
                print(f"Group {name}: {state.tolist()}")
            else:
                print(f"Group {name}: {state.tolist()} (DOMs are {dom_experts[name]}, {len(dom_experts[name])})")


def run_qwen_tamp(
        task: str,
        num_average_groups: Optional[int] = None,
        model_name: Optional[str] = "Qwen/Qwen1.5-MoE-A2.7B-Chat",
        dominant: Optional[str] = "knowledge", 
        group: Optional[str] = "expert", 
        similarity_base: Optional[str] = "router-logits", 
        merge: Optional[str] = "zipit", 
        mode: Optional[str] = "normal", 
        n_sentences: Optional[int] = 32,
        max_block_size: Optional[int] = 2048,
        calib_set: Optional[str] = "c4", 
        train_batch_size: Optional[int] = 16,
        eval_batch_size: Optional[int] = 32,
        seed: Optional[int] = 42,
        partition: Optional[int] = 1,
        start_layer: Optional[int] = 0,
        output_path: Optional[str] = None,
        result_path: Optional[str] = None,
        model_path: Optional[str] = None,
        group_limit: Optional[int] = 4,
        data_limit: Optional[int] = 50000,
        num_fewshot: Optional[int] = 0,
        try_oracle: Optional[bool] = False,
        random_start_center: Optional[bool] = False,
        weight: Optional[str] = None,
        cluster: Optional[str] = "hierarchical",
        linkage: Optional[str] = "ward",
        hierarchical_stopping_metric: Optional[str] = "silhouette",
        ingredient: Optional[str] = "act", 
        overlap_metric: Optional[str] = "cosine", 
        dynamic_group: Optional[bool] = False,
        opt_path: Optional[str] = None,
        model = None,
        tokenizer = None,
):
    """Group and merge the experts of a Qwen2-MoE model and return the merged model.

    Steps:
      1. Bundle the options into a ``TAMPArgs`` and seed torch with ``seed``.
      2. Load the tokenizer/model from ``model_name`` when they are not supplied.
         A freshly loaded tokenizer gets ``pad_token_id = eos_token_id``; a
         freshly loaded model optionally has its weights replaced from ``model_path``.
      3. Build a calibration dataloader and a grouper (see ``get_grouper``).
      4. Group experts according to ``dominant``.
      5. Merge according to ``group`` / ``merge`` / ``mode``.
      6. Print the resulting grouping.

    Args:
        task: Stored in the config and logged; not otherwise used here.
        num_average_groups: Target number of (average) expert groups passed to
            the grouping/merging helpers.
        dominant: Grouping strategy. One of "random", "frequency",
            "routing-score", "knowledge", "no", "best" or "init".
            "best" and "init" load pre-computed pickles and support only 30 or 45 groups.
        group: "expert" or "module" grouping granularity.
        merge: "no" skips merging, "freq" merges with usage weights, and "fsm"
            merges by feature selection. Any other value is forwarded to
            ``merge_by_groups_within_and_across_models``.
        ingredient: Forwarded to ``merge_by_groups_within_and_across_models``.
        model, tokenizer: Optional preloaded objects; loaded from ``model_name``
            when ``None``.
        Remaining arguments are forwarded into ``TAMPArgs`` / the grouper /
        the merge helpers unchanged.

    Returns:
        The (merged) model.

    Raises:
        ValueError: If ``dominant`` is unknown, or if ``num_average_groups`` is
            not 30/45 for "best"/"init". ``get_grouper`` also raises it for an
            unknown ``group``.
    """
    logging.info(f"Merge model {model_name} with {num_average_groups} group, {dominant} dominant + {similarity_base} grouping + {merge} {mode} merge with ingredient {ingredient}, evaluate on {task}")
    logging.info(f"Oracle: {try_oracle}, weight: {weight}")
    logging.info(f"Cluster: {cluster}, linkage: {linkage}, hierarchical_stopping_metric: {hierarchical_stopping_metric}, overlap_metric: {overlap_metric}")
    logging.info(f"Group: {group}, dynamic_group: {dynamic_group}, opt_path: {opt_path}.")

    args = TAMPArgs(
        task=task,
        num_average_groups=num_average_groups,
        model_name=model_name,
        dominant=dominant,
        similarity_base=similarity_base,
        merge=merge,
        mode=mode,
        max_block_size=max_block_size,
        calib_set=calib_set,
        n_sentences=n_sentences,
        train_batch_size=train_batch_size,
        eval_batch_size=eval_batch_size,
        seed=seed,
        partition=partition,
        start_layer=start_layer,
        output_path=output_path,
        result_path=result_path,
        model_path=model_path,
        group_limit=group_limit,
        data_limit=data_limit,
        num_fewshot=num_fewshot,
        try_oracle=try_oracle,
        random_start_center=random_start_center,
        weight=weight,
        cluster=cluster,
        linkage=linkage,
        hierarchical_stopping_metric=hierarchical_stopping_metric,
        overlap_metric=overlap_metric,
        group=group,
        dynamic_group=dynamic_group,
        opt_path=opt_path,
    )
    torch.manual_seed(args.seed)

    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        tokenizer.pad_token_id = tokenizer.eos_token_id

    if model is None:
        model = Qwen2MoeForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16, device_map="auto"
        )
        if model_path:
            # Overwrite the pretrained weights with a locally saved state dict.
            model.load_state_dict(torch.load(model_path))

    model.eval()
    dataloader_for_merging = get_dataloader(args, tokenizer)
    grouper = get_grouper(args, model.config)

    logging.info("[TAMP] Number of parameters before merging: %sB", model.num_parameters() / 1e9)
    logging.info("[TAMP] Merging into average %s groups...", num_average_groups)
    group_st = time.time()
    # Usage statistics are needed by the frequency-based grouping and merging paths.
    if merge == "freq" or dominant == "frequency" or mode == "freq":
        grouper.compute_all_usages(model, dataloader_for_merging)
        print_usage_frequency(grouper._usage_frequency_state_dict)

    # ---- Grouping -------------------------------------------------------
    dom_experts = None
    if merge == "fsm" or merge == "no":
        # No separate grouping step: "no" skips merging entirely, and for "fsm"
        # merge_by_feature_selection (below) returns its own dominant experts.
        pass
    elif dominant == "random":
        grouper.group_experts_randomly(num_groups=args.num_average_groups)
    elif dominant == "frequency":
        grouper.compute_all_similarities(model, dataloader_for_merging)
        dom_experts = grouper.group_experts_globally_from_dominant_experts(
            num_average_groups=num_average_groups, merging_layers=list(range(start_layer, model.config.num_hidden_layers))
        )
    elif dominant == "routing-score":
        grouper.compute_all_usages(model, dataloader_for_merging, mode="routing-score")
        print_usage_frequency(grouper._usage_frequency_state_dict)
        dom_experts = grouper.group_experts_globally_from_dominant_experts(
            num_average_groups=num_average_groups, merging_layers=list(range(start_layer, model.config.num_hidden_layers))
        )
    elif dominant == "knowledge":
        # Grouping and merging happen in this one call, so the expert-level
        # merge below (guarded by `dominant != "knowledge"`) does not run again.
        grouper.compute_all_similarities(model, dataloader_for_merging)
        model = grouper.all_in_one_knowledge_dominant(
            model=model,
            dataloader=dataloader_for_merging,
            merge=merge,
            mode=mode,
            num_groups=num_average_groups,
            dominant_alone=False,
            usage_weighted=False,
        )
        dom_experts = grouper.core_experts
    elif dominant == "no":
        dom_experts = grouper.cluster_experts(model=model, dataloader=dataloader_for_merging, num_groups=num_average_groups)
    elif dominant == "best":
        best_paths = _BEST_GROUPING_PATHS.get(num_average_groups)
        if best_paths is None:
            raise ValueError(f"Unknown num_average_groups for best clustering: {num_average_groups}")
        group_dict_path, dom_experts_path = best_paths
        grouper._group_state_dict = _load_pickle(group_dict_path)
        dom_experts = _load_pickle(dom_experts_path)
    elif dominant == "init":
        init_path = _INIT_CENTER_PATHS.get(num_average_groups)
        if init_path is None:
            raise ValueError(f"Init dominant is only available for 30 or 45 groups for {model_name}, but the input is `{num_average_groups}`")
        grouper.load_init_center_state_dict(init_path)
        dom_experts = grouper.cluster_experts(model=model, dataloader=dataloader_for_merging, num_groups=num_average_groups)
    else:
        raise ValueError(f"Unknown dominant: {dominant}")

    # ---- Merging --------------------------------------------------------
    if group == "module":
        model = grouper.merge_weighted(model, merge, core_experts=dom_experts)
    elif dominant != "knowledge":
        if merge == "no":
            pass
        elif merge == "freq":
            model = merge_by_groups_with_usage_weighted(
                model, grouper=grouper, merging_layers=list(range(start_layer, model.config.num_hidden_layers))
            )
        elif merge == "fsm":
            model, dom_experts = merge_by_feature_selection(
                model, grouper=grouper, dataloader=dataloader_for_merging, num_groups=num_average_groups, mode=mode
            )
        else:
            model = merge_by_groups_within_and_across_models(
                qwen_model=model,
                grouper=grouper,
                dataloader=dataloader_for_merging,
                merge=merge,
                mode=mode,
                partition=partition,
                core_experts=dom_experts,
                dominant_alone=False,
                usage_weighted=False,
                ingredient=ingredient,
            )

    print(f"[TAMP] Merging time: {time.time() - group_st:.2f} seconds")

    if merge != "no":
        _print_grouping_results(grouper, group, dom_experts)

    del grouper

    print("[TAMP] Number of parameters after merging:", model.num_parameters())

    return model


if __name__ == "__main__":
    # Command-line entry point: python-fire exposes run_qwen_tamp's parameters
    # as CLI flags (e.g. `--num_average_groups 30 --dominant init`).
    Fire(run_qwen_tamp)
