import logging
import hydra
from omegaconf import OmegaConf
import cramming
import torch
from safetensors.torch import load_file
import matplotlib.pyplot as plt
import seaborn as sns
import json
import numpy as np
import re
import pandas as pd
import datasets
import os
from typing import List, Dict
from cramming.data.tokenizer_preparation import get_tokenizer
from cramming.data.arithmetic_tokenizers import CustomCharLevelTokenizerForAddingPadding_Base100, CustomCharLevelTokenizerForAddingPadding_Base1000
import random
from docx import Document
from docx.shared import RGBColor

log = logging.getLogger(__name__)

def load_checkpoint(model, file, device):
    """Load a safetensors checkpoint into ``model`` and move the model to ``device``.

    Args:
        model: module exposing ``load_state_dict`` and ``to``.
        file: path of the safetensors file, passed to ``load_file``.
        device: device the tensors are loaded onto and the model is moved to.

    Returns:
        The result of ``model.to(device)``.

    A strict load is attempted first. If it raises ``RuntimeError`` the
    mismatch is logged and the load is retried with ``strict=False``.
    """
    model_state = load_file(file, device=device)
    if "encoder.embedding.word_embedding.weight" not in model_state:
        # Hack to save space when saving the model: the embedding weight is not stored and is
        # restored from the decoder weight here.
        model_state["encoder.embedding.word_embedding.weight"] = model_state["decoder.weight"]

    # Drop a leading "module." from every key so the names match the bare model.
    prefix = "module."
    sanitized_state = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in model_state.items()
    }

    try:
        model.load_state_dict(sanitized_state, strict=True)
        log.info("finished loading state dict")
    except RuntimeError as e:
        # The message may not contain the expected marker; fall back to the full text rather
        # than raising a second exception from inside this handler.
        message = str(e)
        _, marker, difference = message.partition("Error(s) in loading state_dict for")
        log.info(f"State dict difference is {difference if marker else message}... Ok?")
        model.load_state_dict(sanitized_state, strict=False)
    return model.to(device)

def line_plotter(accs, type="accs", large=False, ood_only=False):
    """Draw a line chart of per-digit-count scores and save it to the working directory.

    Args:
        accs: mapping whose keys become the x ticks and whose values are multiplied by 100
            before plotting.
        type: ``"accs"`` or ``"reverse"`` label the plot "Accuracy"; any other value uses the
            "answer contained in output" label. Also used as the file-name prefix.
        large: when true, ``large_`` is inserted into the file name.
        ood_only: when true, ``ood_only_`` is inserted into the file name.
    """
    digit_counts = list(accs.keys())
    percentages = [score * 100 for score in accs.values()]
    plt.plot(digit_counts, percentages, marker='o')

    if type in ("accs", "reverse"):
        y_label = "Accuracy"
    else:
        y_label = "Percentage with answer contained in output"

    plt.xlabel("Number of Digits")
    plt.ylabel(y_label)
    plt.title(y_label)
    plt.ylim(-1, 101)
    plt.xticks(digit_counts)
    plt.tight_layout()

    infix = ("large_" if large else "") + ("ood_only_" if ood_only else "")
    plt.savefig(f"{type}_{infix}line_plot", bbox_inches='tight')
    # Clear the current figure so the next plot starts from a blank canvas.
    plt.clf()

def grid_plotter(data, type="accs", name='_large', extra_path=None):
    """Render a 2-D grid of scores as a heatmap and save it.

    Args:
        data: array-like grid; it is converted to a numpy array and multiplied by 100.
        type: file-name prefix.
        name: file-name infix placed before ``_grid_plot``.
        extra_path: optional string prepended verbatim to the file name.
    """
    percent_grid = np.array(data) * 100
    frame = pd.DataFrame(percent_grid)

    plt.figure(figsize=(10, 8))
    # NOTE(review): `annot` is not passed, so `fmt`/`annot_kws` presumably have no visible
    # effect -- confirm whether cell annotations were intended.
    sns.heatmap(frame, cmap="YlGnBu", fmt=".1f", annot_kws={'size': 8, 'rotation': 0})

    plt.title("Accuracy - percetange, rounded to 1dp")
    plt.ylabel("1st Number Length")
    plt.xlabel("2nd Number Length")

    # Ticks sit at cell centres (0.5, 1.5, ...) and are labelled 1..size on both axes;
    # the row count is used for both axes.
    size = percent_grid.shape[0]
    centres = np.arange(0.5, size + 0.5, 1)
    labels = np.arange(1, size + 1, 1)
    plt.xticks(centres, labels=labels)
    plt.yticks(centres, labels=labels)

    prefix = extra_path if extra_path is not None else ""
    plt.savefig(f"{prefix}{type}{name}_grid_plot", bbox_inches='tight')
    plt.clf()

def default_plotting(ax, x_range, fs=18):
    """Apply the shared x-axis styling to ``ax``.

    Sets the x label, hides all four spines and uses ``x_range`` both as tick positions
    and as tick labels, with font size ``fs``.
    """
    ax.set_xlabel("Time (in iterations)", fontsize=fs)
    for side in ("top", "bottom", "left", "right"):
        ax.spines[side].set_visible(False)
    ax.set_xticks(x_range)
    ax.set_xticklabels(x_range, fontsize=fs)

def plot_completion(data, ax, tokens, fs=18):
    """Plot ``data`` as a thick line on ``ax`` with one y tick per entry of ``tokens``."""
    ax.plot(data, linewidth=3)
    tick_positions = [*range(len(tokens))]
    ax.set_yticks(tick_positions)
    ax.set_yticklabels(tokens, fontsize=fs)

def plot_confidence(data_array, ax, n_times, tokens, fs=18):
    """Draw a stacked bar chart on ``ax``, one bar segment per column of ``data_array``.

    Each column of ``data_array`` (row of its transpose) is stacked on top of the previous
    ones at x positions ``0 .. n_times - 1``, labelled with the matching ``tokens`` entry
    and coloured from the ``tab20`` colormap. ``fs`` is accepted but not used here.
    """
    positions = list(range(n_times))
    palette = plt.colormaps.get_cmap('tab20')
    stack_top = np.zeros(n_times)
    for token_idx, heights in enumerate(data_array.T):
        ax.bar(positions, heights, bottom=stack_top, label=tokens[token_idx], color=palette(token_idx))
        # Raise the baseline so the next token's bars sit on top of this one's.
        stack_top += heights
    ax.legend()

def compare_lists_up_to_eos(list1, list2):
    """Return True when both token lists are identical up to their own first "[EOS]".

    Each list is cut at its own first "[EOS]" (or left whole if it has none) and the two
    prefixes are compared. Anything after "[EOS]" is ignored.

    Cutting both lists at the smaller of the two EOS positions would report a sequence
    that stops early (or runs on too long) as a match, so each list is cut independently.
    """
    def before_eos(tokens):
        return tokens[:tokens.index("[EOS]")] if "[EOS]" in tokens else tokens

    return before_eos(list1) == before_eos(list2)

def multi_plotter(data_dict: Dict, num: int, prompt: str, answer_str: str, completion: List[str], tokens: List[str] = None, n_locs: int = 0, n_times: int = 0, truncate:bool=True):
    """Plot one row of per-location panels for every entry of ``data_dict`` and save the figure.

    Args:
        data_dict: maps a series name to per-location data. ``"gen_tokens"`` rows are drawn
            with ``plot_completion`` and ``"logits"`` rows with ``plot_confidence``. Any other
            name prints an error and aborts without saving.
        num: suffix of the output file ``combined_plot_{num}``.
        prompt: prompt text shown in the figure title.
        answer_str: reference answer; it is split into characters and "[EOS]" is appended.
        completion: generated tokens, one per location.
        tokens: token labels forwarded to the row plotters.
        n_locs: number of location columns; overridden when ``truncate`` is true.
        n_times: number of time steps on the x axis.
        truncate: when true, ``n_locs`` becomes ``min(len(completion), len(answer))``.
    """
    answer = [*answer_str, "[EOS]"]
    if truncate:
        n_locs = min(len(completion), len(answer))
    fs = 18
    sns.set_theme(style="darkgrid")
    # squeeze=False keeps `axs` two-dimensional even for a single row or a single column,
    # so the `axs[row, i]` indexing below is always valid.
    fig, axs = plt.subplots(len(data_dict), n_locs, figsize=(n_locs * 4, 16), sharey=False, squeeze=False)

    label_freq = 2 if n_times <= 10 else 5
    x_range = list(range(0, n_times + 2, label_freq))  # for x labels
    for row, (name, data) in enumerate(data_dict.items()):
        for i in range(n_locs):
            default_plotting(axs[row, i], x_range, fs)

            title = f"Loc {i}"
            ans = f'ans: {answer[i]}' if i < len(answer) else ''
            comp = f'guess: {completion[i]}' if i < len(completion) else ''
            note = f'{ans} {comp}'
            title = f"{title} {f'({note})' if len(note) > 1 else ''}"
            if name == "gen_tokens":
                plot_completion(data[i], axs[row, i], tokens, fs)
                y_label = "Output Token"
            elif name == "logits":
                data_array = np.array(data[i])
                plot_confidence(data_array, axs[row, i], n_times, tokens, fs=fs)
                y_label = "Logits"
            else:
                print(f"INVALID DATA: {name}")
                # Release the half-built figure instead of leaving it open.
                plt.close(fig)
                return
            axs[row, i].set_title(title, fontsize=fs)
            axs[row, 0].set_ylabel(y_label, fontsize=fs)

    is_correct = '<Correct>' if compare_lists_up_to_eos(completion, answer) else '<Wrong>'

    fig.suptitle(f"Prompt: {prompt} Answer: {''.join(answer)} Output: {''.join(completion)} {is_correct}", fontsize=fs + 36)
    plt.tight_layout()
    plt.savefig(f"combined_plot_{num}", bbox_inches='tight')
    # Each call creates a new figure; close it so repeated calls do not pile up open figures.
    plt.close(fig)

def thinking_plotter(arr, num, prompt, answer, completion):
    """Plot, for each output location, the token index chosen at every iteration.

    Args:
        arr: 2-D array; its ``shape`` is unpacked as ``(n_locs, n_times)`` and row ``i`` is
            drawn in panel ``i``.
        num: suffix of the output file ``iteration_plot_{num}``.
        prompt, answer, completion: values shown in the figure title.
    """
    n_locs, _ = arr.shape

    fs = 18
    sns.set_theme(style="darkgrid")
    # squeeze=False always returns a 2-D array of axes; without it a single location yields a
    # bare Axes object and `axs[i]` fails.
    fig, axes_grid = plt.subplots(1, n_locs, figsize=(n_locs*4, 8), sharey=False, squeeze=False)
    axs = axes_grid[0]
    tokens = ["[PAD]", "[UNK]", "[BOS]", "[EOS]", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "+", "$-$", "$\\times$", "="]
    x_ticks = [0, 2, 4, 6, 8, 10]

    for i in range(n_locs):
        ax = axs[i]
        ax.plot(arr[i], linewidth=3)
        ax.set_xticks(x_ticks)
        ax.set_xticklabels(x_ticks, fontsize=fs)
        for side in ("top", "bottom", "left", "right"):
            ax.spines[side].set_visible(False)
        ax.set_title(f"Location {i}", fontsize=fs)
        # One y tick per known token, so the y value reads as a token.
        ax.set_yticks(list(range(len(tokens))))
        ax.set_yticklabels(tokens, fontsize=fs)
        ax.set_xlabel("Time (in iterations)", fontsize=fs)
    axs[0].set_ylabel("Output Token", fontsize=fs)
    fig.suptitle(f"Prompt: {prompt} Answer: {answer} Output: {completion}", fontsize=fs+12)
    plt.tight_layout()
    plt.savefig(f"iteration_plot_{num}", bbox_inches='tight')
    plt.clf()

@torch.inference_mode()
def model_gen(model, *inputs, token_limit, temperature, **kwargs): # through forward call
    """Sample ``token_limit`` tokens autoregressively through the model's forward call.

    Args:
        model: callable returning a mapping with a ``"logits"`` entry for the current input.
        *inputs: only ``inputs[0]`` (the prompt ids) is used.
        token_limit: number of tokens to generate.
        temperature: factor the logits are *multiplied* by before the softmax.
        **kwargs: accepted and ignored.

    Returns:
        The generated token ids concatenated along the last dimension (the prompt is not
        included). With ``token_limit <= 0`` an empty slice of the prompt is returned
        instead of failing on an unbound variable.
    """
    device_inputs = inputs[0]  # e.g. tensor([[ 7,  1, 14,  1,  7,  1, 17]]), taken out of the tuple
    predicted_ids = []

    for _ in range(token_limit):
        logits = model(device_inputs)["logits"]
        predicted_token = torch.multinomial(torch.softmax(logits * temperature, dim=-1), 1)
        # Feed the sampled token back in for the next step.
        device_inputs = torch.cat([device_inputs, predicted_token], dim=-1)
        predicted_ids.append(predicted_token)

    if not predicted_ids:
        # Nothing generated: keep dtype, device and leading dimensions, zero length on the last axis.
        return device_inputs[..., :0]
    # Concatenate once at the end rather than on every iteration.
    return torch.cat(predicted_ids, dim=-1)

def verbose_prompt_eval(device_inputs, answer, model, tokenizer, cfg_arch, token_limit=20, temp=0.7, crammed=False, greedy=False):
    """
    Generate a completion with step tracking enabled.

    Calls ``model._generate`` with ``track_steps=True`` and
    ``steps_at_generation_time=cfg_arch.maximal_recurrence_in_eval``.

    Returns a tuple of:
      * a numpy array built from the tracked values (each entry converted with ``.cpu().item()``),
      * the generated ids of the first batch element converted to token strings,
      * the logits object returned by ``model._generate`` (passed through untouched).

    ``answer`` and ``crammed`` are accepted but not used.
    """
    predicted_ids, tracking, logits = model._generate(
        device_inputs,
        token_limit=token_limit,
        temperature=temp,
        steps_at_generation_time=cfg_arch.maximal_recurrence_in_eval,
        track_steps=True,
        greedy=greedy,
    )
    # Only the first batch element is decoded.
    first_sequence = predicted_ids[0].tolist()
    verbose_decoded_completion = list(map(tokenizer._convert_id_to_token, first_sequence))

    tracked_values = np.array(
        [[entry.cpu().item() for entry in tracked_row] for tracked_row in tracking]
    )
    return tracked_values, verbose_decoded_completion, logits

def rouge_eval(tf, tf_text, device_inputs, answer, model, tokenizer, cfg_arch, token_limit=20, temp=0.7, crammed=False, greedy=False):
    """
    Generate a completion and score it against ``answer`` with ``tf_text.metrics.rouge_l``.

    ``tf`` and ``tf_text`` are passed in by the caller. Generation goes through ``model_gen``
    when ``crammed`` is true and through ``model._generate`` otherwise. The completion is cut
    after the first EOS token, decoded, and compared character by character with ``answer``.

    Returns the three values produced by ``rouge_l`` (unpacked as f, p, r), each converted
    with ``.numpy().tolist()``.
    """
    if not crammed:
        predicted_ids = model._generate(
            device_inputs,
            token_limit=token_limit,
            temperature=temp,
            steps_at_generation_time=cfg_arch.maximal_recurrence_in_eval,
            greedy=greedy,
        )
    else:
        predicted_ids = model_gen(model, device_inputs, token_limit=token_limit, temperature=temp)

    # Keep everything up to and including the first EOS; drop whatever follows it.
    eos_id = tokenizer._convert_token_to_id(tokenizer.eos_token)
    eos_hits = (predicted_ids == eos_id).nonzero(as_tuple=False)
    if eos_hits.numel() > 0:
        # Column 1 of the first hit is its position along the sequence dimension.
        keep = eos_hits[0, 1] + 1
        predicted_ids = predicted_ids[:, :keep]

    # Only the first batch element is decoded; the text is split into single characters.
    hypothesis_chars = list(tokenizer.decode(predicted_ids[0].tolist()))

    with tf.device('/gpu:0'):
        hyp = tf.cast(tf.ragged.constant([hypothesis_chars]), dtype=tf.string)
        reference_chars = list(answer)
        ref = tf.cast(tf.ragged.constant([reference_chars]), dtype=tf.string)
        f, p, r = tf_text.metrics.rouge_l(hyp, ref)
    return f.numpy().tolist(), p.numpy().tolist(), r.numpy().tolist()

def reverse_numbers(input_string):
    """Return ``input_string`` with every run of digits reversed in place.

    Non-digit characters are left untouched, e.g. ``"12 + 345"`` becomes ``"21 + 543"``.
    """
    return re.sub(r'\d+', lambda digits: digits.group(0)[::-1], input_string)

def pad_numbers(input_string, min_length=12):
    """Left-pad every run of digits in ``input_string`` with zeros to ``min_length`` characters.

    Runs that are already at least ``min_length`` long are kept as they are.
    """
    # The regex only matches digits, so zfill never sees a sign and acts as a plain left pad.
    return re.sub(r'\d+', lambda digits: digits.group(0).zfill(min_length), input_string)

def tokenize_and_split_line(line, tokenizer, device, reverse_inputs=False, pad_zeros=0, remove_padding=False):
    """Split one dataset line into prompt and answer and tokenize the prompt.

    The line is split at its last space. If it contains no space, it is split at the first
    ``=`` instead and the prompt is rebuilt with spaces around ``+`` and ``=``.

    Args:
        line: text of the form ``"<prompt> <answer>"`` or ``"<a>+<b>=<answer>"``.
        tokenizer: callable returning a mapping with ``"input_ids"``. For the base-100 and
            base-1000 tokenizers the answer is additionally passed through ``_tokenize``
            and re-joined.
        device: device the tokenized prompt is moved to.
        reverse_inputs: reverse every digit run in the prompt.
        pad_zeros: if positive, zero-pad every digit run in the prompt to this length.
        remove_padding: strip all spaces from the prompt.

    Returns:
        ``(prompt, device_inputs, answer)`` where ``device_inputs`` has a leading batch
        dimension of size one.
    """
    try:
        prompt, answer = line.rsplit(" ", 1)
    except ValueError:
        # No space in the line, so unpacking failed: fall back to splitting on '='.
        equal_sign_index = line.find('=')
        prompt = line[:equal_sign_index].replace('+', ' + ').strip() + " = "
        answer = line[equal_sign_index + 1:].strip()
    else:
        prompt = prompt + " "  # adding the removed space back onto prompt

    base_tokenizers = (
        CustomCharLevelTokenizerForAddingPadding_Base1000,
        CustomCharLevelTokenizerForAddingPadding_Base100,
    )
    if isinstance(tokenizer, base_tokenizers):
        answer = ''.join(tokenizer._tokenize(answer))

    # Order matters: pad first, then reverse, then strip spaces.
    if pad_zeros > 0:
        prompt = pad_numbers(prompt, pad_zeros)
    if reverse_inputs:
        prompt = reverse_numbers(prompt)
    if remove_padding:
        prompt = prompt.replace(" ", "")

    tokenized_inputs = torch.as_tensor(tokenizer(prompt)["input_ids"], dtype=torch.long)[None, :]
    device_inputs = tokenized_inputs.to(device)
    return prompt, device_inputs, answer

def write_list_to_word_file(lists, save_file_name='colored_text.docx'):
    """Write colour-coded token lists to a Word document.

    Args:
        lists: iterable of lists of ``(text, status)`` items. Each inner list becomes one
            paragraph followed by an empty paragraph.
        save_file_name: path the document is saved to.

    Status ``"KEEP"`` is written in black, ``"CORRECT"`` in green and anything else in red.
    Items whose text contains ``PAD`` are written as a single space. An item that cannot
    be inspected is printed together with its list and the process exits.
    """
    document = Document()

    for lst in lists:
        paragraph = document.add_paragraph()
        for item in lst:
            try:
                if 'PAD' in item[0]:
                    item = (' ', item[1])
            except:
                print(item)
                print(lst)
                exit()

            status = item[1]
            if status == "KEEP":
                rgb = (0, 0, 0)  # black color
            elif status == "CORRECT":
                rgb = (0, 128, 0)  # green color
            else:
                rgb = (255, 0, 0)  # red color

            run = paragraph.add_run(item[0])
            run.font.color.rgb = RGBColor(*rgb)
            run.add_text(" ")
        document.add_paragraph()  # Add a new line between lists

    document.save(save_file_name)

def gdm_index_hints_helper(num, tokenizer):
    """Interleave index-hint tokens with the columns of ``num``.

    Column ``i`` of ``num`` is preceded by a column filled with the id of
    ``tokenizer.char_set[i]``, so an input of width ``w`` yields an output of width ``2 * w``.
    ``num`` is indexed as a 2-D tensor; the hint columns use its dtype and device.
    """
    char_set = tokenizer.char_set
    n_rows, n_cols = num.shape[0], num.shape[1]

    columns = []
    for col in range(n_cols):
        hint_id = tokenizer._convert_token_to_id(char_set[col])
        hint_column = hint_id * torch.ones((n_rows, 1), dtype=num.dtype, device=num.device)
        columns.append(hint_column)
        columns.append(num[:, col:col + 1])

    # With no columns there is nothing to interleave; hand the input back unchanged.
    if not columns:
        return num
    return torch.cat(columns, dim=1)

# Sentinel distinguishing "attribute could not be read" from any real config value.
_GRID_CFG_MISSING = object()


def _grid_cfg_lookup(cfg, key):
    """Return ``cfg.<key>``, or ``_GRID_CFG_MISSING`` when reading the attribute fails."""
    try:
        return getattr(cfg, key)
    except Exception:
        return _GRID_CFG_MISSING


def _grid_band(lower, upper, both_within_upper, extra=None):
    """Build a ``(data_size_1, data_size_2) -> bool`` predicate for one evaluation band.

    The predicate requires at least one size ``>= lower`` (skipped when ``lower`` is None),
    then both sizes (``both_within_upper``) or at least one size ``<= upper``, and finally
    ``extra(data_size_1, data_size_2)`` when ``extra`` is given.
    """
    def in_band(data_size_1, data_size_2):
        if lower is not None and not (data_size_1 >= lower or data_size_2 >= lower):
            return False
        if both_within_upper:
            within = data_size_1 <= upper and data_size_2 <= upper
        else:
            within = data_size_1 <= upper or data_size_2 <= upper
        if not within:
            return False
        return True if extra is None else extra(data_size_1, data_size_2)
    return in_band


def grid_logic(cfg):
    """Choose which cells of the operand-length grid are evaluated, based on ``cfg``.

    Keys probed, in order: ``large``, ``ood_only``, ``up_to_40``, ``up_to_50``,
    ``checkerboard`` and ``big_eval_step_1`` .. ``big_eval_step_10``.

    NOTE: a band is selected when its key can be *read* from ``cfg``, whatever the value
    (even ``False``). The last readable key in the order above wins. ``large`` is only
    logged; the default band is used when no other key is readable.

    ``checkerboard`` may be ``'even'`` or ``'odd'`` and restricts the ``big_eval`` bands to
    cells whose two sizes sum to an even or odd number. Any other non-None value prints an
    error and exits.

    Returns:
        ``(logic_func, name, max_size)``: ``logic_func(data_size_1, data_size_2)`` says
        whether a cell is evaluated, ``name`` is the file-name suffix for the selected band
        and ``max_size`` is one more than the band's largest size (101 for all ``big_eval``
        bands).
    """
    # original testing: the default band
    logic_func = _grid_band(None, 23, False)
    name = '_large'
    max_size = 23 + 1

    # (log label, value) pairs in evaluation order, reported at the end.
    report = []

    large = _grid_cfg_lookup(cfg, "large")
    report.append(("large", True if large is _GRID_CFG_MISSING else large))

    # (config key, log label, lower bound, upper bound); the upper bound uses `or` here.
    simple_bands = (
        ("ood_only", "ood only", 24, 30),
        ("up_to_40", "up to 40", 31, 40),
        ("up_to_50", "up to 50", 41, 50),
    )
    for key, label, lower, upper in simple_bands:
        value = _grid_cfg_lookup(cfg, key)
        if value is _GRID_CFG_MISSING:
            report.append((label, False))
            continue
        report.append((label, value))
        logic_func = _grid_band(lower, upper, False)
        name = f"_{key}"
        max_size = upper + 1

    # checkerboarding: for the large eval we can checkerboard
    checkerboard = _grid_cfg_lookup(cfg, "checkerboard")
    if checkerboard is _GRID_CFG_MISSING or checkerboard is None:
        checkerboard_func = None
        checkerboard_str = ""
    elif checkerboard == 'even':
        def checkerboard_func(data_size_1, data_size_2):
            return ((data_size_1 + data_size_2) % 2 == 0)
        checkerboard_str = "_even"
    elif checkerboard == 'odd':
        def checkerboard_func(data_size_1, data_size_2):
            return ((data_size_1 + data_size_2) % 2 == 1)
        checkerboard_str = "_odd"
    else:
        print("checkerboard config not allowed")
        exit()

    # If we are testing up to 100, split into 10 steps each of approximately equal number of
    # forward passes required. Here both sizes must be within the upper bound.
    big_eval_bounds = (
        (None, 46), (47, 58), (59, 67), (68, 74), (75, 80),
        (81, 85), (86, 90), (91, 94), (95, 97), (98, 100),
    )
    for step, (lower, upper) in enumerate(big_eval_bounds, start=1):
        value = _grid_cfg_lookup(cfg, f"big_eval_step_{step}")
        label = f"big eval {step}"
        if value is _GRID_CFG_MISSING:
            report.append((label, False))
            continue
        report.append((label, value))
        logic_func = _grid_band(lower, upper, True, extra=checkerboard_func)
        name = f"_big_eval_{step}" + checkerboard_str
        max_size = 100 + 1

    for label, value in report:
        log.info(f"{label} = {value}")
    log.info(f"the last true value in the above list will be run, mul and pos arith can take control after this")

    return logic_func, name, max_size

def main(cfg):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    local_checkpoint_folder = os.path.join(cfg.base_dir, cfg.name, "checkpoints")
    tokenizer, cfg_arch, model_file = cramming.utils.find_pretrained_checkpoint(cfg.eval.checkpoint,
                                                                                local_checkpoint_folder,
                                                                                cfg.eval.arch_modifications)
    try:
        cfg_arch.mask_before_equals = cfg_arch.mask_before_equals
    except:
        cfg_arch.mask_before_equals = False

    try:
        cfg_arch.loss_reduction = cfg_arch.loss_reduction
    except:
        cfg_arch.loss_reduction = 'mean'

    try:
        cfg_arch.maximal_recurrence_in_eval = cfg.max_rec
    except:
        cfg_arch.maximal_recurrence_in_eval = 10
    log.info(f"cfg_arch.maximal_recurrence_in_eval changed to {cfg_arch.maximal_recurrence_in_eval}")

    try:
        cfg_arch.attention.max_length = cfg_arch.attention.max_length
    except:
        cfg_arch.attention.max_length = 0

    if 'forward_only_model_with_skip' not in cfg_arch: # trained before we corrected the typo
        try:
            cfg_arch.forward_only_model_with_skip = cfg_arch.forward_only_model_with_sikp
            cfg_arch.pop("forward_only_model_with_sikp")
        except:
            cfg_arch.forward_only_model_with_skip = False
    log.info(f"forward_only_model_with_skip = {cfg_arch.forward_only_model_with_skip}")

    try: # change max k value in recycle at run time, or set it for models trained before this
        cfg_arch.embedding.max_recycle_len = cfg_arch.embedding.max_recycle_len 
    except:
        cfg_arch.embedding.max_recycle_len = 100
    log.info(f"embedding.max_recycle_len = {cfg_arch.embedding.max_recycle_len}")

    try: # change pos embedding at run time
        cfg_arch.embedding.pos_embedding = cfg.embedding.pos_embedding
    except:
        cfg_arch.embedding.pos_embedding = cfg_arch.embedding.pos_embedding
    log.info(f"cfg_arch.embedding.pos_embedding = {cfg_arch.embedding.pos_embedding}")

    cfg_arch.throttle = False

    try:
        reverse_inputs = cfg.reverse_inputs
    except:
        reverse_inputs = False
    log.info(f"reverse inputs = {reverse_inputs}")

    try:
        pad_zeros = cfg.pad_zeros
    except:
        pad_zeros = 0
    log.info(f"pad zeros = {pad_zeros}")

    try:
        advanced_plotting = cfg.advanced_plotting
    except:
        advanced_plotting = False
    log.info(f"advanced plotting = {advanced_plotting}")
    
    try:
        extended_eval = cfg.extended_eval
    except:
        extended_eval = False
    log.info(f"extended eval = {extended_eval}")

    logic_func, name, max_size = grid_logic(cfg)

    try:
        mul = cfg.mul
        def logic_func_for_mul(data_size_1, data_size_2):
            return (data_size_1 <= 25 or data_size_2 <= 25)
        logic_func = logic_func_for_mul
        name = '_large'
        max_size = 25+1
    except:
        mul = False
    log.info(f"mul = {mul}")
    try:
        pos_arth = cfg.pos_arth
        def logic_func_for_pos(data_size_1, data_size_2):
            return (data_size_1 <= 25 or data_size_2 <= 25)
        logic_func = logic_func_for_pos
        name = '_large'
        max_size = 25+1
    except:
        pos_arth = False
    log.info(f"pos_arth = {pos_arth}")
    try:
        pos_arth_ood = cfg.pos_arth_ood
        def logic_func_for_pos_ood(data_size_1, data_size_2):
            return (data_size_1 >= 26 or data_size_2 >=26) and (data_size_1 <=40 and data_size_2 <=40)
        logic_func = logic_func_for_pos_ood
        name = '_ood_only'
        max_size = 40+1
    except:
        pos_arth_ood = False
    log.info(f"pos_arth_ood = {pos_arth_ood}")
    try:
        pos_arth_add_rev = cfg.pos_arth_add_rev
        def logic_func_for_pos_add_rev(data_size_1, data_size_2):
            return (data_size_1 <= 25 or data_size_2 <= 25)
        logic_func = logic_func_for_pos_add_rev
        name = '_large'
        max_size = 25+1
    except:
        pos_arth_add_rev = False
    log.info(f"pos_arth_add_rev = {pos_arth_add_rev}")
    try:
        pos_arth_add_rev_ood = cfg.pos_arth_add_rev_ood
        def logic_func_for_pos_add_rev_ood(data_size_1, data_size_2):
            return (data_size_1 >= 26 or data_size_2 >=26) and (data_size_1 <=40 and data_size_2 <=40)
        logic_func = logic_func_for_pos_add_rev_ood
        name = '_ood'
        max_size = 40+1
    except:
        pos_arth_add_rev_ood = False
    log.info(f"pos_arth_add_rev_ood = {pos_arth_add_rev_ood}")
    try:
        pos_arth_add_forward = cfg.pos_arth_add_forward
        def logic_func_for_pos_add_forward(data_size_1, data_size_2):
            return (data_size_1 <= 25 or data_size_2 <= 25)
        logic_func = logic_func_for_pos_add_forward
        name = '_large'
        max_size = 25+1
    except:
        pos_arth_add_forward = False
    log.info(f"pos_arth_add_forward = {pos_arth_add_forward}")
    try:
        pos_arth_add_forward_ood = cfg.pos_arth_add_forward_ood
        def logic_func_for_pos_add_forward_ood(data_size_1, data_size_2):
            return (data_size_1 >= 26 or data_size_2 >=26) and (data_size_1 <=40 and data_size_2 <=40)
        logic_func = logic_func_for_pos_add_forward_ood
        name = '_ood'
        max_size = 40+1
    except:
        pos_arth_add_forward_ood = False
    log.info(f"pos_arth_add_forward_ood = {pos_arth_add_forward_ood}")

    try:
        no_carries = cfg.no_carries
    except:
        no_carries = False
    log.info(f"no carries = {no_carries}")

    try:
        remove_padding = cfg.remove_padding
    except:
        remove_padding = True
    log.info(f"remove padding = {remove_padding}")

    try:
        add_random_pad = cfg.add_random_pad
    except:
        add_random_pad = False
    log.info(f"add_random_pad padding = {add_random_pad}")

    try:
        del_ignore_after_EOS = cfg.del_ignore_after_EOS # gives EOS presidence over DEL
    except:
        del_ignore_after_EOS = True
    log.info(f"del_ignore_after_EOS = {del_ignore_after_EOS}")

    try:
        token_limit = cfg.token_limit
    except:
        token_limit = 20
    log.info(f"token limit = {token_limit}")

    try:
        crammed = cfg.crammed
    except:
        crammed = False

    try:
        rouge = cfg.rouge
    except:
        rouge = False
    log.info(f"rouge = {rouge}")

    use_del = False
    # import tokeniser
    cfg_data_sources_values_list = list(cfg.data.sources.values())[0]
    if cfg_data_sources_values_list["provider"] == "arithmetic":
        tokenizer = get_tokenizer(cfg_data_sources_values_list["tokenizer_type"])

        # cramming.data.pretraining_preparation.CustomCharLevelTokenizer()
        # if cfg_data_sources_values_list["tokenizer_type"] == "white_space":
        #     tokenizer = cramming.data.pretraining_preparation.CustomCharLevelTokenizerWithWhiteSpace()
        # elif cfg_data_sources_values_list["tokenizer_type"] == "pad":
        #     tokenizer = cramming.data.pretraining_preparation.CustomCharLevelTokenizerForAddingPadding()
        # elif cfg_data_sources_values_list["tokenizer_type"] == "base_100":
        #     tokenizer = cramming.data.pretraining_preparation.CustomCharLevelTokenizerForAddingPadding_Base100()
        # elif cfg_data_sources_values_list["tokenizer_type"] == "base_1000":
        #     tokenizer = cramming.data.pretraining_preparation.CustomCharLevelTokenizerForAddingPadding_Base1000()
        # elif cfg_data_sources_values_list["tokenizer_type"] == "mod":
        #     tokenizer = cramming.data.pretraining_preparation.CustomCharLevelTokenizerForAddingPaddingForMod()
        if cfg_data_sources_values_list["tokenizer_type"] == "del":
            # tokenizer = cramming.data.pretraining_preparation.CustomCharLevelTokenizerForDelete()
            use_del=True
            del_token = tokenizer._convert_token_to_id('D')
    else: 
        log.info("exiting as this is only for arithmetic")
    vocab = tokenizer.ids_to_tokens
    EOS_token = tokenizer._convert_token_to_id(tokenizer.eos_token)
    PAD_token = tokenizer._convert_token_to_id(tokenizer.pad_token)
    assert PAD_token == 0

    # Load model
    if 'alpha' not in cfg_arch:
        cfg_arch['alpha'] = 1.0

    model = cramming.construct_model(cfg_arch, tokenizer).to(device)
    model = cramming.backend.load_model_checkpoint(model, model_file)
    model.to(device)
    model.eval()

    try:
        temp = cfg.temp
    except:
        temp = 1.0
    log.info(f"temperature = {temp}")

    try:
        greedy = cfg.greedy
    except:
        greedy = True
    log.info(f"greedy = {greedy}, note: if greedy = True this overrides any temperature arguments")
    ## Greedy deconding will overide any temperature arguments


    try:
        max_size_given = cfg.max_size_given
    except:
        max_size_given = None

    if max_size_given is not None:
        max_size = max_size_given

    # Grid plots - grid search from 1x1 to 12x12 data
    data_sizes = list(range(1, max_size))
    acc_grid = np.zeros((len(data_sizes),len(data_sizes)))
    start_ind_1 = 0
    start_ind_2 = 0
    tuple_method = False
    completed_one = False
    if "big_eval" in name:
        tuple_method = True
        # go up two layers and search for grid
        try:
            with open(f"../../accs_grid_quick{name}.json", 'r') as file:
                data = json.load(file)
            start_ind_1 = data[1]
            start_ind_2 = data[2]
            acc_grid = np.array(data[0])
            log.info("loaded grid from previous run")
        except:
            pass


    try:
        start_ind_1_given = cfg.start_ind_1_given
    except:
        start_ind_1_given = None

    if start_ind_1_given is not None:
        start_ind_1 = start_ind_1_given


    try:
        start_ind_2_given = cfg.start_ind_2_given
    except:
        start_ind_2_given = None

    if start_ind_2_given is not None:
        start_ind_2 = start_ind_2_given


    try:
        write_strikes = cfg.write_strikes
    except:
        write_strikes = False


    try:
        hurry_up = cfg.hurry_up
    except:
        hurry_up = False


    if hurry_up:
        folder_path = f"../../all_outputs"
        os.makedirs(folder_path, exist_ok=True)

    os.makedirs("outputs", exist_ok=True)
    if use_del:
        os.makedirs("output_strikes", exist_ok=True)

    print(f"start_ind_1 = {start_ind_1}, start_ind_2 = {start_ind_2}")
    print(data_sizes)

    if not extended_eval:
        for data_size_1 in data_sizes:
            for data_size_2 in data_sizes:
                if not hurry_up:
                    if (data_size_1 < start_ind_1 or data_size_2 < start_ind_2) and not completed_one:
                        continue
                else:
                    proceed = False
                    # if both data sizes are less than the start indices, then dont proceed
                    # but if one of them is greater than the start indices, then proceed
                    if data_size_1 >= start_ind_1 or data_size_2 >= start_ind_2:
                        proceed = True

                    if not proceed:
                        continue

                print(f"evaluating for {data_size_1} and {data_size_2}")

                if hurry_up:
                    if os.path.exists(f"../../all_outputs/acc_for_{data_size_1}_{data_size_2}.txt"):
                        with open(f"../../all_outputs/acc_for_{data_size_1}_{data_size_2}.txt", 'r') as file:
                            acc = float(file.read())
                        acc_grid[data_size_1 - 1, data_size_2 - 1] = acc
                        continue


                if logic_func(data_size_1, data_size_2):
                    completed_one = True
                    log.info(f"Starting iteration in grid eval for size: {data_size_1} and {data_size_2}")
                    correct_total = 0

                    file_path = f"../../../../data/arithmetic_data/+_grid_eval_dataset_padded_tokenized/+_n_{data_size_1}_m_{data_size_2}_examples_100_diff_lens_seed_42/hf_tokenized_dataset"
                    if reverse_inputs:
                        file_path = f"../../../../data/arithmetic_data/+_grid_eval_dataset_reverse_all_tokenized/+_n_{data_size_1}_m_{data_size_2}_examples_100_diff_lens_seed_42/hf_tokenized_dataset"
                    if mul:
                        file_path = f"../../../../data/arithmetic_data/x_grid_eval_dataset_2_reverse_all_tokenized/x_n_{data_size_1}_m_{data_size_2}_examples_100_diff_lens_exact_seed_91/hf_tokenized_dataset"
                    if pos_arth or pos_arth_ood:
                        # file_path = f"../../../../data/arithmetic_data/pos_arith_eval/pos_arith_{data_size_1}_{data_size_2}/hf_tokenized_dataset"
                        file_path = f"../../../../data/arithmetic_data/pos_or_one_vec_zeros_eval/or_one_vec_zeros_{data_size_1}_{data_size_2}/hf_tokenized_dataset"
                    if pos_arth_add_rev or pos_arth_add_rev_ood:
                        file_path = f"../../../../data/arithmetic_data/pos_arith_add_rev_eval/pos_arith_add_rev_{data_size_1}_{data_size_2}/hf_tokenized_dataset"
                    if pos_arth_add_forward or pos_arth_add_forward_ood:
                        file_path = f"../../../../data/arithmetic_data/pos_arith_add_forward_eval/pos_arith_add_forward_{data_size_1}_{data_size_2}/hf_tokenized_dataset"
                    tokenized_dataset = datasets.load_from_disk(file_path)["test"]
                    data_loader = torch.utils.data.DataLoader(tokenized_dataset, batch_size=100, shuffle=False)
                    equals_tensor = data_size_1+data_size_2+6
                    if pos_arth or pos_arth_ood or pos_arth_add_rev or pos_arth_add_rev_ood or pos_arth_add_forward or pos_arth_add_forward:
                        equals_tensor = data_size_1+data_size_2+2

                    for batch in data_loader:
                        tokenized_prompts = batch["input_ids"][:equals_tensor]
                        tokenized_prompts = torch.stack(tokenized_prompts).to(device)
                        tokenized_prompts = torch.transpose(tokenized_prompts, 0, 1)
                        tokenized_answers = batch["input_ids"][equals_tensor:]
                        tokenized_answers = torch.stack(tokenized_answers).to(device)
                        tokenized_answers = torch.transpose(tokenized_answers, 0, 1)
   
                        if remove_padding and (cfg_data_sources_values_list["tokenizer_type"] != "gdm_index"): # only tested for forwards addition 
                            num1 = tokenized_prompts[:,:data_size_1]
                            op = tokenized_prompts[:,data_size_1+1:data_size_1+2]
                            num2 = tokenized_prompts[:,data_size_1+3:data_size_1+data_size_2+3]
                            equals = tokenized_prompts[:,data_size_1+data_size_2+4:data_size_1+data_size_2+5]
                            tokenized_prompts = torch.cat((num1, op, num2, equals), dim=1)
                        if add_random_pad:
                            num1 = tokenized_prompts[:,:data_size_1]
                            op = tokenized_prompts[:,data_size_1+1:data_size_1+2]
                            num2 = tokenized_prompts[:,data_size_1+3:data_size_1+data_size_2+3]
                            equals = tokenized_prompts[:,data_size_1+data_size_2+4:data_size_1+data_size_2+5]
                            tokenized_prompts = torch.cat((num1, op, num2, equals), dim=1)
                            p = 0.09
                            indexes_to_pad = []
                            for index in range(tokenized_prompts.shape[1]):
                                this_p = p
                                while random.random() <= this_p:
                                    this_p *= 0.1
                                    indexes_to_pad.append(index)

                            adjusted_indices = [index + i for i, index in enumerate(indexes_to_pad)]
                            result_large = tokenized_prompts
                            for index in adjusted_indices:
                                result_large = torch.cat((result_large[:, :index], torch.zeros(result_large.size(0), 1, dtype=result_large.dtype, device=result_large.device), result_large[:, index:]), 1)
                            
                        if cfg_data_sources_values_list["tokenizer_type"] == "gdm_index":
                            # adding in the index hints to the input numbers
                            num1 = tokenized_prompts[:,:data_size_1]
                            num1 = gdm_index_hints_helper(num1, tokenizer)
                            op = tokenized_prompts[:,data_size_1+1:data_size_1+2]
                            num2 = tokenized_prompts[:,data_size_1+3:data_size_1+data_size_2+3]
                            num2 = gdm_index_hints_helper(num2, tokenizer)
                            equals = tokenized_prompts[:,data_size_1+data_size_2+4:data_size_1+data_size_2+5]
                            tokenized_prompts = torch.cat((num1, op, num2, equals), dim=1)

                            predicted_ids = None

                            ## below inserts the characters for the model, we decided against this in the end
                            predicted_ids = model._generate(tokenized_prompts, token_limit=(tokenized_answers.shape[1]*2), temperature=temp, steps_at_generation_time=cfg_arch.maximal_recurrence_in_eval, greedy=greedy, quick=True)
                            predicted_ids = torch.transpose(predicted_ids, 0, 1)

                            new_tensor = torch.zeros_like(predicted_ids)
                            for i in range(predicted_ids.size(0)): # inefficient!!
                                # Filter out values greater than 17
                                filtered_values = predicted_ids[i][predicted_ids[i] <= 17]
                                # Place filtered values in new tensor and pad with zeros
                                new_tensor[i, :len(filtered_values)] = filtered_values

                            predicted_ids = new_tensor[:, :tokenized_answers.shape[1]] # trim off the excess
                            predicted_ids = torch.transpose(predicted_ids, 0, 1)

                        else:
                            predicted_ids = model._generate(tokenized_prompts, token_limit=tokenized_answers.shape[1], temperature=temp, steps_at_generation_time=cfg_arch.maximal_recurrence_in_eval, greedy=greedy, quick=True)
                        if len(predicted_ids.shape) > 1: # i.e. we have a batch of more than one
                            predicted_ids = torch.transpose(predicted_ids, 0, 1)
                        else:
                            predicted_ids = predicted_ids.reshape((1,-1)) # add a batch dim otherwise
                            # print(tokenized_prompts)
                            # print(tokenized_answers)
                            # print(predicted_ids)
                            # # print(tokenized_prompts.shape)
                            # print(tokenized_answers.shape)
                            # print(predicted_ids.shape)
                            # print((tokenized_answers == 4).sum().item())
                            # print((predicted_ids == 4).sum().item())
                        # ignore everything after EOS on eval by replacing all tokens after EOS with PAD
                    eval_tensor = predicted_ids.clone()
                    input_tensor_EOS = (eval_tensor == EOS_token).int()
                    indices_of_EOS = torch.argmax(input_tensor_EOS, dim=1)
                    mask = torch.arange(eval_tensor.size(1)).to(device) > indices_of_EOS[:, None]
                    eval_tensor[mask] = PAD_token
                    
                    # compare eval tensor to correct outputs
                    # The original (now corrected) pos arith datasets were one 0 too short during training so we had to change the eval data to match
                    # if pos_arth or pos_arth_ood:
                    #     tokenized_answers = torch.cat((tokenized_answers[:, :-2], tokenized_answers[:, -1:]), dim=1)
                    #     eval_tensor = eval_tensor[:, :-1]
                    elementwise_equal = torch.eq(eval_tensor, tokenized_answers)
                    rows_equal = torch.all(elementwise_equal, dim=1)
                    num_equal_rows = torch.sum(rows_equal).item()
                    correct_total += (num_equal_rows/tokenized_prompts.shape[0])
                    log.info(f"accuracy for {data_size_1}, {data_size_2}: {num_equal_rows} = {correct_total*100}%")

                    # combine the prompts and outputs
                    complete_lines = torch.cat((tokenized_prompts,predicted_ids), dim=1)
                    tokens_list = complete_lines.tolist()
                    decoded_batch = list(map(lambda seq: list(map(lambda token: vocab[token], seq)), tokens_list)) # map token ids to tokens
                    log.info(f"example for {data_size_1}, {data_size_2}: {decoded_batch[0]}")
                    # save the answers down so we don't eval twice ever
                    with open(f"outputs/+_n_{data_size_1}_m_{data_size_2}.json", 'w') as json_file:
                        json.dump(decoded_batch, json_file)

                    acc_grid[(data_size_1-1),(data_size_2-1)] = correct_total

                    if hurry_up:
                        with open(f"../../all_outputs/acc_for_{data_size_1}_{data_size_2}.txt", "w") as file:
                            file.write(f"{correct_total}")

                    if tuple_method:
                        with open(f"../../accs_grid_quick{name}.json", "w") as file:
                            tuple_to_save = (acc_grid.tolist(),data_size_1,data_size_2)
                            json.dump(tuple_to_save, file)

        log.info(f"acc grid: {acc_grid}")

        with open(f"accs_grid_quick{name}.json", "w") as file:
            json.dump(acc_grid.tolist(), file)
        
        # Grid plots - one for accs one for contains
        grid_plotter(acc_grid, name=name)

        if hurry_up:
            grid_plotter(acc_grid, name=f"{start_ind_1}_{start_ind_2}_{max_size}", extra_path="../../all_outputs/")
            exit()

    if extended_eval:
        number = int(re.findall(r'\d+', name)[0])
        log.info("starting extended eval")
        # this is hard coded for reverse all, addition past 100x100 grid, removing the padding

        accs = dict()
        batch_size_extended_eval = 100

        old_data_path = None
        for root, dirs, files in os.walk("../.."):
            if f"over_100_{number}.json" in files:
                old_data_path = os.path.join(root, f"over_100_{number}.json")

        if number == 1:
            start = 101
            list_to_do = range(start,161)
        elif number == 2:
            list_to_do = [1000, 800]
            batch_size_extended_eval = 10
        elif number == 3:
            list_to_do = [200, 700, 900]
        elif number == 4:
            list_to_do = [300, 400, 500, 600]
        else:
            print("number too high")
            exit()

        if old_data_path is not None: # read the old accs dict and don't repeat what we have already done
            with open(old_data_path, 'r') as file:
                data = json.load(file)
            accs = {int(k): v for k, v in data.items()}
            to_do = set(list_to_do).difference(set(accs.keys()))
            list_to_do = list(to_do)

        log.info(f"In extended eval with number {number}")

        for data_size in list_to_do:
            log.info(f"Extended eval {data_size}")
            correct_total = 0
            file_path = f"../../../../data/arithmetic_data/+_grid_eval_dataset_reverse_all_tokenized_over_100/+_n_{data_size}_m_{data_size}_examples_100_diff_lens_exact_seed_42/hf_tokenized_dataset"
            tokenized_dataset = datasets.load_from_disk(file_path)["test"]
            data_loader = torch.utils.data.DataLoader(tokenized_dataset, batch_size=batch_size_extended_eval, shuffle=False)
            equals_tensor = data_size+data_size+6

            for batch in data_loader:
                tokenized_prompts = batch["input_ids"][:equals_tensor]
                tokenized_prompts = torch.stack(tokenized_prompts).to(device)
                tokenized_prompts = torch.transpose(tokenized_prompts, 0, 1)
                tokenized_answers = batch["input_ids"][equals_tensor:]
                tokenized_answers = torch.stack(tokenized_answers).to(device)
                tokenized_answers = torch.transpose(tokenized_answers, 0, 1)

                num1 = tokenized_prompts[:,:data_size]
                op = tokenized_prompts[:,data_size+1:data_size+2]
                num2 = tokenized_prompts[:,data_size+3:data_size+data_size+3]
                equals = tokenized_prompts[:,data_size+data_size+4:data_size+data_size+5]
                tokenized_prompts = torch.cat((num1, op, num2, equals), dim=1)

                predicted_ids = model._generate(tokenized_prompts, token_limit=tokenized_answers.shape[1], temperature=temp, steps_at_generation_time=cfg_arch.maximal_recurrence_in_eval, greedy=greedy, quick=True)
                predicted_ids = torch.transpose(predicted_ids, 0, 1) # add a batch dim

                eval_tensor = predicted_ids.clone()
                input_tensor_EOS = (eval_tensor == EOS_token).int()
                indices_of_EOS = torch.argmax(input_tensor_EOS, dim=1)
                mask = torch.arange(eval_tensor.size(1)).to(device) > indices_of_EOS[:, None]
                eval_tensor[mask] = PAD_token
                elementwise_equal = torch.eq(eval_tensor, tokenized_answers)
                
                rows_equal = torch.all(elementwise_equal, dim=1)
                num_equal_rows = torch.sum(rows_equal).item()
                correct_total += (num_equal_rows/tokenized_prompts.shape[0])
                log.info(f"accuracy for {data_size}, {data_size}: {num_equal_rows} = {correct_total*100}%")

                # combine the prompts and outputs
                complete_lines = torch.cat((tokenized_prompts,predicted_ids), dim=1)
                tokens_list = complete_lines.tolist()
                decoded_batch = list(map(lambda seq: list(map(lambda token: vocab[token], seq)), tokens_list)) # map token ids to tokens
                log.info(f"example for {data_size}, {data_size}: {decoded_batch[0]}")
                # save the answers down so we don't eval twice ever

            accs[data_size] = correct_total
            with open(f"over_100_{number}.json", 'w') as json_file:
                    json.dump(accs, json_file)

    # Run prompter to get thinking plots too -- one for each data size
    if (not crammed) and (not advanced_plotting) and (not extended_eval):
        if tuple_method: # i.e. for the big eval only make the thinking plots for the first one
            try:
                big_eval_step_1 = cfg.big_eval_step_1
            except:
                big_eval_step_1 = False
            if big_eval_step_1:
                token_limit = 105
            else:
                exit()
        vocab = tokenizer.get_vocab()
        sorted_tokenizer_keys = sorted(vocab, key=lambda x: vocab[x])
        data_sizes = data_sizes = list(range(1, max_size))
        for data_size in data_sizes:
            file_path = f"../../../../data/arithmetic_data/+_grid_eval_dataset/+_n_{data_size}_m_{data_size}_examples_100_diff_lens_seed_42.txt"
            if no_carries:
                file_path = f"../../../../data/arithmetic_data/+_no_carries_grid_eval_dataset/+_n_{data_size}_m_{data_size}_examples_100_diff_lens_seed_42.txt"
            elif reverse_inputs:
                file_path = f"../../../../data/arithmetic_data/+_grid_eval_dataset_reverse_all_tokenized/+_n_{data_size}_m_{data_size}_examples_100_diff_lens_seed_42/+_n_{data_size}_m_{data_size}_examples_100_diff_lens_seed_42.txt"
                reverse_input_in_tokenise = False
                reverse_input_in_answer = False
            elif mul:
                file_path = f"../../../../data/arithmetic_data/x_grid_eval_dataset_2/x_n_{data_size}_m_{data_size}_examples_100_diff_lens_exact_seed_91.txt"
                reverse_input_in_tokenise = True
                reverse_input_in_answer = True
            elif pos_arth or pos_arth_ood:
                    # file_path = f"../../../../data/arithmetic_data/pos_arith_eval/pos_arith_{data_size}_{data_size}/positional_arithmetic_n_{data_size}_n_{data_size}.txt"
                    file_path = f"../../../../data/arithmetic_data/pos_or_one_vec_zeros_eval/or_one_vec_zeros_{data_size}_{data_size}/positional_arithmetic_n_{data_size}_m_{data_size}.txt"
            elif pos_arth_add_rev or pos_arth_add_rev_ood:
                file_path = f"../../../../data/arithmetic_data/pos_arith_add_rev_eval/pos_arith_add_rev_{data_size}_{data_size}/positional_arithmetic_n_{data_size}_m_{data_size}.txt"
            elif pos_arth_add_forward or pos_arth_add_forward_ood:
                file_path = f"../../../../data/arithmetic_data/pos_arith_add_forward_eval/pos_arith_add_forward_{data_size}_{data_size}/positional_arithmetic_n_{data_size}_m_{data_size}.txt"
            with open(file_path, "r") as file:
                for line in file:
                    prompt, tokenized_prompt, answer = tokenize_and_split_line(line.strip(), tokenizer, device, reverse_inputs=reverse_input_in_tokenise, pad_zeros=pad_zeros, remove_padding=remove_padding)
                    if use_del:
                        prompt = prompt[:-1]
                    tracking, verbose_completion, logits = verbose_prompt_eval(tokenized_prompt, answer, model, tokenizer, cfg_arch, token_limit=token_limit, temp=temp, greedy=greedy)
                    if reverse_input_in_answer:
                        answer = str(answer)[::-1]
                    log.info(f"Prompt: {prompt}")
                    log.info(f"Answer: {answer}")
                    log.info(f"Verbose completion: {verbose_completion}")

                    data = {
                        "gen_tokens": tracking,
                        "logits": logits,
                    }
                    multi_plotter(data, data_size, prompt, answer, verbose_completion, sorted_tokenizer_keys, n_locs=token_limit, n_times=cfg_arch.maximal_recurrence_in_eval)
                    # thinking_plotter(tracking, data_size, prompt, answer, verbose_completion)
                    break
                    
    elif advanced_plotting:
        vocab = tokenizer.get_vocab()
        sorted_tokenizer_keys = sorted(vocab, key=lambda x: vocab[x])
        data_sizes = data_sizes = list(range(20, max_size))
        for data_size in data_sizes:
            
            file_path = f"../../../../data/arithmetic_data/+_grid_eval_dataset_reverse_all_tokenized/+_n_{data_size}_m_{data_size}_examples_100_diff_lens_seed_42/+_n_{data_size}_m_{data_size}_examples_100_diff_lens_seed_42.txt"
            reverse_input_in_tokenise = False
            reverse_input_in_answer = False
            with open(file_path, "r") as file:
                for line in file:
                    prompt, tokenized_prompt, answer = tokenize_and_split_line(line.strip(), tokenizer, device, reverse_inputs=reverse_input_in_tokenise, pad_zeros=pad_zeros, remove_padding=remove_padding)
                    em_acc = []
                    for rec in range(1,cfg.max_rec+1):
                        cfg_arch.maximal_recurrence_in_eval = rec
                        tracking, verbose_completion, logits = verbose_prompt_eval(tokenized_prompt, answer, model, tokenizer, cfg_arch, token_limit=token_limit, temp=temp, greedy=greedy)
                        answer = list(answer)
                        verbose_up_to_eos = verbose_completion[:next((i for i, x in enumerate(verbose_completion) if x == '[EOS]'), len(verbose_completion))]
                        count = sum(1 for x, y in zip(answer, verbose_up_to_eos) if x == y)/len(answer)
                        em_acc.append(count)
                    # plot
                    recurrences = range(1, cfg.max_rec+1)
                    plt.figure(figsize=(8, 6))
                    plt.plot(recurrences, em_acc, marker='o', linestyle='-', color='b')
                    plt.title(f'EM Accuracy by Number of Recurrences, {prompt}{"".join(verbose_up_to_eos)}')
                    plt.xlabel('Number of Recurrences')
                    plt.ylabel('EM Accuracy')
                    plt.ylim(0, 1.1)
                    plt.xticks(recurrences)
                    plt.savefig(f"em_acc_plot_{data_size}", bbox_inches='tight')
                    plt.clf()
                    break

    log.info("Eval complete")

@hydra.main(config_path="cramming/config", config_name="cfg_eval", version_base="1.3")
def launch(cfg):
    """Hydra entry point for the evaluation script.

    Hydra builds ``cfg`` from the ``cfg_eval`` config found under
    ``cramming/config``. The config is passed through
    ``cramming.utils.pathfinder``, the result is logged as YAML with
    interpolations resolved, and evaluation is then delegated to ``main``.

    Args:
        cfg: The configuration object supplied by the ``hydra.main`` decorator.
    """
    log.info("calling main launch")

    # Everything downstream (logging and main) uses the pathfinder output,
    # not the raw config handed in by Hydra.
    located_cfg = cramming.utils.pathfinder(cfg)

    # Dump the full configuration so the run can be reproduced from the log.
    cfg_as_yaml = OmegaConf.to_yaml(located_cfg, resolve=True)
    log.info(cfg_as_yaml)

    main(located_cfg)

# Script entry point: only start the Hydra-decorated launcher when this file
# is executed directly, not when it is imported as a module.
if __name__ == "__main__":
    launch()