import argparse
import json
import pickle
import os
import torch
import numpy as np
from utils_exp import *
from sklearn.metrics.pairwise import cosine_similarity

def parse_args():
    """Parse the command-line options of the evaluation script.

    Returns:
        argparse.Namespace with the attributes ``dataset``, ``model``,
        ``cache_dir``, ``hf_token``, ``cot``, ``load_from_cpu``, ``device``,
        ``result_dir``, ``half_float``, ``center`` and ``n_ans``.  Note that
        ``device`` and ``n_ans`` are kept as strings.
    """
    # (flag, keyword arguments) pairs, registered in this order.
    option_specs = (
        ("--dataset", dict(type=str, choices=["mmlu", "medmcqa", "commonsenseqa", "hellaswag"], default="mmlu")),
        ("--model", dict(type=str, choices=["gemma-2-2b-it", "gemma-2-9b-it", "gemma-2-27b-it", "Meta-Llama-3.1-8B-Instruct"], default="gemma-2-2b-it")),
        ("--cache_dir", dict(type=str, default="******")),
        ("--hf_token", dict(type=str, default="******")),
        ("--cot", dict(action='store_true')),
        ("--load_from_cpu", dict(action='store_true', help="whether or not to first load model to cpu then move to cuda device")),
        ("--device", dict(type=str, choices=["0", "1", "2"], default="0")),
        ("--result_dir", dict(type=str, default="******")),
        ("--half_float", dict(action='store_true')),
        ("--center", dict(action='store_true')),
        ("--n_ans", dict(type=str, choices=["6", "12"], default="12")),
    )

    cli = argparse.ArgumentParser()
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli.parse_args()

def set_top_k(acts_1point, k=0.5):
    """Keep only the largest positive activations of every row.

    For each row, the strictly positive entries are ranked and roughly the
    top ``k`` fraction of them is kept; every entry below the resulting
    threshold (including all zero and negative entries) is set to 0.

    Args:
        acts_1point: 2-D tensor; each ``acts_1point[i]`` is processed
            independently.  The input is not modified.
        k: fraction of the positive entries to keep (e.g. ``0.1`` keeps about
            the largest 10%).  Values are effectively clamped so that at
            least one and at most all positive entries survive.

    Returns:
        A new tensor with the same shape and dtype as ``acts_1point``.
    """
    result = acts_1point.clone()
    for dim in range(len(result)):
        row = result[dim]  # view into ``result``: edits below are in place
        positive = row > 0
        num_pos = torch.sum(positive)
        n_pos = int(num_pos)
        if n_pos == 0:
            # Nothing positive to rank (previously an IndexError): every
            # entry is "below the threshold", so the whole row is zeroed.
            result[dim] = 0
            continue
        sorted_pos, _ = torch.sort(row[positive])
        # Index of the smallest value that is kept.  Clamped to a valid
        # position so k == 0 no longer indexes past the end and k > 1 does
        # not wrap around through a negative index.
        cut = int((1 - k) * num_pos)
        cut = max(0, min(cut, n_pos - 1))
        threshold = sorted_pos[cut]
        result[dim][row < threshold] = 0
    return result

def calculate_cosine_sim(acts, k=0.0, center=False, num_prompts=12):
    """Compute, per data point, the pairwise cosine similarity of its answers.

    Args:
        acts: tensor indexed as ``acts[point, prompt, feature]``
            (n_points x n_prompts x feature_dim).
        k: if > 0, each point's activations are first sparsified with
            ``set_top_k(..., k)`` (e.g. ``0.1`` keeps about the largest 10%
            of the positive entries).
        center: if True, the mean feature vector over all points and prompts
            is subtracted before the similarity is computed.
        num_prompts: number of prompts per point; sets the size of each
            similarity matrix.

    Returns:
        ``np.ndarray`` of shape ``(n_points, num_prompts, num_prompts)``.
        Points whose similarity could not be computed keep an all-zero matrix.

    Warns:
        RuntimeWarning: once, listing the points that were skipped.
    """
    import warnings

    n_points = len(acts)
    sim_all_dims = np.zeros((n_points, num_prompts, num_prompts))
    # The global mean is only needed (and only computed) when centering.
    center_tensor = torch.mean(acts, dim=(0, 1)) if center else None

    skipped = []
    for j in range(n_points):
        try:
            point_acts = acts[j, :, :]  # n_prompts x feature_dim
            if center:
                point_acts = point_acts - center_tensor
            if k > 0:
                point_acts = set_top_k(point_acts, k)
            sim_all_dims[j] = cosine_similarity(point_acts.cpu().numpy())
        except (ValueError, IndexError, RuntimeError) as exc:
            # Best effort: a bad point (e.g. non-finite values or a shape
            # mismatch) keeps its zero matrix, but it is reported below
            # instead of being swallowed silently.
            skipped.append((j, exc))

    if skipped:
        first_idx, first_exc = skipped[0]
        warnings.warn(
            f"cosine similarity skipped for {len(skipped)} of {n_points} points "
            f"(indices {[idx for idx, _ in skipped][:10]}...); "
            f"first error at point {first_idx}: {first_exc!r}",
            RuntimeWarning,
        )
    return sim_all_dims

def get_avg_sim(row, p_idxs, sims):
    """Average pairwise similarity among the answers listed in ``p_idxs``.

    Args:
        row: index of the data point; ``sims[row]`` is its square
            answer-by-answer similarity matrix.
        p_idxs: indices of the answers that form the group.
        sims: per-point similarity matrices.

    Returns:
        ``1`` when the group has a single member, otherwise the mean of
        ``sims[row][i, j]`` over all pairs ``i < j`` with both in ``p_idxs``.
    """
    if len(p_idxs) == 1:
        return 1

    row_sims = sims[row]
    # Group members that are valid positions of the matrix, in ascending order.
    members = [pos for pos in range(len(row_sims)) if pos in p_idxs]

    total = 0
    n_pairs = 0
    for offset, first in enumerate(members):
        for second in members[offset + 1:]:
            total += row_sims[first, second]
            n_pairs += 1
    return total / n_pairs

def evaluation_function(row, p_idxs, sims_ans, lamb_1=0.5, ep=False, n_ans=12):
    """Score one candidate answer as a mix of agreement and group size.

    Args:
        row: index of the data point (only used when ``ep`` is False).
        p_idxs: indices of the answers that gave this candidate.
        sims_ans: when ``ep`` is True, the agreement score itself; otherwise
            the per-point similarity matrices handed to ``get_avg_sim``.
        lamb_1: weight of the agreement term; the size term gets
            ``1 - lamb_1``.
        ep: whether ``sims_ans`` is already a scalar score.
        n_ans: total number of answers per point; the group size is divided
            by 5 when ``n_ans == 6`` and by 11 otherwise.

    Returns:
        ``lamb_1 * agreement + (1 - lamb_1) * len(p_idxs) / d``.
    """
    if not ep:
        # Map the average similarity from [-1, 1] to [0, 1].
        agreement = (get_avg_sim(row, p_idxs, (sims_ans)) + 1) / 2
    else:
        agreement = sims_ans

    denom = 5 if n_ans == 6 else 11
    volume = len(p_idxs) / denom
    return lamb_1 * agreement + (1 - lamb_1) * volume

def _resolve_sae_config(model_name):
    """Return ``(banner, sae_release, layers, sae_ids, sae_dim)`` for a model name."""
    if "gemma-2-2b" in model_name:
        banner = "LOADING SAES FOR GEMMA 2B"
        sae_release = "gemma-scope-2b-pt-res-canonical"
        layers, width = [3, 7, 13, 20, 23], "16k"
    elif "gemma-2-9b" in model_name:
        banner = "LOADING SAES FOR GEMMA 9B"
        sae_release = "gemma-scope-9b-it-res-canonical"
        layers, width = [9, 20, 31], "16k"
    elif "gemma-2-27b" in model_name:
        banner = "LOADING SAES FOR GEMMA 27B"
        sae_release = "gemma-scope-27b-pt-res-canonical"
        layers, width = [10, 22, 34], "131k"
    elif "3.1" in model_name:
        banner = "LOADING SAES FOR LLAMA 3.1 8B"
        sae_release = "llama_scope_lxr_8x"
        layers, width = [3, 8, 16, 24, 29], None
    else:
        raise ValueError(f"no SAE configuration for model {model_name!r}")

    if width is None:
        sae_ids = [f"l{item}r_8x" for item in layers]
    else:
        sae_ids = [f"layer_{item}/width_{width}/canonical" for item in layers]

    sae_dim = 16384 if "gemma" in model_name else 32768
    if "27" in model_name:
        sae_dim = 131072
    return banner, sae_release, layers, sae_ids, sae_dim


def _load_act_shards(acts_dir, file_names, tag, device, dtype=None):
    """Load every file whose name contains ``tag`` and stack them along dim 0.

    Returns ``None`` when no file matches.
    """
    shards = []
    for name in file_names:
        if tag not in name:
            continue
        # NOTE: weights_only=False unpickles arbitrary objects; only point
        # this at activation files you produced yourself.
        shard = torch.load(acts_dir + "/" + name, weights_only=False, map_location=device)
        if dtype is not None:
            shard = shard.to(dtype)
        shards.append(shard)
    if not shards:
        return None
    # One stack at the end instead of re-copying the growing tensor per file.
    return shards[0] if len(shards) == 1 else torch.vstack(shards)


def _ground_truth(dname, generations, ds, n_points):
    """Collect the gold answer index of every data point for dataset ``dname``."""
    gts = np.zeros((n_points,))
    for i in range(n_points):
        if dname == "mmlu":
            gts[i] = int(generations[str(i)]["answer"])
        elif dname == "medmcqa":
            gts[i] = ds[i]['cop']
        elif dname == "commonsenseqa":
            gts[i] = ord(ds[i]['answerKey']) - ord('A')
        elif dname == "hellaswag":
            gts[i] = ds[i]["answerID"]
    return gts


def _sae_cosine_sim(sae, layer_acts, sae_dim, center, n_ans):
    """Encode one layer's activations with ``sae`` and return their similarities."""
    n_points = layer_acts.shape[0]
    encoded = torch.zeros((n_points, n_ans, sae_dim), dtype=layer_acts.dtype, device=layer_acts.device)
    # Encode point by point to keep the peak memory of the SAE call small.
    for point in range(n_points):
        encoded[point] = sae.encode(layer_acts[point])
    return calculate_cosine_sim(encoded, k=0, center=center, num_prompts=n_ans)


def _sim_scorer(sims, lam, n_ans):
    """Build a scorer that rates a candidate by similarity and group size."""
    def score(row, candidate, members):
        return evaluation_function(row, members, sims, lamb_1=lam, n_ans=n_ans)
    return score


def _ep_scorer(ep_scores, lam, n_ans):
    """Build a scorer that rates a candidate by its recorded 'ent_prob' and group size."""
    def score(row, candidate, members):
        return evaluation_function(row, members, ep_scores[row][candidate]['ent_prob'],
                                   lamb_1=lam, ep=True, n_ans=n_ans)
    return score


def _vote_scorer(n_ans):
    """Build a scorer that rates a candidate purely by its group size (majority vote)."""
    def score(row, candidate, members):
        return evaluation_function(row, members, 0.0, lamb_1=0, ep=True, n_ans=n_ans)
    return score


def _pick_answers(rows, preds, score_fn):
    """For every row in ``rows`` return the distinct prediction with the highest score.

    Ties keep the smallest prediction value; a row with no score above -1 yields -1.
    """
    answers = np.zeros((len(rows),)) - 1
    for out_pos, row in enumerate(rows):
        row = int(row)
        row_preds = preds[row]
        best_score = -1
        for candidate in np.unique(row_preds):
            members = np.where(row_preds == candidate)[0]
            candidate_score = score_fn(row, candidate, members)
            if candidate_score > best_score:
                best_score = candidate_score
                answers[out_pos] = candidate
    return answers


def _accuracy(answers, truths):
    """Fraction of ``answers`` equal to ``truths``, rounded to 4 decimals."""
    return np.round(np.sum(answers == truths) / len(truths), 4)


def _best_configs(train_accs, lambdas):
    """Return every ``(layer, lambda)`` pair that reaches the best training accuracy.

    ``train_accs`` maps a layer index to accuracies aligned with ``lambdas``.
    """
    best = max(max(accs) for accs in train_accs.values())
    return [(layer, lam)
            for layer, accs in train_accs.items()
            for lam, acc in zip(lambdas, accs)
            if acc == best]


def main():
    """Evaluate answer-selection strategies on points where sampled answers disagree.

    Pipeline, repeated for every (n_prompts, n_samples) combination:
      1. flatten the recorded predictions and raw activations to ``n_ans``
         answers per point;
      2. keep the points whose answers disagree (and contain no ``-2``
         prediction) and split them 50/50 into a selection and an eval part;
      3. on the selection part, sweep the mixing weight lambda (and, for
         raw / SAE activations, the model layer) and keep the best settings;
      4. on the eval part, report single-answer, majority-vote, embedding,
         raw-activation and SAE-activation accuracies.

    Reads its configuration from the command line (see ``parse_args``) and
    prints all results to stdout.
    """
    args = parse_args()
    torch.set_grad_enabled(False)

    ds, _ = load_ds(args.dataset, args.cache_dir, cot=args.cot)

    # load generations and recorded predictions
    cot_path = "_cot" if args.cot else ""
    run_dir = args.result_dir + f"{args.dataset}_{args.model}{cot_path}"
    fname_gen = args.result_dir + f"gen_{args.dataset}_{args.model}{cot_path}_12x12.json"
    with open(fname_gen, "r") as f:
        generations = json.load(f)
    # NOTE: pickle.load executes arbitrary code; result_dir must be trusted.
    with open(run_dir + "/preds.pkl", "rb") as f:
        preds_all = pickle.load(f)
    n_points = len(generations)

    # n_ans = n_p * n_s
    n_ans = 12
    prompt_sample_combs = [(12, 1), (6, 2), (4, 3), (3, 4), (2, 6), (1, 12)]
    if args.n_ans == "6":
        n_ans = 6
        prompt_sample_combs = [(6, 1), (3, 2), (2, 3), (1, 6)]

    # load SAEs
    device = "cpu" if args.load_from_cpu else f"cuda:{args.device}"
    banner, sae_release, layers, sae_ids, sae_dim = _resolve_sae_config(args.model)
    print(banner)
    saes = [SAE.from_pretrained(sae_release, sid, device=device)[0] for sid in sae_ids]

    # load recorded embedding scores (indexed below as [n_p][row][prediction]['ent_prob'])
    print("\nloading entailment probabilities...")
    ep_file = "embedding_scores_6.pkl" if args.n_ans == "6" else "embedding_scores.pkl"
    with open(run_dir + "/" + ep_file, "rb") as f:
        embedding_scores = pickle.load(f)

    # Raw activations are stored in shards grouped by prompt index:
    # "01" -> prompts 0-1, "25" -> prompts 2-5, "611" -> prompts 6-11.
    # NOTE(review): shards are selected by substring match on the file name,
    # so any other file in run_dir containing these digits would be picked up.
    print("\nloading model raw activations...")
    files = sorted(f for f in os.listdir(run_dir) if os.path.isfile(os.path.join(run_dir, f)))
    acts_01 = _load_act_shards(run_dir, files, "01", device)
    acts_25 = _load_act_shards(run_dir, files, "25", device, dtype=torch.float16)
    acts_611 = _load_act_shards(run_dir, files, "611", device, dtype=torch.float16)

    gts = _ground_truth(args.dataset, generations, ds, n_points)

    # candidate mixing weights between agreement and group size
    lambdas = np.round(np.arange(0, 1, 0.04), 4)

    for n_p, n_s in prompt_sample_combs:
        # flatten predictions to (n_points, n_ans), prompt-major
        preds = np.zeros((n_points, n_ans))
        col_idx = 0
        for i in range(n_p):
            for j in range(n_s):
                preds[:, col_idx] = preds_all[:, i, j]
                col_idx += 1

        print(f"\n\n=========== {n_p}, {n_s} ===========\n")
        # rows whose answers are not all identical, minus rows with a failed (-2) prediction
        mask = np.any(preds[:, 1:] != preds[:, :-1], axis=1)
        fail_rows = np.unique(np.where(preds == -2)[0])
        # setdiff1d keeps an integer dtype even when the result is empty
        diff_rows = np.setdiff1d(np.where(mask)[0], fail_rows)
        print("accuracies of each prompt:")
        print([np.round(item, 4) for item in (preds == gts[:, None]).sum(axis=0) / n_points])
        print(f"\nnumber of points (out of {n_points}) with different answers:", len(diff_rows), f"{np.round(diff_rows.shape[0]/n_points, 4) * 100:.2f}%")
        print("\namount these points, the answer distributions are:")
        # {number of distinct answers: {sorted count pattern: number of rows}}
        count_dict = {}
        for row_preds in preds[diff_rows]:
            _, counts = np.unique(row_preds, return_counts=True)
            pattern = str(np.sort(counts))
            bucket = count_dict.setdefault(len(counts), {})
            # every row is counted, including the first one of each bucket
            bucket[pattern] = bucket.get(pattern, 0) + 1
        for n_distinct in count_dict:
            print(f"{n_distinct} answers", count_dict[n_distinct])

        # first half selects hyperparameters, second half is held out
        n_train_points = int(len(diff_rows) * 0.5)
        train_rows = diff_rows[:n_train_points]
        test_rows = diff_rows[n_train_points:]
        if len(train_rows) == 0 or len(test_rows) == 0:
            print("\nnot enough points with different answers for a train/eval split; skipping")
            continue
        gts_train = gts[train_rows]
        gts_test = gts[test_rows]

        # flatten raw activations to (n_points, n_ans, n_layers, hidden)
        all_answer_acts = torch.zeros((n_points, n_ans, len(saes), acts_01.shape[-1]), dtype=acts_01.dtype, device=device)
        for i in range(n_p):
            if i < 2:
                source, offset = acts_01, 0
            elif i < 6:
                source, offset = acts_25, 2
            else:
                source, offset = acts_611, 6
            start = i * n_s
            all_answer_acts[:, start:start + n_s] = source[:, i - offset, 0:n_s, :, :]

        ep_scores = embedding_scores[n_p]

        # ---- selection: sweep lambda (and layer) on the train rows ----
        # Similarities are indexed by dataset row, so scorers receive the
        # dataset row, not the position inside diff_rows.
        sims_raw = {}
        sims_sae = {}
        train_accs_raw = {}
        train_accs_sae = {}
        for act_layer in range(len(saes)):
            layer_acts = all_answer_acts[:, :, act_layer, :]
            sims_raw[act_layer] = calculate_cosine_sim(layer_acts, k=0, center=args.center, num_prompts=n_ans)
            sims_sae[act_layer] = _sae_cosine_sim(saes[act_layer], layer_acts, sae_dim, args.center, n_ans)

            train_accs_raw[act_layer] = [
                _accuracy(_pick_answers(train_rows, preds, _sim_scorer(sims_raw[act_layer], lam, n_ans)), gts_train)
                for lam in lambdas
            ]
            train_accs_sae[act_layer] = [
                _accuracy(_pick_answers(train_rows, preds, _sim_scorer(sims_sae[act_layer], lam, n_ans)), gts_train)
                for lam in lambdas
            ]
            print("model layer:", act_layer, layers[act_layer])

        # the embedding baseline does not depend on the layer: sweep it once
        train_accs_ep = np.asarray([
            _accuracy(_pick_answers(train_rows, preds, _ep_scorer(ep_scores, lam, n_ans)), gts_train)
            for lam in lambdas
        ])
        rc_lambdas_emb = list(np.unique(lambdas[np.where(train_accs_ep == np.max(train_accs_ep))]))
        print("best embedding lambda:", rc_lambdas_emb)

        rc_hpps_raw = _best_configs(train_accs_raw, lambdas)
        print("best raw hyperparams:", len(rc_hpps_raw), rc_hpps_raw)

        rc_hpps_sae = _best_configs(train_accs_sae, lambdas)
        print("best sae hyperparams:", len(rc_hpps_sae), rc_hpps_sae)

        # ---- evaluation on the held-out rows ----
        # baseline: accuracy of each single answer column
        test_preds = preds[test_rows]
        single_accs = [(test_preds[:, col] == gts_test).sum() / len(gts_test) for col in range(n_ans)]

        # baseline: majority vote (largest answer group wins)
        mode_acc = _accuracy(_pick_answers(test_rows, preds, _vote_scorer(n_ans)), gts_test)

        # Evaluation reuses the similarities computed during selection, i.e.
        # the same centering (args.center) the hyperparameters were tuned with.
        accs_raw = [
            _accuracy(_pick_answers(test_rows, preds, _sim_scorer(sims_raw[layer], lam, n_ans)), gts_test)
            for layer, lam in rc_hpps_raw
        ]
        accs_sae = [
            _accuracy(_pick_answers(test_rows, preds, _sim_scorer(sims_sae[layer], lam, n_ans)), gts_test)
            for layer, lam in rc_hpps_sae
        ]
        accs_ep = [
            _accuracy(_pick_answers(test_rows, preds, _ep_scorer(ep_scores, lam, n_ans)), gts_test)
            for lam in rc_lambdas_emb
        ]

        print("AVG SINGLE MODEL     ACCURACY: ", np.round(np.mean(single_accs), 4), "\u00B1", np.round(np.std(single_accs), 4))
        print("AVG MODEL MODE       ACCURACY: ", np.round(mode_acc, 4))
        print("BEST EMB             ACCURACY: ", np.round(np.max(accs_ep), 4))
        print("BEST RAW ACTS        ACCURACY: ", np.round(np.max(accs_raw), 4), np.array(rc_hpps_raw)[np.where(np.asarray(accs_raw) == np.max(accs_raw))])
        print("BEST SAE ACTS        ACCURACY: ", np.round(np.max(accs_sae), 4), np.array(rc_hpps_sae)[np.where(np.asarray(accs_sae) == np.max(accs_sae))])

# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()

