import os
import torch
import gc
from tqdm.auto import tqdm

# Project-local helpers for evaluation, seeding, and benchmark loading
from utils.eval_utils import fv_to_vocab, n_shot_eval, analogy_eval, composite_eval, one_pair_no_intervention
from utils.opt_utils import one_pair_active
from utils.model_utils import set_seed
from utils.prompt_utils import load_se_benchmarks

# Enable autograd anomaly detection for the whole process (takes effect at import time).
# NOTE(review): PyTorch documents this as a debugging aid that slows execution —
# confirm it is meant to stay on outside of debugging.
torch.autograd.set_detect_anomaly(True)
# Turn off the cuDNN backend globally (also an import-time side effect).
torch.backends.cudnn.enabled = False

# Empty input/output pair; vector_evaluation passes `[blank]` as `data` to the one-pair helpers.
blank = {"input":"", "output":""}

def _record_topk(avgs, og_avgs, slot, results, topk):
    """Append the selected top-k value of `results` to `avgs[slot]` / `og_avgs[slot]`."""
    # Entry `topk - 1` is selected; with topk=0 this is index -1, i.e. the last entry.
    avgs[slot].append(results['intervention_topk'][topk-1][1])
    og_avgs[slot].append(results['clean_topk'][topk-1][1])


def vector_evaluation(datasetnames, datasets, fv_vectors, model, model_config, tokenizer,
                      fv_tokens, all_zs_results, all_gold_results, all_shuffled_results,
                      zs_probs, seed, aff=None, topk=0, n_shots=10,
                      prefixes=None, separators=None, edit_layer:int=-1, semeval=False,
                      filter_sets=None):
    """
    Evaluates one function vector (FV) per dataset and accumulates the results.

    For every dataset the FV is decoded to vocabulary tokens, then evaluated in a
    shuffled-label n-shot setting, a zero-shot setting and (if `semeval`) a
    golden-pair analogy setting. If `semeval` is set, activation correlation
    matrices over the SemEval benchmarks are computed as well.

    Parameters:
    datasetnames: list of dataset names; `datasetnames[i]` pairs with `fv_vectors[i]`
    datasets: mapping from dataset name to the dataset handed to the eval helpers
    fv_vectors: function vectors, indexed in parallel with `datasetnames`
    model, model_config, tokenizer: forwarded unchanged to the eval helpers
    fv_tokens: mapping dataset name -> {token: accumulated score}; updated in place
    all_zs_results: mapping dataset name -> list; zero-shot results are appended in place
    all_gold_results: mapping dataset name -> list; analogy results are appended in place (only if `semeval`)
    all_shuffled_results: mapping dataset name -> list; shuffled-label results are appended in place
    zs_probs: not read here; returned unchanged
    seed: seed passed to `set_seed` before each evaluation stage
    aff: unused
    topk: entry `topk - 1` of the 'intervention_topk' / 'clean_topk' results is recorded
          (the default 0 therefore selects index -1, the last entry)
    n_shots: number of shots for the shuffled-label evaluation
    prefixes, separators: forwarded unchanged to the eval helpers
    edit_layer: layer passed to the helpers as `edit_layer` / `active_layer`
    semeval: whether to run the analogy evaluation and the benchmark similarity step
    filter_sets: optional mapping dataset name -> filter set for `n_shot_eval`

    Returns:
    fv_tokens, all_zs_results, all_gold_results, all_shuffled_results, zs_probs:
        the (mutated) input containers
    beog_sim: correlation matrix (torch.corrcoef) of the benchmark activations collected
        without intervention, or None if `semeval` is False
    befv_sim: same, for activations collected via `one_pair_active` with the FV of the
        benchmark's relation, or None if `semeval` is False
    avgs: three lists [shuffled, zero-shot, analogy] of recorded 'intervention_topk' values
    og_avgs: three lists [shuffled, zero-shot, analogy] of recorded 'clean_topk' values
    """

    avgs, og_avgs = [[], [], []], [[], [], []]
    for di, dataset_name in enumerate(datasetnames):
        fv_vector = fv_vectors[di]
        dataset = datasets[dataset_name]
        filter_set = None if filter_sets is None else filter_sets[dataset_name]

        set_seed(seed)
        # Disable gradients for the evaluation passes.
        # NOTE: this switches autograd off globally and nothing in this function
        # switches it back on, so the setting persists after the call returns.
        torch.set_grad_enabled(False)

        # Decoded FV: accumulate each decoded token's score for this dataset
        token_scores = fv_tokens[dataset_name]
        for tok, sco in fv_to_vocab(fv_vector, model, model_config, tokenizer):
            if tok not in token_scores:
                token_scores[tok] = sco
            else:
                token_scores[tok] += sco

        # Shuffled-label Evaluation
        shuffled_results = n_shot_eval(dataset=dataset, fv_vector=fv_vector, edit_layer=edit_layer,
                                n_shots=n_shots, mean_act=True, model=model, model_config=model_config,
                                prefixes=prefixes, separators=separators, tokenizer=tokenizer,
                                dataname=dataset_name, filter_set=filter_set, shuffle_labels=True)
        _record_topk(avgs, og_avgs, 0, shuffled_results, topk)
        all_shuffled_results[dataset_name].append(shuffled_results)

        # Zero-shot Evaluation (re-seeded so it does not depend on the previous stage)
        set_seed(seed)
        zs_results = n_shot_eval(dataset=dataset, fv_vector=fv_vector, edit_layer=edit_layer,
                                n_shots=0, mean_act=True, model=model, model_config=model_config,
                                prefixes=prefixes, separators=separators, tokenizer=tokenizer,
                                dataname=dataset_name, filter_set=filter_set)
        _record_topk(avgs, og_avgs, 1, zs_results, topk)
        all_zs_results[dataset_name].append(zs_results)

        # Golden-pair Analogy Evaluation (SemEval)
        if semeval:
            set_seed(seed)
            gold_results = analogy_eval(dataset=dataset, fv_vector=fv_vector, edit_layer=edit_layer, mean_act=True,
                                        model=model, model_config=model_config,
                                        prefixes=prefixes, separators=separators, tokenizer=tokenizer,
                                        dataname=dataset_name)
            _record_topk(avgs, og_avgs, 2, gold_results, topk)
            all_gold_results[dataset_name].append(gold_results)

        # Release memory between datasets
        gc.collect()
        torch.cuda.empty_cache()

    # Human Similarity
    set_seed(seed)
    beog_sim, befv_sim = None, None
    if semeval:
        all_benches, _, all_benchrels = load_se_benchmarks()
        benches = all_benches["bench"]
        n_benches = len(benches)
        beog_acts, be_acts = [], []
        for p in tqdm(range(n_benches), total=n_benches, desc="Benchmark (FFV)", leave=False, ncols=60):
            target_bench = benches[p]
            # Use the FV of the dataset whose name equals this benchmark's relation;
            # `.index` raises ValueError if that relation is not in `datasetnames`.
            bench_fv = fv_vectors[datasetnames.index(all_benchrels[p])]

            # Activation without intervention
            beog_results = one_pair_no_intervention(data=[blank], active_layer=edit_layer,
                                        model=model, model_config=model_config, tokenizer=tokenizer,
                                        context=target_bench)
            beog_acts.append(beog_results["activation"][0])

            # Activation with the FV applied
            be_results = one_pair_active(data=[blank], fv_vector=bench_fv, edit_layer=edit_layer,
                                        model=model, model_config=model_config, tokenizer=tokenizer,
                                        context=target_bench)
            be_acts.append(be_results["activation"][0])
        # Pairwise correlation between benchmarks (one stacked row per benchmark)
        beog_sim = torch.corrcoef(torch.stack(beog_acts)).cpu()
        befv_sim = torch.corrcoef(torch.stack(be_acts)).cpu()

    gc.collect()
    torch.cuda.empty_cache()

    return fv_tokens, all_zs_results, all_gold_results, all_shuffled_results, \
        zs_probs, beog_sim, befv_sim, avgs, og_avgs