import logging
import hydra
from omegaconf import OmegaConf
import cramming
import torch
from safetensors.torch import load_file
import matplotlib.pyplot as plt
import seaborn as sns
import json
import numpy as np
import re
import pandas as pd
import datasets
import os
from typing import List, Dict
from cramming.data.tokenizer_preparation import get_tokenizer
from cramming.data.arithmetic_tokenizers import CustomCharLevelTokenizerForAddingPadding_Base100, CustomCharLevelTokenizerForAddingPadding_Base1000
import random
from docx import Document
from docx.shared import RGBColor

# Module-level logger used by the evaluation helpers in this file.
log = logging.getLogger(name=__name__)

def load_checkpoint(model, file, device):
    """Load a safetensors checkpoint into ``model`` and move it to ``device``.

    Keys saved from a wrapped model (``module.`` prefix) are stripped first,
    so the tied-embedding fallback below also works for such checkpoints.
    If the word-embedding weight was not stored, it is filled from
    ``decoder.weight``.

    Args:
        model: module whose ``load_state_dict`` receives the weights.
        file: path passed to ``safetensors.torch.load_file``.
        device: device passed to ``load_file`` and ``model.to``.

    Returns:
        ``model.to(device)``.

    Raises:
        KeyError: if neither the word-embedding weight nor ``decoder.weight``
            is present in the checkpoint.
    """
    model_state = load_file(file, device=device)
    # Strip the wrapper prefix before anything else; otherwise the embedding
    # check misses "module."-prefixed keys and the decoder lookup raises KeyError.
    sanitized_state = {
        (k[len("module."):] if k.startswith("module.") else k): v
        for k, v in model_state.items()
    }
    if "encoder.embedding.word_embedding.weight" not in sanitized_state:
        # Hack to save space when saving the model, more clever though would be save the right one in the first place
        sanitized_state["encoder.embedding.word_embedding.weight"] = sanitized_state["decoder.weight"]
    try:
        model.load_state_dict(sanitized_state, strict=True)
        log.info("finished loading state dict")
    except RuntimeError as e:
        message = str(e)
        # Report only the part after PyTorch's marker when present; fall back
        # to the whole message instead of raising IndexError when it is not.
        _, marker, details = message.partition("Error(s) in loading state_dict for")
        log.info(f"State dict difference is {details if marker else message}... Ok?")
        model.load_state_dict(sanitized_state, strict=False)
    return model.to(device)

def line_plotter(accs, type="accs", large=False, ood_only=False):
    """Plot ``accs`` values (scaled by 100) against their keys and save the figure.

    Keys of ``accs`` are the x positions/ticks ("Number of Digits"); values
    are multiplied by 100 (presumably fractions -> percent) and the y-axis is
    fixed to [-1, 101]. ``type`` picks the y-label ("Accuracy" for "accs" or
    "reverse") and prefixes the filename; ``large``/``ood_only`` add tags to it.
    The current pyplot figure is cleared afterwards.
    """
    digit_counts = list(accs.keys())
    percentages = [fraction * 100 for fraction in accs.values()]
    plt.plot(digit_counts, percentages, marker='o')

    plt.xlabel("Number of Digits")
    if type in ("accs", "reverse"):
        axis_label = "Accuracy"
    else:
        axis_label = "Percentage with answer contained in output"
    plt.ylabel(axis_label)
    plt.title(axis_label)
    plt.ylim(-1, 101)
    plt.xticks(digit_counts)
    plt.tight_layout()

    large_tag = 'large_' if large else ''
    ood_tag = 'ood_only_' if ood_only else ''
    plt.savefig(f"{type}_{large_tag}{ood_tag}line_plot", bbox_inches='tight')
    plt.clf()

def grid_plotter(data, type="accs", name='_large', extra_path=None):
    """Save an annotated heatmap of a 2-D grid of values.

    Values are multiplied by 100 and annotated to one decimal place. Row
    ``i`` / column ``j`` are labelled ``i + 1`` / ``j + 1`` (1st / 2nd number
    length). The file is written to
    ``f"{extra_path or ''}{type}{name}_grid_plot"`` (``extra_path`` is
    prepended verbatim when not None).

    Args:
        data: array-like of values, presumably fractions in [0, 1].
        type: filename prefix.
        name: filename suffix placed before ``_grid_plot``.
        extra_path: optional string prepended to the filename.
    """
    data = np.array(data) * 100
    df = pd.DataFrame(data)

    # Keep a handle so the figure can be closed afterwards; clf() alone leaves
    # every created figure open, and they accumulate across repeated calls.
    fig = plt.figure(figsize=(10, 8))
    sns.heatmap(df, annot=True, cmap="YlGnBu", fmt=".1f", annot_kws={'size': 8, 'rotation': 0})

    plt.title("Accuracy - percentage, rounded to 1dp")
    plt.ylabel("1st Number Length")
    plt.xlabel("2nd Number Length")
    # Centre ticks on the cells with 1-based labels; use each axis' own length
    # so non-square grids are labelled correctly.
    n_rows, n_cols = df.shape
    plt.xticks(np.arange(0.5, n_cols + 0.5, 1), labels=np.arange(1, n_cols + 1, 1))
    plt.yticks(np.arange(0.5, n_rows + 0.5, 1), labels=np.arange(1, n_rows + 1, 1))

    prefix = "" if extra_path is None else extra_path
    plt.savefig(f"{prefix}{type}{name}_grid_plot", bbox_inches='tight')
    plt.close(fig)

def default_plotting(ax, x_range, fs=18):
    """Apply the shared styling for a per-location subplot.

    Sets the "Time (in iterations)" x-label, hides all four spines, and uses
    ``x_range`` as both the x tick positions and their labels (font size ``fs``).
    """
    ax.set_xlabel("Time (in iterations)", fontsize=fs)
    for side in ("top", "bottom", "left", "right"):
        ax.spines[side].set_visible(False)
    ax.set_xticks(x_range)
    ax.set_xticklabels(x_range, fontsize=fs)

def plot_completion(data, ax, tokens, fs=18):
    """Draw the series ``data`` on ``ax`` with token-labelled y ticks.

    One y tick is placed at each integer ``0 .. len(tokens) - 1`` and labelled
    with the matching entry of ``tokens``, so a y value ``k`` reads as ``tokens[k]``.
    """
    tick_positions = list(range(len(tokens)))
    ax.plot(data, linewidth=3)
    ax.set_yticks(tick_positions)
    ax.set_yticklabels(tokens, fontsize=fs)

def plot_confidence(data_array, ax, n_times, tokens, fs=18):
    """Draw a stacked bar chart on ``ax`` with one stack segment per column.

    Column ``j`` of ``data_array`` is drawn at x positions ``0 .. n_times - 1``
    on top of the columns before it, labelled ``tokens[j]`` and coloured with
    ``tab20`` colour ``j``; a legend is added at the end. ``fs`` is accepted
    but not used.
    """
    positions = list(range(n_times))
    palette = plt.colormaps.get_cmap('tab20')
    stack_base = np.zeros(n_times)
    for token_idx, series in enumerate(data_array.T):
        ax.bar(positions, series, bottom=stack_base, label=tokens[token_idx], color=palette(token_idx))
        # Raise the baseline so the next column's bars sit on top of this one.
        stack_base += series
    ax.legend()

def compare_lists_up_to_eos(list1, list2):
    """Return True if both token lists match up to their own first ``"[EOS]"``.

    Each list is cut independently at its first ``"[EOS]"`` (or kept whole if
    it has none), and the two prefixes are compared. Cutting each list at its
    own EOS means a strict prefix, e.g. ``["1", "[EOS]"]`` vs
    ``["1", "2", "[EOS]"]``, is not counted as a match.
    """
    def _before_eos(tokens):
        # Slice of ``tokens`` up to (excluding) the first "[EOS]".
        cut = tokens.index("[EOS]") if "[EOS]" in tokens else len(tokens)
        return tokens[:cut]

    return _before_eos(list1) == _before_eos(list2)

def multi_plotter(data_dict: Dict, num: int, prompt: str, answer_str: str, completion: List[str], tokens: List[str] = None, n_locs: int = 0, n_times: int = 0, truncate:bool=True):
    """Plot per-location generation diagnostics and save ``combined_plot_{num}``.

    One subplot row is drawn per entry of ``data_dict`` and one column per
    output location. Supported keys are ``"gen_tokens"`` (drawn with
    ``plot_completion``) and ``"logits"`` (stacked bars via
    ``plot_confidence``). Any other key prints ``INVALID DATA: <key>`` and
    returns without saving.

    Args:
        data_dict: mapping from plot kind to per-location data; ``data[i]``
            is what gets drawn for location ``i``.
        num: suffix of the output filename.
        prompt: prompt text shown in the figure title.
        answer_str: expected answer; split into characters with ``"[EOS]"``
            appended for the per-location titles and the correctness check.
        completion: generated tokens, one per location.
        tokens: token labels forwarded to the plot helpers.
        n_locs: number of location columns; replaced by
            ``min(len(completion), len(answer_str) + 1)`` when ``truncate``.
        n_times: number of iterations shown on the x-axis.
        truncate: see ``n_locs``.
    """
    answer = [*answer_str]
    answer.append("[EOS]")
    if truncate:
        n_locs = min(len(completion), len(answer))
    fs = 18
    sns.set_theme(style="darkgrid")
    # squeeze=False keeps ``axs`` 2-D so ``axs[row, i]`` also works with a
    # single data row or a single location column.
    fig, axs = plt.subplots(len(data_dict), n_locs, figsize=(n_locs * 4, 16), sharey=False, squeeze=False)

    label_freq = 2 if n_times <= 10 else 5
    x_range = list(range(0, n_times + 2, label_freq))  # x tick positions/labels
    for row, (name, data) in enumerate(data_dict.items()):
        for i in range(n_locs):
            ax = axs[row, i]
            default_plotting(ax, x_range, fs)

            title = f"Loc {i}"
            ans = f'ans: {answer[i]}' if i < len(answer) else ''
            comp = f'guess: {completion[i]}' if i < len(completion) else ''
            note = f'{ans} {comp}'
            title = f"{title} {f'({note})' if len(note) > 1 else ''}"
            if name == "gen_tokens":
                plot_completion(data[i], ax, tokens, fs)
                y_label = "Output Token"
            elif name == "logits":
                data_array = np.array(data[i])
                plot_confidence(data_array, ax, n_times, tokens, fs=fs)
                y_label = "Logits"
            else:
                print(f"INVALID DATA: {name}")
                plt.close(fig)
                return
            ax.set_title(title, fontsize=fs)
            axs[row, 0].set_ylabel(y_label, fontsize=fs)
    # The comparison ignores "[EOS]" and anything after it.
    is_correct = '<Correct>' if compare_lists_up_to_eos(completion, answer) else '<Wrong>'

    fig.suptitle(f"Prompt: {prompt} Answer: {''.join(answer)} Output: {''.join(completion)} {is_correct}", fontsize=fs + 36)
    plt.tight_layout()
    plt.savefig(f"combined_plot_{num}", bbox_inches='tight')
    # Close so repeated calls neither accumulate open figures nor leave this
    # figure as pyplot's current one for later plotting calls.
    plt.close(fig)

def thinking_plotter(arr, num, prompt, answer, completion):
    """Plot one subplot per output location and save ``iteration_plot_{num}``.

    ``arr`` must be 2-D; ``arr[i]`` is the series drawn for location ``i``
    (presumably token ids indexing the fixed ``tokens`` list below, which
    labels the y-axis; TODO confirm). The prompt, answer and completion are
    shown in the figure title.
    """
    n_locs, _ = arr.shape

    fs = 18
    sns.set_theme(style="darkgrid")
    # squeeze=False keeps ``axs`` 2-D. With a single location, plt.subplots
    # would otherwise return a bare Axes that cannot be indexed.
    fig, axs = plt.subplots(1, n_locs, figsize=(n_locs * 4, 8), sharey=False, squeeze=False)
    tokens = ["[PAD]", "[UNK]", "[BOS]", "[EOS]", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "+", "$-$", "$\\times$", "="]
    x_ticks = [0, 2, 4, 6, 8, 10]

    for i in range(n_locs):
        ax = axs[0, i]
        ax.plot(arr[i], linewidth=3)
        ax.set_xticks(x_ticks)
        ax.set_xticklabels(x_ticks, fontsize=fs)
        for side in ("top", "bottom", "left", "right"):
            ax.spines[side].set_visible(False)
        ax.set_title(f"Location {i}", fontsize=fs)
        ax.set_yticks(list(range(len(tokens))))
        ax.set_yticklabels(tokens, fontsize=fs)
        ax.set_xlabel("Time (in iterations)", fontsize=fs)
    axs[0, 0].set_ylabel("Output Token", fontsize=fs)
    fig.suptitle(f"Prompt: {prompt} Answer: {answer} Output: {completion}", fontsize=fs + 12)
    plt.tight_layout()
    plt.savefig(f"iteration_plot_{num}", bbox_inches='tight')
    # Close (not just clear) so repeated calls don't accumulate open figures.
    plt.close(fig)

@torch.inference_mode()
def model_gen(model, *inputs, token_limit, temperature, **kwargs): # through forward call
    """Sample ``token_limit`` tokens autoregressively through ``model``'s forward call.

    Args:
        model: callable returning a dict with a ``"logits"`` entry. The softmax
            over its last dimension is fed to ``torch.multinomial``, so it is
            presumably ``(batch, vocab)`` next-token logits (TODO confirm).
        *inputs: only ``inputs[0]`` is used, as the prompt ids; sampled tokens
            are appended along its last dimension.
        token_limit: number of tokens to sample.
        temperature: softmax temperature. Logits are divided by it, so larger
            values flatten the distribution. Must be non-zero.
        **kwargs: accepted and ignored.

    Returns:
        The sampled ids concatenated along the last dimension, or an empty
        ``inputs[0][:, :0]`` slice when ``token_limit`` is 0.
    """
    device_inputs = inputs[0]
    predicted_ids = []

    for _ in range(token_limit):
        logits = model(device_inputs)["logits"]
        # Standard temperature scaling divides the logits. Multiplying inverted
        # the effect (T < 1 flattened the distribution instead of sharpening it).
        predicted_token = torch.multinomial(torch.softmax(logits / temperature, dim=-1), 1)
        device_inputs = torch.cat([device_inputs, predicted_token], dim=-1)
        predicted_ids.append(predicted_token)

    if not predicted_ids:
        # token_limit == 0: return an empty slice instead of an unbound local.
        return device_inputs[:, :0]
    # Concatenate once after the loop rather than on every step.
    return torch.cat(predicted_ids, dim=-1)

def verbose_prompt_eval(device_inputs, answer, model, tokenizer, cfg_arch, token_limit=20, temp=0.7, crammed=False, greedy=False):
    """
    Generate a completion while recording per-step tracking data.

    Calls ``model._generate(..., track_steps=True)`` with
    ``cfg_arch.maximal_recurrence_in_eval`` steps and returns a triple:

    * the tracking rows, with every element converted via ``.cpu().item()``
      and packed into a NumPy array;
    * the first batch row's predicted ids mapped to token strings with
      ``tokenizer._convert_id_to_token``;
    * the logits exactly as ``_generate`` returned them.

    ``answer`` and ``crammed`` are accepted but unused.
    """
    predicted_ids, tracking, logits = model._generate(
        device_inputs,
        token_limit=token_limit,
        temperature=temp,
        steps_at_generation_time=cfg_arch.maximal_recurrence_in_eval,
        track_steps=True,
        greedy=greedy,
    )
    first_row_ids = predicted_ids[0].tolist()
    completion_tokens = list(map(tokenizer._convert_id_to_token, first_row_ids))
    tracking_array = np.array([[step.cpu().item() for step in row] for row in tracking])
    return tracking_array, completion_tokens, logits

def rouge_eval(tf, tf_text, device_inputs, answer, model, tokenizer, cfg_arch, token_limit=20, temp=0.7, crammed=False, greedy=False):
    """
    Score one generated completion against ``answer`` with character-level ROUGE-L.

    Generation uses ``model_gen`` when ``crammed`` is set. Otherwise it uses
    ``model._generate`` with ``cfg_arch.maximal_recurrence_in_eval`` steps.
    The ids are cut just after the first EOS match, and the first batch row
    is decoded and split into characters. The result is scored with
    ``tf_text.metrics.rouge_l`` against the characters of ``answer``.
    ``tf``/``tf_text`` are presumably the TensorFlow and TensorFlow Text
    modules, supplied by the caller.

    Returns:
        ``(f, p, r)`` from ``rouge_l``, each converted with ``.numpy().tolist()``.
    """
    if not crammed:
        predicted_ids = model._generate(device_inputs, token_limit=token_limit, temperature=temp, steps_at_generation_time=cfg_arch.maximal_recurrence_in_eval, greedy=greedy)
    else:
        predicted_ids = model_gen(model, device_inputs, token_limit=token_limit, temperature=temp)

    eos_id = tokenizer._convert_token_to_id(tokenizer.eos_token)
    eos_positions = (predicted_ids == eos_id).nonzero(as_tuple=False)
    if eos_positions.numel() > 0:
        # Keep every column up to and including the first EOS found in the
        # batch (nonzero lists matches in row-major order).
        first_eos_col = eos_positions[0, 1]
        predicted_ids = predicted_ids[:, :first_eos_col + 1]

    hypothesis_chars = list(tokenizer.decode(predicted_ids[0].tolist()))
    reference_chars = list(answer)

    with tf.device('/gpu:0'):
        hyp = tf.cast(tf.ragged.constant([hypothesis_chars]), dtype=tf.string)
        ref = tf.cast(tf.ragged.constant([reference_chars]), dtype=tf.string)
        f, p, r = tf_text.metrics.rouge_l(hyp, ref)
    return f.numpy().tolist(), p.numpy().tolist(), r.numpy().tolist()

def reverse_numbers(input_string):
    """Reverse every maximal run of digits in ``input_string`` in place.

    Non-digit characters are left untouched, e.g. ``"123 + 45"`` becomes
    ``"321 + 54"``.
    """
    return re.sub(r'\d+', lambda digits: digits.group(0)[::-1], input_string)

def pad_numbers(input_string, min_length=12):
    """Left-pad every run of digits in ``input_string`` with zeros.

    Each run is widened to at least ``min_length`` characters. Longer runs are
    left unchanged, e.g. ``pad_numbers("12+3", 4) == "0012+0003"``.
    """
    digit_run = re.compile(r'\d+')
    return digit_run.sub(lambda m: m.group(0).rjust(min_length, '0'), input_string)

def tokenize_and_split_line(line, tokenizer, device, reverse_inputs=False, pad_zeros=0, remove_padding=False):
    """Split a line into prompt and answer, preprocess the prompt and tokenize it.

    The line is split at its last space, and the prompt keeps a trailing
    space. If the line contains no space, it is parsed as ``"<a>+<b>=<answer>"``
    instead: ``+`` is surrounded by spaces and the prompt ends in ``" = "``.
    (If that fallback finds no ``=``, ``find`` returns -1: the prompt is built
    from ``line[:-1]`` and the answer is the whole stripped line.)

    For the base-100/base-1000 padding tokenizers, the answer is passed
    through ``tokenizer._tokenize`` and re-joined into a string. The prompt is
    then processed in this order: zero-padded to ``pad_zeros`` digits (if
    > 0), digit-reversed (``reverse_inputs``), and stripped of spaces
    (``remove_padding``).

    Returns:
        ``(prompt, device_inputs, answer)``, where ``device_inputs`` is the
        prompt's ``input_ids`` as a long tensor with a leading batch dimension
        of 1, moved to ``device``.
    """
    try:
        prompt, answer = line.rsplit(" ", 1)
    except ValueError:
        # No space to split on: fall back to the compact "a+b=c" form.
        equal_sign_index = line.find('=')
        prompt = line[:equal_sign_index].replace('+', ' + ').strip() + " = "
        answer = line[equal_sign_index + 1:].strip()
    else:
        prompt = prompt + " "  # add back the space removed by rsplit

    padding_tokenizers = (CustomCharLevelTokenizerForAddingPadding_Base1000, CustomCharLevelTokenizerForAddingPadding_Base100)
    if isinstance(tokenizer, padding_tokenizers):
        answer = ''.join(tokenizer._tokenize(answer))
    if pad_zeros > 0:
        prompt = pad_numbers(prompt, pad_zeros)
    if reverse_inputs:
        prompt = reverse_numbers(prompt)
    if remove_padding:
        prompt = prompt.replace(" ", "")
    tokenized_inputs = torch.as_tensor(tokenizer(prompt)["input_ids"], dtype=torch.long)[None, :]
    device_inputs = tokenized_inputs.to(device)
    return prompt, device_inputs, answer

def gdm_index_hints_helper(num, tokenizer):
    """Insert an index-hint token id in front of every column of ``num``.

    For a 2-D ``num`` with ``cols`` columns, column ``i`` is preceded by the
    id of ``tokenizer.char_set[i]``. The result has ``2 * cols`` columns laid
    out as ``hint_0, num[:, 0], hint_1, num[:, 1], ...``. Hint columns take
    ``num``'s dtype and device. Input with zero columns is returned unchanged.
    """
    char_set = tokenizer.char_set
    rows, cols = num.shape[0], num.shape[1]
    if cols == 0:
        return num
    pieces = []
    for col in range(cols):
        hint_id = tokenizer._convert_token_to_id(char_set[col])
        pieces.append(torch.full((rows, 1), hint_id, dtype=num.dtype, device=num.device))
        pieces.append(num[:, col:col + 1])
    return torch.cat(pieces, dim=1)

def _read_cfg_value(cfg, key, default):
    """Return ``cfg.<key>``, or ``default`` when the attribute cannot be read."""
    try:
        return getattr(cfg, key)
    except Exception:  # best-effort lookup: absent/unreadable config keys fall back
        return default


def _ood_window(lower, upper):
    """Predicate: at least one length >= ``lower`` and at least one <= ``upper``."""
    def predicate(data_size_1, data_size_2):
        return (data_size_1 >= lower or data_size_2 >= lower) and (data_size_1 <= upper or data_size_2 <= upper)
    return predicate


def _big_eval_window(lower, upper, checkerboard_func):
    """Predicate for one big-eval slice: both <= ``upper``, one >= ``lower`` (if set), checkerboard passes."""
    def predicate(data_size_1, data_size_2):
        above_lower = lower is None or data_size_1 >= lower or data_size_2 >= lower
        return above_lower and data_size_1 <= upper and data_size_2 <= upper and checkerboard_func(data_size_1, data_size_2)
    return predicate


def grid_logic(cfg):
    """Choose which (size_1, size_2) grid cells to evaluate from ``cfg`` flags.

    Starts from the default ``_large`` mode (cells with a length <= 23,
    ``max_size`` 24). Each *truthy* flag then overrides the mode, in this
    order (later wins): ``ood_only``, ``up_to_40``, ``up_to_50``,
    ``big_eval_step_1`` ... ``big_eval_step_10``. Flags that are absent or
    falsy are ignored. ``cfg.checkerboard`` (``"even"``, ``"odd"`` or unset)
    additionally restricts the big-eval slices to cells whose length sum has
    that parity, and appends ``_even``/``_odd`` to their name.

    Args:
        cfg: config object whose flags are read as attributes. An attribute
            that cannot be read counts as absent.

    Returns:
        ``(logic_func, name, max_size)``: a predicate over the two operand
        lengths, the filename suffix for the chosen mode, and ``max_size``
        (exclusive upper bound of the size range).

    Raises:
        ValueError: if ``cfg.checkerboard`` is set to anything other than
            ``"even"`` or ``"odd"``.
    """
    # Default: original testing grid.
    def logic_func_large(data_size_1, data_size_2):
        return (data_size_1 <= 23 or data_size_2 <= 23)
    logic_func, name, max_size = logic_func_large, '_large', 23 + 1

    large = _read_cfg_value(cfg, "large", True)  # only logged
    flag_values = {}

    # (flag, name, max_size, predicate); the last truthy flag wins.
    range_modes = [
        ("ood_only", '_ood_only', 30 + 1, _ood_window(24, 30)),
        ("up_to_40", '_up_to_40', 40 + 1, _ood_window(31, 40)),
        ("up_to_50", '_up_to_50', 50 + 1, _ood_window(41, 50)),
    ]
    for flag, mode_name, mode_max, predicate in range_modes:
        value = _read_cfg_value(cfg, flag, False)
        flag_values[flag] = value
        if value:
            logic_func, name, max_size = predicate, mode_name, mode_max

    # Checkerboarding: the big eval can be split into even/odd length sums.
    checkerboard = _read_cfg_value(cfg, "checkerboard", None)
    if checkerboard is None:
        checkerboard_str = ""

        def checkerboard_func(data_size_1, data_size_2):
            return True
    elif checkerboard in ("even", "odd"):
        wanted_parity = 0 if checkerboard == "even" else 1
        checkerboard_str = f"_{checkerboard}"

        def checkerboard_func(data_size_1, data_size_2):
            return (data_size_1 + data_size_2) % 2 == wanted_parity
    else:
        raise ValueError(f"checkerboard config not allowed: {checkerboard!r}")

    # Testing up to 100 is split into 10 slices with roughly equal numbers of
    # forward passes: (lower bound for one length or None, upper bound for both).
    big_eval_bounds = [
        (None, 46),  # 1 -> 46
        (47, 58),
        (59, 67),
        (68, 74),
        (75, 80),
        (81, 85),
        (86, 90),
        (91, 94),
        (95, 97),
        (98, 100),
    ]
    for step, (lower, upper) in enumerate(big_eval_bounds, start=1):
        flag = f"big_eval_step_{step}"
        value = _read_cfg_value(cfg, flag, False)
        flag_values[flag] = value
        if value:
            logic_func = _big_eval_window(lower, upper, checkerboard_func)
            name = f'_big_eval_{step}' + checkerboard_str
            max_size = 100 + 1

    log.info(f"large = {large}")
    log.info(f"ood only = {flag_values['ood_only']}")
    log.info(f"up to 40 = {flag_values['up_to_40']}")
    log.info(f"up to 50 = {flag_values['up_to_50']}")
    for step in range(1, len(big_eval_bounds) + 1):
        log.info(f"big eval {step} = {flag_values[f'big_eval_step_{step}']}")
    log.info("the last true value in the above list will be run, mul and pos arith can take control after this")

    return logic_func, name, max_size

def main(cfg):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    local_checkpoint_folder = os.path.join(cfg.base_dir, cfg.name, "checkpoints")
    tokenizer, cfg_arch, model_file = cramming.utils.find_pretrained_checkpoint(cfg.eval.checkpoint,
                                                                                local_checkpoint_folder,
                                                                                cfg.eval.arch_modifications)
    try:
        cfg_arch.mask_before_equals = cfg_arch.mask_before_equals
    except:
        cfg_arch.mask_before_equals = False

    try:
        cfg_arch.loss_reduction = cfg_arch.loss_reduction
    except:
        cfg_arch.loss_reduction = 'mean'

    try:
        cfg_arch.maximal_recurrence_in_eval = cfg.max_rec
    except:
        cfg_arch.maximal_recurrence_in_eval = 10
    log.info(f"cfg_arch.maximal_recurrence_in_eval changed to {cfg_arch.maximal_recurrence_in_eval}")

    try:
        cfg_arch.attention.max_length = cfg_arch.attention.max_length
    except:
        cfg_arch.attention.max_length = 0

    if 'forward_only_model_with_skip' not in cfg_arch: # trained before we corrected the typo
        try:
            cfg_arch.forward_only_model_with_skip = cfg_arch.forward_only_model_with_sikp
        except:
            cfg_arch.forward_only_model_with_skip = False
        cfg_arch.pop("forward_only_model_with_sikp")
    log.info(f"forward_only_model_with_skip = {cfg_arch.forward_only_model_with_skip}")

    try:
        cfg_arch.embedding.max_recycle_len = cfg_arch.embedding.max_recycle_len 
    except:
        cfg_arch.embedding.max_recycle_len = 100
    log.info(f"embedding.max_recycle_len = {cfg_arch.embedding.max_recycle_len}")

    cfg_arch.throttle=False

    try:
        reverse_inputs = cfg.reverse_inputs
    except:
        reverse_inputs = False
    log.info(f"reverse inputs = {reverse_inputs}")

    try:
        pad_zeros = cfg.pad_zeros
    except:
        pad_zeros = 0
    log.info(f"pad zeros = {pad_zeros}")

    try:
        advanced_plotting = cfg.advanced_plotting
    except:
        advanced_plotting = False
    log.info(f"advanced plotting = {advanced_plotting}")
    
    try:
        extended_eval = cfg.extended_eval
    except:
        extended_eval = False
    log.info(f"extended eval = {extended_eval}")

    try:
        sort_simple = cfg.sort_simple
    except:
        sort_simple = False
    log.info(f"Sort simple data = {sort_simple}")

    try:
        sort_reverse = cfg.sort_reverse
    except:
        sort_reverse = False
    log.info(f"Sort reverse data = {sort_reverse}")

    try:
        sort_index_hints = cfg.sort_index_hints
    except:
        sort_index_hints = False
    log.info(f"Sort index hint data = {sort_index_hints}")


    logic_func, name, max_size = grid_logic(cfg)

    try:
        mul = cfg.mul
        def logic_func_for_mul(data_size_1, data_size_2):
            return (data_size_1 <= 25 or data_size_2 <= 25)
        logic_func = logic_func_for_mul
        name = '_large'
        max_size = 25+1
    except:
        mul = False
    log.info(f"mul = {mul}")
    try:
        pos_arth = cfg.pos_arth
        def logic_func_for_pos(data_size_1, data_size_2):
            return (data_size_1 <= 25 or data_size_2 <= 25)
        logic_func = logic_func_for_pos
        name = '_large'
        max_size = 25+1
    except:
        pos_arth = False
    log.info(f"pos_arth = {pos_arth}")
    try:
        pos_arth_ood = cfg.pos_arth_ood
        def logic_func_for_pos_ood(data_size_1, data_size_2):
            return (data_size_1 >= 26 or data_size_2 >=26) and (data_size_1 <=40 and data_size_2 <=40)
        logic_func = logic_func_for_pos_ood
        name = '_ood_only'
        max_size = 40+1
    except:
        pos_arth_ood = False
    log.info(f"pos_arth_ood = {pos_arth_ood}")
    try:
        pos_arth_add_rev = cfg.pos_arth_add_rev
        def logic_func_for_pos_add_rev(data_size_1, data_size_2):
            return (data_size_1 <= 25 or data_size_2 <= 25)
        logic_func = logic_func_for_pos_add_rev
        name = '_large'
        max_size = 25+1
    except:
        pos_arth_add_rev = False
    log.info(f"pos_arth_add_rev = {pos_arth_add_rev}")
    try:
        pos_arth_add_rev_ood = cfg.pos_arth_add_rev_ood
        def logic_func_for_pos_add_rev_ood(data_size_1, data_size_2):
            return (data_size_1 >= 26 or data_size_2 >=26) and (data_size_1 <=40 and data_size_2 <=40)
        logic_func = logic_func_for_pos_add_rev_ood
        name = '_ood'
        max_size = 40+1
    except:
        pos_arth_add_rev_ood = False
    log.info(f"pos_arth_add_rev_ood = {pos_arth_add_rev_ood}")
    try:
        pos_arth_add_forward = cfg.pos_arth_add_forward
        def logic_func_for_pos_add_forward(data_size_1, data_size_2):
            return (data_size_1 <= 25 or data_size_2 <= 25)
        logic_func = logic_func_for_pos_add_forward
        name = '_large'
        max_size = 25+1
    except:
        pos_arth_add_forward = False
    log.info(f"pos_arth_add_forward = {pos_arth_add_forward}")
    try:
        pos_arth_add_forward_ood = cfg.pos_arth_add_forward_ood
        def logic_func_for_pos_add_forward_ood(data_size_1, data_size_2):
            return (data_size_1 >= 26 or data_size_2 >=26) and (data_size_1 <=40 and data_size_2 <=40)
        logic_func = logic_func_for_pos_add_forward_ood
        name = '_ood'
        max_size = 40+1
    except:
        pos_arth_add_forward_ood = False
    log.info(f"pos_arth_add_forward_ood = {pos_arth_add_forward_ood}")

    try:
        no_carries = cfg.no_carries
    except:
        no_carries = False
    log.info(f"no carries = {no_carries}")

    try:
        remove_padding = cfg.remove_padding
    except:
        remove_padding = True
    log.info(f"remove padding = {remove_padding}")

    try:
        add_random_pad = cfg.add_random_pad
    except:
        add_random_pad = False
    log.info(f"add_random_pad padding = {add_random_pad}")

    try:
        token_limit = cfg.token_limit
    except:
        token_limit = 20
    log.info(f"token limit = {token_limit}")

    try:
        crammed = cfg.crammed
    except:
        crammed = False

    try:
        rouge = cfg.rouge
    except:
        rouge = False
    log.info(f"rouge = {rouge}")

    cfg_data_sources_values_list = list(cfg.data.sources.values())[0]
    if cfg_data_sources_values_list["provider"] == "arithmetic":
        tokenizer = get_tokenizer(cfg_data_sources_values_list["tokenizer_type"])

        # cramming.data.pretraining_preparation.CustomCharLevelTokenizer()
        # if cfg_data_sources_values_list["tokenizer_type"] == "white_space":
        #     tokenizer = cramming.data.pretraining_preparation.CustomCharLevelTokenizerWithWhiteSpace()
        # elif cfg_data_sources_values_list["tokenizer_type"] == "pad":
        #     tokenizer = cramming.data.pretraining_preparation.CustomCharLevelTokenizerForAddingPadding()
        # elif cfg_data_sources_values_list["tokenizer_type"] == "base_100":
        #     tokenizer = cramming.data.pretraining_preparation.CustomCharLevelTokenizerForAddingPadding_Base100()
        # elif cfg_data_sources_values_list["tokenizer_type"] == "base_1000":
        #     tokenizer = cramming.data.pretraining_preparation.CustomCharLevelTokenizerForAddingPadding_Base1000()
        # elif cfg_data_sources_values_list["tokenizer_type"] == "mod":
        #     tokenizer = cramming.data.pretraining_preparation.CustomCharLevelTokenizerForAddingPaddingForMod()
        if cfg_data_sources_values_list["tokenizer_type"] == "del":
            tokenizer = cramming.data.pretraining_preparation.CustomCharLevelTokenizerForDelete()
    else: 
        log.info("exiting as this is only for arithmetic")
    vocab = tokenizer.ids_to_tokens
    EOS_token = tokenizer._convert_token_to_id(tokenizer.eos_token)
    PAD_token = tokenizer._convert_token_to_id(tokenizer.pad_token)
    assert PAD_token == 0

    # Load model
    if 'alpha' not in cfg_arch:
        cfg_arch['alpha'] = 1.0

    model = cramming.construct_model(cfg_arch, tokenizer).to(device)
    model = cramming.backend.load_model_checkpoint(model, model_file)
    model.to(device)
    model.eval()

    try:
        temp = cfg.temp
    except:
        temp = 1.0
    log.info(f"temperature = {temp}")

    try:
        greedy = cfg.greedy
    except:
        greedy = True
    log.info(f"greedy = {greedy}, note: if greedy = True this overrides any temperature arguments")
    ## Greedy deconding will overide any temperature arguments

    try:
        max_size_given = cfg.max_size_given
    except:
        max_size_given = None

    if max_size_given is not None:
        max_size = max_size_given

    # Grid plots - grid search from 1x1 to 12x12 data
    data_sizes = list(range(1, max_size))
    acc_grid = np.zeros((len(data_sizes),len(data_sizes)))
    start_ind_1 = 0
    start_ind_2 = 0
    tuple_method = False
    completed_one = False
    if "big_eval" in name:
        tuple_method = True
        # go up two layers and search for grid
        try:
            with open(f"../../accs_grid_quick{name}.json", 'r') as file:
                data = json.load(file)
            start_ind_1 = data[1]
            start_ind_2 = data[2]
            acc_grid = np.array(data[0])
            log.info("loaded grid from previous run")
        except:
            pass


    try:
        start_ind_1_given = cfg.start_ind_1_given
    except:
        start_ind_1_given = None

    if start_ind_1_given is not None:
        start_ind_1 = start_ind_1_given


    try:
        start_ind_2_given = cfg.start_ind_2_given
    except:
        start_ind_2_given = None

    if start_ind_2_given is not None:
        start_ind_2 = start_ind_2_given


        
    os.makedirs("outputs", exist_ok=True)

    all_outputs_folder_path = f"../../all_outputs_max_recurrence={cfg_arch.maximal_recurrence_in_eval}"
    os.makedirs(all_outputs_folder_path, exist_ok=True)

    if not extended_eval:
        for data_size_1 in data_sizes:
            for data_size_2 in data_sizes:
                proceed = False
                if data_size_1 >= start_ind_1 or data_size_2 >= start_ind_2:
                    proceed = True

                if not proceed:
                    continue

                # check if done
                # if done it will be done and saved in f"../../acc_for_{data_size_1}_{data_size_2}.txt"
                if os.path.exists(f"{all_outputs_folder_path}/top_half_intersection_acc_for_{data_size_1}_{data_size_2}.txt"):
                    with open(f"{all_outputs_folder_path}/acc_for_{data_size_1}_{data_size_2}.txt", 'r') as file:
                        acc = float(file.read())
                    acc_grid[data_size_1-1, data_size_2-1] = acc
                    continue

                if logic_func(data_size_1, data_size_2):
                    completed_one = True
                    log.info(f"Starting iteration in grid eval for size: {data_size_1} and {data_size_2}")
                    file_path = f"../../../../data/arithmetic_data/+_grid_eval_dataset_padded_tokenized/+_n_{data_size_1}_m_{data_size_2}_examples_100_diff_lens_seed_42/hf_tokenized_dataset"
                    if sort_simple:
                        file_path = f"../../../../data/arithmetic_data/sort_simple/sort_uniform_distribution_sort_basic_max_digits_n_{data_size_1}_max_length_m_{data_size_2}_200_p_00/hf_tokenized_dataset"
                    if sort_reverse:
                        file_path = f"../../../../data/arithmetic_data/sort_reverse/sort_uniform_distribution_sort_basic_max_digits_n_{data_size_1}_max_length_m_{data_size_2}_200_p_00_reverse_all/hf_tokenized_dataset"
                    if sort_index_hints:
                        file_path = f"../../../../data/arithmetic_data/sort_index_hints/sort_uniform_distribution_sort_basic_max_digits_n_{data_size_1}_max_length_m_{data_size_2}_200_p_00_reverse_all_with_index_hints_circular/hf_tokenized_dataset"

                    tokenized_dataset = datasets.load_from_disk(file_path)["test"]
                    data_loader = torch.utils.data.DataLoader(tokenized_dataset, batch_size=100, shuffle=False)

                    correct_total = 0
                    all_total = 0
                    top_1_total = 0
                    top_half_total = 0
                    top_half_intersection_accuracy_total = 0
                    for batch in data_loader:
                        input_ids = batch["input_ids"]
                        input_ids = torch.stack(input_ids).to(device)
                        input_ids = torch.transpose(input_ids, 0, 1)

                        all = 0
                        correct = 0
                        top_1 = 0
                        top_half = 0
                        top_half_intersection_acc = 0
                        for i in range(len(input_ids)):
                            example = input_ids[i]
                            # print("All : ", tokenizer.decode(example.tolist()))
                            equals_token = tokenizer._convert_token_to_id("=")
                            # print(example)
                            # print(equals_token)
                            equals_indices = torch.where(example == equals_token)[0].item()
                            # print(equals_indices)
                            question = example[:equals_indices + 1]
                            answer = example[equals_indices + 1:]
                            # print("Question : ", tokenizer.decode(question.tolist()))
                            # print("Answer : ", tokenizer.decode(answer.tolist()))
                            # print("Question : ", question)
                            # add first dimension of batch to question and answer
                            question = question.unsqueeze(0)

                            local_token_limit = int(len(answer) * 2)
                            predicted_ids = model._generate(question,
                                                            token_limit=local_token_limit,
                                                            temperature=temp,
                                                            steps_at_generation_time=cfg_arch.maximal_recurrence_in_eval,
                                                            greedy=greedy, quick=True)
                            predicted_ids = predicted_ids.squeeze()

                            eos_token = tokenizer._convert_token_to_id(tokenizer.eos_token)
                            eos_indices = torch.where(answer == eos_token)[0].item()
                            answer = answer[:eos_indices]

                            predicted_ids = predicted_ids[:len(answer)]
                            # print(answer)
                            # print(predicted_ids)
                            # print(tokenizer.decode(answer.tolist()))
                            # print(tokenizer.decode(predicted_ids.tolist()))
                            # print("__________________________")

                            if torch.equal(predicted_ids, answer):
                                correct += 1

                            top_1_target = answer[0]
                            top_1_predicted = predicted_ids[0]
                            if torch.equal(top_1_target, top_1_predicted):
                                top_1 += 1

                            half_len = len(answer)//2 + 1
                            top_half_target = answer[:half_len]
                            top_half_predicted = predicted_ids[:half_len]
                            if torch.equal(top_half_target, top_half_predicted):
                                top_half += 1

                            # convert it into set and then see how many of them are present
                            top_half_target_set = set(top_half_target.tolist())
                            top_half_predicted_set = set(top_half_predicted.tolist())
                            # find the intersection
                            intersection = top_half_target_set.intersection(top_half_predicted_set)
                            top_half_intersection_acc += len(intersection) / len(top_half_target_set)


                            all += 1

                        correct_total += correct
                        top_1_total += top_1
                        top_half_total += top_half
                        all_total += all
                        top_half_intersection_accuracy_total += top_half_intersection_acc


                    acc = correct_total / all_total
                    acc_top_1 = top_1_total / all_total
                    acc_top_half = top_half_total / all_total
                    acc_top_half_intersection = top_half_intersection_accuracy_total / all_total

                    log.info(f"accuracy for data that has numbers "
                             f"with maximum number of digits as {data_size_1} , "
                             f"and the array of length {data_size_2} is {acc * 100}")
                    log.info(f"Top 1 accuracy for data that has numbers "
                             f"with maximum number of digits as {data_size_1} , "
                             f"and the array of length {data_size_2} is {acc_top_1 * 100}")
                    log.info(f"Top half accuracy for data that has numbers "
                             f"with maximum number of digits as {data_size_1} , "
                             f"and the array of length {data_size_2} is {acc_top_half * 100}")
                    log.info(f"Top half intersection accuracy for data that has numbers "
                             f"with maximum number of digits as {data_size_1} , "
                             f"and the array of length {data_size_2} is {acc_top_half_intersection * 100}")

                    question = tokenizer.decode(question.squeeze().tolist())
                    answer = tokenizer.decode(answer.tolist())
                    predicted = tokenizer.decode(predicted_ids.tolist())
                    log.info(f"For example : sort {question} for which the answer is {answer} , "
                             f"and the predicted is {predicted}")
                    acc_grid[(data_size_1-1), (data_size_2-1)] = acc * 100
                    # save all in case of crash

                    with open(f"{all_outputs_folder_path}/acc_for_{data_size_1}_{data_size_2}.txt", "w") as file:
                        file.write(f"{acc * 100}")
                    with open(f"{all_outputs_folder_path}/top_1_acc_for_{data_size_1}_{data_size_2}.txt", "w") as file:
                        file.write(f"{acc_top_1 * 100}")
                    with open(f"{all_outputs_folder_path}/top_half_acc_for_{data_size_1}_{data_size_2}.txt", "w") as file:
                        file.write(f"{acc_top_half * 100}")
                    with open(f"{all_outputs_folder_path}/top_half_intersection_acc_for_{data_size_1}_{data_size_2}.txt", "w") as file:
                        file.write(f"{acc_top_half_intersection * 100}")


        log.info(f"acc grid: {acc_grid}")

        with open(f"accs_grid_quick_{start_ind_1}_{start_ind_2}_{max_size}.json", "w") as file:
            json.dump(acc_grid.tolist(), file)

        # Grid plots - one for accs one for contains
        grid_plotter(acc_grid, name=f"{start_ind_1}_{start_ind_2}_{max_size}")
        grid_plotter(acc_grid, name=f"{start_ind_1}_{start_ind_2}_{max_size}", extra_path=all_outputs_folder_path)
        exit()

    if extended_eval:
        log.info("starting extended eval")
        # this is hard coded for reverse all, addition past 100x100 grid, removing the padding
        accs = dict()
        start = 101
        print(os.getcwd())
        old_data_path = None
        for root, dirs, files in os.walk("../.."):
            if "over_100.json" in files:
                old_data_path = os.path.join(root, "over_100.json")

        if old_data_path is not None:
            with open(old_data_path, 'r') as file:
                data = json.load(file)
                
            accs = {int(k): v for k, v in data.items()}
            start = max(accs.keys()) + 1
            log.info(f"loaded accs: {accs}")
            log.info(f"starting from: {start}")

        for data_size in range(start,171):
            log.info(f"Extended eval {data_size}")
            correct_total = 0
            file_path = f"../../../../data/arithmetic_data/+_grid_eval_dataset_reverse_all_tokenized_over_100/+_n_{data_size}_m_{data_size}_examples_100_diff_lens_exact_seed_42/hf_tokenized_dataset"
            tokenized_dataset = datasets.load_from_disk(file_path)["test"]
            data_loader = torch.utils.data.DataLoader(tokenized_dataset, batch_size=100, shuffle=False)
            equals_tensor = data_size+data_size+6

            for batch in data_loader:
                tokenized_prompts = batch["input_ids"][:equals_tensor]
                tokenized_prompts = torch.stack(tokenized_prompts).to(device)
                tokenized_prompts = torch.transpose(tokenized_prompts, 0, 1)
                tokenized_answers = batch["input_ids"][equals_tensor:]
                tokenized_answers = torch.stack(tokenized_answers).to(device)
                tokenized_answers = torch.transpose(tokenized_answers, 0, 1)



                num1 = tokenized_prompts[:,:data_size]
                op = tokenized_prompts[:,data_size+1:data_size+2]
                num2 = tokenized_prompts[:,data_size+3:data_size+data_size+3]
                equals = tokenized_prompts[:,data_size+data_size+4:data_size+data_size+5]
                tokenized_prompts = torch.cat((num1, op, num2, equals), dim=1)

                predicted_ids = model._generate(tokenized_prompts, token_limit=tokenized_answers.shape[1], temperature=temp, steps_at_generation_time=cfg_arch.maximal_recurrence_in_eval, greedy=greedy, quick=True)
                predicted_ids = torch.transpose(predicted_ids, 0, 1) # add a batch dim

                eval_tensor = predicted_ids.clone()
                input_tensor_EOS = (eval_tensor == EOS_token).int()
                indices_of_EOS = torch.argmax(input_tensor_EOS, dim=1)
                mask = torch.arange(eval_tensor.size(1)).to(device) > indices_of_EOS[:, None]
                eval_tensor[mask] = PAD_token
                elementwise_equal = torch.eq(eval_tensor, tokenized_answers)
                
                rows_equal = torch.all(elementwise_equal, dim=1)
                num_equal_rows = torch.sum(rows_equal).item()
                correct_total += (num_equal_rows/tokenized_prompts.shape[0])
                log.info(f"accuracy for {data_size}, {data_size}: {num_equal_rows} = {correct_total*100}%")

                # combine the prompts and outputs
                complete_lines = torch.cat((tokenized_prompts,predicted_ids), dim=1)
                tokens_list = complete_lines.tolist()
                decoded_batch = list(map(lambda seq: list(map(lambda token: vocab[token], seq)), tokens_list)) # map token ids to tokens
                log.info(f"example for {data_size}, {data_size}: {decoded_batch[0]}")
                # save the answers down so we don't eval twice ever
                with open(f"over_100.json", 'w') as json_file:
                    json.dump(accs, json_file)

            accs[data_size] = correct_total

    # Run prompter to get thinking plots too -- one for each data size
    if (not crammed) and (not advanced_plotting) and (not extended_eval):
        if tuple_method: # i.e. for the big eval only make the thinking plots for the first one
            try:
                big_eval_step_1 = cfg.big_eval_step_1
            except:
                big_eval_step_1 = False
            if big_eval_step_1:
                token_limit = 105
            else:
                exit()
        vocab = tokenizer.get_vocab()
        sorted_tokenizer_keys = sorted(vocab, key=lambda x: vocab[x])
        data_sizes = data_sizes = list(range(1, max_size))
        for data_size in data_sizes:
            file_path = f"../../../../data/arithmetic_data/+_grid_eval_dataset/+_n_{data_size}_m_{data_size}_examples_100_diff_lens_seed_42.txt"
            if no_carries:
                file_path = f"../../../../data/arithmetic_data/+_no_carries_grid_eval_dataset/+_n_{data_size}_m_{data_size}_examples_100_diff_lens_seed_42.txt"
            elif reverse_inputs:
                file_path = f"../../../../data/arithmetic_data/+_grid_eval_dataset_reverse_all_tokenized/+_n_{data_size}_m_{data_size}_examples_100_diff_lens_seed_42/+_n_{data_size}_m_{data_size}_examples_100_diff_lens_seed_42.txt"
                reverse_input_in_tokenise = False
                reverse_input_in_answer = False
            elif mul:
                file_path = f"../../../../data/arithmetic_data/x_grid_eval_dataset_2/x_n_{data_size}_m_{data_size}_examples_100_diff_lens_exact_seed_91.txt"
                reverse_input_in_tokenise = True
                reverse_input_in_answer = True
            elif pos_arth or pos_arth_ood:
                    file_path = f"../../../../data/arithmetic_data/pos_arith_eval/pos_arith_{data_size}_{data_size}/positional_arithmetic_n_{data_size}_n_{data_size}.txt"
            elif pos_arth_add_rev or pos_arth_add_rev_ood:
                file_path = f"../../../../data/arithmetic_data/pos_arith_add_rev_eval/pos_arith_add_rev_{data_size}_{data_size}/positional_arithmetic_n_{data_size}_m_{data_size}.txt"
            elif pos_arth_add_forward or pos_arth_add_forward_ood:
                file_path = f"../../../../data/arithmetic_data/pos_arith_add_forward_eval/pos_arith_add_forward_{data_size}_{data_size}/positional_arithmetic_n_{data_size}_m_{data_size}.txt"
            with open(file_path, "r") as file:
                for line in file:
                    prompt, tokenized_prompt, answer = tokenize_and_split_line(line.strip(), tokenizer, device, reverse_inputs=reverse_input_in_tokenise, pad_zeros=pad_zeros, remove_padding=remove_padding)
                    tracking, verbose_completion, logits = verbose_prompt_eval(tokenized_prompt, answer, model, tokenizer, cfg_arch, token_limit=token_limit, temp=temp, greedy=greedy)
                    if reverse_input_in_answer:
                        answer = str(answer)[::-1]
                    log.info(f"Prompt: {prompt}")
                    log.info(f"Answer: {answer}")
                    log.info(f"Verbose completion: {verbose_completion}")

                    data = {
                        "gen_tokens": tracking,
                        "logits": logits,
                    }
                    multi_plotter(data, data_size, prompt, answer, verbose_completion, sorted_tokenizer_keys, n_locs=token_limit, n_times=cfg_arch.maximal_recurrence_in_eval)
                    # thinking_plotter(tracking, data_size, prompt, answer, verbose_completion)
                    break
                    
    elif advanced_plotting:
        vocab = tokenizer.get_vocab()
        sorted_tokenizer_keys = sorted(vocab, key=lambda x: vocab[x])
        data_sizes = data_sizes = list(range(20, max_size))
        for data_size in data_sizes:
            
            file_path = f"../../../../data/arithmetic_data/+_grid_eval_dataset_reverse_all_tokenized/+_n_{data_size}_m_{data_size}_examples_100_diff_lens_seed_42/+_n_{data_size}_m_{data_size}_examples_100_diff_lens_seed_42.txt"
            reverse_input_in_tokenise = False
            reverse_input_in_answer = False
            with open(file_path, "r") as file:
                for line in file:
                    prompt, tokenized_prompt, answer = tokenize_and_split_line(line.strip(), tokenizer, device, reverse_inputs=reverse_input_in_tokenise, pad_zeros=pad_zeros, remove_padding=remove_padding)
                    em_acc = []
                    for rec in range(1,cfg.max_rec+1):
                        cfg_arch.maximal_recurrence_in_eval = rec
                        tracking, verbose_completion, logits = verbose_prompt_eval(tokenized_prompt, answer, model, tokenizer, cfg_arch, token_limit=token_limit, temp=temp, greedy=greedy)
                        answer = list(answer)
                        verbose_up_to_eos = verbose_completion[:next((i for i, x in enumerate(verbose_completion) if x == '[EOS]'), len(verbose_completion))]
                        count = sum(1 for x, y in zip(answer, verbose_up_to_eos) if x == y)/len(answer)
                        em_acc.append(count)
                    # plot
                    recurrences = range(1, cfg.max_rec+1)
                    plt.figure(figsize=(8, 6))
                    plt.plot(recurrences, em_acc, marker='o', linestyle='-', color='b')
                    plt.title(f'EM Accuracy by Number of Recurrences, {prompt}{"".join(verbose_up_to_eos)}')
                    plt.xlabel('Number of Recurrences')
                    plt.ylabel('EM Accuracy')
                    plt.ylim(0, 1.1)
                    plt.xticks(recurrences)
                    plt.savefig(f"em_acc_plot_{data_size}", bbox_inches='tight')
                    plt.clf()
                    break

    log.info("Eval complete")

@hydra.main(config_path="cramming/config", config_name="cfg_eval", version_base="1.3")
def launch(cfg):
    """Hydra entry point for the evaluation script.

    The ``@hydra.main`` decorator composes the ``cfg_eval`` config from
    ``cramming/config`` and passes it in as ``cfg``. The config is run through
    ``cramming.utils.pathfinder`` (presumably to resolve/normalize paths —
    confirm against that helper). The fully resolved YAML is logged, and the
    result is handed to ``main``.

    Args:
        cfg: the Hydra/OmegaConf configuration object built by the decorator.
    """
    log.info("calling main launch")
    # pathfinder returns the config object that the rest of the script uses
    prepared_cfg = cramming.utils.pathfinder(cfg)
    # Log the config with interpolations resolved so the run is reproducible from the log
    yaml_dump = OmegaConf.to_yaml(prepared_cfg, resolve=True)
    log.info(yaml_dump)
    main(prepared_cfg)

if __name__ == "__main__":
    # No arguments are passed: the @hydra.main decorator on `launch` builds the
    # config (from cfg_eval plus any command-line overrides) and supplies `cfg`.
    launch()