import argparse
import json
import pickle
import torch
from tqdm import tqdm
import numpy as np
from huggingface_hub import login
import os
from utils_exp import *

def parse_args():
    """Parse the command-line options of the activation-extraction script.

    Returns:
        argparse.Namespace with the attributes ``dataset``, ``model``,
        ``cache_dir``, ``hf_token``, ``cot``, ``load_from_cpu``, ``device``,
        ``result_dir``, ``start_idx_save``, ``n_points_sae`` and ``offset``.
        Every non-flag option is kept as a string.
    """
    dataset_choices = ["mmlu", "medmcqa", "commonsenseqa", "hellaswag"]
    model_choices = [
        "gemma-2-2b-it",
        "gemma-2-9b-it",
        "gemma-2-27b-it",
        "Meta-Llama-3.1-8B-Instruct",
    ]
    gpu_choices = ["0", "1", "2"]

    cli = argparse.ArgumentParser()

    # What to run on.
    cli.add_argument("--dataset", type=str, choices=dataset_choices, default="mmlu")
    cli.add_argument("--model", type=str, choices=model_choices, default="gemma-2-2b-it")

    # Credentials and cache location.
    for option in ("--cache_dir", "--hf_token"):
        cli.add_argument(option, type=str, default="******")

    # Boolean switches.
    cli.add_argument("--cot", action='store_true')
    cli.add_argument(
        "--load_from_cpu",
        action='store_true',
        help="whether or not to first load model to cpu then move to cuda device",
    )

    # Placement and output.
    cli.add_argument("--device", type=str, choices=gpu_choices, default="0")
    cli.add_argument("--result_dir", type=str, default="******")

    # Numeric-looking options that are nevertheless parsed as strings.
    for option in ("--start_idx_save", "--n_points_sae", "--offset"):
        cli.add_argument(option, type=str, default="0")

    return cli.parse_args()


def _layers_for_model(model_name):
    """Return the layer indices whose ``hook_resid_post`` output is recorded.

    Matching is by substring, checked in a fixed order with the last match
    winning; a name that matches nothing yields an empty list.
    """
    layer_table = (
        ("gemma-2-2b", [3, 7, 13, 20, 23]),
        ("gemma-2-9b", [9, 20, 31]),
        ("gemma-2-27b", [10, 22, 34]),
        ("3.1", [3, 8, 16, 24, 29]),
    )
    layers = []
    for marker, candidate in layer_table:
        if marker in model_name:
            layers = candidate
    return layers


def main():
    """Extract residual-stream activations and answer predictions.

    Steps:
      1. Log in to the Hugging Face hub, load the model and the dataset.
      2. Read previously generated answers from
         ``<result_dir>gen_<dataset>_<model>[_cot]_12x12.json``.
      3. For each data point, each of the 12 prompt types and each stored
         generation, re-run the model on the prompt plus the generated tokens
         that precede the answer token, and record the ``hook_resid_post``
         activations of the selected layers.
      4. Save the three activation tensors through ``save_acts`` and pickle
         the prediction array to ``<result_dir><dataset>_<model>[_cot]/preds.pkl``.

    Note that ``result_dir`` is used as a raw string prefix (no path
    separator is inserted), exactly as given on the command line.
    """
    args = parse_args()
    torch.set_grad_enabled(False)

    # load dataset and model
    login(args.hf_token)
    print("LOADING DATASET AND MODEL...")
    if args.cot:
        print("COT PROMPTS VERSION...")

    # Honour --device in both loading modes. Previously the GPU index was only
    # applied together with --load_from_cpu; otherwise "cuda:0" was always used.
    device = f"cuda:{args.device}"
    load_device = "cpu" if args.load_from_cpu else device
    # All supported models are loaded in float16 (the former 27b-specific
    # branch selected the same value).
    dtype = "float16"
    mname = args.model
    if "Llama" in mname:
        mname = 'meta-llama/Llama-3.1-8B-Instruct'
    model = load_model(mname, args.cache_dir, device=load_device, dtype=dtype)
    ds, ds_util = load_ds(args.dataset, args.cache_dir, cot=args.cot)
    if args.load_from_cpu:
        model.to(device)
    print("DONE LOADING MODEL")

    # load results
    dname = args.dataset
    model.eval()
    cot_tag = "_cot" if args.cot else ""
    fname_gen = args.result_dir + f"gen_{dname}_{args.model}{cot_tag}_12x12.json"
    with open(fname_gen, "r") as f:
        generations = json.load(f)

    # generate and save activations
    sae_names = [f"blocks.{layer}.hook_resid_post" for layer in _layers_for_model(args.model)]
    n_layers = len(sae_names)

    # create prediction object
    n_points = len(generations)
    # -2 indicates there's no answer recorded. -1 is None of the above. 0-4 A-E
    preds = np.zeros((n_points, 12, 12)) - 2
    model_hidden_size = model.cfg.d_model

    # Number of stored generations per prompt type; the three buffers below are
    # sized so that each group of prompt types fits its largest sample count.
    n_samples = [12, 6, 4, 3, 2, 2, 1, 1, 1, 1, 1, 1]
    raw_acts_01 = torch.zeros((n_points, 2, 12, n_layers, model_hidden_size), device=device, dtype=torch.float16)
    raw_acts_25 = torch.zeros((n_points, 4, 4, n_layers, model_hidden_size), device=device, dtype=torch.float16)
    raw_acts_611 = torch.zeros((n_points, 6, 1, n_layers, model_hidden_size), device=device, dtype=torch.float16)

    is_gemma = "gemma" in args.model

    # get predictions and activations
    print(f"GETTING ACTIVATIONS FOR {n_points} POINTS...")
    for idx in tqdm(range(n_points)):
        point_gens = generations[str(idx)]
        for i, n_gen in enumerate(n_samples):
            # Pick the buffer (and the row inside it) that holds prompt type i.
            if i < 2:
                bucket, row = raw_acts_01, i
            elif i <= 5:
                bucket, row = raw_acts_25, i - 2
            else:
                bucket, row = raw_acts_611, i - 6

            for j in range(n_gen):
                generation = point_gens[str(i)][j]
                preds[idx][i][j] = get_model_pred(generation, dname, metric=None, ref=None, cot=args.cot)
                if is_gemma:
                    prompt = get_prompt(ds, idx, prompt_type=i, gemma_model=True, dname=args.dataset, d_util=ds_util, cot=args.cot)
                else:
                    prompt = get_prompt(ds, idx, prompt_type=i, gemma_model=False, llama_model=True, dname=args.dataset, d_util=ds_util, cot=args.cot)
                ans_tokens = model.to_str_tokens(generation, prepend_bos=False)
                answer_idx = get_answering_idx(ans_tokens)

                # Override the parsed prediction with what the answer token says.
                if answer_idx == -1:
                    preds[idx][i][j] = -1
                else:
                    letter = ans_tokens[answer_idx].strip()
                    if len(letter) == 1 and "A" <= letter <= "E":
                        preds[idx][i][j] = ord(letter) - ord("A")

                # Append the generated tokens that come before the answer
                # token. Guarded so that -1 appends nothing (a bare [:-1]
                # slice would append almost the whole generation).
                if answer_idx > 0:
                    prompt += "".join(ans_tokens[:answer_idx])

                _, cache = model.run_with_cache(prompt)
                for layer_pos, hook_name in enumerate(sae_names):
                    layer_acts = cache[hook_name]
                    if answer_idx != -1:
                        # Activation at the last position of the extended prompt.
                        layer_acts = layer_acts[:, -1, :]
                    else:
                        # answer_idx is -1 here, so the original slice
                        # `-(answer_idx):-1` is exactly `1:-1`: the mean over
                        # all positions except the first and the last.
                        # NOTE(review): confirm this fallback is intended.
                        layer_acts = layer_acts[:, 1:-1, :].mean(dim=1).to(device)
                    bucket[idx][row][j][layer_pos] = layer_acts.detach()
                clean_gpus()

    # NOTE(review): ans_or_cot="cot" and cot=True are passed regardless of
    # --cot; left unchanged, confirm against save_acts.
    save_acts(raw_acts_01, dname, args.model, ans_or_cot="cot", starting_idx=0, cot=True, save_suffix="_01")
    save_acts(raw_acts_25, dname, args.model, ans_or_cot="cot", starting_idx=0, cot=True, save_suffix="_25")
    save_acts(raw_acts_611, dname, args.model, ans_or_cot="cot", starting_idx=0, cot=True, save_suffix="_611")

    save_dir = args.result_dir + f"{dname}_{args.model}{cot_tag}"
    # Make sure the output directory exists so the run cannot fail at the very
    # last step after all activations have been computed.
    os.makedirs(save_dir, exist_ok=True)
    save_fname = save_dir + "/preds.pkl"
    with open(save_fname, "wb") as f:
        pickle.dump(preds, f)

# Run the extraction only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
