# External Modules
from datasets import load_dataset
from torch.utils.data import DataLoader
from IPython import embed
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
import numpy as np
import wandb
import torch
import logging
import os
import socket
import argparse
import sys

# Internal Modules
import generate_data
from models import AdditionFlanT5
import utils

# wandb.login()
# os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Machine name; it is interpolated into the checkpoint mount path below.
host_name = socket.gethostname()


def _build_parser():
    """Create the command-line parser for this script.

    Options are registered in the same order they are listed in ``--help``.
    """
    parser = argparse.ArgumentParser()

    # Required positional: identifies the checkpoint sub-directory.
    parser.add_argument("uuid", type=str, help="UUID of the model to self train")

    # Optional switches and valued options.
    parser.add_argument("--gpu", default=False, action="store_true", help="Use GPU")
    parser.add_argument("--type", default="decomp", type=str, help="Type of dataset to self train on")
    parser.add_argument("--size", default="small", type=str, help="Size of the model to self train")
    parser.add_argument("--batch_size", default=10, type=int, help="Batch size to use for self training")
    parser.add_argument("--num_digits_start", default=3, type=int, help="Number of digits to start with")
    parser.add_argument("--flash", action="store_true", default=False, help="Flash self training")
    parser.add_argument("--skip_check", action="store_true", default=False, help="Skip checking")
    parser.add_argument("--silent", action="store_true", default=False, help="Be silent")
    parser.add_argument("--traintype", type=str, default="supervised", help="Type of training to use")
    return parser


# Arguments are parsed at import time; the rest of the module reads `args`.
argument_parser = _build_parser()
args = argument_parser.parse_args()

# Directory holding the checkpoints for the requested model UUID on this host.
CHECKPOINT_DIR = f"/mnt/batch/tasks/shared/LS_root/mounts/clusters/{host_name}/code/checkpoints/supervised/{args.uuid}"
# Hugging Face identifier of the base model; used to load the tokenizer.
MODEL = f"google/byt5-{args.size}"

# Send timestamped INFO-level (and above) log records to stdout.
logging.basicConfig(
    level=logging.INFO,
    stream=sys.stdout,
    datefmt='%Y-%m-%d %H:%M:%S',
    format='%(asctime)s %(levelname)-8s %(message)s',
)

def main():
    """Evaluate saved checkpoints digit-by-digit and report extrapolation.

    Starting at ``args.num_digits_start``, the loop looks for a checkpoint
    named ``model-<traintype>-<n>-digits.ckpt`` under ``CHECKPOINT_DIR``.
    For every checkpoint found it:

    1. optionally (unless ``--skip_check``) verifies accuracy at ``n`` digits;
    2. tests successively longer inputs (``n + 1`` up to 50 digits) until
       ``utils.verify_accuracy`` reports a non-perfect result;
    3. prints how far the checkpoint generalized.

    The loop stops at the first digit count that has no checkpoint on disk.
    """
    if not args.gpu:
        # Make torch report that CUDA is unavailable for the rest of the run
        # (presumably so nothing downstream selects a GPU -- TODO confirm).
        torch.cuda.is_available = lambda : False
        device = "cpu"
    else:
        device = "gpu"

    tokenizer = AutoTokenizer.from_pretrained(MODEL)
    # Register the scratchpad markers ahead of the tokenizer's existing special tokens.
    tokenizer.add_special_tokens({"additional_special_tokens": ["<scratchpad>", "</scratchpad>"] + tokenizer.additional_special_tokens})
    trainer = Trainer(accelerator=device, default_root_dir="logs/self_train", enable_progress_bar=False)
    print("Checkpoint dir: {}".format(CHECKPOINT_DIR))

    # Largest input length (in digits) probed during the extrapolation check.
    max_test_digits = 50

    num_digits = args.num_digits_start
    lightning_model = AdditionFlanT5(type=args.type, silent=args.silent)
    lightning_model.initialize(lightning_model.model, tokenizer, None)

    while True:
        # model_location = CHECKPOINT_DIR + "/model-supervised-{}-digits.ckpt".format(num_digits)
        model_location = CHECKPOINT_DIR + "/model-{}-{}-digits.ckpt".format(args.traintype, num_digits)
        # model_location = CHECKPOINT_DIR + "/model-failsafe.ckpt"

        if not os.path.exists(model_location):
            print("No {} digit model found, stopping".format(num_digits))
            break

        lightning_model.set_cur_num_digits(num_digits)
        # Swap in the weights of the checkpoint under evaluation.
        lightning_model.model = AutoModelForSeq2SeqLM.from_pretrained(model_location)

        if not args.skip_check:
            perfect_acc, accs = utils.verify_accuracy(trainer, lightning_model, num_digits, args.type, batch_size=args.batch_size, flash=args.flash, silent=args.silent)

            if perfect_acc:
                print("{}-digit model is sufficiently trained.".format(num_digits))
            else:
                print("{}-digit model is not sufficiently trained. Acc {}".format(num_digits, accs))

        # Now we check extrapolation. `first_failure` stays None when every
        # tested length is perfect, or when there is nothing left to test.
        first_failure = None
        for test_digits in range(num_digits + 1, max_test_digits + 1):
            lightning_model.set_cur_num_digits(test_digits)
            lightning_model.accuracy_check_threshold = np.inf
            extrapolates, all_accs = utils.verify_accuracy(trainer, lightning_model, test_digits, args.type, batch_size=args.batch_size, digit_only=True, flash=False, silent=args.silent)
            if not extrapolates:
                # Remember the failing length and its accuracy for the report.
                first_failure = (test_digits, all_accs[-1])
                break

        if first_failure is not None:
            failed_digits, failed_acc = first_failure
            print("{}-digit model perfectly generalizes to {} digit and gets accuracy {} on {} digits".format(num_digits, failed_digits - 1, failed_acc, failed_digits))
        elif num_digits < max_test_digits:
            # The loop ran to completion: every length up to the limit was perfect.
            print("{}-digit model perfectly generalizes to {} digit (largest length tested)".format(num_digits, max_test_digits))
        else:
            # Empty range: the model is already at or beyond the test limit.
            print("{}-digit model: no extrapolation check run (test limit is {} digits)".format(num_digits, max_test_digits))

        num_digits += 1
        

# Call main() only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()