import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
import argparse
import json
import gc
from tqdm.auto import tqdm
from decimal import Decimal
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from transformers import get_scheduler

# Include prompt creation helper functions
from utils.prompt_utils import load_dataset, load_se_benchmarks
from utils.extract_utils import get_mean_head_activations, compute_indirect_effect, compute_function_vector, compute_universal_function_vector
from utils.eval_utils import n_shot_eval_no_intervention, analogy_eval_no_intervention
from utils.opt_utils import finetune_vector
from utils.model_utils import load_gpt_model_and_tokenizer, set_seed
from ffv_evaluation import vector_evaluation

# Process-wide autograd / cuDNN settings, applied as soon as this module is imported.
# NOTE(review): anomaly detection is a debugging aid that the PyTorch docs warn slows down
# execution of every backward pass; consider enabling it only when --debug is passed.
torch.autograd.set_detect_anomaly(True)
# Disables cuDNN for every op in this process. The reason is not recorded here
# (presumably a workaround for a cuDNN issue or for determinism — TODO confirm).
torch.backends.cudnn.enabled = False

def vector_optimization(datasetpaths, n_steps:int=5, lr:float=0.05, n_seeds:int=5, n_trials:int=100, n_shots:int=10, n_top_heads:int=10,
                         edit_layer:int=-1, prefixes=None, separators=None,
                         isRand=False, limit=None):
    """
    Builds per-dataset function vectors (FVs), fine-tunes them into FFVs, and evaluates both.

    For every seed this:
      1. loads each dataset and runs the clean n-shot baseline (plus the analogy
         baseline for SemEval),
      2. loads the FV from disk or computes it (mean head activations -> indirect
         effect -> FV, or the universal FV when ``is_uni``) and caches it,
      3. evaluates the FVs (skipped when ``rand_init``),
      4. loads the fine-tuned FVs from disk or optimizes them with AdamW + cosine LR,
      5. evaluates the FFVs and prints per-dataset / averaged results.

    NOTE: this function reads many module-level globals that are set in the
    ``__main__`` block (``model``, ``tokenizer``, ``model_config``, ``datasetnames``,
    ``semeval``, ``topk``, ``rand_init``, ``save_path_root``, ``model_name``, ``optlim``,
    ``filter_query``, ``update``, ``ckpt``, ``cew``, ``klw``, ``l2w``, ...), so it can only
    run after that block has initialised them.

    Parameters:
    datasetpaths: dataset names to process (one FV per dataset)
    n_steps: number of optimization epochs
    lr: AdamW learning rate (the cosine schedule decays it to lr/10)
    n_seeds: number of seeds to run
    n_trials: number of prompts to compute task-conditioned mean head activations over
    n_shots: number of shots for task-conditioned mean prompts
    n_top_heads: number of attention heads used to build the FV
    edit_layer: layer setting; ``edit_layer - 1`` is forwarded to evaluation/fine-tuning
    prefixes, separators: prompt-template dicts forwarded to the helpers
    isRand: if True each seed is drawn at random instead of using 1..n_seeds
    limit: forwarded to ``finetune_vector`` as ``oglim``

    Returns:
    dict with keys "full_results", "ana_results", "fv_zs_results", "fv_shuffled_results",
    "fv_ana_results", "ffv_zs_results", "ffv_shuffled_results", "ffv_ana_results"
    (each keyed by dataset name and indexed per seed), and "all_fva_sims" =
    [og_sims, fva_sims, ffva_sims] (one entry appended per seed; fva_sims stays empty
    when ``rand_init``).
    """

    ana_results = {k:[] for k in datasetnames}
    full_results = {k:[] for k in datasetnames}
    fv_zs_results = {k:[] for k in datasetnames}
    fv_shuffled_results = {k:[] for k in datasetnames}
    fv_ana_results = {k:[] for k in datasetnames}
    ffv_zs_results = {k:[] for k in datasetnames}
    ffv_shuffled_results = {k:[] for k in datasetnames}
    ffv_ana_results = {k:[] for k in datasetnames}
    fv_probs, ffv_probs = {}, {}
    og_sims, fva_sims, ffva_sims = [], [], []
    fv_tokens = {k:{} for k in datasetnames}
    ffv_tokens = {k:{} for k in datasetnames}
    # NOTE(review): with the default edit_layer=-1 this forwards -2; confirm how
    # vector_evaluation / finetune_vector interpret negative layer indices.
    lay = edit_layer-1

    for i in range(n_seeds):
        fvs = []
        typedatasets, datasets = {}, {}
        filter_sets = {k:None for k in datasetnames}
        seed = np.random.randint(100000) if isRand else i+1
        print(f"seed:{seed}")
        # Suffix identifying the cached FV / indirect-effect files for this configuration.
        fvsuf = '_%dheads_%dshot_%dtrial_%dOptPairs_seed%d%s' % (n_top_heads, n_shots,
                                                                 n_trials, optlim, seed,
                                                                 "_FiltQuery" if filter_query else "")
        # avgs[0] is a flat list of full-shot accuracies; avgs[1:] each hold
        # [original, FV, FFV] lists for Shuffled, 0-shot and Analogy respectively.
        avgs = [[], [[], [], []], # Full-shot, Shuffled, FV Shuffled, FFV Shuffled,
                [[], [], []], # 0-shot, FV 0-shot, FFV 0-shot,
                [[], [], []]] # 1-shot Analogy, FV Analogy, FFV Analogy

        print("Initialization...")
        fvdir = subfeat.replace(presub, '').replace("Simple", "abstractive")

        filter_set = None
        for di in tqdm(range(len(datasetpaths)), total=len(datasetpaths), ncols=50, desc="Dataset Setup", leave=False):
            dataset_name = datasetpaths[di]
            try:
                set_seed(seed)
                # Disable gradients when extracting activations & computing FV
                torch.set_grad_enabled(False)

                # SemEval dataset names start with a 2-digit relation-type id (1-based).
                typeind = int(dataset_name[:2])-1 if semeval else di
                typename, typelist = typedatasetnames[typeind], typedatasetpaths[typeind]
                typedataset = load_dataset(typename, typelist, seed=seed, istype=True,
                                                optlim=optlim, root_data_dir=root_data_dir, gold=semeval)

                if semeval:
                    set_seed(seed)
                    dataset = load_dataset(dataset_name, [dataset_name], seed=seed, istype=False,
                                                optlim=optlim, root_data_dir=root_data_dir, gold=semeval)
                else: dataset = typedataset
                typedatasets[dataset_name] = typedataset
                datasets[dataset_name] = dataset

                set_seed(seed)
                fs_results = n_shot_eval_no_intervention(dataset=dataset, n_shots=min(n_shots, len(dataset['opt'])),
                                                         model=model, model_config=model_config, tokenizer=tokenizer,
                                                         compute_ppl=True, test_split='test', context_split='opt')
                full_results[dataset_name].append(fs_results)
                avgs[0].append(fs_results['clean_topk'][topk-1][1])

                if filter_query:
                    set_seed(seed)
                    filter_results = n_shot_eval_no_intervention(dataset=dataset, n_shots=n_shots, model=model, model_config=model_config, tokenizer=tokenizer, compute_ppl=True, prefixes=prefixes, separators=separators)
                    # Keep only queries the model already answers correctly (rank 0).
                    filter_set = np.where(np.array(filter_results['clean_rank_list']) == 0)[0]
                    filter_sets[dataset_name] = filter_set
                else: filter_set = None

                if semeval:
                    set_seed(seed)
                    o_results = analogy_eval_no_intervention(dataset=dataset, model=model, model_config=model_config, tokenizer=tokenizer, compute_ppl=False,
                                                            test_split='test', context_split='gold')
                    ana_results[dataset_name].append(o_results)
                    avgs[-1][0].append(o_results['clean_topk'][topk-1][1])

                indirect_effect = None
                if rand_init:
                    fvmu = torch.randn((1,model_config['resid_dim'])).to(device=model.device, dtype=model.dtype)
                else:
                    fv_path = f'{save_path_root}/{model_name}/FVs/{fvdir}/{typename}_FV{fvsuf}.pt'
                    if os.path.exists(fv_path): fvmu = torch.load(fv_path, weights_only=True)
                    else:
                        os.makedirs(f'{save_path_root}/{model_name}/FVs/{fvdir}', exist_ok=True)
                        if filter_query:
                            set_seed(seed+42)
                            fs_results_validation = n_shot_eval_no_intervention(dataset=typedataset, n_shots=n_shots, model=model, model_config=model_config,
                                                                                tokenizer=tokenizer, compute_ppl=False, test_split='query')
                            filter_set_validation = np.where(np.array(fs_results_validation['clean_rank_list']) == 0)[0]
                        else: filter_set_validation = None
                        set_seed(seed)
                        tqdm.write("Getting mean head activations...", end="\r")
                        mean_activations = get_mean_head_activations(typedataset, model=model, model_config=model_config, tokenizer=tokenizer, n_icl_examples=min(n_shots, len(dataset['train'])),
                                                                        N_TRIALS=n_trials, filter_set=filter_set_validation, split="query")
                        if is_uni:
                            fvmu, _ = compute_universal_function_vector(mean_activations, model, model_config)
                        else:
                            ie_dir = f'{save_path_root}/{model_name}/AIEs/{fvdir}'
                            ie_path = f'{ie_dir}/{typename}_indirect_effect{fvsuf}.pt'
                            # Older runs saved the indirect effect under a shorter suffix;
                            # such files are renamed to the current naming scheme.
                            legacy_ie_path = ie_path.replace(fvsuf, '_seed%d%s' % (seed, "_FiltQuery" if filter_query else "",))
                            if os.path.exists(ie_path): indirect_effect = torch.load(ie_path)
                            elif os.path.exists(legacy_ie_path):
                                tqdm.write("Renaming IE...", end="\r")
                                os.rename(legacy_ie_path, ie_path)
                                indirect_effect = torch.load(ie_path)
                            else:
                                os.makedirs(ie_dir, exist_ok=True)
                                set_seed(seed)
                                indirect_effect = compute_indirect_effect(typedataset, mean_activations, model=model, model_config=model_config, tokenizer=tokenizer,
                                                                        n_shots=min(n_shots, len(dataset['train'])), n_trials=25, last_token_only=True,
                                                                        split="query", filter_set=filter_set_validation)
                                torch.save(indirect_effect, ie_path)

                            tqdm.write("Obtaining function vector...", end="\r")
                            fvmu, _ = compute_function_vector(mean_activations, indirect_effect, model, model_config=model_config, n_top_heads=n_top_heads)
                        torch.save(fvmu, fv_path)
                fvs.append(fvmu)
            except Exception as e:
                # Best-effort: a failing dataset is skipped for this seed.
                # NOTE(review): fvs is then shorter than datasetnames; confirm that
                # vector_evaluation / finetune_vector tolerate the misalignment.
                print(e)
                print("%s Run %d not working! Skipping..." % (dataset_name, i+1))
                gc.collect()
                torch.cuda.empty_cache()

        if rand_init:
            print("FV randomly initialized!")
            # Only avgs[1:] hold [orig, FV, FFV] sub-lists; avgs[0] is a flat float list.
            for avg in avgs[1:]: avg[1].append(0)
        else:
            print("FV Evaluation...")
            fv_tokens, fv_zs_results, fv_ana_results, fv_shuffled_results, \
            fv_probs, _, fvbea_sims, ifv_avgs, _ = \
                vector_evaluation(datasetnames, datasets, fvs, model, model_config, tokenizer,
                                fv_tokens, fv_zs_results, fv_ana_results, fv_shuffled_results, fv_probs,
                                seed, n_shots=n_shots, edit_layer=lay, semeval=semeval,
                                topk=topk, filter_sets=filter_sets)

            fva_sims.append(fvbea_sims)
            for ifv in range(len(ifv_avgs)):
                avgs[ifv+1][1] += ifv_avgs[ifv]

        print("Optimization...")
        # Derive the FFV file suffix from the FV suffix by inserting the optimisation settings.
        sufreps = [["shot_", "shot_lr%s_" % str(lr).replace(".", "")],
                   ["OptPairs_","OptPairs_%sCE_%sKL_%depoch_" % (str(Decimal(str(cew)).normalize()),
                                                                str(Decimal(str(klw)).normalize()),
                                                                n_steps,)]]
        ffvsuf = fvsuf
        for old, new in sufreps: ffvsuf = ffvsuf.replace(old, new)

        ffvpre = ""
        if rand_init: ffvpre = "rand_" + ffvpre
        ffv_path_suf = f'Finetuned_FV{ffvsuf}.pt'
        ffv_path_pre = f'{save_path_root}/{model_name}/Finetuned_FVs/{fvdir}/{ffvpre}'
        ffv_paths = [f'{ffv_path_pre}{name}_{ffv_path_suf}' for name in datasetnames]
        try:
            if update: raise StopIteration("FFV needs to be updated! Optimizing from FV...")
            finetuned_fvs = [torch.load(ffv_path, weights_only=True) for ffv_path in ffv_paths]
        except Exception as e:
            # Cache miss or forced update: report why, then optimise from the FVs.
            print(e)
            print("Optimizing FV...")
            os.makedirs(f'{save_path_root}/{model_name}/Finetuned_FVs/{fvdir}', exist_ok=True)
            epochs = n_steps
            if ckpt > 0 and epochs > ckpt:
                try:
                    ck_paths = [fpath.replace("%depoch" % (n_steps,), "%depoch" % (ckpt,))
                                for fpath in ffv_paths]
                    cks = [torch.load(ck_path, weights_only=True) for ck_path in ck_paths]
                    fvs = cks
                    epochs -= ckpt
                except Exception as e:
                    print(e)
                    print("Can't extract from checkpoint! Starting from scratch...")

            # Enable Gradients for Optimization
            torch.set_grad_enabled(True)
            for p in model.parameters(): p.requires_grad = False
            for fvu in fvs: fvu.requires_grad=True
            # Build the optimizer/scheduler only now: fvs may have just been replaced by
            # checkpoint tensors, and the optimizer must own the tensors being trained.
            fvn = len(fvs)
            opt = torch.optim.AdamW(fvs, lr=lr, weight_decay=l2w)
            num_training_steps = n_steps * fvn * 10
            sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, num_training_steps, eta_min=lr/10)
            finetuned_fvs = finetune_vector(datasets=datasets, fv_vectors=fvs,
                                            edit_layer=lay, n_steps=epochs,
                                            prefixes=prefixes, separators=separators,
                                            weight=cew, klweight=klw,
                                            model=model, model_config=model_config,
                                            tokenizer=tokenizer, optimizer=opt,
                                            lr_scheduler=sched, split="opt", oglim=limit)
            if not rand_init:
                for di in range(len(datasetnames)): torch.save(finetuned_fvs[di], ffv_paths[di])
            for fvu in fvs: fvu.requires_grad=False
            torch.set_grad_enabled(False)
        aff = None

        print("FFV Evaluation...")
        ffv_tokens, ffv_zs_results, ffv_ana_results, ffv_shuffled_results, \
            ffv_probs, beog_sims, ffvbea_sims, ffv_avgs, og_avgs = \
            vector_evaluation(datasetnames, datasets, finetuned_fvs, model, model_config, tokenizer,
                                ffv_tokens, ffv_zs_results, ffv_ana_results, ffv_shuffled_results, ffv_probs, seed, aff=aff,
                                prefixes=prefixes, separators=separators, edit_layer=lay, semeval=semeval,
                                topk=topk, filter_sets=filter_sets)

        og_sims.append(beog_sims)
        ffva_sims.append(ffvbea_sims)
        for ff in range(len(ffv_avgs)):
            avgs[ff+1][0] += og_avgs[ff]
            avgs[ff+1][2] += ffv_avgs[ff]

        for dataset_name in datasetnames:
            # Datasets whose setup failed for this seed have no fresh results to show.
            if dataset_name not in datasets: continue

            print(f"Dataset: {dataset_name}")
            print('fullshot_topk:', full_results[dataset_name][-1]['clean_topk'], 'fullshot_rank_list:', full_results[dataset_name][-1]['clean_rank_list'][:15])
            if not rand_init: print('FV_0shot_topk:', fv_zs_results[dataset_name][-1]['intervention_topk'], 'FV_0shot_rank_list:', fv_zs_results[dataset_name][-1]['intervention_rank_list'][:15])
            print('FFV_0shot_topk:', ffv_zs_results[dataset_name][-1]['intervention_topk'], 'FFV_0shot_rank_list:', ffv_zs_results[dataset_name][-1]['intervention_rank_list'][:15])
            if semeval:
                print('Analogy_topk:', ana_results[dataset_name][-1]['clean_topk'], 'Analogy_rank_list:', ana_results[dataset_name][-1]['clean_rank_list'])
                if not rand_init: print('FV_analogy_topk:', fv_ana_results[dataset_name][-1]['intervention_topk'], 'FV_analogy_rank_list:', fv_ana_results[dataset_name][-1]['intervention_rank_list'])
                print('FFV_analogy_topk:', ffv_ana_results[dataset_name][-1]['intervention_topk'], 'FFV_analogy_rank_list:', ffv_ana_results[dataset_name][-1]['intervention_rank_list'])
        gc.collect()
        torch.cuda.empty_cache()

        print(f"Averages: 10-shot={np.mean(avgs[0])}, 0-Shot={np.mean(avgs[2][0])}, FV 0-Shot={np.mean(avgs[2][1])}, FFV 0-shot={np.mean(avgs[2][2])}")
        print(f"Averages: Shuffled-label={np.mean(avgs[1][0])}, FV Shuffled={np.mean(avgs[1][1])}, FFV Shuffled={np.mean(avgs[1][2])}")
        if semeval: print(f"Averages: Analogy={np.mean(avgs[3][0])}, FV Analogy={np.mean(avgs[3][1])}, FFV Analogy={np.mean(avgs[3][2])}")

    all_fva_sims = [og_sims, fva_sims, ffva_sims]

    return {"full_results":full_results, "ana_results":ana_results,
            "fv_zs_results":fv_zs_results, "fv_shuffled_results":fv_shuffled_results, "fv_ana_results":fv_ana_results,
            "ffv_zs_results":ffv_zs_results, "ffv_shuffled_results":ffv_shuffled_results, "ffv_ana_results":ffv_ana_results,
            "all_fva_sims": all_fva_sims
        }
    
if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name', help='Name of model to be loaded', type=str, required=False, default='EleutherAI/gpt-j-6b')
    parser.add_argument('--root_data_dir', help='Root directory of data files', type=str, required=False, default='../dataset_files')
    parser.add_argument('--data_dir', help='Directory of dataset', type=str, required=False, default='Simple')
    parser.add_argument('--save_path_root', help='File path to save mean activations to', type=str, required=False, default='../results')
    parser.add_argument('--n_top_heads', help='Number of attenion head outputs used to compute function vector', required=False, type=int, default=10)
    parser.add_argument('--edit_layer', help='Layer to apply optimized FV on', required=False, type=int, default=-1)
    parser.add_argument('--n_seeds', help='Number of seeds', type=int, required=False, default=5)
    parser.add_argument('--n_trials', help='Number of trials to use for computing task-conditioned mean head activations', type=int, required=False, default=100)
    parser.add_argument('--n_shots', help='Number of shots to use for prompts when computing task-conditioned mean head activations', type=int, required=False, default=5)
    parser.add_argument('--lr', help="Learning Rate for AdamW Optimizer", type=float, required=False, default=0.05)
    parser.add_argument('--n_steps', help="Number of epochs", type=int, required=False, default=5)
    parser.add_argument('--cew', help="CE weight", type=float, required=False, default=1)
    parser.add_argument('--klw', help="KL weight", type=float, required=False, default=0)
    parser.add_argument('--l2w', help="L2 weight", type=float, required=False, default=0.01)
    parser.add_argument('--l2a', help="L2 weight for affine", type=float, required=False, default=0.01)
    parser.add_argument('--is_rand', help="Determines if seeds should be randomized", action='store_true')
    # Validated by argparse itself (an `assert` would be stripped under `python -O`).
    parser.add_argument('--prompt', help='Prompt template prefixes to be used', type=str, required=False, default="QA",
                        choices=["QA", "AB", "Analogy"])
    parser.add_argument('--is_uni', help='Whether to use universal FV', action='store_true')
    parser.add_argument('--topk', help='Returns accuracy for top-k (must be less than 5)', type=int, required=False, default=0)
    parser.add_argument('--optlim', help='Determines number of pairs for optimization', type=int, required=False, default=10)
    parser.add_argument('--ckpt', help='Train from checkpoint', type=int, required=False, default=0)
    parser.add_argument('--filter_query', help="Filter for initial FV", action='store_true')
    parser.add_argument('--rand_init', help="Determines if FV should be randomly initialized", action='store_true')
    parser.add_argument('--update', help="Updates FFV", action='store_true')
    # parser.add_argument('--affine', help="Uses affine transform", action='store_true')
    # parser.add_argument('--lra', help="Learning Rate for affine optimizer", type=float, required=False, default=0.01)
    # parser.add_argument('--ackpt', help='Train from affine checkpoint', type=int, required=False, default=0)
    # parser.add_argument('--n_steps_a', help="Number of epochs for affine", type=int, required=False, default=25)
    # parser.add_argument('--updatea', help="Updates affine transformation", action='store_true')
    parser.add_argument('--debug', help="Sets to debug mode", action='store_true')

    args = parser.parse_args()
    print(args)

    # Gather inputs. Many of these names are read as module globals by vector_optimization.
    modname = args.model_name
    model_name = modname.split('/')[-1]
    root_data_dir, data_dir = args.root_data_dir, args.data_dir
    save_path_root = args.save_path_root
    n_seeds, n_trials, n_shots = args.n_seeds, args.n_trials, args.n_shots
    lr = args.lr
    cew, klw = args.cew, args.klw
    l2w = args.l2w
    # NOTE(review): an explicit `--l2w 0` is mapped back to the 0.01 default here.
    if l2w == 0: l2w = 0.01 # The default
    n_steps = args.n_steps
    n_top_heads = args.n_top_heads
    edit_layer = args.edit_layer
    is_rand = args.is_rand
    semeval = data_dir == "SemEval"
    prompt = args.prompt
    update = args.update
    ckpt = args.ckpt
    filter_query = False if semeval else args.filter_query
    is_uni = args.is_uni
    optlim = args.optlim if semeval else 0
    rand_init = args.rand_init
    # affine = args.affine
    topk = args.topk if semeval else 1
    debug = args.debug
    if debug: n_seeds = 1

    # Load Model & Tokenizer
    torch.set_grad_enabled(False)
    model, tokenizer, model_config = load_gpt_model_and_tokenizer(modname)
    if tokenizer.pad_token is None and 'llama' in model_name.lower():
        tokenizer.pad_token = tokenizer.eos_token

    # Prompt templates: "QA" is the default; "Analogy" and "AB" override parts of it.
    prefixes = {"input":"Q:", "output":"A:", "instructions":""}
    separators = {"input":"\n", "output":"\n\n", "instructions":""}
    if prompt == "Analogy":
        prefixes = {"input":"", "output":":", "instructions":""}
        separators["input"] = ""
    elif prompt == "AB":
        prefixes = {"input":"A:", "output":"B:", "instructions":""}

    from itertools import groupby
    if semeval:
        # SemEval files are grouped into relation types by their 2-character prefix.
        alldatasetpaths = [os.path.splitext(dat)[0] for dat in sorted(os.listdir(os.path.join(root_data_dir, "SemEval")))]
        typedatasetpaths = [list(v) for k,v in groupby(sorted(alldatasetpaths), lambda x: x[:2])]
        typedatasetnames = [os.path.splitext(dat)[0] for dat in sorted(os.listdir(os.path.join(root_data_dir, "SemEvalType")))]
    else:
        alldatasetpaths = ['antonym', 'capitalize', 'country-capital', 'english-french', 'present-past', 'singular-plural'] \
            if data_dir == 'Simple' else [os.path.splitext(dat)[0] for dat in sorted(os.listdir(os.path.join(root_data_dir,data_dir)))]
        typedatasetpaths = [[v] for v in alldatasetpaths]
        typedatasetnames = alldatasetpaths
    datasetpaths = alldatasetpaths
    datasetnames = alldatasetpaths

    args.datasetpaths = datasetpaths

    # Set up result outline
    presub = f'{prompt}_' if prompt in ["AB", "Analogy"] else ""
    subfeat = presub + data_dir
    laysuf = "_layer%d" % (edit_layer,) if edit_layer != -1 else ""
    laypre = f'{edit_layer}/{subfeat}/' if edit_layer != -1 else f'All_Layers/{subfeat}/'

    os.makedirs(os.path.join(save_path_root, '%s/accuracy/FFV_accs/%s' % (model_name,laypre,)), exist_ok=True)
    os.makedirs(os.path.join(save_path_root, '%s/cosine_similarity/%s' % (model_name,laypre,)), exist_ok=True)

    pre, suf = "", ""
    # if affine: pre = "affine%d_" % (n_steps_a,) + pre
    if rand_init: pre = "randinit_" + pre
    pre = laypre + pre

    # Output file suffix encoding the run configuration.
    suf = "_%s" % (data_dir.replace("SemEval", "semeval"),) + laysuf + suf
    suf += "_lr%s_ceweight%s_klweight%s" % (str(lr).replace(".", ""), str(Decimal(str(cew)).normalize()),
                                            str(Decimal(str(klw)).normalize()) if klw <= 1 else str(int(klw)),)
    if l2w not in [0, 0.01]: suf += "_l2weight%s" % (str(l2w),)
    # if affine:
    #     if lra != 0.01: suf += "_lra%s" % (str(lra).replace(".", ""))
    #     if l2a not in [0, 0.01]: suf += "_l2a%s" % (str(l2a),)
    # if temp != 1.0: suf += "_%stemp" % (str(Decimal(str(temp)).normalize()),)
    if filter_query: suf += "_FilterQuery"
    suf += "_%dheads_%depoch" % (n_top_heads, n_steps,)
    if debug: suf += "_test"
    mainsuf = suf

    all_results = vector_optimization(datasetpaths, n_steps=n_steps, lr=lr, n_seeds=n_seeds, n_trials=n_trials, n_shots=n_shots, n_top_heads=n_top_heads,
                             edit_layer=edit_layer, prefixes=prefixes, separators=separators, isRand=is_rand)

    if semeval:
        # Save seed-averaged activation cosine-similarity matrices for OG / FV / FFV.
        all_benches, all_benchnames, all_benchrels = load_se_benchmarks()
        fv_strs = ["OG", "FV", "FFV"]
        for fi in range(len(fv_strs)):
            if rand_init and fi == 1: continue  # no FV similarities are recorded under rand_init
            simpre, simsuf = laypre if fi == 0 else pre, "_%s%s" % (data_dir,"_test" if debug else "") if fi == 0 else mainsuf
            df_FVacosim = pd.DataFrame(torch.mean(torch.stack(all_results["all_fva_sims"][fi]), dim=0).float().numpy(), columns=all_benchnames, index=all_benchnames)
            df_FVacosim.to_csv(os.path.join(save_path_root,"%s/cosine_similarity/%s%s_activation_similarity%s.csv" % (model_name,simpre,fv_strs[fi],simsuf,)))

    with open(os.path.join(save_path_root, '%s/accuracy/FFV_accs/%s%stext_results%s.txt'
                           % (model_name, pre, "rand_" if is_rand else "", mainsuf,)), 'w') as out_file:

        def seed_accs(result_key, dataset_name, field, tot):
            """Return the per-seed top-k accuracy for one result type of one dataset."""
            return [all_results[result_key][dataset_name][i][field][topk-1][1] for i in range(tot)]

        def pct_stats(vals):
            """Return (mean, std) of `vals` in percent, ignoring NaNs."""
            return np.nanmean(vals)*100, np.nanstd(vals)*100

        def report(label, mean, std):
            """Write one '<label> results: mean % +/- std' line to the text report."""
            print(f"{label} results:", np.round(mean, 3), '% +/-', np.round(std, 3), file=out_file)

        accs = []
        # Column sums: 7 full-shot/shuffled/0-shot columns followed by 3 analogy columns.
        meanacc, stdacc = np.zeros(10), np.zeros(10)
        n_reported = 0  # datasets that actually contribute to the averages
        for dataset_name in datasetpaths:

            tot = len(all_results["ffv_zs_results"][dataset_name])

            # Datasets without usable FFV results are left out of both the report and averages.
            ffv_sh_acc = seed_accs("ffv_shuffled_results", dataset_name, 'intervention_topk', tot)
            if np.isnan(np.nanmean(ffv_sh_acc)): continue
            ffv_zs_acc = seed_accs("ffv_zs_results", dataset_name, 'intervention_topk', tot)
            if np.isnan(np.nanmean(ffv_zs_acc)): continue
            n_reported += 1

            fs_mean, fs_std = pct_stats(seed_accs("full_results", dataset_name, 'clean_topk', tot))
            sh_mean, sh_std = pct_stats(seed_accs("ffv_shuffled_results", dataset_name, 'clean_topk', tot))
            ffvsh_mean, ffvsh_std = pct_stats(ffv_sh_acc)
            zs_mean, zs_std = pct_stats(seed_accs("ffv_zs_results", dataset_name, 'clean_topk', tot))
            ffvz_mean, ffvz_std = pct_stats(ffv_zs_acc)
            if rand_init:
                # No FV baseline is evaluated for randomly initialised runs.
                fvsh_mean, fvsh_std = 0, 0
                fvz_mean, fvz_std = 0, 0
            else:
                fvsh_mean, fvsh_std = pct_stats(seed_accs("fv_shuffled_results", dataset_name, 'intervention_topk', tot))
                fvz_mean, fvz_std = pct_stats(seed_accs("fv_zs_results", dataset_name, 'intervention_topk', tot))

            # Header first so every following line belongs to this dataset.
            print(f"{dataset_name.title()}:", file=out_file)
            report("Full-Shot", fs_mean, fs_std)
            report("Shuffled", sh_mean, sh_std)
            if not rand_init: report("FV Shuffled", fvsh_mean, fvsh_std)
            report("FFV Shuffled", ffvsh_mean, ffvsh_std)
            report("0-Shot", zs_mean, zs_std)
            if not rand_init: report("FV 0-Shot", fvz_mean, fvz_std)
            report("FFV 0-Shot", ffvz_mean, ffvz_std)

            meanrow = [fs_mean, sh_mean, fvsh_mean, ffvsh_mean, zs_mean, fvz_mean, ffvz_mean]
            meanacc[:-3] += meanrow
            acc = [dataset_name] + meanrow
            sdrow = [fs_std, sh_std, fvsh_std, ffvsh_std, zs_std, fvz_std, ffvz_std]
            stdacc[:-3] += sdrow
            accsd = [dataset_name + " STD"] + sdrow

            if semeval:
                an_mean, an_std = pct_stats(seed_accs("ana_results", dataset_name, 'clean_topk', tot))
                ffvg_mean, ffvg_std = pct_stats(seed_accs("ffv_ana_results", dataset_name, 'intervention_topk', tot))
                report("Analogy", an_mean, an_std)
                if rand_init:
                    fvg_mean, fvg_std = 0, 0
                else:
                    fvg_mean, fvg_std = pct_stats(seed_accs("fv_ana_results", dataset_name, 'intervention_topk', tot))
                    report("FV Analogy", fvg_mean, fvg_std)
                report("FFV Analogy", ffvg_mean, ffvg_std)
                meanacc[-3:] += [an_mean, fvg_mean, ffvg_mean]
                stdacc[-3:] += [an_std, fvg_std, ffvg_std]
                acc += [an_mean, fvg_mean, ffvg_mean]
                accsd += [an_std, fvg_std, ffvg_std]
            accs += [acc, accsd]

        # Average only over datasets that were actually reported (skipped ones add nothing).
        n_avg = max(n_reported, 1)
        meanacc /= n_avg
        stdacc /= n_avg
        cols = ["Name", "Full-Shot", "Shuffled-label", "FV Shuffled", "FFV Shuffled", "0-Shot", "FV 0-Shot", "FFV 0-Shot"]
        print(f"Average:", file=out_file)
        for c, col in enumerate(cols[1:]):
            if col[:2] == "FV" and rand_init: continue
            report(col, meanacc[c], stdacc[c])
        if semeval:
            report("Analogy", meanacc[-3], stdacc[-3])
            if not rand_init: report("FV Analogy", meanacc[-2], stdacc[-2])
            report("FFV Analogy", meanacc[-1], stdacc[-1])
            cols += ["Analogy", "FV Analogy", "FFV Analogy"]
        else:
            meanacc = meanacc[:-3]
            stdacc = stdacc[:-3]
        accs.append(["Average", *meanacc])
        accs.append(["Average STD", *stdacc])

        accdf = pd.DataFrame(accs, columns=cols)
        accdf.to_csv(os.path.join(save_path_root,"%s/accuracy/FFV_accs/%sFFV_accs%s.csv" % (model_name, pre, mainsuf,)), index=False)