import torch
import ujson as json
import os
from transformers import AutoModelForCausalLM, AutoTokenizer
from itertools import product
import pickle

# Building blocks for the probe strings; every entry carries a leading space.
# " A" .. " Z"
letters = [" " + chr(code) for code in range(ord("A"), ord("Z") + 1)]
# " zero" .. " nine"
numbers_en = [" zero", " one", " two", " three", " four", " five", " six", " seven", " eight", " nine"]
animal_choice = [" cat", " dog", " tiger", " lion"]
# Every letter/number/animal triple (e.g. " A zero cat"), ordered like
# itertools.product: letters vary slowest, animals fastest -> 26 * 10 * 4 strings.
combinations = [
    letter + number + animal
    for letter in letters
    for number in numbers_en
    for animal in animal_choice
]


def get_first_token_logits(model, tokenizer, prompt, vocab_size, device):
    """Return the model's logits for the token that would follow ``prompt``.

    Args:
        model: Causal LM called as ``model(**inputs)``; its output must expose
            a ``.logits`` attribute.
        tokenizer: Callable tokenizer whose result supports ``.to(device)``.
        prompt: Text to condition on.
        vocab_size: Number of leading logit columns to keep.
        device: Device the encoded inputs are moved to.

    Returns:
        CPU tensor ``logits[:, -1, :vocab_size]`` (one row per batch item; a
        single prompt string gives one row), computed without autograd tracking.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # no_grad so the returned tensor carries no autograd graph even when this
    # helper is called outside an enclosing torch.no_grad() block.
    with torch.no_grad():
        outputs = model(**inputs)

    # Last position = prediction for the next (first generated) token; columns
    # beyond vocab_size are dropped.
    return outputs.logits[:, -1, :vocab_size].cpu()


def get_logits_dict(
    prompt: str,
    context_window: str,
    generation_model: "AutoModelForCausalLM",
    generation_tokenizer: "AutoTokenizer",
    vocab_size: int,
    device=None,
):
    """Get the nested logits dictionary for ``context_window`` following ``prompt``.

    ``context_window`` is tokenized as a continuation of the prompt's last line,
    then its token ids are appended to the encoded prompt one at a time. After
    each appended id the model's last-position logits (first ``vocab_size``
    columns, moved to CPU) are recorded.

    Args:
        prompt: Full prompt text; encoded with special tokens.
        context_window: Text appended after the prompt, e.g. " A zero cat".
        generation_model: Causal LM called as ``generation_model(input_ids)``
            whose output exposes ``.logits``.
        generation_tokenizer: Tokenizer providing ``encode``.
        vocab_size: Number of leading logit columns to keep.
        device: Device for the input ids. Defaults to
            ``generation_model.device``. Previously a module-level ``device``
            global was required, which only exists when this file runs as a
            script.

    Returns:
        A chain of dicts keyed by token id,
        ``{id_0: {"logits": L_0, id_1: {"logits": L_1, ...}}}``, where ``L_i``
        are the logits after feeding ``id_0 .. id_i``. Empty dict when
        ``context_window`` contributes no tokens.
    """
    if device is None:
        device = generation_model.device

    input_ids = generation_tokenizer.encode(
        prompt, return_tensors="pt", add_special_tokens=True
    ).to(device)

    # Tokenize the prompt's last line alone and together with the context
    # window; the suffix of the longer encoding gives the context window's ids
    # as they appear in context. Slicing (instead of encoding the window alone)
    # avoids differing tokenizer treatment of the leading space.
    post_input_str = prompt.rsplit("\n", 1)[-1]
    post_input_ids = generation_tokenizer.encode(
        post_input_str, add_special_tokens=False
    )
    main_prompt_ids = generation_tokenizer.encode(
        post_input_str + context_window, add_special_tokens=False
    )
    context_window_ids = main_prompt_ids[len(post_input_ids) :]

    print(f"Length of context window ids: {len(context_window_ids)}")

    gen_logits = []
    with torch.no_grad():
        for token_id in context_window_ids:
            # Re-runs the full prefix each step (no KV cache); acceptable for the
            # short context windows used here.
            next_token = torch.tensor([[token_id]]).to(device)
            input_ids = torch.cat([input_ids, next_token], dim=-1)
            output = generation_model(input_ids)
            gen_logits.append(output.logits[:, -1, :vocab_size].cpu().detach())

    # Build the chain from the innermost token outward (iterative, no recursion).
    nested_dict = {}
    for token_id, logits in reversed(list(zip(context_window_ids, gen_logits))):
        nested_dict = {token_id: {"logits": logits, **nested_dict}}
    return nested_dict


def merge_nested_dicts(dict1, dict2):
    """
    Merge ``dict2`` into ``dict1`` in place, descending into sub-dicts.

    When both sides hold a dict under the same key, the two are merged
    recursively; in every other case the value from ``dict2`` replaces (or is
    added to) ``dict1``. Values from ``dict2`` are inserted by reference, not
    copied. Returns None.
    """
    for key, incoming in dict2.items():
        both_dicts = (
            key in dict1
            and isinstance(dict1[key], dict)
            and isinstance(incoming, dict)
        )
        if both_dicts:
            merge_nested_dicts(dict1[key], incoming)
        else:
            dict1[key] = incoming


def update_root_dict_with_nested(root_dict, nested_dict):
    """
    Fold ``nested_dict`` into ``root_dict`` in place and hand ``root_dict`` back.

    Thin wrapper around ``merge_nested_dicts`` that returns the mutated root so
    it can be used in expressions.
    """
    merge_nested_dicts(dict1=root_dict, dict2=nested_dict)
    return root_dict


# Main function
def _read_prompt(path):
    """Return the full text of the prompt file at ``path``."""
    with open(path, "r", encoding="utf-8") as f:
        return f.read()


def run(model_name: str, model_path=None):
    """Compute logit trees for both n-gram prompts and pickle them.

    For each prompt, the root ``"logits"`` entry holds the next-token logits of
    the bare prompt; the per-token logit chains of every string in
    ``combinations`` are merged underneath, keyed by token id.

    Args:
        model_name: Short model name used in output file names (and to pick the
            OPT vocabulary size).
        model_path: Path or hub id passed to ``from_pretrained``. When omitted,
            falls back to a module-level ``model_path`` global (set by the
            script entry point) for backward compatibility.

    Raises:
        ValueError: If no model path is available from either source.
    """
    if model_path is None:
        model_path = globals().get("model_path")
    if model_path is None:
        raise ValueError("model_path must be provided (e.g. via --model_path)")

    # Paths are relative to the current working directory.
    logits_paths = [
        f"../../data/logits/ngram-p1-logits-{model_name}.pickle",
        f"../../data/logits/ngram-p2-logits-{model_name}.pickle",
    ]
    prompt_paths = [
        "../../data/prompts/ngram-p1.txt",
        "../../data/prompts/ngram-p2.txt",
    ]
    prompts = [_read_prompt(path) for path in prompt_paths]

    device = "cuda" if torch.cuda.is_available() else "cpu"

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    # NOTE(review): 50272 presumably matches the OPT output-layer width, which
    # differs from tokenizer.vocab_size for these checkpoints -- confirm.
    if model_name in ["opt27b", "opt13b"]:
        vocab_size = 50272
    else:
        vocab_size = tokenizer.vocab_size

    model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")

    for prompt, logits_file_path in zip(prompts, logits_paths):
        # Root of the tree: logits of the next token after the bare prompt.
        with torch.no_grad():
            results = {
                "logits": get_first_token_logits(
                    model, tokenizer, prompt, vocab_size, device
                )
            }

            for item in combinations:
                print(f"Getting logits for: {item}")
                nested_dict = get_logits_dict(
                    prompt, item, model, tokenizer, vocab_size
                )
                update_root_dict_with_nested(results, nested_dict)

        os.makedirs(os.path.dirname(logits_file_path), exist_ok=True)
        with open(logits_file_path, "wb") as file:
            pickle.dump(results, file)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Dump next-token logit trees for the n-gram prompts."
    )
    parser.add_argument(
        "--model_name",
        type=str,
        default="opt27b",
        help="Short model name used in output file names.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="5",
        help="Value written to CUDA_VISIBLE_DEVICES.",
    )
    # Required: without it from_pretrained(None) fails with an obscure error.
    parser.add_argument(
        "--model_path",
        type=str,
        required=True,
        help="Path or hub id passed to from_pretrained.",
    )
    args = parser.parse_args()

    # Kept as module-level globals because other functions in this file may
    # read them (model_path, device).
    model_path = args.model_path

    # Must be set before CUDA is first initialised to take effect.
    os.environ["CUDA_VISIBLE_DEVICES"] = f"{args.device}"
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(args)
    run(args.model_name)
