from datasets import load_dataset, concatenate_datasets
from sae_lens.toolkit.pretrained_saes_directory import get_pretrained_saes_directory
from sae_lens import HookedSAETransformer, SAE
from transformer_lens import HookedTransformer
from pathlib import Path
import pandas as pd
import requests
import re
import gc
import torch

def clean_gpus() -> None:
    """Run a Python garbage-collection pass, then empty PyTorch's CUDA cache.

    Intended to be called between memory-heavy steps; returns nothing.
    """
    gc.collect()  # drop unreachable Python objects first
    torch.cuda.empty_cache()
    return None

def load_ds(dname, cache_dir, cot=False):
    """Load and shuffle the evaluation data for one supported benchmark.

    Args:
        dname: Benchmark key; one of "mmlu", "medmcqa", "commonsenseqa",
            "hellaswag".
        cache_dir: Directory forwarded to ``load_dataset`` as ``cache_dir``.
        cot: Accepted for call-site compatibility; not used here.

    Returns:
        A ``(dataset, None)`` tuple. The dataset is shuffled with a fixed
        seed; for "hellaswag" only the first 3000 shuffled rows are kept.
        An unrecognised ``dname`` yields ``(None, None)``.
    """
    shuffle_seed = 1000

    if dname == "mmlu":
        # Only the high-school subsets are used; the full subject list is kept
        # so the selection rule stays visible.
        mmlu_tasks = [
            'abstract_algebra', 'anatomy', 'astronomy', 'business_ethics', 'clinical_knowledge', 'college_biology', 'college_chemistry',
            'college_computer_science', 'college_mathematics', 'college_medicine', 'college_physics', 'computer_security', 'conceptual_physics',
            'econometrics', 'electrical_engineering', 'elementary_mathematics', 'formal_logic', 'global_facts', 'high_school_biology',
            'high_school_chemistry', 'high_school_computer_science', 'high_school_european_history', 'high_school_geography',
            'high_school_government_and_politics', 'high_school_macroeconomics', 'high_school_mathematics', 'high_school_microeconomics',
            'high_school_physics', 'high_school_psychology', 'high_school_statistics', 'high_school_us_history', 'high_school_world_history',
            'human_aging', 'human_sexuality', 'international_law', 'jurisprudence', 'logical_fallacies', 'machine_learning', 'management',
            'marketing', 'medical_genetics', 'miscellaneous', 'moral_disputes', 'moral_scenarios', 'nutrition', 'philosophy', 'prehistory',
            'professional_accounting', 'professional_law', 'professional_medicine', 'professional_psychology', 'public_relations', 'security_studies',
            'sociology', 'us_foreign_policy', 'virology', 'world_religions',
        ]
        hs_parts = [
            load_dataset("cais/mmlu", task, split="test", cache_dir=cache_dir)
            for task in mmlu_tasks
            if "high_school" in task
        ]
        merged = concatenate_datasets(hs_parts)
        return merged.shuffle(seed=shuffle_seed), None

    if dname == "medmcqa":
        raw = load_dataset("openlifescienceai/medmcqa", split="validation", cache_dir=cache_dir)
        return raw.shuffle(seed=shuffle_seed), None

    if dname == "commonsenseqa":
        raw = load_dataset("tau/commonsense_qa", split="validation", cache_dir=cache_dir)
        return raw.shuffle(seed=shuffle_seed), None

    if dname == "hellaswag":
        raw = load_dataset("DatologyAI/hellaswag", split="eval", cache_dir=cache_dir)
        # Keep a fixed-size subset of the shuffled data.
        return raw.shuffle(seed=shuffle_seed).select(range(3000)), None

    # Unknown benchmark name.
    return None, None

# prompt utils
def prompt_add_gemma_template(prompt, dname):
    """Wrap ``prompt`` as a single user turn followed by an opened model turn.

    ``dname`` is accepted but does not influence the result.
    """
    user_open = "<bos><start_of_turn>user\n"
    model_open = "<end_of_turn>\n<start_of_turn>model\n"
    return "".join((user_open, prompt, model_open))

def prompt_add_llama_template(prompt):
    """Wrap ``prompt`` as a user message followed by an opened assistant header."""
    user_header = "<|start_header_id|>user<|end_header_id|>\n\n"
    assistant_header = "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
    return "".join((user_header, prompt, assistant_header))

def prompt_construction(question, answers, dname):
    """Format a question and its candidate answers as the body of a prompt.

    Args:
        question: Question (or sentence stem for "hellaswag") text.
        answers: Indexable candidate answers; four are read, five for
            "commonsenseqa".
        dname: Benchmark key selecting the layout.

    Returns:
        The formatted text ending in a blank line, or ``None`` when ``dname``
        is not one of the supported benchmarks.
    """
    if dname in ("mmlu", "medmcqa"):
        return (
            f"Question: {question}\n\n"
            f"Candidate answers:\nA: {answers[0]}\nB: {answers[1]}\nC: {answers[2]}\nD: {answers[3]}\n\n"
        )

    if dname == "commonsenseqa":
        # Five options (A-E) for this benchmark.
        return (
            f"Question: {question}\n\n"
            f"Candidate answers:\nA: {answers[0]}\nB: {answers[1]}\nC: {answers[2]}\nD: {answers[3]}\nE: {answers[4]}\n\n"
        )

    if dname == "hellaswag":
        # Sentence-completion framing instead of a plain question.
        return (
            'Complete the following sentence with one of the candidate answer below:\n'
            f"Sentence: {question}\n\n"
            f"Candidate answers:\nA: {answers[0]}\nB: {answers[1]}\nC: {answers[2]}\nD: {answers[3]}\n\n"
        )

    return None

# prompt templates
def get_prompt(d, idx, prompt_type=0, gemma_model=True, llama_model=False, dname="mmlu", d_util=None, cot=False):
    """Build the full prompt for example ``d[idx]`` using one of 12 paraphrases.

    Args:
        d: Indexable dataset; ``d[idx]`` must provide "question" plus the
            answer fields of the chosen benchmark ("choices" for mmlu and
            hellaswag, "opa".."opd" for medmcqa, ``choices["text"]`` for
            commonsenseqa).
        idx: Index of the example in ``d``.
        prompt_type: Paraphrase id, 0-11. Any other value leaves the prompt
            body empty.
        gemma_model: If truthy, wrap the result in the Gemma chat template
            (checked before ``llama_model``).
        llama_model: If truthy and ``gemma_model`` is falsy, wrap the result
            in the Llama chat template.
        dname: Benchmark key; unsupported names leave the prompt body empty.
        d_util: Not used.
        cot: If truthy, append the chain-of-thought instruction (answer at the
            end); otherwise append the answer-first instruction.

    Returns:
        The prompt string, chat-templated when a model flag is set.
    """
    # One row per prompt_type (the row position is the type id):
    # (opening line, chain-of-thought instruction, answer-first instruction).
    variants = (
        (
            "You are a helpful AI assistant, answer the following question:\n",
            "Think step by step. Briefly justify your reasoning process, then put your final chosen answer in the form: [The answer is: (X)] at the end.",
            "First put your answer in the form (X), then add a brief justification for your choice of answer.",
        ),
        (
            "You are a knowledgeable helper, look at the following question:\n",
            "Let's break this question down step by step. Write some short explanations for your reasoning, then put your answer in the form: [The answer is: (X)] at the end of your response.",
            "Please put your answer in the form (X) at the start of your response, then add a short explanation for your answer.",
        ),
        (
            "You are an expert in multiple choice questions, answer the following question concisely:\n",
            "Think about the question step by step. Provide some brief explanations for your thinking process. Put your answer in the form: [The answer is: (X)] to the end.",
            "Put your answer in the form (X) first, then add a brief explanation of why you chose answer.",
        ),
        (
            "You are a helpful AI assistant, answer this question:\n",
            "Think step by step about this question. Add a brief justification for your choice of answer. Output your answer in the form: [The answer is: (X)] at the end of your response.",
            "Output your answer in the form (X) at the start of your response, then add a brief justification for your choice of answer.",
        ),
        (
            "Answer the following question:\n",
            "Let's think step by step. Provide short explanations of your thinking steps. At the end of your response, put your choice of answer in the form: [The answer is: (X)].",
            "First put your answer in the form (X), then add a brief explanation for this answer.",
        ),
        (
            "Here's a question I need you to help with:\n",
            "Let's break down this question and think step by step. Briefly outline your reasoning process. Output your choice of answer with the form: [The answer is: (X)] to the end.",
            "Output your choice of answer with the form (X) first, then briefly explain your answer.",
        ),
        (
            "Look at the following question and answer it:\n",
            "Think step by step. List out your thinking. Keep it short. Put your answer in the form: [The answer is: (X)] at the end of your response.",
            "Put your answer in the form (X) at the start of your response, then justify your answer.",
        ),
        (
            "I have a multiple choice question which you are going to help with:\n",
            "Let's think slowly and step by step. First briefly output your thinking process with short justifications, then finally output your answer in the form: [The answer is: (X)].",
            "First, output your answer in the form (X), then provide some justifications for why this answer is correct.",
        ),
        (
            "Please help me answer the following question:\n",
            "Look at the question step by step. Explain your thoughts very briefly and finally output the answer in the form: [The answer is: (X)].",
            "Output the answer in this format (X) first, then explain your answer.",
        ),
        (
            "Which candidate answer do you think is correct for this question:\n",
            "Consider this question step by step with short explanations for your thoughts, then put your answer in the form: [The answer is: (X)] at the end of your response.",
            "Put your answer in the form (X) at the start of your response, then add a short explanation for your answer.",
        ),
        (
            "Here is a question in the multiple choice form with four potential answers:\n",
            "Analyse the question and candidate answers with step-by-step thinking, then state the correct answer in the form: [The answer is: (X)] at the end of your outputs.",
            "Answer your choice of candidate answer in the form (X) at the start of your response, then brifly justify why it should be correct.",
        ),
        (
            "Below is a multiple choice question. Look at the question and the candidate answers, select the correct one:\n",
            "Think about it step by step, present short explanations for your thoughts. At the end of your output, state your answer in the form: [The answer is: (X)].",
            "At the start of your response, put your chosen answer in the form (X), then provide some justifications for it.",
        ),
    )

    body = ""
    if dname in ("mmlu", "medmcqa", "commonsenseqa", "hellaswag"):
        record = d[idx]
        question = record["question"]

        # Each benchmark stores its candidate answers under different fields.
        if dname == "medmcqa":
            options = [record['opa'], record['opb'], record['opc'], record['opd']]
        elif dname == "commonsenseqa":
            options = record['choices']['text']
        else:  # "mmlu" and "hellaswag"
            options = record["choices"]

        # Compare with == so any value that equals a type id selects its row.
        for type_id, (opening, cot_tail, direct_tail) in enumerate(variants):
            if prompt_type == type_id:
                instruction = cot_tail if cot else direct_tail
                body = opening + prompt_construction(question, options, dname) + instruction
                break

    # Gemma takes precedence when both model flags are set.
    if gemma_model:
        return prompt_add_gemma_template(body, dname)
    if llama_model:
        return prompt_add_llama_template(body)
    return body

def load_model(mname, cache_dir, device="cuda", dtype="float16", n_devices=1):
    """Instantiate the language model used for the experiments.

    Args:
        mname: Model name. If it contains "27", the fixed checkpoint
            "gemma-2-27b-it" is loaded instead of ``mname`` itself.
        cache_dir: Forwarded as ``cache_dir`` to the loader.
        device: Forwarded as ``device`` to the loader.
        dtype: Forwarded as ``dtype`` to the loader.
        n_devices: Forwarded only on the non-"27" path.

    Returns:
        A ``HookedTransformer`` (via ``from_pretrained_no_processing``) for
        "27" names, otherwise a ``HookedSAETransformer``.
    """
    wants_27b = "27" in mname
    if not wants_27b:
        return HookedSAETransformer.from_pretrained(
            mname,
            device=device,
            n_devices=n_devices,
            cache_dir=cache_dir,
            dtype=dtype,
        )

    # NOTE(review): n_devices is not passed on this path — confirm intended.
    return HookedTransformer.from_pretrained_no_processing(
        "gemma-2-27b-it",
        device=device,
        cache_dir=cache_dir,
        dtype=dtype,
    )

# Benchmarks whose responses are graded as a multiple-choice letter.
_MCQ_DATASETS = ("mmlu", "medmcqa", "commonsenseqa", "hellaswag")

# Option letters in index order: "a" -> 0 ... "e" -> 4.
_OPTION_LETTERS = "abcde"

# Chain-of-thought patterns, tried in order on the lower-cased response.
# The gap after the colon is matched lazily and the letter must stand alone
# (word boundaries), so letters from the following words are never picked up.
_COT_ANSWER_PATTERNS = (
    re.compile(r"the.{0,10} answer is:.{0,6}?\b([a-e])\b"),
    re.compile(r"answer:.{0,5}?\b([a-e])\b"),
)

# Answer-first patterns, tried in order on the raw response. Only the
# requested "(X)" form is case-insensitive and accepts E; the looser fallback
# heuristics stay limited to upper-case A-D because "E" would also match
# ordinary text such as "**Explanation".
_DIRECT_ANSWER_PATTERNS = (
    re.compile(r"\(([A-E])\)", re.IGNORECASE),
    re.compile(r"\*\*([ABCD])"),
    re.compile(r"\*\* ([ABCD])"),
    re.compile(r"([ABCD]):"),
    re.compile(r"([ABCD])\)"),
)


def _first_option_index(patterns, text):
    """Return the option index captured by the first pattern matching ``text``, else -1."""
    for pattern in patterns:
        found = pattern.search(text)
        if found is not None:
            return _OPTION_LETTERS.index(found.group(1).lower())
    return -1


def get_model_pred(response, dname, metric=None, ref=None, cot=False):
    """Extract the predicted option index from a model response.

    Args:
        response: Generated text to parse.
        dname: Benchmark key; only "mmlu", "medmcqa", "commonsenseqa" and
            "hellaswag" are parsed.
        metric: Not used.
        ref: Not used.
        cot: If truthy, look for "... answer is: X" / "answer: X" (any case);
            otherwise look for "(X)" first and then for the fallback forms
            "**X", "** X", "X:" and "X)".

    Returns:
        0-4 for options A-E, or -1 when no answer is found or ``dname`` is
        unsupported. Patterns are tried in priority order and, within a
        pattern, the earliest occurrence in the response wins.
    """
    if dname not in _MCQ_DATASETS:
        return -1
    if cot:
        return _first_option_index(_COT_ANSWER_PATTERNS, response.lower())
    return _first_option_index(_DIRECT_ANSWER_PATTERNS, response)

# Locate the token position of the answer letter in a tokenized model response.
def get_answering_idx(tokens):
    """Return the index of the last answer-letter token in ``tokens``.

    Scanning from the end, a token qualifies when, lower-cased and stripped,
    it is one of "a"-"e" and the 10 tokens before it contain "answer" together
    with ":" or ":**" (after strip/lower). Positions below 10 are never
    examined because they lack a full 10-token window.

    Args:
        tokens: Sequence of string tokens.

    Returns:
        The position of the qualifying token, or -1 if there is none.
    """
    option_letters = ("a", "b", "c", "d", "e")
    window = 10

    for pos in range(len(tokens) - 1, window - 1, -1):
        if tokens[pos].lower().strip() not in option_letters:
            continue
        preceding = [tok.strip().lower() for tok in tokens[pos - window:pos]]
        if "answer" not in preceding:
            continue
        if ":" in preceding or ":**" in preceding:
            return pos
    return -1

def load_one_sae_with_descriptions(sae_release, sae_id, model_id="gemma-2-9b", token="******", device="cuda:0", timeout=120):
    """Load one pretrained SAE and download its Neuronpedia explanations.

    Args:
        sae_release: Release name passed to ``SAE.from_pretrained`` and used
            as the key into the pretrained-SAE directory.
        sae_id: SAE id within the release.
        model_id: Sent as ``modelId`` in the export request.
        token: Sent as the ``X-Api-Key`` request header.
        device: Device passed to ``SAE.from_pretrained``.
        timeout: Seconds passed to ``requests.get`` as ``timeout``, so a
            stalled connection fails instead of blocking indefinitely.

    Returns:
        ``(sae, df)`` where ``df`` is a ``pandas.DataFrame`` built from the
        JSON body of the export response.

    Raises:
        requests.HTTPError: If the export endpoint answers with an error
            status (raised before the body is parsed).
        requests.Timeout: If the request exceeds ``timeout``.
    """
    # get SAE
    sae, _, _ = SAE.from_pretrained(
        release=sae_release,  # see other options in sae_lens/pretrained_saes.yaml
        sae_id=sae_id,  # won't always be a hook point
        device=device)

    # get corresponding descriptions
    release = get_pretrained_saes_directory()[sae_release]
    # Only the last path component of the directory entry is sent as saeId.
    np_id = release.neuronpedia_id[sae_id].split("/")[-1]
    querystring = {"modelId":model_id,"saeId":np_id}
    url = "https://www.neuronpedia.org/api/explanation/export"
    headers = {"X-Api-Key": token}
    response = requests.get(url, headers=headers, params=querystring, timeout=timeout)
    # Fail loudly on 4xx/5xx rather than feeding an error payload to pandas.
    response.raise_for_status()
    df = pd.DataFrame(response.json())
    return sae, df

def load_saes(sae_release, sae_ids, model_id="gemma-2-9b", token="******", device="cuda:0"):
    """Load several SAEs from one release together with their explanations.

    Args:
        sae_release: Release name shared by all requested SAEs.
        sae_ids: Iterable of SAE ids, loaded in order.
        model_id: Forwarded to ``load_one_sae_with_descriptions``.
        token: Forwarded to ``load_one_sae_with_descriptions``.
        device: Forwarded to ``load_one_sae_with_descriptions``.

    Returns:
        ``(saes, dfs)``: two lists aligned with ``sae_ids``.
    """
    loaded = [
        load_one_sae_with_descriptions(sae_release, one_id, model_id=model_id, token=token, device=device)
        for one_id in sae_ids
    ]
    saes = [sae for sae, _ in loaded]
    dfs = [df for _, df in loaded]
    return saes, dfs

def save_acts(acts, dname, mname, ans_or_cot="ans", starting_idx=0, save_dir="******", cot=False, save_suffix="", chunk_size=500):
    """Write ``acts`` to disk in consecutive row chunks, one ``.pt`` file each.

    Files go to ``{save_dir}/{dname}_{mname}`` (with ``_cot`` appended when
    ``cot`` is truthy) and are named
    ``{ans_or_cot}{save_suffix}_{k}.pt`` with ``k`` counting up from
    ``starting_idx``.

    Args:
        acts: Object with ``.shape[0]`` that supports row slicing; each slice
            is passed to ``torch.save``.
        dname: Dataset name used in the directory name.
        mname: Model name used in the directory name.
        ans_or_cot: File-name prefix.
        starting_idx: Number given to the first chunk file.
        save_dir: Base output directory.
        cot: Adds the ``_cot`` directory suffix when truthy.
        save_suffix: Extra text inserted after the file-name prefix.
        chunk_size: Maximum number of rows per file.

    Raises:
        ValueError: If ``chunk_size`` is not positive.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    cot_path = "_cot" if cot else ""
    save_dir += f"/{dname}_{mname}{cot_path}"
    # Create missing parents too and tolerate an existing directory, so a
    # fresh save_dir or a concurrent writer does not abort the save.
    Path(save_dir).mkdir(parents=True, exist_ok=True)

    n_rows = acts.shape[0]
    for i, start in enumerate(range(0, n_rows, chunk_size)):
        stop = min(start + chunk_size, n_rows)
        save_fname = save_dir + f"/{ans_or_cot}{save_suffix}_{i+starting_idx}.pt"
        print(f"SAVING TO: {save_fname}")
        torch.save(acts[start:stop], save_fname)

