import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
import argparse
import json
import gc
from tqdm.auto import tqdm
from decimal import Decimal
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from scipy.stats import spearmanr
from transformers import get_scheduler
from sklearn.metrics import f1_score

# Include prompt creation helper functions
from utils.prompt_utils import load_dataset, load_dataset_comp, load_green, load_se_benchmarks
from utils.eval_utils import composite_eval, one_pair_no_intervention
from utils.opt_utils import one_pair_active, finetune_affine
from utils.model_utils import load_gpt_model_and_tokenizer, set_seed

# Process-wide autograd / cuDNN configuration, applied as soon as the module
# is imported. Anomaly detection makes autograd report the forward operation
# behind a failing backward pass (at a runtime cost); cuDNN is disabled.
torch.autograd.set_detect_anomaly(mode=True)
torch.backends.cudnn.enabled = False

# Empty input/output pair; cfv_evaluation passes it as `data=[blank]` to
# one_pair_active for the SemEval benchmarks.
blank = dict(input="", output="")

def _load_or_compute_probs(path, seed, compute_fn):
    """Return a per-pair probability matrix, using ``path`` as a CSV cache.

    If ``path`` exists, the matrix is read back from its third column onwards
    (the first two columns hold the ``A`` and ``B`` halves of each pair).
    Otherwise the parent directory is created, the RNG is re-seeded with
    ``seed``, ``compute_fn()`` is called, its result is written to ``path``,
    and the stacked tensor is returned.

    ``compute_fn`` must return the mapping produced by ``composite_eval``:
    ``"A:B"`` pair-string keys whose ``value[-1]`` is a 1-D tensor with one
    entry per name in the module-level ``datasetnames`` (this is what the CSV
    column layout requires).
    """
    if os.path.exists(path):
        cached = pd.read_csv(path)
        return torch.tensor(cached.iloc[:, 2:].values)

    os.makedirs(os.path.dirname(path), exist_ok=True)
    set_seed(seed)
    probs = compute_fn()
    rows = [pair.split(':') + probs[pair][-1].tolist() for pair in probs]
    pd.DataFrame(rows, columns=["A", "B", *datasetnames]).to_csv(path, index=False)
    # Always stack the rows so the fresh result matches what a later cache
    # hit would return (the raw mapping has no ``.to()``).
    return torch.stack([probs[pair][-1] for pair in probs])


def _mix_weights(prob_matrix, model, aff=None):
    """Move ``prob_matrix`` to the model's device/dtype, softmax over dim 1, then apply ``aff`` if given."""
    weights = F.softmax(prob_matrix.to(device=model.device, dtype=model.dtype), dim=1)
    return weights if aff is None else aff(weights)


def cfv_evaluation(fv_vectors, model, model_config, tokenizer,
                   seed, aff=None, topk=0,
                   prefixes=None, separators=None, edit_layer:int=-1,
                   bart=False, soft=False):
    """Evaluate composite function vectors (CFVs) on SemEval, Green and BATS.

    For each benchmark, a per-pair probability matrix over the task FVs is
    loaded from (or computed into) a CSV cache. It is softmaxed over dim 1,
    optionally passed through ``aff``, and multiplied with the concatenated
    ``fv_vectors`` to form one CFV per pair. That CFV is then injected at
    ``edit_layer`` through ``one_pair_active``.

    Parameters:
    fv_vectors: list of task FV tensors, concatenated along dim 0
        (presumably each is (1, d) so the result is (n_tasks, d) — TODO confirm)
    model, model_config, tokenizer: the loaded language model bundle
    seed: RNG seed applied before each benchmark and each cache computation
    aff: optional module applied to the softmaxed probability matrix
    topk: index ``topk - 1`` is read from the ``*_topk`` result lists
        (so ``topk=0`` reads the last entry)
    prefixes, separators: prompt template pieces forwarded to ``composite_eval``
    edit_layer: layer passed to the intervention helpers
    bart, soft: accepted for interface compatibility; not used here

    Also reads these module-level globals set in ``__main__``:
    ``mainsuf``, ``affsuf``, ``save_path_root``, ``model_name``, ``laypre``,
    ``datasetnames``, ``all_benches``, ``gcon``, ``gdatas`` and ``all_bats_names``.

    Returns:
    becfv_sim: ``torch.corrcoef`` of the SemEval CFV activations (CPU)
    gpred: array (2, len(gdatas)); row 0 is the clean score, row 1 the
        intervention score for each Green pair
    bapred: array (2, len(all_bats_names)); the same two scores averaged over
        each BATS category's queries
    og_bats_sim: correlation of per-category mean activations without intervention
    bats_sim: correlation of per-category mean activations with the CFV
    """
    lay = edit_layer
    distsuf = mainsuf.replace(affsuf, "").replace("_test", "")
    # Loop-invariant: the stacked task FVs that every benchmark mixes.
    fv_matrix = torch.cat(fv_vectors)
    eval_kwargs = dict(fv_vectors=fv_vectors, edit_layer=lay, model=model,
                       model_config=model_config, prefixes=prefixes,
                       separators=separators, tokenizer=tokenizer)
    active_kwargs = dict(edit_layer=lay, model=model,
                         model_config=model_config, tokenizer=tokenizer)

    # ---- SemEval benchmarks: correlation of CFV-steered activations ----
    set_seed(seed)
    be_path = os.path.join(save_path_root, "%s/distribution/Bench_dists/%sBench_CFV_probs%s_seed%d.csv" % (model_name, laypre, distsuf, seed,))
    # The lambdas below are invoked immediately inside the helper, so the
    # loop-variable capture in the BATS section cannot go stale.
    be_probs_matrix = _load_or_compute_probs(
        be_path, seed,
        lambda: composite_eval(dataset=all_benches, intervention_prob_dict={},
                               test_split="bench", dataname="SemEval", **eval_kwargs))
    be_cfvs = _mix_weights(be_probs_matrix, model, aff) @ fv_matrix

    n_bench = len(all_benches["bench"])
    becfv_acts = []
    for p in tqdm(range(n_bench), total=n_bench, desc="SemEval (CFV)", leave=False, ncols=60):
        becfv_results = one_pair_active(data=[blank], fv_vector=be_cfvs[p % len(be_cfvs)],
                                        context=all_benches["bench"][p], **active_kwargs)
        becfv_acts.append(becfv_results["activation"][0])
    becfv_sim = torch.corrcoef(torch.stack(becfv_acts)).cpu()

    # ---- Green: clean vs. intervention scores per target pair ----
    set_seed(seed)
    gr_path = os.path.join(save_path_root, "%s/distribution/Green_dists/%sGreen_CFV_probs%s_seed%d.csv" % (model_name, laypre, distsuf, seed,))
    gr_probs_matrix = _load_or_compute_probs(
        gr_path, seed,
        lambda: composite_eval(dataset=gcon, intervention_prob_dict={},
                               test_split="green", dataname="Green (P.Dist)", **eval_kwargs))
    cfvs = _mix_weights(gr_probs_matrix, model, aff) @ fv_matrix

    gpred = np.zeros((2, len(gdatas)))
    for p in tqdm(range(len(gdatas)), total=len(gdatas), ncols=60, desc="Green (CFV)", leave=False):
        tp_results = one_pair_active(data=gdatas[p], fv_vector=cfvs[p % len(cfvs)],
                                     context=gcon['green'][p], **active_kwargs)
        gpred[:, p] = [tp_results['clean_topk'][topk - 1][1],
                       tp_results['intervention_topk'][topk - 1][1]]

    # ---- BATS 3.0: per-category scores and activation similarities ----
    all_bats_og_acts, all_bats_acts = [], []
    bapred = np.zeros((2, len(all_bats_names)))
    for bai in tqdm(range(len(all_bats_names)), total=len(all_bats_names),
                    ncols=60, desc="BATS3.0 (CFV)"):
        bats_name = all_bats_names[bai]
        set_seed(seed)
        bat = load_dataset_comp(bats_name, seed=seed)
        ba_path = os.path.join(save_path_root, "%s/distribution/BATS_dists/%s%s_BATS_CFV_probs%s_seed%d.csv" % (model_name, laypre, bats_name, distsuf, seed,))
        ba_probs_matrix = _load_or_compute_probs(
            ba_path, seed,
            lambda: composite_eval(dataset=bat, intervention_prob_dict={},
                                   test_split="query", dataname=f"BATS3.0 ({bats_name})",
                                   **eval_kwargs))
        ba_cfvs = _mix_weights(ba_probs_matrix, model, aff) @ fv_matrix

        n_query = len(bat['query'])
        baog_acts, ba_acts = [], []
        for p in tqdm(range(n_query), total=n_query,
                      ncols=60, desc=f"BATS3.0 ({bats_name} Eval)", leave=False):
            source_pair, target_pair = bat['query'][p], bat['target'][p]

            baog_results = one_pair_no_intervention(data=[target_pair], active_layer=lay,
                                                    model=model, model_config=model_config,
                                                    tokenizer=tokenizer, context=source_pair)
            baog_acts.append(baog_results["activation"][0])

            bats_results = one_pair_active(data=[target_pair], fv_vector=ba_cfvs[p],
                                           context=source_pair, **active_kwargs)
            ba_acts.append(bats_results["activation"][0])

            bapred[:, bai] += [bats_results['clean_topk'][topk - 1][1],
                               bats_results['intervention_topk'][topk - 1][1]]
        bapred[:, bai] /= n_query
        all_bats_og_acts.append(torch.mean(torch.stack(baog_acts), dim=0))
        all_bats_acts.append(torch.mean(torch.stack(ba_acts), dim=0))

    og_bats_sim = torch.corrcoef(torch.stack(all_bats_og_acts)).cpu()
    bats_sim = torch.corrcoef(torch.stack(all_bats_acts)).cpu()
    return becfv_sim, gpred, bapred, og_bats_sim, bats_sim

def vector_optimization(datasetnames, n_steps:int=5, lr:float=0.05, n_seeds:int=5, n_trials:int=100, n_shots:int=10, n_top_heads:int=10,
                         edit_layer:int=-1, prefixes=None, separators=None,
                         isRand=False):
    """
    Load finetuned task FVs, optionally fit an affine mixing layer, and run ``cfv_evaluation`` per seed.

    For each seed, the finetuned FV of every dataset in ``datasetnames`` is loaded from
    ``{save_path_root}/{model_name}/Finetuned_FVs/{fold}/``. When the module-level ``affine`` flag is set,
    an ``nn.Linear(n_tasks, n_tasks)`` transform is either loaded from disk or trained with
    ``finetune_affine`` and saved. It is then passed to ``cfv_evaluation``.

    Parameters:
    datasetnames: names of the task datasets whose finetuned FVs are combined
    n_steps: FV finetuning epoch count (selects which saved FV files are loaded)
    lr: FV finetuning learning rate (used in file names)
    n_seeds: number of seeds to run
    n_trials: trial count of the original FVs (used in file names)
    n_shots: shot count of the original FVs (used in file names)
    n_top_heads: head count of the original FVs (used in file names)
    edit_layer: layer argument; ``edit_layer - 1`` is what gets passed on
    prefixes, separators: prompt template pieces forwarded to ``cfv_evaluation``
    isRand: draw random seeds instead of 1..n_seeds

    Also reads module-level settings assigned in ``__main__``, including optlim, filter_query, cew,
    klw, l2w, affine, lra, l2a, n_steps_a, ackpt, updatea, subfeat, laypre, model, model_config,
    tokenizer, bart, soft and topk.

    Returns:
    dict of per-seed lists:
      cfv_sims: SemEval activation correlation matrices
      angr_results / cfvgr_results: Green clean / intervention scores
      anba_results / cfvba_results: per-category BATS clean / intervention scores
      all_fvb_sims: [BATS similarities without intervention, BATS similarities with the CFV]
    """
    angr_results, cfvgr_results = [], []
    anba_results, cfvba_results = [], []
    cfv_sims = []
    ogb_sims, cfvb_sims = [], []
    # NOTE(review): the CLI layer is shifted down by one here, so the default
    # edit_layer=-1 ("All_Layers" in output paths) becomes -2 — confirm intended.
    lay = edit_layer-1

    for i in range(n_seeds):
        datasets = {}
        seed = np.random.randint(100000) if isRand else i+1
        print(f"seed:{seed}")
        fvsuf = '_%dheads_%dshot_%dtrial_%dOptPairs_seed%d%s' % (n_top_heads, n_shots,
                                                                 n_trials, optlim, seed,
                                                                 "_FiltQuery" if filter_query else "")

        print("Initialization...")

        cfvsuf = fvsuf.replace("shot_", "shot_lr%s_" % str(lr).replace(".", "")).replace("OptPairs_", "OptPairs_%sCE_%sKL_%depoch_" % (str(Decimal(str(cew)).normalize()),
                                                                                                                                      str(Decimal(str(klw)).normalize()),
                                                                                                                                      n_steps,))
        if l2w not in [0, 0.01]:
            cfvsuf = cfvsuf.replace("KL_", "KL_%sL2_" % (str(l2w),))
        cfv_path_suf = f'Finetuned_FV{cfvsuf}.pt'

        cfv_paths = []
        for dname in tqdm(datasetnames, total=len(datasetnames), ncols=50, desc="FFV Setup", leave=False):
            datafold = datasetfolds[dname]
            cfv_path = f'{save_path_root}/{model_name}/Finetuned_FVs/{datafold}/{dname}_{cfv_path_suf}'
            if datafold == "SemEval":
                # SemEval FV files use a model-specific epoch count and 10 optimization pairs.
                cfv_path = (cfv_path.replace(f"{n_steps}epoch", f"{semeval_epochs[model_name]}epoch")
                                    .replace(f"{optlim}OptPairs", "10OptPairs"))
            cfv_paths.append(cfv_path)
        finetuned_fvs = [torch.load(cfv_path, weights_only=True) for cfv_path in cfv_paths]
        fvn = len(finetuned_fvs)

        if affine:
            aff = nn.Linear(fvn, fvn, device=model.device, dtype=model.dtype)
            opta = torch.optim.AdamW(aff.parameters(), lr=lra, weight_decay=l2a)
            # NOTE(review): the schedule length uses n_steps, not n_steps_a — confirm intended.
            num_training_steps_a = n_steps * 1000
            sched_a = torch.optim.lr_scheduler.CosineAnnealingLR(opta, num_training_steps_a)

            print("Optimization...")
            torch.set_grad_enabled(True)
            try:
                reps = [["Finetuned_FV_", ""],
                        ["lr%s_" % str(lr).replace(".", ""), "lr%s_" % str(lra).replace(".", "")],
                        ["_%depoch_" % (n_steps,), "_%depoch_" % (n_steps_a,)]]
                aff_path_suf = cfv_path_suf
                for old, new in reps:
                    aff_path_suf = aff_path_suf.replace(old, new)
                for p in model.parameters():
                    p.requires_grad = False
                # Detach the loaded FVs so backward passes stop at them and only
                # the affine layer accumulates gradients.
                finetuned_fvs = [fv.detach() for fv in finetuned_fvs]
                aff_path = f'{save_path_root}/{model_name}/Finetuned_FVs/{subfeat}/Affine_{aff_path_suf}'

                loaded = False
                if updatea:
                    print("Transformation needs to be updated! Optimizing...")
                else:
                    try:
                        aff.load_state_dict(torch.load(aff_path)['model_state_dict'])
                        loaded = True
                    except Exception as e:
                        # Missing or incompatible file: fall back to training below.
                        print(f"Could not load affine transform from {aff_path}: {e}")

                if not loaded:
                    print("Optimizing Affine...")
                    os.makedirs(f'{save_path_root}/{model_name}/Finetuned_FVs/{subfeat}', exist_ok=True)
                    a_epochs = n_steps_a
                    ack_path = aff_path.replace("_%depoch_" % (n_steps_a,), "_%depoch_" % (ackpt,))
                    if (ackpt > 0 and a_epochs - ackpt > 0) and os.path.exists(ack_path):
                        try:
                            checkpoint = torch.load(ack_path)
                            aff.load_state_dict(checkpoint['model_state_dict'])
                            opta.load_state_dict(checkpoint['optimizer_state_dict'])
                            sched_a.load_state_dict(checkpoint['scheduler_state_dict'])
                            a_epochs -= ackpt
                        except Exception as e:
                            print("Converting checkpoint to state_dict...")
                            try:
                                # weights_only=False unpickles arbitrary objects: only use on
                                # trusted, locally produced checkpoints.
                                check = torch.load(ack_path, weights_only=False)
                                torch.save(check.state_dict(), ack_path)
                                aff.load_state_dict(check.state_dict())
                                a_epochs -= ackpt
                            except Exception as e:
                                print(e)
                                print("Can't extract from checkpoint! Starting from scratch...")

                    for dataset_name in tqdm(datasetnames, total=len(datasetnames), ncols=50, desc="Dataset Setup", leave=False):
                        is_semeval = datasetfolds[dataset_name] == "SemEval"
                        dalim = 10 if is_semeval else optlim
                        datype = not is_semeval
                        set_seed(seed)
                        datasets[dataset_name] = load_dataset(dataset_name, [dataset_name], seed=seed, istype=datype,
                                                              optlim=dalim, root_data_dir=root_data_dir, gold=datype)
                    for af in aff.parameters():
                        af.requires_grad = True

                    os.makedirs(f'{save_path_root}/{model_name}/distribution/CFV_dists/{laypre}', exist_ok=True)
                    postpath = f'{save_path_root}/{model_name}/distribution/CFV_dists/{laypre}/ProbDist{cfvsuf}.csv'
                    splitpath = f'{save_path_root}/{model_name}/distribution/CFV_dists/DataSplit_seed{seed}.csv'
                    set_seed(seed)
                    aff = finetune_affine(datasets=datasets, fv_vectors=finetuned_fvs,
                                          aff=aff, edit_layer=lay, n_steps=a_epochs,
                                          model=model, model_config=model_config,
                                          tokenizer=tokenizer, optimizer=opta, lr_scheduler=sched_a,
                                          weight=cew, split="opt",
                                          postpath=postpath, splitpath=splitpath)
                    torch.save({'model_state_dict': aff.state_dict(),
                                'optimizer_state_dict': opta.state_dict(),
                                'scheduler_state_dict': sched_a.state_dict()
                                }, aff_path)
            finally:
                # Restore inference mode even if loading/training raised.
                torch.set_grad_enabled(False)
        else:
            aff = None

        print("CFV Evaluation...")
        cfv_sim, cfv_gpred, cfv_bapred, og_bat_sim, cfvbat_sim = \
            cfv_evaluation(finetuned_fvs, model, model_config, tokenizer,
                           seed, aff, prefixes=prefixes, separators=separators,
                           edit_layer=lay, bart=bart, soft=soft, topk=topk)

        angr_results.append(cfv_gpred[0])
        cfvgr_results.append(cfv_gpred[1])
        anba_results.append(cfv_bapred[0])
        cfvba_results.append(cfv_bapred[1])
        ogb_sims.append(og_bat_sim)
        cfvb_sims.append(cfvbat_sim)
        cfv_sims.append(cfv_sim)
        # Free the affine module and cached GPU memory before the next seed.
        del aff
        gc.collect()
        torch.cuda.empty_cache()
        print(f"Averages: Green Analogy={np.mean(angr_results[-1])}, Green CFV={np.mean(cfvgr_results[-1])}")
        print(f"Averages: BATS Analogy={np.mean(anba_results[-1])}, BATS CFV={np.mean(cfvba_results[-1])}")

    all_fvb_sims = [ogb_sims, cfvb_sims]

    return {"cfv_sims": cfv_sims, "angr_results": angr_results, "cfvgr_results": cfvgr_results,
            "anba_results": anba_results, "cfvba_results": cfvba_results, "all_fvb_sims": all_fvb_sims
            }
    
if __name__ == "__main__":
    # Script entry point. Names assigned here are module-level globals that
    # cfv_evaluation and vector_optimization read directly (model_name,
    # save_path_root, laypre, mainsuf, datasetfolds, ...), so this code must
    # stay at module scope rather than move into a main() function.

    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name', help='Name of model to be loaded', type=str, required=False, default='EleutherAI/gpt-j-6b')
    parser.add_argument('--root_data_dir', help='Root directory of data files', type=str, required=False, default='../dataset_files')
    parser.add_argument('--save_path_root', help='File path to save mean activations to', type=str, required=False, default='../results')
    parser.add_argument('--n_top_heads', help='Number of attenion head outputs used to compute function vector', required=False, type=int, default=10)
    parser.add_argument('--edit_layer', help='Layer to apply optimized FV on', required=False, type=int, default=-1)
    parser.add_argument('--n_seeds', help='Number of seeds', type=int, required=False, default=5)
    parser.add_argument('--n_trials', help='Number of trials to use for computing task-conditioned mean head activations', type=int, required=False, default=100)
    parser.add_argument('--n_shots', help='Number of shots to use for prompts when computing task-conditioned mean head activations', type=int, required=False, default=5)
    parser.add_argument('--lr', help="Learning Rate for Adam Optimizer", type=float, required=False, default=0.05)
    parser.add_argument('--n_steps', help="Number of epochs", type=int, required=False, default=5)
    parser.add_argument('--cew', help="CE weight", type=float, required=False, default=1)
    parser.add_argument('--klw', help="KL weight", type=float, required=False, default=0)
    parser.add_argument('--l2w', help="L2 weight", type=float, required=False, default=0.01)
    parser.add_argument('--lra', help="Learning Rate for affine optimizer", type=float, required=False, default=0.001)
    parser.add_argument('--l2a', help="L2 weight for affine", type=float, required=False, default=0.01)
    parser.add_argument('--is_rand', help="Determines if seeds should be randomized", action='store_true')
    # `choices` replaces a post-hoc assert (asserts are stripped under `python -O`).
    parser.add_argument('--prompt', help='Prompt template prefixes to be used', type=str, required=False, default="QA",
                        choices=["QA", "AB", "Analogy"])
    parser.add_argument('--pure', help='Whether to use prompt FVs for optimization, otherwise use default', action='store_true')
    parser.add_argument('--soft', help='Uses softmax for posterior, otherwise log-softmax', action='store_true')
    parser.add_argument('--batch', help='Batch size for training', type=int, required=False, default=1)
    parser.add_argument('--is_uni', help='Whether to use universal FV', action='store_true')
    parser.add_argument('--topk', help='Returns accuracy for top-k (must be less than 5)', type=int, required=False, default=0)
    parser.add_argument('--optlim', help='Determines number of pairs for optimization', type=int, required=False, default=10)
    parser.add_argument('--ckpt', help='Train from checkpoint', type=int, required=False, default=0)
    parser.add_argument('--filter_query', help="Filter for initial FV", action='store_true')
    parser.add_argument('--bart', help="Uses BART probs", action='store_true')
    parser.add_argument('--update', help="Updates cfv", action='store_true')
    parser.add_argument('--isGold', help="Distinguishes Golden pairs from data", action='store_true')
    parser.add_argument('--affine', help="Uses affine transform", action='store_true')
    parser.add_argument('--ackpt', help='Train from affine checkpoint', type=int, required=False, default=0)
    parser.add_argument('--n_steps_a', help="Number of epochs for affine", type=int, required=False, default=10)
    parser.add_argument('--updatea', help="Updates affine transformation", action='store_true')
    parser.add_argument('--debug', help="Sets to debug mode", action='store_true')

    args = parser.parse_args()
    print(args)
    # The previous `assert(cond, msg)` tested a non-empty tuple and never fired.
    if not 0 <= args.topk <= 5:
        parser.error("Top-K only goes up to 5! Try another topk")

    # Gather inputs
    modname = args.model_name
    model_name = modname.split('/')[-1]
    root_data_dir = args.root_data_dir
    save_path_root = args.save_path_root
    n_seeds = args.n_seeds
    n_trials = args.n_trials
    n_shots = args.n_shots
    lr = args.lr
    cew, klw = args.cew, args.klw
    l2w = args.l2w
    if l2w == 0: l2w = 0.01  # The default
    lra, l2a = args.lra, args.l2a
    if l2a == 0: l2a = 0.01  # The default
    n_steps, n_steps_a = args.n_steps, args.n_steps_a
    n_top_heads = args.n_top_heads
    edit_layer = args.edit_layer
    is_rand = args.is_rand
    prompt = args.prompt
    pure = args.pure
    soft = args.soft
    update, updatea = args.update, args.updatea
    ckpt, ackpt = args.ckpt, args.ackpt
    filter_query = args.filter_query
    is_uni = args.is_uni
    batch = args.batch
    # NOTE(review): --optlim and --isGold are parsed but overridden below. The
    # optlim value is baked into cached FV/affine file names, so it is left
    # as-is — confirm whether args.optlim / args.isGold should be honoured.
    optlim = 0
    affine = args.affine
    bart = args.bart
    isGold = False
    debug = args.debug
    topk = args.topk
    if debug: n_seeds = 1

    # Load Model & Tokenizer
    torch.set_grad_enabled(False)
    model, tokenizer, model_config = load_gpt_model_and_tokenizer(modname)
    if tokenizer.pad_token is None and 'llama' in model_name.lower():
        tokenizer.pad_token = tokenizer.eos_token

    prefixes = {"input":"Q:", "output":"A:", "instructions":""}
    separators = {"input":"\n", "output":"\n\n", "instructions":""}
    if prompt == "Analogy":
        prefixes = {"input":"", "output":":", "instructions":""}
        separators["input"] = ""
    elif prompt == "AB":
        prefixes = {"input":"A:", "output":"B:", "instructions":""}

    # Discover task datasets; each file stem maps to the folder it came from.
    data_dir = "comp"
    datafs = ["abstractive", "SemEval", "Google", "MSR"]
    datasetfolds, alldatasetnames = {}, []
    for dataf in datafs:
        for dat in sorted(os.listdir(os.path.join(root_data_dir, dataf))):
            dname = os.path.splitext(dat)[0]
            alldatasetnames.append(dname)
            datasetfolds[dname] = dataf
    datasetnames = alldatasetnames
    args.datasetnames = datasetnames

    # Set up result outline
    presub = f'{prompt}_' if prompt in ["AB", "Analogy"] else ""
    subfeat = presub + data_dir
    laysuf = "_layer%d" % (edit_layer,) if edit_layer != -1 else ""
    laypre = f'{edit_layer}/{subfeat}/' if edit_layer != -1 else f'All_Layers/{subfeat}/'

    for result_sub in ["accuracy/BATS_accs", "accuracy/Green_accs",
                       "distribution/CFV_dists", "distribution/Bench_dists",
                       "distribution/BATS_dists", "distribution/Green_dists",
                       "cosine_similarity"]:
        os.makedirs(os.path.join(save_path_root, '%s/%s/%s' % (model_name, result_sub, laypre)), exist_ok=True)

    pre, suf = "", ""
    if bart: pre = "BART_" + pre
    if affine: pre = "affine%d_" % (n_steps_a,) + pre
    pre = laypre + pre
    suf = f"_{data_dir}" + laysuf + suf
    suf += "_lr%s_ceweight%s_klweight%s" % (str(lr).replace(".", ""), str(Decimal(str(cew)).normalize()),
                                            str(Decimal(str(klw)).normalize()) if klw <= 1 else str(int(klw)),)
    if l2w not in [0, 0.01]: suf += "_l2weight%s" % (str(l2w),)
    affsuf = ""
    if affine:
        if lra != 0.01: affsuf += "_lra%s" % (str(lra).replace(".", ""))
        if l2a not in [0, 0.01]: affsuf += "_l2a%s" % (str(l2a),)
        suf += affsuf
    if filter_query: suf += "_FilterQuery"

    suf += "_%dheads_%depoch" % (n_top_heads, n_steps,)
    if debug: suf += "_test"
    mainsuf = suf

    semeval_epochs = {'gpt2-medium':25, 'gpt-j-6b':10,
                      'Llama-2-7b-chat-hf':5,
                      'Meta-Llama-3.1-8B-Instruct':5}

    all_benches, all_benchnames, all_benchrels = load_se_benchmarks()
    all_bats_names = [os.path.splitext(bn)[0] for bn in sorted(os.listdir(os.path.join(root_data_dir, "BATS")))]
    gcon, source_gpairnames, gpairnames, gdatas, unorms = load_green()

    all_results = vector_optimization(datasetnames, n_steps=n_steps, lr=lr, n_seeds=n_seeds, n_trials=n_trials, n_shots=n_shots, n_top_heads=n_top_heads,
                                      edit_layer=edit_layer, prefixes=prefixes, separators=separators, isRand=is_rand)

    # SemEval Activation File (mean over seeds)
    df_FVccosim = pd.DataFrame(torch.mean(torch.stack(all_results["cfv_sims"]), dim=0).float().numpy(), columns=all_benchnames, index=all_benchnames)
    df_FVccosim.to_csv(os.path.join(save_path_root, "%s/cosine_similarity/%sCFV_activation_similarity%s.csv" % (model_name, pre, mainsuf,)))

    # Green Evaluation Files: per-pair mean/std over seeds, then overall and
    # first-half ("Far") / second-half ("Near") averages.
    angr_arr, cfvgr_arr = np.array(all_results["angr_results"]), np.array(all_results["cfvgr_results"])
    angr_means, cfvgr_means = np.mean(angr_arr, axis=0), np.mean(cfvgr_arr, axis=0)
    angr_stds, cfvgr_stds = np.std(angr_arr, axis=0), np.std(cfvgr_arr, axis=0)
    gr_accs = [[source_gpairnames[o], gpairnames[o], np.mean(angr_means[o]), np.mean(cfvgr_means[o]), np.mean(angr_stds[o]), np.mean(cfvgr_stds[o])]
               for o in range(len(gpairnames))]
    half = len(gpairnames) // 2
    gr_accs.append(["Average", "Average", np.mean(angr_means), np.mean(cfvgr_means), np.mean(angr_stds), np.mean(cfvgr_stds)])
    gr_accs.append(["Average", "Far Average", np.mean(angr_means[:half]), np.mean(cfvgr_means[:half]), np.mean(angr_stds[:half]), np.mean(cfvgr_stds[:half])])
    gr_accs.append(["Average", "Near Average", np.mean(angr_means[half:]), np.mean(cfvgr_means[half:]), np.mean(angr_stds[half:]), np.mean(cfvgr_stds[half:])])
    graccdf = pd.DataFrame(gr_accs, columns=["Source Pair", "Target Pair", "One-Shot", "Composite FV", "One-Shot STD", "Composite FV STD"])
    graccdf.to_csv(os.path.join(save_path_root, "%s/accuracy/Green_accs/%sGreen_accs%s.csv" % (model_name, pre, mainsuf,)), index=False)

    # BATS Evaluation Files: similarity matrices without ("OG") and with ("CFV") intervention
    for fi, fv_str in enumerate(["OG", "CFV"]):
        if fi == 0:
            simpre, simsuf = laypre, "_comp%s" % ("_test" if debug else "")
        else:
            simpre, simsuf = pre, mainsuf
        df_FVbcosim = pd.DataFrame(torch.mean(torch.stack(all_results["all_fvb_sims"][fi]), dim=0).float().numpy(), columns=all_bats_names, index=all_bats_names)
        df_FVbcosim.to_csv(os.path.join(save_path_root, "%s/cosine_similarity/%s%s_BATS_similarity%s.csv" % (model_name, simpre, fv_str, simsuf,)))

    anba_arr, cfvba_arr = np.array(all_results["anba_results"]), np.array(all_results["cfvba_results"])
    anba_means, cfvba_means = np.mean(anba_arr, axis=0), np.mean(cfvba_arr, axis=0)
    anba_stds, cfvba_stds = np.std(anba_arr, axis=0), np.std(cfvba_arr, axis=0)
    ba_accs = [[all_bats_names[o], np.mean(anba_means[o]), np.mean(cfvba_means[o]), np.mean(anba_stds[o]), np.mean(cfvba_stds[o])]
               for o in range(len(all_bats_names))]
    ba_accs.append(["Average", np.mean(anba_means), np.mean(cfvba_means), np.mean(anba_stds), np.mean(cfvba_stds)])
    baaccdf = pd.DataFrame(ba_accs, columns=["Task", "One-Shot", "Composite FV", "One-Shot STD", "Composite FV STD"])
    baaccdf.to_csv(os.path.join(save_path_root, "%s/accuracy/BATS_accs/%sBATS_accs%s.csv" % (model_name, pre, mainsuf,)), index=False)

