# Utils must be imported first before the transformers library
import utils


# External Modules
from datasets import load_dataset
from torch.utils.data import DataLoader
from IPython import embed
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, LlamaTokenizer, LlamaForCausalLM
from tqdm import tqdm

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.strategies import DeepSpeedStrategy
from lightning_fabric.utilities.seed import seed_everything
from peft import get_peft_model, LoraConfig, TaskType

import wandb
import random
import copy
import torch
import logging
import numpy as np
import os
import uuid
from datetime import datetime
from pytz import timezone
import socket
import argparse
import git
import gc
import json
import multiprocessing
import sys
import pickle

# Internal Modules
import generate_data
from models import AdditionFlanT5

NUM_DEVICES = torch.cuda.device_count()
# Default global (all-device) batch size: 8 examples per visible GPU.
# setup_model divides the global batch by NUM_DEVICES to get the per-device batch.
DEFAULT_BATCH_SIZE = 8 * NUM_DEVICES

argument_parser = argparse.ArgumentParser()
argument_parser.add_argument("--checkpoint", type=str, default=None) 
argument_parser.add_argument("--num_digits_start", type=int, default=3)
argument_parser.add_argument("--batch_size", type=int, default=DEFAULT_BATCH_SIZE)
argument_parser.add_argument("--size", type=str, default="small")
argument_parser.add_argument("--uuid", type=str, default=None)
argument_parser.add_argument("--seed", type=int, default=0)
argument_parser.add_argument("--type", type=str, default="decomp", help="Currently implemented types are 'decomp', 'full', and 'remove'")

argument_parser.add_argument("--debug", action="store_true", default=False)
argument_parser.add_argument("--fast", action="store_true", default=False)
argument_parser.add_argument("--no_flash", action="store_true", default=False)
argument_parser.add_argument("--alpaca", action="store_true", default=False)
argument_parser.add_argument("--wandb_id", type=str, default=None)
argument_parser.add_argument("--generate_data", action="store_true", default=False)
argument_parser.add_argument("--read_data_from", type=str, default=None)
argument_parser.add_argument("--generate_digit_start", type=int, default=0)

args = argument_parser.parse_args()

# NOTE(review): pytz's 'EST' zone is a fixed UTC-5 offset with no daylight-saving
# adjustment; in this file it is only used to timestamp the wandb run name in setup().
tz = timezone('EST')
# Passing --uuid reuses that run's checkpoint directory (and its config.json)
# instead of starting a fresh run.
if args.uuid is None:
    UUID = str(uuid.uuid4())
else:
    UUID = args.uuid

host_name = socket.gethostname()
CHECKPOINT_DIR = "/mnt/batch/tasks/shared/LS_root/mounts/clusters/{}/code/checkpoints/supervised/".format(host_name) + UUID
CONFIG_FILE = CHECKPOINT_DIR + "/config.json"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

NUM_TRAINING_EXAMPLES = 2000 if args.generate_digit_start > 0 else 10000
SAVE_FREQUENCY = 70 * NUM_DEVICES if args.num_digits_start == 3 else 25 * NUM_DEVICES
TRAINING_TYPE = "english"
VAL_SIZE = 128
MAX_EPOCHS = 10
SEED = args.seed
# MODEL = "google/flan-t5-{}".format(args.size)
MODEL = "google/byt5-{}".format(args.size)

# --fast: tiny smoke-test configuration.
if args.fast:
    NUM_TRAINING_EXAMPLES = 1
    SAVE_FREQUENCY = 10
    VAL_SIZE = 1
    MAX_EPOCHS = 2

LOG_DIR = "logs/supervised_train"
if args.debug:
    logging.basicConfig(
        stream=sys.stdout, 
        format='%(asctime)s %(levelname)-8s %(message)s',
        level=logging.INFO,
        datefmt='%Y-%m-%d %H:%M:%S')
else:
    # basicConfig(filename=...) opens the log file immediately, so the (relative)
    # directory must exist or startup fails with FileNotFoundError.
    os.makedirs(LOG_DIR, exist_ok=True)
    logging.basicConfig(
        filename=os.path.join(LOG_DIR, "{}_supervisedtrain.log".format(UUID)),
        format='%(asctime)s %(levelname)-8s %(message)s',
        level=logging.INFO,
        datefmt='%Y-%m-%d %H:%M:%S')
    print("Logging to {}".format(logging.root.handlers[0].baseFilename))

# NOTE(review): presumably turns off the transformers wandb integration; the Lightning
# WandbLogger created in setup() is configured separately.
os.environ['WANDB_DISABLED'] = 'true'
MASTER_LOCATION = os.path.join(CHECKPOINT_DIR, "master.log")
def setup(true_batch_size, checkpoint):
    """Create the tokenizer, wandb logger, model wrapper and an initial Trainer.

    Args:
        true_batch_size: global batch size across all devices.
        checkpoint: checkpoint path, or a falsy value to load pretrained weights.

    Returns:
        (model, tokenizer, trainer, wandb_logger)
    """
    wandb.login()

    scratchpad_tokens = ["<scratchpad>", "</scratchpad>", "|||"]
    if not args.alpaca:
        tokenizer = AutoTokenizer.from_pretrained(MODEL)
    else:
        tokenizer = LlamaTokenizer.from_pretrained("chainyo/alpaca-lora-7b")
        # The Llama tokenizer is given an explicit pad token here.
        tokenizer.add_special_tokens({'pad_token': '<PAD>'})

    # Register the scratchpad markers ahead of any existing additional special tokens.
    tokenizer.add_special_tokens({"additional_special_tokens": scratchpad_tokens + tokenizer.additional_special_tokens})
    logging.info("Loading model and tokenizer")

    run_name = str(datetime.now(tz))
    wandb_logger = WandbLogger(name=run_name, project='SupervisedAdditionRLDeepSpeedEra', id=args.wandb_id)
    write_to_master_log("Wandb ID: {}".format(wandb_logger.experiment.id))

    model = setup_model(wandb_logger, tokenizer, true_batch_size, checkpoint)
    trainer = setup_trainer(wandb_logger, true_batch_size)
    logging.info("Loaded model and tokenizer")

    return model, tokenizer, trainer, wandb_logger

def setup_model(wandb_logger, tokenizer, true_batch_size, checkpoint):
    """Build the AdditionFlanT5 wrapper and attach a freshly loaded base network.

    When `checkpoint` is truthy no weights are loaded here (the wrapper is
    initialized with model=None); main() loads the checkpoint weights itself.
    The per-device batch size is stored as true_batch_size // NUM_DEVICES.
    """
    model = AdditionFlanT5(type=args.type)

    if checkpoint:
        nn_model = None
    elif args.alpaca:
        nn_model = LlamaForCausalLM.from_pretrained(
            "chainyo/alpaca-lora-7b",
            load_in_8bit=True,
            torch_dtype=torch.float16,
            device_map="auto",
        )
    else:
        nn_model = AutoModelForSeq2SeqLM.from_pretrained(MODEL, torch_dtype=torch.bfloat16)

    model.initialize(model=nn_model, tokenizer=tokenizer, wandb_logger=wandb_logger)
    # Per-device batch; this becomes 0 if the global batch is smaller than the device count.
    model.hparams.batch_size = true_batch_size // NUM_DEVICES
    return model

def _grad_accumulation_steps(target_batch_size, true_batch_size):
    """Return accumulation steps so that true_batch_size * steps ~= target_batch_size.

    The result is always at least 1. A non-positive true_batch_size (reached after
    repeated OOM halving, e.g. 1 // 2 == 0) yields 1 instead of dividing by zero,
    so the caller can report the out-of-memory condition itself.
    """
    if true_batch_size <= 0:
        return 1
    return max(1, target_batch_size // true_batch_size)


def setup_trainer(wandb_logger, true_batch_size):
    """Build a Lightning Trainer using DeepSpeed stage 2 with optimizer offload.

    Args:
        wandb_logger: logger passed to the Trainer.
        true_batch_size: global batch size. Gradient accumulation is set so the
            effective batch stays close to DEFAULT_BATCH_SIZE.

    Returns:
        The configured Trainer. Lightning checkpointing is disabled; main() saves
        weights with save_pretrained instead.
    """
    early_stop_callback = EarlyStopping(
        monitor="min_val_acc",
        patience=100,
        verbose=True,
        min_delta=0.01,
        mode='max',
        stopping_threshold=0.999 if not args.fast else 0.001,
    )
    lr_monitor = LearningRateMonitor(logging_interval='step')
    # checkpoint_callback = ModelCheckpoint(dirpath=CHECKPOINT_DIR, save_top_k=1, every_n_train_steps=SAVE_FREQUENCY, monitor="loss", filename="supervised-{epoch:02d}-{global_step}",)
    # logging.info("Saving model at {}".format(checkpoint_callback.dirpath))

    strategy = DeepSpeedStrategy(
        stage=2,
        offload_optimizer=True,
        # logging_level = logging.INFO,
    )

    precision = "16-mixed" if args.alpaca else "bf16-mixed"

    # One fewer GPU is used for training when --generate_data is set.
    num_gpus = NUM_DEVICES if not args.generate_data else NUM_DEVICES - 1

    trainer = Trainer(accelerator="gpu",
                      devices=num_gpus,
                      strategy=strategy,
                      default_root_dir="/home/azureuser/cloudfiles/code/Users/DIRNAME/addition/logs/supervised_train/",
                      logger=wandb_logger,
                      val_check_interval=SAVE_FREQUENCY,
                      precision=precision,
                      enable_checkpointing=False,
                      callbacks=[lr_monitor, early_stop_callback],
                      accumulate_grad_batches=1,
                      max_epochs=MAX_EPOCHS,
                      max_steps=-1)

    # Compensate for a batch size shrunk by OOM recovery. Guarded so a batch size of 0
    # does not raise ZeroDivisionError here, before main() can report the OOM.
    trainer.accumulate_grad_batches = _grad_accumulation_steps(DEFAULT_BATCH_SIZE, true_batch_size)
    return trainer

def generate_trainset(num_digits, tokenizer, use_flash, model=None, batch_size=-1, trainer=None):
    """Build the addition dataset for `num_digits` and split off validation sets.

    With --read_data_from, no new examples are requested. Instead, the data for
    digits generate_digit_start..num_digits is cleared and reloaded from per-device
    files whose path template contains NUMDIGITS / DEVICENUMBER / DATATYPE placeholders.

    Args:
        num_digits: primary digit count passed to MultiDigitAdditionDataset.
        tokenizer: tokenizer passed through to the dataset.
        use_flash: whether flash data is also cleared/loaded.
        model, batch_size: passed through to the dataset unchanged.
        trainer: unused; kept for call compatibility.

    Returns:
        (full_dataset, val_datasets), where val_datasets = full_dataset.split(VAL_SIZE).

    Raises:
        ValueError: if --read_data_from is set but --generate_digit_start is not in
            the range 1..num_digits.
    """
    if args.read_data_from:
        print("READ DATA FROM " + args.read_data_from)
        # Validate before the dataset construction below. Raise instead of assert so
        # the check is not stripped under `python -O`.
        if not 0 < args.generate_digit_start <= num_digits:
            raise ValueError("Must specify generate_digit_start when reading data from file")

    # generate_digit_start == 0 maps to np.inf (no cutoff); otherwise it is passed as the cutoff digit.
    stop_generation_digit = np.inf if args.generate_digit_start == 0 else args.generate_digit_start
    num_new_examples = NUM_TRAINING_EXAMPLES + VAL_SIZE if not args.read_data_from else 0

    # Need 2 for at least one validation, one train
    num_old_examples = max(2, int(0.1 * NUM_TRAINING_EXAMPLES))

    full_dataset = generate_data.MultiDigitAdditionDataset(num_examples=num_new_examples,
                                                           primary_num_digits=num_digits,
                                                           num_old_examples=num_old_examples,
                                                           dataset_type=TRAINING_TYPE,
                                                           tokenizer=tokenizer,
                                                           type=args.type,
                                                           force_min_number_examples=max(2, int(0.3 * NUM_TRAINING_EXAMPLES)),
                                                           use_flash=use_flash,
                                                           model=model,
                                                           batch_size=batch_size,
                                                           stop_generation_digit=stop_generation_digit,
                                                           device=None,
                                                           uuid=UUID)

    if args.read_data_from:
        for digit in tqdm(range(args.generate_digit_start, num_digits + 1)):

            # Clear out all the generated data
            full_dataset.addition_datasets[digit - 1].clear()
            if use_flash:
                full_dataset.flash_datasets[digit - 1].clear()

            for i in range(NUM_DEVICES):
                try:
                    basefile = args.read_data_from.replace("NUMDIGITS", str(digit)) \
                                                     .replace("DEVICENUMBER", str(i))

                    decomp_file = basefile.replace("DATATYPE", "decomp")
                    flash_file = basefile.replace("DATATYPE", "flash")

                    full_dataset.load_from_file(digit, decomp_file, "decomp")
                    if use_flash:
                        full_dataset.load_from_file(digit, flash_file, "flash")

                except Exception as e:
                    # Best-effort: any failure loading this device's files is reported and skipped.
                    print("Failed to load data from digit {} and device {}".format(digit, i))
                    print(str(e))

    print("Full dataset has length {}".format(len(full_dataset)))
    val_datasets = full_dataset.split(VAL_SIZE)

    return full_dataset, val_datasets

def load_from_config():
    """Return (num_digits, checkpoint, batch_size, total_steps) for the current run.

    If CONFIG_FILE exists (written by write_to_config in an earlier invocation), its
    values are returned. Otherwise the run is fresh, and the command-line values are
    used, as main() does without --uuid, with total_steps = 0.

    Raises:
        KeyError: if the config file lacks one of the four expected keys.
    """
    if not os.path.exists(CONFIG_FILE):
        # Honour --checkpoint/--batch_size on a fresh --uuid run instead of silently
        # ignoring them. With default flags this is identical to (None, DEFAULT_BATCH_SIZE).
        return args.num_digits_start, args.checkpoint, args.batch_size, 0

    with open(CONFIG_FILE, 'r') as f:
        config = json.load(f)

    num_digits = config['num_digits']
    checkpoint = config['checkpoint']
    batch_size = config['batch_size']
    total_steps = config['total_steps']

    logging.info("Loaded config {}".format(config))

    return num_digits, checkpoint, batch_size, total_steps

def write_to_config(*, num_digits, checkpoint, batch_size, total_steps, message):
    """Persist the curriculum state to CONFIG_FILE and record `message`.

    The function only writes state. Callers that want a restart exit the process
    themselves afterwards; load_from_config reads the file back on the next run.
    `message` goes to the root logger and, with the written config, to the master log.
    """
    def _to_builtin(value):
        # json cannot serialise numpy integers (e.g. np.int64); cast them to int.
        if not isinstance(value, np.integer):
            raise TypeError
        return int(value)

    state = dict(num_digits=num_digits, checkpoint=checkpoint, batch_size=batch_size, total_steps=total_steps)
    with open(CONFIG_FILE, 'w') as f:
        json.dump(state, f, default=_to_builtin)

    logging.info(message)
    write_to_master_log("{}\nWrote to config file: {}".format(message, state))

def write_to_master_log(message):
    """Append `message` and a newline to <CHECKPOINT_DIR>/master.log.

    Nothing is written when the root logger's effective level is above INFO.
    Once main() has raised non-global-zero processes to ERROR, only the main
    process writes.
    """
    if logging.getLogger().getEffectiveLevel() > logging.INFO:
        return

    log_path = os.path.join(CHECKPOINT_DIR, "master.log")
    with open(log_path, "a+") as handle:
        handle.write(message + "\n")

def _val_check_interval(num_examples, true_batch_size, save_frequency):
    """Return the number of training batches between validation runs.

    Validate every `save_frequency` batches, but at least once per epoch
    (num_examples // true_batch_size batches). Never return less than 1: a dataset
    smaller than one global batch would otherwise give an interval of 0.
    """
    return max(1, min(save_frequency, num_examples // true_batch_size))


def prepare_dataloaders(model, tokenizer, trainer, num_digits, true_batch_size, use_flash, checkpoint):
    """Build the train/validation DataLoaders for `num_digits`-digit training.

    Also sets trainer.val_check_interval (2 under --fast). `checkpoint` is unused
    and kept for call compatibility.

    Returns:
        (train_loader, val_loaders, train_dataset, val_datasets). val_loaders is a
        list with one DataLoader per validation dataset.
    """
    model.set_cur_num_digits(num_digits)

    # generate_model = model.model if args.generate_data else None

    generate_batch_size = true_batch_size if args.generate_data else -1
    train_dataset, val_datasets = generate_trainset(num_digits, tokenizer, use_flash=use_flash, model=None, batch_size=generate_batch_size, trainer=trainer)

    logging.info("Example of training data")
    logging.info(tokenizer.decode(train_dataset[0]['input_ids']))
    # Negative label ids (ignored positions) are filtered out before decoding.
    logging.info(tokenizer.decode(train_dataset[0]['labels'][train_dataset[0]['labels'] >= 0]))
    # logging.info(str(train_dataset[0]['numerical_answer']))

    trainer.val_check_interval = _val_check_interval(len(train_dataset), true_batch_size, SAVE_FREQUENCY)

    if args.fast:
        trainer.val_check_interval = 2

    logging.info("Trying to learn {} digits with true batch size {}".format(num_digits, true_batch_size))
    train_loader = DataLoader(train_dataset, batch_size=model.hparams.batch_size, shuffle=True, pin_memory=True, num_workers=NUM_DEVICES * 2)
    val_loader = [DataLoader(val_dataset, batch_size=16, shuffle=False, pin_memory=True, num_workers=NUM_DEVICES) for val_dataset in val_datasets]

    return train_loader, val_loader, train_dataset, val_datasets

def main():
    """Curriculum-train the model on additions with a growing number of digits.

    The run resumes from CONFIG_FILE when --uuid is given. Each loop iteration:
    - builds a fresh Trainer;
    - trains on `num_digits`-digit data;
    - saves the HF weights and validates;
    - records the next state with write_to_config.

    The process exits via sys.exit after a CUDA OOM (with the batch size halved in
    the config) and after each stage when --read_data_from is set (code 100 on
    failure, 0 on success). Presumably an external driver then restarts it.
    """
    seed_everything(SEED)

    if args.uuid:
        # Load everything from config file
        num_digits, checkpoint, true_batch_size, total_steps = load_from_config()

    else:
        total_steps = 0
        num_digits = args.num_digits_start
        checkpoint = args.checkpoint
        true_batch_size = args.batch_size

    if args.fast:
        true_batch_size = 16
        # --read_data_from defaults to None; only redirect it to the fast data when it was given.
        if args.read_data_from:
            args.read_data_from = args.read_data_from.replace("/data/", "/fastdata/")

    model, tokenizer, trainer, wandb_logger = setup(true_batch_size, checkpoint)
    # The per-device batch is true_batch_size // NUM_DEVICES, which is 0 once OOM
    # recovery has halved the global batch below the device count.
    if model.hparams.batch_size == 0:
        raise Exception("Ran out of memory with batch size 1. Stopping")

    while True:
        trainer = setup_trainer(wandb_logger, true_batch_size)

        # Only log errors if we're not at global zero
        if trainer.is_global_zero:
            logging.getLogger().setLevel(logging.INFO)
        else:
            logging.getLogger().setLevel(logging.ERROR)

        try:
            logging.info("Starting training from checkpoint {} and {} digits".format(checkpoint, num_digits))
            if checkpoint:
                model.model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint,
                                                                    torch_dtype=torch.bfloat16)

                # Verify it has perfect accuracy up to num_digits-1. If args.fast, skip this check.
                # NOTE: currently disabled. The call below is commented out and `if False`
                # short-circuits, so perfect_acc/accs are never evaluated.
                # perfect_acc, accs = utils.verify_accuracy(trainer, model, num_digits-1, type=args.type, batch_size=model.hparams.batch_size, flash=not args.no_flash)
                if False and not perfect_acc and not args.fast:

                    min_mistake = np.argmin(accs < .99)
                    message = "Checkpoint {} does not have perfect accuracy up to {} digits, failing at digit {}. Relearning that checkpoint. Accs {}. This is not intended behavior.".format(checkpoint, num_digits-1, min_mistake, accs)

                    num_digits = min_mistake
                    write_to_config(num_digits=num_digits, checkpoint=checkpoint, batch_size=true_batch_size, total_steps=total_steps, message=message)

            logging.info("Ready to begin training")
            train_loader, val_loader, _, _ = prepare_dataloaders(model, tokenizer, trainer, num_digits, true_batch_size, use_flash=not args.no_flash, checkpoint=checkpoint)

            if args.fast and trainer.is_global_zero:
                print("Training dataset")
                print_loader(train_loader, tokenizer)

                for i, val_loader_i in enumerate(val_loader):
                    print("Validation dataset {}".format(i))
                    print_loader(val_loader_i, tokenizer)

            trainer.fit(model, train_loader, val_loader)

            if args.generate_data or args.read_data_from:
                model_chpt_path = CHECKPOINT_DIR + "/model-selftrain-{}-digits.ckpt".format(num_digits)
            else:
                model_chpt_path = CHECKPOINT_DIR + "/model-supervised-{}-digits.ckpt".format(num_digits)

            if args.fast:
                model_chpt_path += ".fast"

            model.model.save_pretrained(model_chpt_path, from_pt=True)

            results = trainer.validate(model, val_loader)

        except Exception as e:
            if "CUDA out of memory" in str(e):
                message = "Ran out of memory with batch size {}. Trying again.".format(true_batch_size)
                write_to_config(num_digits=num_digits, checkpoint=checkpoint, batch_size=true_batch_size // 2, total_steps=total_steps, message=message)
                sys.exit(0)
            else:
                try:
                    # Best-effort save of the current weights before re-raising.
                    if not args.fast:
                        print("Trying to save the model")
                        model_chpt_path = CHECKPOINT_DIR + "/model-failsafe.ckpt"
                        model.model.save_pretrained(model_chpt_path, from_pt=True)
                        print("Finished saving the model")
                except Exception:
                    # A failed failsafe save must not mask the original error, but
                    # KeyboardInterrupt/SystemExit are no longer swallowed here.
                    print("Failed to save the model")

                # Re-raise the original training error with its traceback.
                raise

        logging.info("Successfully finished training on {} digits with batch size {}".format(num_digits, true_batch_size))

        # If more than 1 epoch, you used the whole trainset.
        num_training_examples_used = min(len(train_loader.dataset), trainer.global_step * true_batch_size)
        write_to_master_log("Training on {} digits used {} training examples and {} epochs".format(num_digits, num_training_examples_used, trainer.current_epoch))
        total_steps += trainer.global_step * true_batch_size
        # For --type decomp, the accuracy check starts at 2 digits; otherwise at 1.
        start_idx = 2 if args.type == "decomp" else 1
        perfect_accuracy = all([results[0]["val_acc_{}_digits".format(i)] >= .99 for i in range(start_idx, num_digits + 1)])

        if perfect_accuracy or args.fast:
            message = "Reached perfect accuracy on {} digit addition in {} steps and batch size {}. Overall total steps {}".format(
                num_digits, trainer.global_step, true_batch_size, total_steps)

            num_digits += 1
            # Don't need to reset batch size. Probably will only get smaller

        else:
            message = "Failed to reach perfect accuracy on {} digit addition. Retrying".format(num_digits)

            write_to_master_log(message)

            if args.read_data_from:
                sys.exit(100)

        write_to_config(num_digits=num_digits, checkpoint=model_chpt_path, batch_size=true_batch_size, total_steps=total_steps, message=message)

        if args.read_data_from:
            sys.exit(0)

        # Don't reload the model
        checkpoint = None

def print_loader(loader, tokenizer):
    """Print the decoded `labels` of every example in `loader.dataset`.

    Each example's label ids are truncated at the first -100 (padding) before
    decoding. The dataset is iterated directly rather than first copied into a list.
    """
    for datapoint in loader.dataset:
        ids = datapoint['labels'].tolist()

        # Remove padding
        if -100 in ids:
            ids = ids[:ids.index(-100)]

        print(tokenizer.decode(ids))

if __name__ == "__main__":
    # Record the parsed arguments and the repository revision before training starts.
    logging.info(args)
    repo = git.Repo(search_parent_directories=True)
    logging.info("Current git hash: {}".format(repo.head.object.hexsha))
    main()