import os
import sys
import json
import time
import torch
import pickle
import logging
import itertools
from fire import Fire
from tqdm import tqdm
from typing import Optional
from transformers import MixtralForCausalLM, AutoTokenizer

from eval import evaluate_fewshot, get_calib_dataloder
from .merging import (
    ExpertsGrouperForMixtral,
    FineGrainedExpertsGrouperForMixtral,
    merge_mixtral_by_groups_with_usage_weighted,
    merge_mixtral_by_groups_within_and_across_models,
    merge_mixtral_by_feature_selection,
    merge_mixtral_by_norm_drop_fre,
    merge_mixtral_by_average_weight
)
from model import PrunableMixtralSparseMoeBlockWrapper
from model.mixtral import MyMixtralForCausalLM
from utils import save_json

# Module-level logger named after this module; oracle_test reports progress through it.
logger = logging.getLogger(__name__)

from typing import Optional





class TAMPArgs:
    """Plain option bag for a TAMP grouping / merging / evaluation run.

    Every constructor argument is stored, unchanged, as an instance
    attribute of the same name (in declaration order). No validation or
    normalisation is performed here.
    """

    def __init__(
        self,
        task,
        num_average_groups: Optional[int] = None,
        model_name: Optional[str] = "mistralai/Mixtral-8x7B-v0.1",
        dominant: Optional[str] = "knowledge",
        similarity_base: Optional[str] = "router-logits",
        merge: Optional[str] = "zipit",
        mode: Optional[str] = "normal",
        max_block_size: Optional[int] = 2048,
        calib_set: Optional[str] = "c4",
        n_sentences: Optional[int] = 32,
        train_batch_size: Optional[int] = 16,
        eval_batch_size: Optional[int] = 32,
        seed: Optional[int] = 42,
        partition: Optional[int] = 1,
        start_layer: Optional[int] = 0,
        output_path: Optional[str] = None,
        result_path: Optional[str] = None,
        model_path: Optional[str] = None,
        group_limit: Optional[int] = 4,
        data_limit: Optional[int] = 50000,
        num_fewshot: Optional[int] = 0,
        try_oracle: Optional[bool] = False,
        random_start_center: Optional[bool] = False,
        weight= None,
        cluster: Optional[str] = "kmeans",
        linkage: Optional[str] = "ward",
        hierarchical_stopping_metric: Optional[str] = "silhouette",
        overlap_metric: Optional[str] = "cosine",
        group: Optional[str] = "expert", 
        dynamic_group: Optional[bool] = False,
        opt_path: Optional[str] = None,
    ):
        # Snapshot the arguments before any other local name is bound, so the
        # mapping holds exactly the parameters (in signature order).
        options = dict(locals())
        del options["self"]
        for option_name, option_value in options.items():
            setattr(self, option_name, option_value)

def get_dataloader(args, tokenizer, num_workers: int = 8):
    """Build the calibration dataloader used for grouping and merging.

    Args:
        args: A ``TAMPArgs``-like object. Reads ``calib_set``, ``max_block_size``,
            ``n_sentences``, ``train_batch_size`` and ``seed``.
        tokenizer: Tokenizer forwarded to ``get_calib_dataloder``.
        num_workers: Number of dataloader worker processes (previously
            hard-coded to 8, which remains the default).

    Returns:
        Whatever ``get_calib_dataloder`` returns for these settings.
    """
    # Honour the user's block size instead of a hard-coded 2048; fall back to
    # the old value when the option is missing or explicitly None.
    max_block_size = getattr(args, "max_block_size", None)
    if max_block_size is None:
        max_block_size = 2048
    return get_calib_dataloder(
        dataset=args.calib_set,
        tokenizer=tokenizer,
        max_block_size=max_block_size,
        n_blocks_for_stat=args.n_sentences,
        batch_size=args.train_batch_size,
        num_workers=num_workers,
        seed=args.seed,
    )

def get_grouper(args, config):
    """Instantiate the expert grouper selected by ``args.group``.

    ``"expert"`` builds an ``ExpertsGrouperForMixtral`` and ``"module"``
    builds a ``FineGrainedExpertsGrouperForMixtral``. Both receive the
    similarity/clustering options. Only the expert-level grouper also gets
    the group/data limits, random start center, overlap metric,
    dynamic-group flag and ``opt_path``.

    Raises:
        ValueError: if ``args.group`` is neither ``"expert"`` nor ``"module"``.
    """
    if args.group != "expert" and args.group != "module":
        raise ValueError(f"Grouping method `{args.group}` is not supported")

    # Options common to both grouper flavours.
    shared_options = {
        "config": config,
        "similarity_base": args.similarity_base,
        "start_layer": args.start_layer,
        "cluster": args.cluster,
        "linkage": args.linkage,
        "hierarchical_stopping_metric": args.hierarchical_stopping_metric,
    }

    if args.group == "module":
        return FineGrainedExpertsGrouperForMixtral(**shared_options)

    return ExpertsGrouperForMixtral(
        **shared_options,
        group_limit=args.group_limit,
        data_limit=args.data_limit,
        random_start_center=args.random_start_center,
        overlap_metric=args.overlap_metric,
        dynamic_group=args.dynamic_group,
        opt_path=args.opt_path,
    )

def evaluation(args, model, tokenizer):
    """Run few-shot evaluation of ``model`` on the requested tasks.

    Tasks come from ``args.task``. The value may be a comma-separated
    string, or a list/tuple of task names (as Fire produces for
    ``--task=a,b``). When it is empty or exactly ``"all"``, a built-in
    default suite is used. Previously the requested task was silently
    ignored and that suite was always run.

    The parent directory of ``args.result_path`` is created if needed. Every
    task is evaluated with ``output_path=args.result_path``.

    Args:
        args: A ``TAMPArgs``-like object. Reads ``task``, ``result_path``,
            ``num_fewshot`` and ``eval_batch_size``.
        model: Model to evaluate.
        tokenizer: Tokenizer matching ``model``.
    """
    default_tasks = (
        "wikitext", "winogrande", "arc_challenge", "arc_easy", "boolq",
        "hellaswag", "mmlu", "openbookqa", "rte",
    )

    # dirname handles bare filenames (""), which os.makedirs would reject,
    # and root-level paths ("/x.json" -> "/").
    result_dir = os.path.dirname(args.result_path)
    if result_dir:
        os.makedirs(result_dir, exist_ok=True)

    requested = args.task
    if isinstance(requested, str):
        # list("boolq") would yield characters; split on commas instead.
        requested = requested.split(",")
    tasks = [str(name).strip() for name in (requested or ()) if str(name).strip()]
    if not tasks or tasks == ["all"]:
        tasks = list(default_tasks)

    for task_name in tasks:
        evaluate_fewshot(
            model, tokenizer=tokenizer, task=task_name, num_fewshot=args.num_fewshot,
            eval_batch_size=args.eval_batch_size, output_path=args.result_path, log=True
        )

def print_usage_frequency(usage_dict):
    """Print one line per key: each value's ``.item()`` rounded to 4 places, each followed by ','."""
    for _key, frequencies in usage_dict.items():
        row = "".join(str(round(freq.item(), 4)) + "," for freq in frequencies)
        print(row)




def oracle_test(
        args,
        model: MixtralForCausalLM,
        grouper: ExpertsGrouperForMixtral,
        dataloader: torch.utils.data.DataLoader,
        oracle_path: str = "oracle/result_tinyllama_r2_n32.pkl",
):
    """Wrap each MoE block with oracle dominant experts and run its ``enumerate()``.

    Steps:
    1. Record the device of each layer's ``block_sparse_moe``.
    2. Load the oracle from ``oracle_path`` (a pickled mapping indexed by
       layer index) and compute expert usage frequencies with ``grouper``.
    3. Replace each ``block_sparse_moe`` with a
       ``PrunableMixtralSparseMoeBlockWrapper`` and set its
       ``cache_X``/``cache_Z``/``cache_R`` flags.
    4. Forward ``dataloader`` once, on CUDA, under ``torch.inference_mode``.
    5. Move layers to CPU. Then, per layer that exposes ``cache_space``,
       move the wrapper back to its original device, run ``enumerate()``,
       print the result, and return it to CPU.
    6. Unwrap every block (``.model``) and move each layer back to its
       original device.

    Args:
        args: A ``TAMPArgs``-like object. Reads ``num_average_groups``, ``merge``,
            ``mode`` and ``weight``.
        model: Mixtral-style model whose ``model.layers[*].block_sparse_moe``
            are wrapped and later unwrapped in place.
        grouper: Grouper used to compute usage frequencies.
        dataloader: Calibration batches (dicts of tensors).
        oracle_path: Path to the pickled oracle (previously hard-coded).

    Returns:
        ``model``, with the original MoE blocks restored.
    """
    # Remember where each MoE block lives so it can be restored after the CPU round-trip.
    original_devices = {
        l: next(layer.block_sparse_moe.parameters()).device
        for l, layer in enumerate(model.model.layers)
    }
    print(original_devices)

    # NOTE(review): unpickling can execute arbitrary code -- only point
    # oracle_path at trusted local files. `with` closes the handle, which
    # the original open(...) never did.
    with open(oracle_path, "rb") as oracle_file:
        oracle = pickle.load(oracle_file)

    for k in oracle:
        print(k, oracle[k])
    grouper.compute_all_usages(model, dataloader)
    usage_frequency_dict = grouper.usage_frequency_state_dict()

    for l, layer in enumerate(model.model.layers):
        ffn_name = f"model.layers.{l}.block_sparse_moe"
        layer.block_sparse_moe = PrunableMixtralSparseMoeBlockWrapper(
            layer.block_sparse_moe,
            dom_experts=oracle[l],
            r=args.num_average_groups,
            usage_freq=usage_frequency_dict[ffn_name],
            merge_method=args.merge,
            mode=args.mode,
            weight=args.weight,
        )
        layer.block_sparse_moe.cache_X = True
        layer.block_sparse_moe.cache_Z = True
        layer.block_sparse_moe.cache_R = True

    with torch.inference_mode():
        for batch in tqdm(dataloader, desc='Model forwarding on sample set...'):
            batch = {k: v.to('cuda') for k, v in batch.items()}
            model_inputs = model.prepare_inputs_for_generation(**batch)
            outputs = model(**model_inputs)
            assert outputs is not None
    logger.info('Moving whole model to cpu...')

    # nn.Module.to() moves parameters in place; no rebinding is needed.
    for layer in model.model.layers:
        layer.to('cpu')

    for l, layer in tqdm(list(enumerate(model.model.layers)), desc='Enumerating loss on sample set...'):
        print("\n---layer ", l)
        b = layer.block_sparse_moe
        if not hasattr(b, 'cache_space'):
            continue

        # Only the block under evaluation is moved back to its device at a time.
        b.to(original_devices[l])
        loss_history = b.enumerate()
        print(loss_history)
        b.to('cpu')

    logger.info('Merging & saving...')
    for l, layer in enumerate(model.model.layers):
        # Unwrap: restore the block that the wrapper was built around.
        layer.block_sparse_moe = layer.block_sparse_moe.model
        layer.to(original_devices[l])
    return model

def global_prune(args, tokenizer):
    """Load the model, prune experts via the grouper, save the statistics, then evaluate.

    The reserve dict and global loss returned by
    ``grouper.prune_experts_in_moe`` are written next to ``args.result_path``
    as ``<stem>_reserve_dict.json`` and ``<stem>_global_loss.json``. Here
    ``<stem>`` is ``args.result_path`` with its final extension removed.

    Args:
        args: A ``TAMPArgs``-like object. Reads ``model_name`` and
            ``result_path``, plus everything ``get_grouper``,
            ``get_dataloader`` and ``evaluation`` read.
        tokenizer: Tokenizer matching ``args.model_name``.
    """
    model = MyMixtralForCausalLM.from_pretrained(
        args.model_name,
        torch_dtype=torch.bfloat16, device_map="auto"
    )
    model.eval()
    grouper = get_grouper(args, model.config)
    dataloader_for_merging = get_dataloader(args, tokenizer)
    model, reserve_dict, global_loss = grouper.prune_experts_in_moe(model=model, dataloader=dataloader_for_merging)

    # splitext strips only the final extension. str.split(".")[0] cut at the
    # first dot anywhere, e.g. "./results/run.json" -> "".
    result_stem = os.path.splitext(args.result_path)[0]
    result_dir = os.path.dirname(args.result_path)
    if result_dir:
        os.makedirs(result_dir, exist_ok=True)
    save_json(reserve_dict, result_stem + "_reserve_dict.json")
    save_json(global_loss, result_stem + "_global_loss.json")
    evaluation(args, model, tokenizer)





def _load_pickle(path):
    """Return the object unpickled from ``path``, closing the file handle afterwards."""
    # NOTE(review): unpickling can execute arbitrary code -- only use on trusted local artifacts.
    with open(path, "rb") as f:
        return pickle.load(f)


def run_mixtral_tamp(
        task: str,
        num_average_groups: Optional[int] = None,
        model_name: Optional[str] = "mistralai/Mixtral-8x7B-v0.1",
        dominant: Optional[str] = "no", 
        group: Optional[str] = "expert", 
        similarity_base: Optional[str] = "router-logits", 
        merge: Optional[str] = "norm_drop_freq", 
        mode: Optional[str] = "normal", 
        n_sentences: Optional[int] = 32,
        max_block_size: Optional[int] = 2048,
        calib_set: Optional[str] = "c4", 
        train_batch_size: Optional[int] = 16,
        eval_batch_size: Optional[int] = 32,
        seed: Optional[int] = 42,
        partition: Optional[int] = 1,
        start_layer: Optional[int] = 0,
        output_path: Optional[str] = None,
        result_path: Optional[str] = None,
        model_path: Optional[str] = None,
        group_limit: Optional[int] = 4,
        data_limit: Optional[int] = 50000,
        num_fewshot: Optional[int] = 0,
        try_oracle: Optional[bool] = False,
        random_start_center: Optional[bool] = False,
        weight: Optional[str] = None,
        cluster: Optional[str] = "hierarchical",
        linkage: Optional[str] = "ward",
        hierarchical_stopping_metric: Optional[str] = "silhouette",
        ingredient: Optional[str] = "act", 
        overlap_metric: Optional[str] = "cosine", 
        dynamic_group: Optional[bool] = False,
        opt_path: Optional[str] = None,
        model = None,
        tokenizer = None,
):
    """Group and merge the experts of an already-loaded Mixtral-style MoE model.

    Pipeline:
    1. Bundle the options into ``TAMPArgs``, seed torch, and build the
       calibration dataloader and grouper.
    2. If ``try_oracle``: run ``oracle_test``, evaluate, and exit the
       process with status 0.
    3. Group experts according to ``dominant``. Grouping is skipped when
       ``merge`` is ``"fsm"`` or ``"no"``.
       - ``"random"``: random groups.
       - ``"frequency"`` / ``"routing-score"``: global grouping from
         dominant experts.
       - ``"knowledge"``: grouping and merging in one step (the model is
         replaced here).
       - ``"no"``: ``cluster_experts``.
       - ``"best"``: precomputed groups loaded from pickles.
       - ``"init"``: precomputed initial centers, then ``cluster_experts``.
    4. Merge according to ``group`` / ``merge``. Skipped for
       ``merge == "no"``, and for ``dominant == "knowledge"`` unless
       ``group == "module"``.
    5. Log grouping results and the parameter count.

    Args:
        task: Task spec stored on ``TAMPArgs`` (used by ``evaluation`` on the
            oracle path) and logged.
        model: The loaded model to merge; required.
        tokenizer: Tokenizer used for the calibration data (and evaluation).
        Other arguments: forwarded to ``TAMPArgs`` and/or the grouping and
            merging helpers, as shown in the body.

    Returns:
        The (possibly replaced) merged model.

    Raises:
        ValueError: if ``model`` is None, if ``dominant`` is unsupported, or
            if ``"best"`` / ``"init"`` is requested for an unsupported
            configuration.
    """
    logging.info(f"Merge model {model_name} with {num_average_groups} group, {dominant} dominant + {similarity_base} grouping + {merge} {mode} merge with ingredient {ingredient}, evaluate on {task}")
    logging.info(f"Oracle: {try_oracle}, weight: {weight}")
    logging.info(f"Cluster: {cluster}, linkage: {linkage}, hierarchical_stopping_metric: {hierarchical_stopping_metric}, overlap_metric: {overlap_metric}")
    logging.info(f"Group: {group}, dynamic_group: {dynamic_group}, opt_path: {opt_path}.")

    # The defaults allow model=None, but nothing below can work without a model;
    # fail with a clear message instead of an AttributeError on model.eval().
    if model is None:
        raise ValueError("run_mixtral_tamp requires an already-loaded `model`; got None")

    args = TAMPArgs(
        task=task,
        num_average_groups=num_average_groups,
        model_name=model_name,
        dominant=dominant,
        similarity_base=similarity_base,
        merge=merge,
        mode=mode,
        max_block_size=max_block_size,
        calib_set=calib_set,
        n_sentences=n_sentences,
        train_batch_size=train_batch_size,
        eval_batch_size=eval_batch_size,
        seed=seed,
        partition=partition,
        start_layer=start_layer,
        output_path=output_path,
        result_path=result_path,
        model_path=model_path,
        group_limit=group_limit,
        data_limit=data_limit,
        num_fewshot=num_fewshot,
        try_oracle=try_oracle,
        random_start_center=random_start_center,
        weight=weight,
        cluster=cluster,
        linkage=linkage,
        hierarchical_stopping_metric=hierarchical_stopping_metric,
        overlap_metric=overlap_metric,
        group=group,
        dynamic_group=dynamic_group,
        opt_path=opt_path,
    )

    torch.manual_seed(args.seed)

    model.eval()
    dataloader_for_merging = get_dataloader(args, tokenizer)
    grouper = get_grouper(args, model.config)

    if try_oracle:
        model = oracle_test(args, model, grouper, dataloader_for_merging)
        evaluation(args, model, tokenizer)
        sys.exit(0)

    logging.info("Number of parameters before merging: %sB", model.num_parameters() / 1e9)
    logging.info("Merging into average %s groups...", num_average_groups)
    group_st = time.time()

    # Layers [start_layer, num_hidden_layers) take part in grouping/merging.
    merging_layers = list(range(start_layer, model.config.num_hidden_layers))

    # ---- Grouping ----
    dom_experts = None

    if merge == "fsm" or merge == "no":
        # fsm groups inside its own merge call; "no" skips merging entirely.
        pass

    elif dominant == "random":
        grouper.group_experts_randomly(num_groups=args.num_average_groups)

    elif dominant == "frequency":
        grouper.compute_all_similarities(model, dataloader_for_merging)
        dom_experts = grouper.group_experts_globally_from_dominant_experts(
            num_average_groups=num_average_groups, merging_layers=merging_layers
        )

    elif dominant == "routing-score":
        grouper.compute_all_usages(model, dataloader_for_merging, mode="routing-score")
        print_usage_frequency(grouper._usage_frequency_state_dict)
        dom_experts = grouper.group_experts_globally_from_dominant_experts(
            num_average_groups=num_average_groups, merging_layers=merging_layers
        )

    elif dominant == "knowledge":
        # Groups AND merges in one call, so the merge stage below skips it
        # (except for module-level grouping).
        grouper.compute_all_similarities(model, dataloader_for_merging)
        model = grouper.all_in_one_knowledge_dominant(
            model=model,
            dataloader=dataloader_for_merging,
            merge=merge,
            mode=mode,
            num_groups=num_average_groups,
            dominant_alone=False,
            usage_weighted=False,
        )
        dom_experts = grouper.core_experts

    elif dominant == "no":
        time_start = time.time()
        dom_experts = grouper.cluster_experts(model=model, dataloader=dataloader_for_merging)
        print(f"cluster_experts time: {time.time() - time_start}")

    elif dominant == "best":
        # Precomputed grouping artifacts.
        if model_name == "s3nh/TinyLLama-4x1.1B-MoE":
            grouper._group_state_dict = _load_pickle("analysis/info/tinyllama/tinyllama_eo_cluster_group_dict.pkl")
            dom_experts = _load_pickle("analysis/info/tinyllama/tinyllama_eo_cluster_dom_experts_dict.pkl")
        elif num_average_groups == 4:
            grouper._group_state_dict = _load_pickle("analysis/info/mixtral/mixtral_4e_eo_cluster_random_center_freq_group_dict.pkl")
            dom_experts = _load_pickle("analysis/info/mixtral/mixtral_4e_eo_cluster_random_center_freq_dom_experts_dict.pkl")
        else:
            raise ValueError(f"Best dominant is only available for 4 groups for {model_name}, but the input is `{num_average_groups}`")

    elif dominant == "init":
        # Precomputed initial cluster centers, then regular clustering.
        if num_average_groups == 4 and model_name != "s3nh/TinyLLama-4x1.1B-MoE":
            grouper.load_init_center_state_dict("analysis/info/mixtral/mixtral_4e_best_initial_center.pkl")
        else:
            raise ValueError(f"Init dominant is only available for 4 groups for Mixtral8x7B, but the input is `{num_average_groups}` with {model_name}")
        dom_experts = grouper.cluster_experts(model=model, dataloader=dataloader_for_merging, num_groups=num_average_groups)

    else:
        raise ValueError(
            "Accepted dominant are `random`, `frequency`, `routing-score`, `knowledge`, "
            f"`no`, `best`, `init`, but the input is `{dominant}`")

    # ---- Merging ----
    if merge == "no":
        pass
    elif group == "module":
        model = grouper.merge_weighted(model, merge, core_experts=dom_experts)
    elif dominant != "knowledge":
        if merge == "freq":
            model = merge_mixtral_by_groups_with_usage_weighted(
                model, grouper=grouper, merging_layers=merging_layers, calib_dataloader=dataloader_for_merging
            )
        elif merge == "fsm":
            model, dom_experts = merge_mixtral_by_feature_selection(
                model, grouper=grouper, dataloader=dataloader_for_merging, num_groups=num_average_groups, mode=mode
            )
        # NOTE(review): the default `merge="norm_drop_freq"` does NOT match this
        # branch ("norm_drop_fre") and falls through to the generic merger
        # below -- confirm which spelling is intended.
        elif merge == "norm_drop_fre":
            time_start = time.time()
            model = merge_mixtral_by_norm_drop_fre(
                model,
                grouper=grouper,
            )
            print(f"merge_mixtral_by_norm_drop_fre time: {time.time() - time_start}")
        elif merge == "avg":
            model = merge_mixtral_by_average_weight(
                model,
                grouper=grouper
            )
        else:
            model = merge_mixtral_by_groups_within_and_across_models(
                mixtral_model=model,
                grouper=grouper,
                dataloader=dataloader_for_merging,
                merge=merge,
                mode=mode,
                partition=partition,
                dominant_alone=False,
                core_experts=dom_experts,
                ingredient=ingredient,
            )

    # `.2f` = two decimals; the previous `2f` only set a minimum width of 2.
    logging.info(f"[TAMP] Merging time: {time.time() - group_st:.2f} seconds")

    # ---- Reporting ----
    if merge != "no":
        logging.info(f"[TAMP] ========= Grouping results ========= ")
        if group == "module":
            for name, state in grouper.group_state_dict().items():
                # logging.info() does not accept print's `end=` keyword (it raised
                # TypeError whenever INFO was enabled); build one line instead.
                labels = ", ".join(
                    f"{module_name}: {label.tolist()}" for module_name, label in state.items()
                )
                logging.info(f"Group {name} - {labels}")
        else:
            for name, state in grouper.group_state_dict().items():
                if dom_experts is None:
                    logging.info(f"Group {name}: {state.tolist()}")
                else:
                    logging.info(f"Group {name}: {state.tolist()} (DOMs are {dom_experts[name]})")
        del grouper

    if merge == "unmerge":
        logging.info(f"[TAMP] ======= Grouping of unmerge ======= ")
        for layer_idx in range(start_layer, model.config.num_hidden_layers):
            moe_block = model.model.layers[layer_idx].block_sparse_moe
            logging.info(f"--- Layer {layer_idx} ---")
            logging.info(f"expert_to_group: {moe_block.expert_to_group}")
            logging.info(f"group_to_expert: {moe_block.group_to_expert}")
            logging.info(f"unmerge_matrix: {moe_block.unmerge_matrix}")

    # Same unit (1e9) as the "before merging" count, so the two are comparable.
    logging.info("[TAMP] Number of parameters after merging: %sB", model.num_parameters() / 1e9)

    torch.cuda.empty_cache()

    return model
