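# Using the same dataset that was used to measure the coactivation hypergraph, evaluate the neuron partition results.
# Input: model, dataset, and partition results for the desired layers
# Output: evaluation metrics (per-token co-partition probabilities, token frequency ranks, and a random-partition baseline)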

import torch
import numpy as np
from tqdm import tqdm
from .utils import load_partition_result,calculate_token_freq
from .coactivation_measure_memory import measure_coactivation_graph
from modelscope import AutoModelForCausalLM, AutoTokenizer
from .data_processor import DataProcessor
from modelscope.msdatasets import MsDataset
from torch.utils.data import Dataset, DataLoader
import os

def corres_act_partition(partitions, coactivation_dataframe, token_freq, layer_id, neurons=14336):
    """Score how well a neuron partition groups the neurons that fire together in one layer.

    For every row of ``coactivation_dataframe`` this computes

        P = sum_k c_k**2 / |A|**2

    where A is the row's set of activated neurons and c_k is how many of them fall
    in part k.  P is the probability that two activated neurons drawn uniformly
    (with replacement) from A lie in the same part.

    Args:
        partitions: iterable of parts, each a sequence of 1-based neuron ids
            (``load_partition_result`` numbers neurons from 1).  Empty parts are
            skipped when counting.
        coactivation_dataframe: DataFrame with an ``"activated_neurons"`` column
            (0-based neuron indices, used directly to index a length-``neurons``
            mask) and a ``"token_id"`` column.
        token_freq: DataFrame with ``"token_id"`` and ``"rank"`` columns.  When a
            token id appears more than once, the first row wins.
        layer_id: layer index; only used in the progress-bar label.
        neurons: number of neurons in the layer (length of the mask).

    Returns:
        ``(probs, ranks, expectation)``:
        ``probs`` is the per-row P (NaN for a row with no activated neurons);
        ``ranks`` is the per-row token rank, with ranks below 4 mapped to 0;
        ``expectation`` is ``sum_k (len(part_k) / neurons) ** 2``, the chance that
        two neurons drawn uniformly from the whole layer share a part (a random
        baseline for P).

    Raises:
        KeyError: if a row's ``token_id`` is missing from ``token_freq``.
    """
    # Single pass over `partitions` so one-shot iterables also work.  The
    # 0-based index array of every non-empty part is built once here instead
    # of once per row.
    expectation = 0
    part_indices = []
    for part in partitions:
        expectation += (len(part) / neurons) ** 2
        if len(part) > 0:
            part_indices.append(np.array(part) - 1)  # partition ids start from 1

    # token_id -> rank lookup table, built once.  setdefault keeps the first
    # occurrence, matching the old `.loc[mask, "rank"].iloc[0]` semantics
    # without scanning the whole frame for every row.
    rank_of = {}
    for tid, rank in zip(token_freq["token_id"], token_freq["rank"]):
        rank_of.setdefault(tid, rank)

    probs = []
    ranks = []
    for _, row in tqdm(coactivation_dataframe.iterrows(),
                       total=len(coactivation_dataframe),
                       desc=f'Processing layer-{layer_id}'):
        activated_neurons = row["activated_neurons"]
        token_rank = rank_of[row["token_id"]]
        # NOTE(review): ranks below 4 are collapsed to 0, presumably to drop the
        # most frequent / special tokens from any rank weighting — confirm.
        ranks.append(0 if token_rank < 4 else token_rank)

        n_active = len(activated_neurons)
        if n_active == 0:
            # P is undefined (0/0); record NaN explicitly instead of via a
            # numpy divide-by-zero warning.
            probs.append(float("nan"))
            continue

        selected_mask = np.zeros(neurons, dtype=bool)
        selected_mask[activated_neurons] = True
        sq_sum = sum(int(selected_mask[idx].sum()) ** 2 for idx in part_indices)
        probs.append(sq_sum / n_active ** 2)
    return probs, ranks, expectation

def _build_eval_dataloader(dataset_name, dataset_subset, data_split, tokenizer,
                           num_samples, max_length, batch_size):
    """Load the split, keep its first `num_samples` rows, tokenize, and batch them in order."""
    dataset = MsDataset.load(dataset_name, subset_name=dataset_subset, split=data_split, trust_remote_code=True)
    if not hasattr(dataset, 'select'):
        dataset = dataset.to_hf_dataset()
    data_processor = DataProcessor(dataset_name)

    # Cap at the split size so small splits don't make `select` raise.
    available = len(dataset)
    n_rows = available if num_samples is None else min(num_samples, available)

    eval_dataset = dataset.select(range(n_rows)).map(
        lambda batch: data_processor.format_and_tokenize(batch, tokenizer, max_length),
        batched=True,
    )
    eval_dataset.set_format(type='torch')
    # shuffle=False: the token-frequency pass and the co-activation pass both
    # iterate this loader and should see the same examples in the same order.
    return DataLoader(eval_dataset, batch_size=batch_size, shuffle=False)


def evaluate_model_partition(model_path, dataset_name, dataset_subset, data_split, partitions_folder,
                             layers=None, num_samples=1000, max_length=512, batch_size=16):
    """Evaluate saved per-layer neuron partitions against a model's measured co-activations.

    Loads the model, tokenizes the first ``num_samples`` examples of the dataset
    split, and measures token frequencies and per-layer co-activation dataframes
    on them.  It then scores ``<partitions_folder>/layer_<i>.parti`` for each
    requested layer with ``corres_act_partition``.

    Args:
        model_path: model id or path for modelscope ``from_pretrained``.
        dataset_name: dataset name for ``MsDataset.load``; also selects the
            ``DataProcessor``.
        dataset_subset: subset name for ``MsDataset.load``.
        data_split: split name for ``MsDataset.load``.
        partitions_folder: directory holding one ``layer_<i>.parti`` file per layer.
        layers: layer indices to evaluate.  Defaults to every layer
            (``range(config.num_hidden_layers)``).
        num_samples: maximum number of examples to use (``None`` means all).
            It is capped at the split size.
        max_length: third argument to ``DataProcessor.format_and_tokenize``
            (presumably the max token length — confirm).
        batch_size: DataLoader batch size.

    Returns:
        ``(probs_result, ranks_result, expectation_result)``: three lists aligned
        with the evaluated layers.  Each entry is the corresponding output of
        ``corres_act_partition`` for that layer.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(model_path, device_map='auto')
    num_layers = model.config.num_hidden_layers
    intermediate_size = model.config.intermediate_size

    try:
        dataloader = _build_eval_dataloader(dataset_name, dataset_subset, data_split, tokenizer,
                                            num_samples, max_length, batch_size)
        token_freq_df = calculate_token_freq(dataloader, tokenizer)
        coactivation_dataframes = measure_coactivation_graph(model, dataloader, tokenizer)
    finally:
        # The per-layer scoring below does not use the model, so release it (and
        # cached GPU memory) as soon as activations are measured — also on error.
        del model
        torch.cuda.empty_cache()

    layer_ids = range(num_layers) if layers is None else list(layers)

    probs_result = []
    ranks_result = []
    expectation_result = []
    for layer in layer_ids:
        partition_result_path = os.path.join(partitions_folder, f'layer_{layer}.parti')
        partitions = load_partition_result(partition_result_path)

        probs, ranks, expectation = corres_act_partition(
            partitions, coactivation_dataframes[layer], token_freq_df, layer, neurons=intermediate_size)
        probs_result.append(probs)
        ranks_result.append(ranks)
        expectation_result.append(expectation)

    return probs_result, ranks_result, expectation_result