# External Modules
import os
import argparse

argument_parser = argparse.ArgumentParser()
argument_parser.add_argument("uuid", type=str, help="UUID of the model to self train")
argument_parser.add_argument("num_digits", type=int, help="Number of digits to self train on")
# Parsed before the heavy ML imports below (presumably so --help and usage
# errors return quickly).
args = argument_parser.parse_args()
if args.num_digits < 1:
    # Operands need at least one digit; report a usage error up front instead
    # of evaluating on meaningless operands / checkpoint names.
    argument_parser.error("num_digits must be a positive integer, got {}".format(args.num_digits))

from IPython import embed
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from pytorch_lightning import Trainer
import numpy as np
import torch
import logging
import os
import socket
import argparse
import sys

# Internal Modules
import generate_data
from models import AdditionFlanT5
import utils
import random

from generate_data import MultiDigitAdditionDataset

# Name of the current compute node; it is embedded in the mounted checkpoint path.
host_name = socket.gethostname()

# Checkpoint directory for the model identified by the `uuid` CLI argument:
#   <cluster mount>/<host_name>/code/checkpoints/supervised/<uuid>
CHECKPOINT_DIR = (
    f"/mnt/batch/tasks/shared/LS_root/mounts/clusters/{host_name}"
    f"/code/checkpoints/supervised/{args.uuid}"
)

# Hugging Face model id of the base ByT5 model (used to load the tokenizer).
_MODEL_SIZE = "small"
MODEL = f"google/byt5-{_MODEL_SIZE}"

def generate_random_n_digit_number(n):
    """Return a uniformly random positive integer with exactly *n* decimal digits.

    The result never has a leading zero, so it lies in
    ``[10 ** (n - 1), 10 ** n - 1]``.

    Args:
        n: Number of digits; must be at least 1.

    Returns:
        An ``int`` with exactly ``n`` digits.

    Raises:
        ValueError: If ``n < 1``.
    """
    if n < 1:
        raise ValueError("n must be a positive integer, got {!r}".format(n))
    # Drawing uniformly from the whole n-digit range is equivalent to a uniform
    # non-zero leading digit followed by n - 1 uniform digits.
    return random.randint(10 ** (n - 1), 10 ** n - 1)

def _parse_answer(output_sentence):
    """Extract the integer answer from a decoded model output.

    The answer is whatever follows the last ``:`` in *output_sentence*, with
    surrounding whitespace stripped.

    Returns:
        The parsed ``int``, or ``None`` when that text is not a valid integer
        literal (e.g. the model produced malformed output).
    """
    answer_text = output_sentence.split(":")[-1].strip()
    try:
        return int(answer_text)
    except ValueError:
        return None


def main():
    """Evaluate a self-trained ByT5 checkpoint on random n-digit additions.

    Loads ``model-selftrain-<args.num_digits>-digits.ckpt`` from
    ``CHECKPOINT_DIR``, asks it ``num_samples`` random "Add fast." questions
    whose operands have ``args.num_digits`` digits, prints the exact-match
    accuracy, and finally opens an IPython shell at this point.
    """
    num_samples = 100

    tokenizer = AutoTokenizer.from_pretrained(MODEL)
    tokenizer.add_special_tokens({"additional_special_tokens": ["<scratchpad>", "</scratchpad>"] + tokenizer.additional_special_tokens})
    # NOTE(review): `trainer` is never used below and the model is never moved
    # to the GPU by this function; kept to preserve existing behaviour.
    trainer = Trainer(accelerator="gpu", default_root_dir="logs/self_train", enable_progress_bar=False)
    print("Checkpoint dir: {}".format(CHECKPOINT_DIR))

    # NOTE(review): hard-coded to 3 while the checkpoint and the operands below
    # use args.num_digits — confirm which value set_cur_num_digits should get.
    num_digits = 3
    lightning_model = AdditionFlanT5(type="decomp", silent=True)
    lightning_model.initialize(lightning_model.model, tokenizer, None)

    model_location = CHECKPOINT_DIR + "/model-{}-{}-digits.ckpt".format("selftrain", args.num_digits)

    lightning_model.set_cur_num_digits(num_digits)
    # Swap the wrapper's base model for the self-trained checkpoint weights.
    lightning_model.model = AutoModelForSeq2SeqLM.from_pretrained(model_location)

    # NOTE(review): `reference` is not used below; kept in case constructing
    # the dataset has side effects that matter.
    reference = MultiDigitAdditionDataset(10, args.num_digits, 0, "english", tokenizer, use_flash=True)

    correct = 0
    for _ in range(num_samples):
        num1 = generate_random_n_digit_number(args.num_digits)
        num2 = generate_random_n_digit_number(args.num_digits)

        input_sentence = "Add fast.\nQ: {}+{}=?".format(num1, num2)
        input_ids = tokenizer.encode(input_sentence, return_tensors="pt").to(lightning_model.device)

        # Deterministic greedy decoding. `temperature` only affects sampling,
        # so passing temperature=0.0 had no effect beyond triggering a
        # generation-config warning; request greedy decoding explicitly.
        output_ids = lightning_model.model.generate(input_ids, max_length=100, do_sample=False)

        output_sentence = tokenizer.decode(output_ids[0], skip_special_tokens=True)
        # A malformed answer counts as incorrect instead of aborting the run.
        if _parse_answer(output_sentence) == num1 + num2:
            correct += 1

    print("Accuracy: {}".format(correct / num_samples))

    # Open an interactive IPython shell here for manual inspection.
    embed()


if __name__ == "__main__":
    main()