import argparse
import pickle
import generate_data
import numpy as np
import logging
from IPython import embed
import os
import pandas as pd
from torch.utils.data import Dataset
from collections import defaultdict
import torch
import random
import copy
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from collections import Counter
import logging
from tqdm import tqdm
import argparse
import pickle
import gc 
import socket

# Command-line interface. Parsing happens at module level because main()
# reads the resulting module-level `args` namespace.
argparser = argparse.ArgumentParser()
argparser.add_argument("uuid", type=str)
argparser.add_argument("--primary_num_digits", default=3, type=int)
argparser.add_argument("--traintype", default="supervised", type=str)
argparser.add_argument("--fast", default=False, action="store_true")
argparser.add_argument("--num_examples", default=300, type=int)  # 10128 / 8 / 5
argparser.add_argument("--device", default=0, type=int)
argparser.add_argument("--num_old_examples", default=0, type=int)
argparser.add_argument("--batch_size", default=0, type=int)
argparser.add_argument("--size", default="small", type=str)

args = argparser.parse_args()
print(args)

# A batch size of 0 means "not given": derive one from --size instead.
# Any size other than "small" or "base" falls back to 32.
if args.batch_size == 0:
    # I'm not sure why batch size doesn't seem to correlate with speed
    args.batch_size = {"small": 512, "base": 128}.get(args.size, 32)

# Fast mode always overrides --num_examples with a small fixed count.
if args.fast:
    args.num_examples = 20

def _rewrite_until_present(dataset, path):
    """Re-run ``dataset.write_to_file(path)`` for as long as *path* is missing.

    Does nothing when the file already exists. There is deliberately no
    retry limit: the loop only ends once ``os.path.exists(path)`` is true.
    """
    while not os.path.exists(path):
        print("File not found at location {}".format(path))
        dataset.write_to_file(path)


def main():
    """Generate addition datasets from a saved checkpoint and write them to disk.

    Reads the module-level ``args`` namespace. The checkpoint loaded is the
    one saved for ``args.primary_num_digits - 1`` digits under a directory
    derived from the host name and ``args.uuid``. The newest entries of
    ``full_dataset.addition_datasets`` and ``full_dataset.flash_datasets``
    are written to per-digit, per-device ``.pkl`` paths inside the data
    directory (``fastdata/`` in fast mode, ``data/`` otherwise).
    """
    host_name = socket.gethostname()
    CHECKPOINT_DIR = "/mnt/batch/tasks/shared/LS_root/mounts/clusters/{}/code/checkpoints/supervised/{}/".format(host_name, args.uuid)
    checkpoint = os.path.join(CHECKPOINT_DIR, "model-{}-{}-digits.ckpt".format(args.traintype, args.primary_num_digits-1))
    if args.fast:
        data_dir = CHECKPOINT_DIR + "fastdata/"

        # Fast self-training runs read a separately named checkpoint.
        if args.traintype == "selftrain":
            checkpoint += ".fast"
    else:
        data_dir = CHECKPOINT_DIR + "data/"

    nn_model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint, torch_dtype=torch.float16)

    tokenizer = AutoTokenizer.from_pretrained("google/byt5-small")

    # Device is the GPU with number args.device if available, otherwise CPU
    device = torch.device("cuda:{}".format(args.device) if torch.cuda.is_available() else "cpu")

    full_dataset = generate_data.MultiDigitAdditionDataset(args.num_examples, args.primary_num_digits, args.num_old_examples, "english", tokenizer,
                                                           model=nn_model, use_flash=True, batch_size=args.batch_size, device=device, uuid=args.uuid, force_min_number_examples=0)

    # Create the output directory if needed. exist_ok=True avoids the
    # check-then-create race: output files are keyed by device, so another
    # process may create the same directory between a check and the mkdir.
    os.makedirs(data_dir, exist_ok=True)

    decomp_file = data_dir + "digits:{}_device:{}_decomp_dataset.pkl".format(args.primary_num_digits, args.device)
    flash_file = data_dir + "digits:{}_device:{}_flash_dataset.pkl".format(args.primary_num_digits, args.device)

    latest_decomp = full_dataset.addition_datasets[-1]
    latest_flash = full_dataset.flash_datasets[-1]

    print("Max output length of tokens is {}".format(full_dataset.pad_length))
    print("Max output length of last dataset is {}".format(latest_decomp.max_output_length))

    # Don't run into errors with padding because you only have a handful of
    # generated examples in fast mode.
    if args.fast:
        latest_decomp.max_output_length = 512
        latest_flash.max_output_length = 512

    latest_decomp.write_to_file(decomp_file)
    latest_flash.write_to_file(flash_file)

    # The write may not be visible on disk right away (presumably a flaky
    # shared mount -- TODO confirm), so keep rewriting until each file exists.
    _rewrite_until_present(latest_decomp, decomp_file)
    _rewrite_until_present(latest_flash, flash_file)

    print("Generated {} examples".format(len(full_dataset)))

# Script entry point: call main() only when this file is executed directly.
if __name__ == "__main__":
    main()