import argparse
import json
import pickle
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
from torch.nn.functional import softmax
from tqdm import tqdm
import numpy as np
from utils_exp import *

def parse_args():
    """Define the script's command-line flags and parse them from ``sys.argv``.

    Returns:
        argparse.Namespace with the attributes ``dataset``, ``model``,
        ``cache_dir``, ``hf_token``, ``cot``, ``load_from_cpu``, ``device``,
        ``result_dir`` and ``n_ans``. Note that ``device`` and ``n_ans`` are
        kept as strings (their ``choices`` are string literals).
    """
    cli = argparse.ArgumentParser()

    # Which benchmark / generator model the stored results belong to.
    cli.add_argument(
        "--dataset",
        type=str,
        choices=["mmlu", "medmcqa", "commonsenseqa", "hellaswag"],
        default="mmlu",
    )
    cli.add_argument(
        "--model",
        type=str,
        choices=["gemma-2-2b-it", "gemma-2-9b-it", "gemma-2-27b-it", "Meta-Llama-3.1-8B-Instruct"],
        default="gemma-2-2b-it",
    )

    # Hugging Face cache location and access token.
    cli.add_argument("--cache_dir", type=str, default="******")
    cli.add_argument("--hf_token", type=str, default="******")

    # Boolean switches (False unless the flag is given).
    cli.add_argument("--cot", action='store_true')
    cli.add_argument(
        "--load_from_cpu",
        action='store_true',
        help="whether or not to first load model to cpu then move to cuda device",
    )

    # CUDA device index, result location, and number of answers per question.
    cli.add_argument("--device", type=str, choices=["0", "1", "2"], default="0")
    cli.add_argument("--result_dir", type=str, default="******")
    cli.add_argument("--n_ans", type=str, choices=["6", "12"], default="12")

    return cli.parse_args()

def get_embedding_scores(premise, hypothesis, btokenizer, bmodel, device):
    """Score (premise, hypothesis) pairs with a sequence-classification model.

    Args:
        premise: first text(s) handed to the tokenizer.
        hypothesis: second text(s) handed to the tokenizer, paired with
            ``premise``.
        btokenizer: tokenizer callable; invoked with padding enabled and
            truncation disabled, returning PyTorch tensors.
        bmodel: classification model whose output exposes ``.logits``.
        device: device the tokenized inputs are moved to before the forward
            pass.

    Returns:
        1-D numpy array holding column 0 of the softmax over the logits, one
        value per input pair.
        NOTE(review): column 0 is presumably the "entailment" class for the
        model used by this script -- confirm against the model's label map.
    """
    encoded = btokenizer(
        premise,
        hypothesis,
        padding=True,
        truncation=False,
        return_tensors="pt",
    )
    encoded = encoded.to(device)

    with torch.no_grad():
        logits = bmodel(**encoded).logits
        class_probs = softmax(logits, dim=-1)

    # clean_gpus is not defined in this file; presumably supplied by the
    # star import from utils_exp.
    clean_gpus()

    first_class_probs = class_probs.detach().cpu().numpy()[:, 0]
    return first_class_probs

def get_avg_entailment_prob(idx, p_idxs, generations_flattened, btokenizer, bmodel, device):
    """Average the pairwise scores among one question's agreeing generations.

    Every ordered pair (a, b) with a != b drawn from ``p_idxs`` is scored via
    ``get_embedding_scores`` (generation ``a`` as premise, generation ``b`` as
    hypothesis); the mean over all such pairs is returned.

    Args:
        idx: key into ``generations_flattened`` selecting the question.
        p_idxs: positions, within ``generations_flattened[idx]``, of the
            generations to compare.
        generations_flattened: mapping from question key to its list of
            generated texts.
        btokenizer: tokenizer forwarded to ``get_embedding_scores``.
        bmodel: model forwarded to ``get_embedding_scores``.
        device: device forwarded to ``get_embedding_scores``.

    Returns:
        ``1`` when ``p_idxs`` has exactly one element (nothing to compare),
        otherwise the mean pairwise score rounded to 5 decimals.
    """
    n_members = len(p_idxs)
    if n_members == 1:
        # A lone generation has no partner, so no model call is made.
        return 1

    texts = generations_flattened[idx]
    # One row per premise, one column per remaining hypothesis.
    entailment_mat = np.zeros((n_members, n_members - 1))
    for row, premise_idx in enumerate(p_idxs):
        # Skip the self-pair by position rather than by value so that every
        # row always receives exactly n_members - 1 scores, even if p_idxs
        # were to contain a repeated entry.
        partner_idxs = [p for col, p in enumerate(p_idxs) if col != row]
        prems = [texts[premise_idx] for _ in partner_idxs]
        hypos = [texts[p] for p in partner_idxs]
        entailment_mat[row] = get_embedding_scores(prems, hypos, btokenizer, bmodel, device)

    return np.round(np.mean(entailment_mat), 5)

def main():
    """Compute and pickle pairwise agreement scores for stored generations.

    Steps:
      1. Load the generations JSON and the predictions pickle from
         ``--result_dir``.
      2. Load the sequence-classification model and its tokenizer.
      3. For each (n_prompts, n_samples) combination, flatten the first
         ``n_prompts`` x ``n_samples`` generations/predictions per question,
         keep the questions whose predictions are not all identical, and for
         every distinct prediction score the generations that share it with
         ``get_avg_entailment_prob``.
      4. Pickle the result as
         ``{n_prompts: {row: {prediction: {"ent_prob": ..., "p_idxs": ...}}}}``.
    """
    args = parse_args()
    torch.set_grad_enabled(False)

    print("LOADING DATASET AND MODEL...")

    # load generations
    cot_path = "_cot" if args.cot else ""
    run_name = f"{args.dataset}_{args.model}{cot_path}"
    fname_gen = args.result_dir + f"gen_{run_name}_12x12.json"
    with open(fname_gen, "r") as f:
        generations = json.load(f)
    n_points = len(generations)

    # NOTE: unpickling can execute arbitrary code; only point --result_dir at
    # files produced by a trusted run of this pipeline.
    fname_preds = args.result_dir + f"{run_name}/preds.pkl"
    with open(fname_preds, "rb") as f:
        preds_all = pickle.load(f)

    # load embedding model
    target_device = f"cuda:{args.device}"
    load_device = "cpu" if args.load_from_cpu else target_device
    bmname = "MoritzLaurer/bge-m3-zeroshot-v2.0"
    cache_dir = args.cache_dir
    btokenizer = AutoTokenizer.from_pretrained(bmname, cache_dir=cache_dir)
    bmodel = AutoModelForSequenceClassification.from_pretrained(
        bmname,
        device_map=load_device,
        cache_dir=cache_dir,
        torch_dtype=torch.float16,
    )
    device = load_device
    if args.load_from_cpu and torch.cuda.is_available():
        # --load_from_cpu means "load on CPU first, then move to the CUDA
        # device"; previously the move never happened and inference silently
        # stayed on CPU. Without CUDA the model simply remains on CPU.
        bmodel = bmodel.to(target_device)
        device = target_device
    bmodel.eval()

    # start record embedding scores
    ep_scores_all_configs = {}
    if args.n_ans == "6":
        n_ans = 6
        prompt_sample_combs = [(6, 1), (3, 2), (2, 3), (1, 6)]
    else:
        n_ans = 12
        prompt_sample_combs = [(12, 1), (6, 2), (4, 3), (3, 4), (2, 6), (1, 12)]

    # iterate over prompt x sample combinations
    for n_p, n_s in prompt_sample_combs:
        # (prompt, sample) cells in the order they are laid out as columns
        cells = [(i, j) for i in range(n_p) for j in range(n_s)]

        # get flattened generations and preds
        preds = np.zeros((n_points, n_ans))
        for col_idx, (i, j) in enumerate(cells):
            preds[:, col_idx] = preds_all[:, i, j]
        generations_flattened = {
            idx: [generations[str(idx)][str(i)][j] for i, j in cells]
            for idx in range(n_points)
        }

        # get idxs where there are multiple predictions: a row qualifies as
        # soon as any two neighbouring columns disagree
        mask = np.any(preds[:, 1:] != preds[:, :-1], axis=1)
        diff_rows = np.where(mask)[0]

        # calculate entailment probability score
        ep_scores = {}
        print(f"Start calculating embedding score: entailment probability for {n_p} prompts * {n_s} samples")
        for row in tqdm(diff_rows):
            idx = int(row)
            this_preds = preds[idx]
            row_scores = {}
            for p in np.unique(this_preds):
                # columns of this row that produced prediction p
                p_idxs = np.where(this_preds == p)[0]
                this_val = get_avg_entailment_prob(idx, p_idxs, generations_flattened, btokenizer, bmodel, device)
                row_scores[p] = {"ent_prob": this_val, "p_idxs": p_idxs}
            ep_scores[row] = row_scores
        ep_scores_all_configs[n_p] = ep_scores

    # save embedding scores
    suffix = "_6" if args.n_ans == "6" else ""
    fname_ep = args.result_dir + f"{run_name}/embedding_scores{suffix}.pkl"
    print("SAVING TO:", fname_ep)
    with open(fname_ep, "wb") as f:
        pickle.dump(ep_scores_all_configs, f)

# Run the scoring pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
