from openai import OpenAI
import argparse
import json
import pickle
import os
import torch
from tqdm import tqdm
import numpy as np
from utils_exp import *
from sklearn.metrics.pairwise import cosine_similarity


def parse_args():
    """Parse the command-line options of the agreement-evaluation script.

    Returns:
        argparse.Namespace with the string attributes ``dataset``, ``model``,
        ``cache_dir``, ``device``, ``result_dir`` and ``apikey``.
    """
    # One (flag, keyword-arguments) entry per option, registered in order.
    option_specs = (
        ("--dataset", {
            "type": str,
            "choices": ["mmlu", "medmcqa", "commonsenseqa", "hellaswag"],
            "default": "mmlu",
        }),
        ("--model", {
            "type": str,
            "choices": [
                "gemma-2-2b-it",
                "gemma-2-9b-it",
                "gemma-2-27b-it",
                "Meta-Llama-3.1-8B-Instruct",
            ],
            "default": "gemma-2-2b-it",
        }),
        ("--cache_dir", {"type": str, "default": "******"}),
        ("--device", {"type": str, "choices": ["0", "1", "2"], "default": "0"}),
        ("--result_dir", {"type": str, "default": "******"}),
        ("--apikey", {"type": str, "choices": ["0", "1", "2", "3"], "default": "0"}),
    )
    cli = argparse.ArgumentParser()
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli.parse_args()

def get_coherence_prompt(gens0, gens1):
    """Build the judge prompt asking which of two response groups is more coherent.

    Args:
        gens0: Responses shown as "Group 1"; interpolated with default string
            formatting.
        gens1: Responses shown as "Group 2"; interpolated the same way.

    Returns:
        The prompt text. Its wording, indentation and blank lines are part of
        what is sent to the judge model, so every piece below is spelled out
        explicitly (including otherwise invisible whitespace).
    """
    pad = "    "  # every line after the first is indented by four spaces
    pieces = [
        "Below are two groups of AI-generated responses for answering the same multiple choice question. \n\n",
        pad + "Each response contain some reasoning process and a final answer. \n\n",
        pad + "The final answers given by each group of responses are the same. \n\n",
        pad + "Take a look each group of responses, decide which group of responses is more coherent in each response's reasonings for choosing their final answer. \n\n \n",
        pad + "Here are the two groups: \n\n",
        pad + f"Group 1: {gens0}, \n\n\n",
        pad + f"Group 2: {gens1}. \n\n\n",
        pad + "\nNow, decide which one is the more coherent group of responses. ONLY output [Group 1] or [Group 2], without any explanations.",
    ]
    return "".join(pieces)

def get_llm_response(prompt, client):
    """Send ``prompt`` as a single user turn to "deepseek-chat" and return the reply.

    Args:
        prompt: Text placed in the user message.
        client: Object exposing ``chat.completions.create(...)``.

    Returns:
        The ``message.content`` of the first choice in the completion.
    """
    conversation = [
        {"role": "system", "content": "You are a helpful assistant"},
        {"role": "user", "content": prompt},
    ]
    completion = client.chat.completions.create(
        model="deepseek-chat",
        messages=conversation,
        stream=False,
    )
    first_choice = completion.choices[0]
    return first_choice.message.content

def _select_saes(model):
    """Return ``(sae_release, sae_ids)`` for the model family named in ``model``.

    Returns ``(None, [])`` when no known family substring matches.
    """
    sae_release = None
    sae_ids = []
    if "gemma-2-2b" in model:
        print("LOADING SAES FOR GEMMA 2B")
        sae_release = "gemma-scope-2b-pt-res-canonical"
        sae_ids = [f"layer_{layer}/width_16k/canonical" for layer in (3, 7, 13, 20, 23)]
    if "gemma-2-9b" in model:
        print("LOADING SAES FOR GEMMA 9B")
        sae_release = "gemma-scope-9b-it-res-canonical"
        sae_ids = [f"layer_{layer}/width_16k/canonical" for layer in (9, 20, 31)]
    if "gemma-2-27b" in model:
        print("LOADING SAES FOR GEMMA 27B")
        sae_release = "gemma-scope-27b-pt-res-canonical"
        sae_ids = [f"layer_{layer}/width_131k/canonical" for layer in (10, 22, 34)]
    if "3.1" in model:
        print("LOADING SAES FOR LLAMA 3.1 8B")
        sae_release = "llama_scope_lxr_8x"
        sae_ids = [f"l{layer}r_8x" for layer in (3, 8, 16, 24, 29)]
    return sae_release, sae_ids


def _load_acts(acts_dir, files, tag, device, dtype=None):
    """Load every file in ``files`` whose name contains ``tag`` and stack them row-wise.

    Args:
        acts_dir: Directory holding the activation files.
        files: Sorted file names inside ``acts_dir``.
        tag: Substring selecting the files to load (e.g. ``"01"``).
        device: ``map_location`` passed to ``torch.load``.
        dtype: Optional dtype each loaded tensor is cast to.

    Raises:
        FileNotFoundError: if no file name contains ``tag``.
    """
    chunks = []
    for fn in files:
        if tag not in fn:
            continue
        # NOTE: weights_only=False unpickles arbitrary objects; only point this
        # at activation files you produced yourself.
        tensor = torch.load(acts_dir + "/" + fn, weights_only=False, map_location=device)
        if dtype is not None:
            tensor = tensor.to(dtype)
        chunks.append(tensor)
    if not chunks:
        raise FileNotFoundError(f"no activation file containing {tag!r} found in {acts_dir}")
    # Stack once at the end instead of re-copying the accumulated tensor per file.
    return chunks[0] if len(chunks) == 1 else torch.vstack(chunks)


def _find_split_rows(preds):
    """Find questions whose answers split 6-vs-6 or 5-vs-7 and tally all vote splits.

    Rows where every answer agrees, or where any answer equals -2, are ignored.

    Returns:
        ``(split_rows, count_dict)``: the list of row indices with a 6/6 or 5/7
        split, and a dict mapping number-of-distinct-answers to a dict of
        ``str(sorted counts) -> number of rows``.
    """
    disagree = np.any(preds[:, 1:] != preds[:, :-1], axis=1)
    diff_rows = np.where(disagree)[0]
    fail_rows = np.where(preds == -2)[0]
    # setdiff1d keeps an integer dtype even when the result is empty.
    diff_rows = np.setdiff1d(diff_rows, fail_rows)

    split_rows = []
    count_dict = {}
    for row in diff_rows:
        elements, counts = np.unique(preds[row], return_counts=True)
        if (counts[0] == 6 and counts[1] == 6) or (5 in counts and 7 in counts):
            split_rows.append(row)
        # Count every row, including the first one seen for each category.
        per_size = count_dict.setdefault(len(elements), {})
        key = str(np.sort(counts))
        per_size[key] = per_size.get(key, 0) + 1
    return split_rows, count_dict


def _mean_pairwise_cosine(acts):
    """Return the mean of the full pairwise cosine-similarity matrix of ``acts`` rows."""
    # NOTE(review): the mean includes the diagonal (self-similarity), which
    # weighs more for smaller groups in 5-vs-7 splits — confirm this is intended.
    return np.mean(cosine_similarity(acts.cpu().numpy()))


def main():
    """Compare three coherence scores against a DeepSeek judge on split questions.

    For every (n_prompts, n_samples) configuration, questions whose 12 answers
    split 6-vs-6 or 5-vs-7 are selected. For each, the more "consistent" answer
    group is chosen by (a) recorded ``ent_prob`` embedding scores, (b) mean
    cosine similarity of raw activations, and (c) mean cosine similarity of SAE
    encodings. The agreement rate of each choice with the LLM judge's choice is
    printed per configuration and aggregated over all configurations.
    """
    args = parse_args()
    # NOTE(review): --apikey is parsed but not used here; the key is hard-coded.
    apikey = "******"
    client = OpenAI(api_key=apikey, base_url="https://api.deepseek.com")
    torch.set_grad_enabled(False)

    device = f"cuda:{args.device}"
    result_dir = args.result_dir
    dataset = args.dataset
    model = args.model
    # Index used both for the layer axis of the activations and for the SAE list.
    chosen_layer = 2 if ("2b" in model or "3.1" in model) else 1
    cot_path = "_cot"
    cache_dir = args.cache_dir
    n_ans = 12
    prompt_sample_combs = [(12, 1), (6, 2), (4, 3), (3, 4), (2, 6), (1, 12)]

    # load stuff
    model_acts_dir = result_dir + f"{dataset}_{model}{cot_path}"
    fname_gen = result_dir + f"gen_{dataset}_{model}{cot_path}_12x12.json"
    with open(fname_gen, "r") as f:
        generations = json.load(f)
    # NOTE: pickle executes code on load; only read result files you trust.
    with open(model_acts_dir + "/preds.pkl", "rb") as f:
        preds_all = pickle.load(f)
    ds, _ = load_ds(dataset, cache_dir, cot=True)

    n_points = len(generations)
    # NOTE(review): gts is filled but not read anywhere below in this script.
    gts = np.zeros((n_points,))
    for i in range(n_points):
        if dataset == "mmlu":
            gts[i] = int(generations[str(i)]["answer"])
        elif dataset == "medmcqa":
            gts[i] = ds[i]['cop']
        elif dataset == "commonsenseqa":
            gts[i] = ord(ds[i]['answerKey']) - ord('A')
        elif dataset == "hellaswag":
            gts[i] = ds[i]["answerID"]

    sae_release, sae_ids = _select_saes(model)
    saes = [SAE.from_pretrained(sae_release, sid, device=device)[0] for sid in sae_ids]

    print("LOADING ACTIVATIONS")
    files = sorted(
        f for f in os.listdir(model_acts_dir)
        if os.path.isfile(os.path.join(model_acts_dir, f))
    )
    # Activations are stored in three file groups covering prompts 0-1, 2-5, 6-11.
    acts_01 = _load_acts(model_acts_dir, files, "01", device)
    acts_25 = _load_acts(model_acts_dir, files, "25", device, dtype=torch.float16)
    acts_611 = _load_acts(model_acts_dir, files, "611", device, dtype=torch.float16)

    # load recorded results for getting answers
    print("\nloading entailment probabilities...")
    with open(model_acts_dir + "/embedding_scores.pkl", "rb") as f:
        embedding_scores = pickle.load(f)

    agreement_emb, agreement_raw, agreement_sae = 0, 0, 0
    agreement_len = 0

    for n_p, n_s in prompt_sample_combs:
        # Flatten the first n_p prompts x n_s samples into n_ans answer columns.
        preds = np.zeros((n_points, n_ans))
        col_idx = 0
        for i in range(n_p):
            for j in range(n_s):
                preds[:, col_idx] = preds_all[:, i, j]
                col_idx += 1
        generations_flattened = {
            idx: [generations[str(idx)][str(i)][j] for i in range(n_p) for j in range(n_s)]
            for idx in range(n_points)
        }

        print(f"\n\n=========== {n_p}, {n_s} ===========\n")
        # 6 vs 6 & 5 vs 7
        split_rows, count_dict = _find_split_rows(preds)
        for k in count_dict:
            print(f"{k} answers", count_dict[k])

        n_split = len(split_rows)
        if n_split == 0:
            # Nothing to judge; avoid dividing by zero below.
            print("no 6-vs-6 or 5-vs-7 questions for this configuration")
            continue

        # get flattened all raw activations
        all_answer_acts = torch.zeros(
            (n_points, n_ans, len(saes), acts_01.shape[-1]),
            dtype=acts_01.dtype,
            device=device,
        )
        act_idx = 0
        for i in range(n_p):
            act_idx_end = act_idx + n_s
            if i < 2:
                all_answer_acts[:, act_idx:act_idx_end] = acts_01[:, i, 0:n_s, :, :]
            elif i < 6:
                all_answer_acts[:, act_idx:act_idx_end] = acts_25[:, i - 2, 0:n_s, :, :]
            else:
                all_answer_acts[:, act_idx:act_idx_end] = acts_611[:, i - 6, 0:n_s, :, :]
            act_idx = act_idx_end

        # evaluate: 1 means the first answer group was chosen, 2 the second
        groups_raw = np.zeros(n_split)
        groups_sae = np.zeros(n_split)
        groups_ep = np.zeros(n_split)
        model_coherence = np.zeros(n_split)
        for k, row in enumerate(tqdm(split_rows)):
            idx = int(row)

            # Re-create the API client every 20 requests, as the script did before.
            if k % 20 == 0:
                client = OpenAI(api_key=apikey, base_url="https://api.deepseek.com")

            # Split the answers of this question into its two answer groups.
            this_preds = preds[idx]
            uniq_preds = np.unique(this_preds)
            this_ep_dict = embedding_scores[n_p][idx]
            this_labels = [uniq_preds[0], uniq_preds[1]]
            idxs_l0 = np.where(this_preds == this_labels[0])[0]
            idxs_l1 = np.where(this_preds == this_labels[1])[0]
            flattened_answers_l0 = {
                f"Answer {i}": generations_flattened[idx][j] for i, j in enumerate(idxs_l0)
            }
            flattened_answers_l1 = {
                f"Answer {i}": generations_flattened[idx][j] for i, j in enumerate(idxs_l1)
            }

            # get label by model activation consistency
            this_answer_acts = all_answer_acts[idx]
            acts_l0 = this_answer_acts[idxs_l0, chosen_layer].to(torch.float16)
            acts_l1 = this_answer_acts[idxs_l1, chosen_layer].to(torch.float16)
            cons_l0 = _mean_pairwise_cosine(acts_l0)
            cons_l1 = _mean_pairwise_cosine(acts_l1)
            # Each group is encoded separately (group 1 was previously scored
            # from group 0's encoding, so the SAE method always chose group 0).
            sae_acts_l0 = saes[chosen_layer].encode(acts_l0).to(torch.float16)
            sae_acts_l1 = saes[chosen_layer].encode(acts_l1).to(torch.float16)
            cons_sae_l0 = _mean_pairwise_cosine(sae_acts_l0)
            cons_sae_l1 = _mean_pairwise_cosine(sae_acts_l1)
            cons_ep_l0 = this_ep_dict[this_labels[0]]['ent_prob']
            cons_ep_l1 = this_ep_dict[this_labels[1]]['ent_prob']

            # Ties go to the first group.
            groups_raw[k] = 1 if cons_l0 >= cons_l1 else 2
            groups_sae[k] = 1 if cons_sae_l0 >= cons_sae_l1 else 2
            groups_ep[k] = 1 if cons_ep_l0 >= cons_ep_l1 else 2

            prompt = get_coherence_prompt(flattened_answers_l0, flattened_answers_l1)
            llm_response = get_llm_response(prompt, client)
            # Any reply containing "1" counts as Group 1; everything else as Group 2.
            model_coherence[k] = 1 if "1" in llm_response else 2

        emb_hits = np.sum(model_coherence == groups_ep)
        raw_hits = np.sum(model_coherence == groups_raw)
        sae_hits = np.sum(model_coherence == groups_sae)
        print(emb_hits / n_split, "EMB agreement rate with DeepSeek")
        print(raw_hits / n_split, "RAW agreement rate with DeepSeek")
        print(sae_hits / n_split, "SAE agreement rate with DeepSeek")
        agreement_emb += emb_hits
        agreement_raw += raw_hits
        agreement_sae += sae_hits
        agreement_len += n_split

    # aggregated results over prompt-sample configs
    if agreement_len == 0:
        print("no 6-vs-6 or 5-vs-7 questions in any configuration")
        return
    print(agreement_emb / agreement_len, "EMB agreement rate with DeepSeek")
    print(agreement_raw / agreement_len, "RAW agreement rate with DeepSeek")
    print(agreement_sae / agreement_len, "SAE agreement rate with DeepSeek")
# Run the evaluation only when this file is executed as a script.
if __name__ == "__main__":
    main()

