from utils import generate_digit_and_special_tokens, decode, manually_generate, whitelist_tokens, move_to_cuda, load_model
from transformers import OPTForCausalLM, AutoTokenizer
from tqdm import tqdm
import torch
import os
import logging
from collections import defaultdict
import copy
from IPython import embed
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import gc
from timer import Timer


# Global switches.
#   DEBUG_MODE   - True: only raise the root logger to INFO (no log file).
#                  False: append DEBUG-level records to logs/addition.log.
#   COMPUTE_MODE - consulted by main(); when False and a results CSV is
#                  already on disk, main() plots it instead of recomputing.
DEBUG_MODE = True
COMPUTE_MODE = True


if DEBUG_MODE:
    logging.getLogger().setLevel(logging.INFO)
else:
    logging.basicConfig(
        filename="logs/addition.log",
        filemode='a',
        format='%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s',
        datefmt='%H:%M:%S',
        level=logging.DEBUG,
    )


# Model size identifiers that main() hands to load_model(), in run order.
MODEL_SIZES = ["125m", "350m", "1.3b", "2.7b", "6.7b", "13b", "30b", "66b"]
# Values of k for which top-k accuracy is measured and plotted.
TOP_K_RANGES = [1, 2, 4, 8, 16, 32, 100, 500, 1000, 2000]

def generate_dataset(N, reps, batch_size):
    """Yield ``reps`` batches of few-shot N-digit addition problems.

    NumPy's global RNG is reseeded with the batch index before each batch,
    so batch ``i`` contains the same problems on every run.

    Args:
        N: Number of decimal digits in each operand (expected to be >= 1).
        reps: Number of batches to yield.
        batch_size: Number of problems per batch.

    Yields:
        A ``(prompts, answers)`` pair of lists, each of length
        ``batch_size``. ``prompts[j]`` is the few-shot prefix followed by
        ``"a+b="``; ``answers[j]`` is that prompt followed by the sum and
        the terminator ``". "``.
    """
    base = "You are an expert at addition. 1+1=2. 6+8=14. "
    low, high = 10 ** (N - 1), 10 ** N

    for rep in range(reps):
        np.random.seed(rep)

        prompts, answers = [], []
        for _ in range(batch_size):
            # np.random.randint excludes its upper bound, so [low, high)
            # is exactly the set of N-digit integers. (Passing high - 1
            # would silently drop the largest operand, e.g. 9 or 99.)
            # int64 is requested explicitly so 10-digit operands do not
            # overflow on platforms whose default integer is 32-bit.
            a = np.random.randint(low, high, dtype=np.int64)
            b = np.random.randint(low, high, dtype=np.int64)

            prompt = f"{base}{a}+{b}="
            prompts.append(prompt)
            answers.append(f"{prompt}{a + b}. ")

        yield prompts, answers

def test_n_digit_addition(model, tokenizer, N, reps=6, batch_size=16):
    """Measure exact-match top-k accuracy of ``model`` on N-digit addition.

    For each of ``reps`` batches produced by ``generate_dataset`` the
    "greedy_exact" scores from ``run_test`` are averaged over the batch,
    then those batch means are averaged over ``reps``.

    Args:
        model: Model forwarded to ``run_test``.
        tokenizer: Tokenizer forwarded to ``run_test``.
        N: Number of digits per operand; must be an ``int``.
        reps: Number of batches to evaluate.
        batch_size: Number of problems per batch.

    Returns:
        Dict mapping each k in ``TOP_K_RANGES`` to its mean accuracy.
    """
    assert isinstance(N, int), "N is not an integer"

    summed_acc = defaultdict(float)

    for prompts, answers in generate_dataset(N, reps, batch_size):
        exact_by_topk = run_test(model, tokenizer, prompts, answers, "greedy_exact")
        # NOTE(review): the two evaluations below are executed but their
        # results are not used anywhere in this function.
        run_test(model, tokenizer, prompts, answers, "greedy")
        run_test(model, tokenizer, prompts, answers, "sample")

        # run_test returns one per-example array per entry of TOP_K_RANGES.
        for topk, per_example in zip(TOP_K_RANGES, exact_by_topk):
            summed_acc[topk] += np.mean(per_example)

    # Turn the per-batch sums into averages over all batches.
    return {topk: total / reps for topk, total in summed_acc.items()}

def run_test(model, tokenizer, prompts, prompts_and_answers, test_type):
    """Score one batch of prompts against their reference answers.

    Args:
        model: Causal language model handed to ``get_logits``.
        tokenizer: Callable tokenizer whose output provides "input_ids" and
            "attention_mask" tensors; it must also carry the
            ``DIGIT_TOKENS`` and ``pad_id`` attributes used downstream.
        prompts: List of prompt strings (without the answer).
        prompts_and_answers: List of the same prompts with the reference
            answer appended, aligned index-by-index with ``prompts``.
        test_type: ``"sample"`` for sequence probabilities, or any string
            containing ``"greedy"`` for top-k accuracy. If the string also
            contains ``"exact"``, any example that is not fully correct is
            scored 0.

    Returns:
        For ``"sample"``: a NumPy array with the probability of each
        reference answer. For greedy variants: a list with one per-example
        accuracy array for every k in ``TOP_K_RANGES``.

    Raises:
        ValueError: If ``test_type`` is not one of the supported kinds.
    """
    is_sample = test_type == "sample"
    is_greedy = "greedy" in test_type
    if not (is_sample or is_greedy):
        # Fail before the expensive forward pass instead of returning None.
        raise ValueError("Unknown test_type: {!r}".format(test_type))

    prompt_enc = tokenizer(prompts, return_tensors="pt", padding=True)
    all_enc = tokenizer(prompts_and_answers, return_tensors="pt", padding=True)

    # Index of the first answer token in each row of all_enc. The padded
    # width of prompt_enc is the same for every row, so it cannot be used
    # as a per-example prompt length. Count real prompt tokens through the
    # attention mask instead, and add the number of leading pad tokens in
    # all_enc (argmax returns the first position whose mask is 1), which
    # makes the offset valid for both right- and left-padding tokenizers.
    prompt_lens = prompt_enc["attention_mask"].sum(dim=1)
    leading_pads = all_enc["attention_mask"].argmax(dim=1)
    start_idxs = (prompt_lens + leading_pads).tolist()

    logits = get_logits(model, all_enc, DIGIT_TOKENS=tokenizer.DIGIT_TOKENS)

    if is_sample:
        return np.exp(get_log_prob_of_sequence(logits, all_enc, start_idxs, tokenizer))

    greedy_results = []
    for topk in TOP_K_RANGES:
        accs = greedy_decode_accuracy(logits, all_enc, start_idxs, top_k=topk, tokenizer=tokenizer)
        if "exact" in test_type:
            # Exact match: anything short of (numerically) 1.0 counts as a miss.
            accs[accs < 1 - 1e-5] = 0.

        greedy_results.append(accs)

    return greedy_results

def get_logits(model, input, DIGIT_TOKENS):
    """Run a gradient-free forward pass and optionally mask the vocabulary.

    Args:
        model: Callable invoked as ``model(**input, return_dict=True)``;
            its result must expose a ``logits`` tensor.
        input: Tokenized batch; passed through ``move_to_cuda`` first.
        DIGIT_TOKENS: Set of token ids that remain allowed, or ``None`` to
            return the logits unmodified.

    Returns:
        The logits tensor. When ``DIGIT_TOKENS`` is given, every entry on
        the last axis whose index is not in the set is overwritten in
        place with ``-inf``.
    """
    with torch.no_grad():
        batch = move_to_cuda(input)
        logits = model(**batch, return_dict=True).logits

        if DIGIT_TOKENS is None:
            return logits

        # Everything outside the whitelist gets -inf so it can never be
        # ranked above an allowed token.
        every_token_id = set(range(logits.shape[-1]))
        banned_ids = list(every_token_id - DIGIT_TOKENS)
        logits[:, :, banned_ids] = -np.inf

        return logits

def greedy_decode_accuracy(logits, input, start_idxs, top_k, tokenizer):
    """Compute the per-example fraction of answer tokens ranked in the top-k.

    A token at position ``p`` is scored against ``logits[row, p - 1]``,
    i.e. the prediction made from the preceding position.

    Args:
        logits: Tensor indexed as ``[example, position, token_id]``.
        input: Mapping whose "input_ids" entry holds the token ids per example.
        start_idxs: For each example, the first position that is scored.
        top_k: Number of highest-scoring candidates that count as a hit.
        tokenizer: Object exposing ``pad_id``; pad tokens are not scored.

    Returns:
        NumPy array with one accuracy in [0, 1] per example.
    """
    token_rows = input["input_ids"]
    accs = np.zeros(logits.shape[0])

    for row, first_scored_pos in enumerate(start_idxs):
        hits = 0
        scored = 0

        for pos, token in enumerate(token_rows[row]):
            # Only non-pad tokens at or after the start index are scored.
            if pos >= first_scored_pos and not token == tokenizer.pad_id:
                scored += 1
                candidates = torch.topk(logits[row, pos - 1, :], top_k).indices.tolist()
                if token in candidates:
                    hits += 1

        # NumPy scalar division keeps the original's nan-on-zero behavior.
        accs[row] = np.float64(hits) / scored
        assert 0 <= accs[row] <= 1, "Accuracy is not between 0 and 1"

    return accs

def get_log_prob_of_sequence(logits, input, start_idxs, tokenizer):
    """Sum the log-probabilities the model assigns to each answer token.

    The token at position ``i`` is scored with the distribution produced
    at position ``i - 1`` of the *same* example, matching the indexing
    used by ``greedy_decode_accuracy``.

    Args:
        logits: Tensor indexed as ``[example, position, token_id]``.
        input: Mapping whose "input_ids" entry holds the token ids per example.
        start_idxs: For each example, the first position that is scored.
        tokenizer: Object exposing ``pad_id``; pad tokens are not scored.

    Returns:
        NumPy array with one summed log-probability per example.
    """
    batch_size = logits.shape[0]
    final_log_prob = np.zeros(batch_size)

    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)

    for idx, start_idx in enumerate(start_idxs):
        for i, token_id in enumerate(input["input_ids"][idx]):
            if i < start_idx or token_id == tokenizer.pad_id:
                continue
            if i == 0:
                # The first token has no preceding position that predicts
                # it; i - 1 would wrap around to the last position.
                continue

            # Use this example's own row (idx, not 0) and the prediction
            # made one step earlier. .item() yields a Python float so the
            # accumulation also works when the logits live on the GPU.
            final_log_prob[idx] += log_probs[idx, i - 1, token_id].item()

    return final_log_prob

def test_generations(model, tokenizer, input_str):
    """Log the text ``model.generate`` produces for one prompt.

    Debugging helper: the prompt is tokenized, moved with ``move_to_cuda``,
    extended by up to three new tokens, and the decoded sequences are
    written to the log at INFO level.

    Args:
        model: Model exposing a ``generate`` method.
        tokenizer: Tokenizer used both for encoding and by ``decode``.
        input_str: Prompt text to continue.
    """
    encoded = tokenizer(input_str, return_tensors="pt")
    inputs = move_to_cuda(encoded)

    library_outputs = model.generate(
        **inputs,
        max_new_tokens=3,
        return_dict_in_generate=True,
        output_scores=True,
    )

    logging.info("#############################################")
    decoded = decode(tokenizer, library_outputs["sequences"])
    logging.info("Library outputs: {}".format(decoded))
    # A side-by-side comparison against manually_generate() used to be
    # logged here; it is currently disabled.

def main():
    """Evaluate every model in ``MODEL_SIZES`` on 1- to 10-digit addition.

    When ``COMPUTE_MODE`` is off and a results CSV written by ``save_data``
    is already on disk, the CSV is loaded and re-plotted and nothing is
    recomputed. Otherwise each model is loaded, evaluated and released in
    turn; results are saved and charts refreshed after every model so that
    partial progress survives an interrupted run.
    """
    # Must match the path save_data() writes to.
    results_csv = "logs/multi_digit_results.csv"

    timer = Timer(print_results=True)

    # If results csv exists, load and plot it instead of recomputing.
    if not COMPUTE_MODE and os.path.exists(results_csv):
        print("Results csv exists, loading and plotting")
        line_chart(pd.read_csv(results_csv))
        return

    # The CSV and PNG outputs are all written below logs/; create it up
    # front so hours of computation are not lost to a missing directory.
    os.makedirs("logs", exist_ok=True)

    data = {}

    # Wrap the list itself (not an enumerate object) so tqdm knows the total.
    for model_size in tqdm(MODEL_SIZES):

        model, tokenizer = load_model(model_size)
        timer.snap("Loaded {} model".format(model_size))

        tokenizer.DIGIT_TOKENS = whitelist_tokens(tokenizer)
        # NOTE(review): hard-coded pad token id; presumably matches the
        # loaded tokenizer's real pad id - confirm against load_model.
        tokenizer.pad_id = 1
        timer.snap("Finished tokenizing")

        for digits in range(1, 11):
            data[(model_size, digits)] = test_n_digit_addition(model, tokenizer, digits)
            timer.snap("Finished {} digit addition".format(digits))

        # Clear up model memory before the next (larger) model is loaded.
        del model, tokenizer
        gc.collect()
        torch.cuda.empty_cache()
        timer.snap("Cleaned up {} model.".format(model_size))

        # Checkpoint results and refresh the charts after every model.
        df = save_data(data)
        line_chart(df)

def stacked_bar_chart(df):
    """Draw a stacked bar chart of top-k accuracy per model size.

    For each k in ``TOP_K_RANGES`` a segment is stacked on top of the
    previous ones whose height is the accuracy gained over the running
    total. The figure is saved to ``logs/plot.png``.

    Args:
        df: DataFrame with a "model_size" column and one
            "greedy_acc_<k>" column per k in ``TOP_K_RANGES``.
    """
    plt.clf()
    plt.xlabel("Model Size")
    plt.ylim(0, 1)
    plt.ylabel("Accuracy")
    plt.title("Accuracy of OPT Models on Single Digit Addition")

    stacked_height = 0
    for topk in TOP_K_RANGES:
        # Height of this segment: what this k adds on top of the stack so far.
        gain = df["greedy_acc_{}".format(topk)] - stacked_height
        plt.bar(df["model_size"], gain, label=str(topk), bottom=stacked_height)
        stacked_height = stacked_height + gain

    plt.legend()
    plt.savefig("logs/plot.png")

def line_chart(df):
    """Save one accuracy-vs-digits line chart per k in ``TOP_K_RANGES``.

    Each chart has one line per model size that has rows in ``df`` and is
    written to ``logs/line_chart_top_<k>.png``.

    Args:
        df: DataFrame with "model_size" and "digits" columns plus one
            "greedy_acc_<k>" column per k in ``TOP_K_RANGES``.
    """
    for topk in TOP_K_RANGES:
        acc_column = "greedy_acc_{}".format(topk)

        # Drop whatever was drawn before so charts do not bleed together.
        plt.clf()
        plt.close('all')

        plt.title("Accuracy of OPT Models on Addition")
        plt.xlabel("Digits in Addition")
        plt.ylabel("Accuracy")
        plt.ylim(0, 1)

        for model_sz in MODEL_SIZES:
            rows = df[df["model_size"] == model_sz]

            # Models that have not been evaluated yet contribute no line.
            if len(rows) > 0:
                plt.plot(rows["digits"], rows[acc_column], label=model_sz)

        plt.legend()
        plt.savefig("logs/line_chart_top_{}.png".format(topk))

def save_data(data):
    """Flatten the results dict into a DataFrame and write it to CSV.

    Args:
        data: Dict mapping ``(model_size, digits)`` to a dict of accuracies
            whose values are ordered like ``TOP_K_RANGES``.

    Returns:
        The DataFrame that was written to ``logs/multi_digit_results.csv``,
        with columns "model_size", "digits" and one "greedy_acc_<k>"
        column per k in ``TOP_K_RANGES``.
    """
    acc_columns = ["greedy_acc_{}".format(topk) for topk in TOP_K_RANGES]

    # One flat row per (model_size, digits) key.
    rows = [
        [model_size, digits, *accs.values()]
        for (model_size, digits), accs in data.items()
    ]

    df = pd.DataFrame(rows, columns=["model_size", "digits"] + acc_columns)
    df.to_csv("logs/multi_digit_results.csv")
    return df

# Run the benchmark only when this file is executed as a script; importing
# the module must not trigger model loading.
if __name__ == "__main__":
    main()