from transformers import GPT2Tokenizer, AutoModelForCausalLM
import numpy as np
from IPython import embed
import os
import pickle
from torch.utils.data import DataLoader

from pytorch_lightning import Trainer
from models import AdditionFlanT5

from generate_data import MultiDigitAdditionDataset
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch
import argparse
import socket
import generate_data

# Command-line interface: two required positional arguments.
#   uuid        -- run identifier; substituted into the checkpoint data path below.
#   num_digits  -- digit count at which main() starts scanning dataset files.
argparser = argparse.ArgumentParser()
argparser.add_argument("uuid", type=str)
argparser.add_argument("num_digits", type=int)

# NOTE: arguments are parsed at module level, so this happens on import as
# well as when the file is run as a script.
args = argparser.parse_args()

# The data directory is keyed by this machine's hostname and the run uuid.
# The value ends with a trailing "/" so file names can be appended directly.
hostname = socket.gethostname()
LOCATION = "/mnt/batch/tasks/shared/LS_root/mounts/clusters/{}/code/checkpoints/supervised/{}/data/".format(hostname, args.uuid)

def _accuracy(correct, total):
    """Return ``correct / total``, or NaN when ``total`` is zero."""
    return correct / total if total else float("nan")


def main():
    """Report flash and decomp label accuracy for successive digit counts.

    Starting at ``args.num_digits``, the flash and decomp dataset pickles of
    all 8 devices under ``LOCATION`` are checked with ``check_dataset`` and
    the per-device counts are summed.  One summary line per dataset kind is
    printed for each digit count, then the digit count is incremented.

    The loop has no other exit: it stops at the first digit count/device for
    which a dataset file is missing, printing the missing path and raising
    ``SystemExit(0)``.  Totals accumulated for that incomplete digit count
    are not printed.

    Raises:
        SystemExit: with code 0 as soon as an expected dataset file is absent.
    """
    num_digits = args.num_digits
    device_count = 8

    while True:
        total_flash_correct = total_flash_count = 0
        total_decomp_correct = total_decomp_count = 0

        for device in range(device_count):
            flash_datafile = "{}digits:{}_device:{}_{}_dataset.pkl".format(LOCATION, num_digits, device, "flash")
            decomp_datafile = "{}digits:{}_device:{}_{}_dataset.pkl".format(LOCATION, num_digits, device, "decomp")

            # Report the file that is actually missing (flash first when both
            # are absent) rather than always naming the flash file.
            missing = [path for path in (flash_datafile, decomp_datafile) if not os.path.exists(path)]
            if missing:
                print("No file found at location {}".format(missing[0]))
                # SystemExit is raised directly: the bare ``exit`` helper is
                # injected by ``site`` and is not guaranteed to exist.
                raise SystemExit(0)

            flash_correct, decompose_correct, flash_total, decompose_total = check_dataset(flash_datafile, decomp_datafile)
            total_flash_correct += flash_correct
            total_flash_count += flash_total
            total_decomp_correct += decompose_correct
            total_decomp_count += decompose_total

        # _accuracy guards against empty datasets, which would otherwise
        # abort the report with ZeroDivisionError.
        print("{} digit flash correct: {} / {} for accuracy {}".format(num_digits, total_flash_correct, total_flash_count, _accuracy(total_flash_correct, total_flash_count)))
        print("{} digit decomp correct: {} / {} for accuracy {}".format(num_digits, total_decomp_correct, total_decomp_count, _accuracy(total_decomp_correct, total_decomp_count)))
        num_digits += 1

def check_dataset(flash_datafile, decomp_datafile):
    """Count how many stored examples carry a correct label.

    Each pickle file must hold a 3-tuple whose first element is a sequence of
    ``(question, answer, numerical_answer)`` examples; the other two elements
    are ignored here.

    * Flash examples are correct when ``int(numerical_answer)`` equals the
      sum of the operands parsed from the ``+``-separated question.
    * Decomp examples are correct when ``answer`` equals ``"A: "`` followed
      by ``generate_data.generate_solution_english`` for the first two
      operands, with the third operand (if present) passed as the carry.

    Returns:
        ``(flash_correct, decomp_correct, flash_total, decomp_total)``.
    """

    def load_examples(path):
        # NOTE(security): pickle.load can execute arbitrary code; only point
        # this at dataset files produced by a trusted run.
        with open(path, "rb") as handle:
            examples, _tokenized, _max_output_length = pickle.load(handle)
        return examples

    def split_operands(question):
        return [generate_data.remove_non_numeric(part) for part in question.split("+")]

    def is_flash_correct(example):
        question, _answer, numerical_answer = example
        operands = split_operands(question)
        return int(numerical_answer) == sum(operands)

    def is_decomp_correct(example):
        question, answer, _numerical_answer = example
        operands = split_operands(question)
        first, second = operands[0], operands[1]
        # Exactly two operands means there is no incoming carry.
        carry = operands[2] if len(operands) != 2 else 0
        expected = "A: " + generate_data.generate_solution_english(first, second, carry, "", decomp_only=True)
        return answer == expected

    # The flash file is fully loaded and scored before the decomp file is opened.
    flash_examples = load_examples(flash_datafile)
    flash_hits = sum(map(is_flash_correct, flash_examples))

    decomp_examples = load_examples(decomp_datafile)
    decomp_hits = sum(map(is_decomp_correct, decomp_examples))

    return flash_hits, decomp_hits, len(flash_examples), len(decomp_examples)
            
# Script entry point: run the accuracy report when executed directly.
if __name__ == "__main__":
    main()