# Utils must be imported first before the transformers library
import utils


# External Modules
from datasets import load_dataset
import time
from torch.utils.data import DataLoader
from IPython import embed
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, LlamaTokenizer, LlamaForCausalLM

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.strategies import DeepSpeedStrategy
from lightning_fabric.utilities.seed import seed_everything
from peft import get_peft_model, LoraConfig, TaskType

import wandb
import random
import copy
import torch
import logging
import numpy as np
import os
import uuid
from datetime import datetime
from pytz import timezone
import socket
import argparse
import git
import gc
import json
import sys

# Internal Modules
import generate_data
from models import AdditionFlanT5

# All visible GPUs are used; the default global batch is 8 examples per device.
NUM_DEVICES = torch.cuda.device_count()
DEFAULT_BATCH_SIZE = NUM_DEVICES * 8


def _build_argument_parser():
    """Return the command-line parser for this training script.

    The parser is consumed once, at import time, to produce the module-level ``args``.
    """
    parser = argparse.ArgumentParser()
    # Run configuration.
    parser.add_argument("--checkpoint", type=str, default=None)  # path given to AutoModelForSeq2SeqLM.from_pretrained in main()
    parser.add_argument("--num_digits_start", type=int, default=3)
    parser.add_argument("--batch_size", type=int, default=DEFAULT_BATCH_SIZE)  # global ("true") batch size
    parser.add_argument("--size", type=str, default="small")  # suffix of the google/byt5-{size} model name
    parser.add_argument("--uuid", type=str, default=None)  # names the checkpoint dir and log file; random if omitted
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--type", type=str, default="decomp", help="Currently implemented types are 'decomp', 'full', and 'remove'")

    # Boolean switches and W&B options.
    parser.add_argument("--debug", action="store_true", default=False)  # log to stdout instead of a file
    parser.add_argument("--fast", action="store_true", default=False)  # tiny dataset / frequent validation for smoke runs
    parser.add_argument("--no_flash", action="store_true", default=False)  # passed on as use_flash=False / flash=False
    parser.add_argument("--alpaca", action="store_true", default=False)  # use the chainyo/alpaca-lora-7b LLaMA model
    parser.add_argument("--wandb_id", type=str, default=None)  # run id handed to WandbLogger
    return parser


argument_parser = _build_argument_parser()
args = argument_parser.parse_args()

# NOTE(review): pytz 'EST' is a fixed UTC-05:00 zone without daylight saving; use
# 'America/New_York' if wall-clock Eastern time is wanted in run names.
tz = timezone('EST')
UUID = args.uuid if args.uuid is not None else str(uuid.uuid4())

# Per-run output directory on the Azure ML mount of the current host.
host_name = socket.gethostname()
CHECKPOINT_DIR = f"/mnt/batch/tasks/shared/LS_root/mounts/clusters/{host_name}/code/checkpoints/supervised/{UUID}"
MASTER_LOCATION = os.path.join(CHECKPOINT_DIR, "master.log")
print(MASTER_LOCATION)
CONFIG_FILE = CHECKPOINT_DIR + "/config.json"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# Dataset size and training schedule. --fast shrinks everything for quick smoke runs.
NUM_TRAINING_EXAMPLES = 128
# Validation interval in training batches; scaled by the device count.
SAVE_FREQUENCY = (70 if args.num_digits_start == 3 else 25) * NUM_DEVICES
TRAINING_TYPE = "english"
VAL_SIZE = 1
MAX_EPOCHS = 1
SEED = args.seed
# Previously: MODEL = "google/flan-t5-{}".format(args.size)
MODEL = f"google/byt5-{args.size}"

if args.fast:
    NUM_TRAINING_EXAMPLES, SAVE_FREQUENCY, VAL_SIZE, MAX_EPOCHS = 10, 2, 2, 2

# Send log records to stdout with --debug, otherwise to a per-run file named after UUID.
_LOG_FORMAT = '%(asctime)s %(levelname)-8s %(message)s'
_LOG_DATEFMT = '%Y-%m-%d %H:%M:%S'
_LOG_DIR = "logs/supervised_train"
if args.debug:
    logging.basicConfig(
        stream=sys.stdout,
        format=_LOG_FORMAT,
        level=logging.INFO,
        datefmt=_LOG_DATEFMT)
else:
    # basicConfig's FileHandler does not create parent directories. Without this, a fresh
    # checkout lacking logs/supervised_train/ crashes here with FileNotFoundError.
    os.makedirs(_LOG_DIR, exist_ok=True)
    logging.basicConfig(
        filename="logs/supervised_train/{}_supervisedtrain.log".format(UUID),
        format=_LOG_FORMAT,
        level=logging.INFO,
        datefmt=_LOG_DATEFMT)
    # NOTE(review): assumes basicConfig installed the file handler. If the root logger already had
    # handlers, basicConfig does nothing and handlers[0] may have no baseFilename.
    print("Logging to {}".format(logging.root.handlers[0].baseFilename))

# NOTE(review): WANDB_DISABLED is set here, yet setup() still calls wandb.login() and builds a
# WandbLogger. Confirm which W&B integration this is meant to switch off.
os.environ['WANDB_DISABLED'] = 'true'
def setup(true_batch_size, checkpoint):
    """Create the tokenizer, W&B logger, model wrapper and trainer.

    Args:
        true_batch_size: Global batch size, forwarded to ``setup_model`` and ``setup_trainer``.
        checkpoint: Checkpoint path or ``None``. When truthy, ``setup_model`` skips loading base weights.

    Returns:
        A ``(model, tokenizer, trainer, wandb_logger)`` tuple.
    """
    wandb.login()

    if not args.alpaca:
        tokenizer = AutoTokenizer.from_pretrained(MODEL)
    else:
        tokenizer = LlamaTokenizer.from_pretrained("chainyo/alpaca-lora-7b")
        # This tokenizer is given an explicit pad token here.
        tokenizer.add_special_tokens({'pad_token': '<PAD>'})

    # Scratchpad delimiters go in front of any additional special tokens the tokenizer already has.
    # NOTE(review): this function does not resize the model's embeddings for the new tokens;
    # presumably AdditionFlanT5.initialize handles that. Confirm.
    scratchpad_tokens = ["<scratchpad>", "</scratchpad>", "|||"]
    tokenizer.add_special_tokens({"additional_special_tokens": scratchpad_tokens + tokenizer.additional_special_tokens})
    logging.info("Loading model and tokenizer")

    run_name = str(datetime.now(tz))
    wandb_logger = WandbLogger(name=run_name, project='SupervisedAdditionRLDeepSpeedEra', id=args.wandb_id)
    write_to_master_log("Wandb ID: {}".format(wandb_logger.experiment.id))

    model = setup_model(wandb_logger, tokenizer, true_batch_size, checkpoint)
    trainer = setup_trainer(wandb_logger, true_batch_size)
    logging.info("Loaded model and tokenizer")
    return model, tokenizer, trainer, wandb_logger

def setup_model(wandb_logger, tokenizer, true_batch_size, checkpoint):
    """Build the ``AdditionFlanT5`` wrapper and attach the underlying network.

    When ``checkpoint`` is truthy, no weights are loaded here (the wrapper gets ``model=None``).
    ``main()`` later assigns ``model.model`` from the checkpoint.

    Args:
        wandb_logger: Logger handed to ``AdditionFlanT5.initialize``.
        tokenizer: Tokenizer handed to ``AdditionFlanT5.initialize``.
        true_batch_size: Global batch size. The per-device share is stored in ``hparams.batch_size``.
        checkpoint: Checkpoint path or ``None``.

    Returns:
        The initialized ``AdditionFlanT5`` instance.
    """
    wrapper = AdditionFlanT5(type=args.type)

    network = None
    if not checkpoint:
        if args.alpaca:
            # 8-bit LLaMA weights sharded automatically across available devices.
            network = LlamaForCausalLM.from_pretrained(
                "chainyo/alpaca-lora-7b",
                load_in_8bit=True,
                torch_dtype=torch.float16,
                device_map="auto",
            )
        else:
            network = AutoModelForSeq2SeqLM.from_pretrained(MODEL, torch_dtype=torch.bfloat16)

    wrapper.initialize(model=network, tokenizer=tokenizer, wandb_logger=wandb_logger)
    # Per-device batch size. Integer division can give 0, which main() treats as a fatal OOM state.
    wrapper.hparams.batch_size = true_batch_size // NUM_DEVICES
    return wrapper

def setup_trainer(wandb_logger, true_batch_size):
    """Build a fresh Lightning ``Trainer`` using DeepSpeed ZeRO stage 2 with optimizer offload.

    Args:
        wandb_logger: Logger attached to the trainer.
        true_batch_size: Global batch size for this round. It is used only to derive gradient
            accumulation as ``DEFAULT_BATCH_SIZE // true_batch_size``, clamped to at least 1.

    Returns:
        The configured ``Trainer``. ``prepare_dataloaders`` may still override
        ``trainer.val_check_interval`` before ``fit``.
    """
    early_stop_callback = EarlyStopping(
        monitor="min_val_acc",
        patience=100,
        verbose=True,
        min_delta=0.01,
        mode='max',
        stopping_threshold=0.999,
    )
    lr_monitor = LearningRateMonitor(logging_interval='step')
    # Checkpointing is disabled (enable_checkpointing=False). A ModelCheckpoint callback used to be here:
    # ModelCheckpoint(dirpath=CHECKPOINT_DIR, save_top_k=1, every_n_train_steps=SAVE_FREQUENCY, monitor="loss", filename="supervised-{epoch:02d}-{global_step}")

    strategy = DeepSpeedStrategy(
        stage=2,
        offload_optimizer=True,
    )

    # setup_model loads the LLaMA model in float16 and the seq2seq model in bfloat16.
    precision = "16-mixed" if args.alpaca else "bf16-mixed"

    # Accumulate gradients so that a reduced batch (e.g. after OOM halving) still approximates
    # DEFAULT_BATCH_SIZE per optimizer step. main() doubles true_batch_size every round, so the
    # integer ratio drops to 0 once true_batch_size > DEFAULT_BATCH_SIZE. Clamp to 1, because
    # Lightning needs a positive accumulation count.
    accumulate_grad_batches = max(1, DEFAULT_BATCH_SIZE // true_batch_size)

    return Trainer(
        accelerator="gpu",
        devices=NUM_DEVICES,
        strategy=strategy,
        default_root_dir="/home/azureuser/cloudfiles/code/Users/DIRNAME/addition/logs/supervised_train/",
        logger=wandb_logger,
        val_check_interval=SAVE_FREQUENCY,
        precision=precision,
        enable_checkpointing=False,
        callbacks=[lr_monitor, early_stop_callback],
        accumulate_grad_batches=accumulate_grad_batches,
        max_epochs=MAX_EPOCHS,
        max_steps=-1,
    )

def generate_trainset(num_digits, tokenizer, use_flash):
    """Create the addition dataset for ``num_digits`` and carve off the validation sets.

    Args:
        num_digits: Primary digit count of the generated problems.
        tokenizer: Tokenizer used by the dataset to encode examples.
        use_flash: Forwarded to ``MultiDigitAdditionDataset``.

    Returns:
        ``(train_dataset, val_datasets)``, where ``val_datasets`` is whatever ``split(VAL_SIZE)`` returns.
    """
    dataset_kwargs = dict(
        num_examples=NUM_TRAINING_EXAMPLES + VAL_SIZE,
        primary_num_digits=num_digits,
        num_old_examples=int(0.1 * NUM_TRAINING_EXAMPLES),
        dataset_type=TRAINING_TYPE,
        tokenizer=tokenizer,
        type=args.type,
        force_min_number_examples=int(0.3 * NUM_TRAINING_EXAMPLES),
        use_flash=use_flash,
    )
    train_dataset = generate_data.MultiDigitAdditionDataset(**dataset_kwargs)
    # The same object is returned as the training set. Presumably split() removes the validation
    # examples from it in place. TODO confirm in generate_data.
    val_datasets = train_dataset.split(VAL_SIZE)
    return train_dataset, val_datasets

def load_from_config():
    """Read the saved run state from ``CONFIG_FILE``.

    Returns:
        ``(num_digits, checkpoint, batch_size, total_steps)``. When no config file exists, returns
        the CLI starting digit count, no checkpoint, ``DEFAULT_BATCH_SIZE`` and 0 steps.

    Raises:
        KeyError: if the file is missing one of the four expected keys.
    """
    fallback = (args.num_digits_start, None, DEFAULT_BATCH_SIZE, 0)
    if not os.path.exists(CONFIG_FILE):
        return fallback

    with open(CONFIG_FILE, 'r') as config_handle:
        config = json.load(config_handle)

    state = tuple(config[key] for key in ('num_digits', 'checkpoint', 'batch_size', 'total_steps'))
    logging.info("Loaded config {}".format(config))
    return state

# Persist the run state to the config file and record the message in the logs (this does not restart the program by itself)
def write_to_config(*, num_digits, checkpoint, batch_size, total_steps, message):
    """Overwrite ``CONFIG_FILE`` with the given run state and log ``message``.

    ``message`` goes to the Python logger and to the master log, followed by the written state.
    All values must be JSON-serializable.
    """
    state = dict(
        num_digits=num_digits,
        checkpoint=checkpoint,
        batch_size=batch_size,
        total_steps=total_steps,
    )
    with open(CONFIG_FILE, 'w') as config_handle:
        json.dump(state, config_handle)

    logging.info(message)
    write_to_master_log(message + "\nWrote to config file: {}".format(state))

def write_to_master_log(message):
    """Append ``message`` and a trailing newline to the run's master log (``MASTER_LOCATION``)."""
    with open(MASTER_LOCATION, "a+") as master_log:
        master_log.write(message + "\n")

def prepare_dataloaders(model, tokenizer, trainer, num_digits, true_batch_size, use_flash):
    """Generate data for ``num_digits``-digit addition and wrap it in DataLoaders.

    Side effects: sets the model's current digit count and ``trainer.val_check_interval``.

    Args:
        model: Model wrapper. Its ``hparams.batch_size`` is the per-loader batch size.
        tokenizer: Tokenizer used to build and decode examples.
        trainer: Trainer whose ``val_check_interval`` is adjusted here.
        num_digits: Digit count to train on.
        true_batch_size: Global batch size, used to size the validation interval.
        use_flash: Forwarded to the dataset.

    Returns:
        ``(train_loader, val_loaders, train_dataset, val_datasets)``. ``val_loaders`` is a list with
        one DataLoader per validation dataset.
    """
    model.set_cur_num_digits(num_digits)
    train_dataset, val_datasets = generate_trainset(num_digits, tokenizer, use_flash=use_flash)

    # Fetch the example once so the three logged lines describe the same item.
    example = train_dataset[0]
    logging.info("Example of training data")
    logging.info(tokenizer.decode(example['input_ids']))
    # Negative label ids are dropped before decoding (presumably an ignore index; confirm).
    logging.info(tokenizer.decode(example['labels'][example['labels'] >= 0]))
    logging.info(str(example['numerical_answer']))

    # Validate every SAVE_FREQUENCY batches, or about once per pass if the data is shorter. Clamp to 1:
    # once true_batch_size exceeds the dataset size the floor division is 0, and Lightning cannot
    # validate "every 0 batches".
    batches_per_pass = len(train_dataset) // true_batch_size
    trainer.val_check_interval = max(1, min(SAVE_FREQUENCY, batches_per_pass))

    if args.fast:
        trainer.val_check_interval = 1

    logging.info("Trying to learn {} digits with true batch size {}".format(num_digits, true_batch_size))
    train_loader = DataLoader(train_dataset, batch_size=model.hparams.batch_size, shuffle=True, pin_memory=True, num_workers=NUM_DEVICES*4)
    val_loader = [DataLoader(val_dataset, batch_size=model.hparams.batch_size, shuffle=False, pin_memory=True, num_workers=NUM_DEVICES) for val_dataset in val_datasets]
    return train_loader, val_loader, train_dataset, val_datasets

def main():
    """Run the supervised training loop.

    The model, tokenizer and logger are built once. Each round then:
      1. builds a fresh trainer for the current ``true_batch_size``;
      2. if ``checkpoint`` is set, reloads it and checks its accuracy on up to ``num_digits - 1`` digits;
      3. trains on ``num_digits``-digit addition;
      4. records the outcome in ``CONFIG_FILE`` and the master log;
      5. doubles ``true_batch_size`` for the next round.

    There is no exit condition in the loop itself. It ends via an exception (re-raised after a
    best-effort failsafe save) or via ``sys.exit(0)`` after a CUDA out-of-memory error.
    """
    start_time = time.time()
    seed_everything(SEED)

    # Resuming from CONFIG_FILE is currently switched off; the initial state comes from the CLI.
    if False:
        # Load everything from config file
        num_digits, checkpoint, true_batch_size, total_steps = load_from_config()

    else:
        total_steps = 0
        num_digits = args.num_digits_start
        checkpoint = args.checkpoint
        true_batch_size = args.batch_size

    model, tokenizer, trainer, wandb_logger = setup(true_batch_size, checkpoint)
    if model.hparams.batch_size == 0:
        raise Exception("Ran out of memory with batch size 1. Stopping")

    print("Finished initializing model and trainer in time {}".format(time.time() - start_time))
    start_time = time.time()

    while True:
        # Rebuild the trainer every round so gradient accumulation follows the current true_batch_size.
        trainer = setup_trainer(wandb_logger, true_batch_size)

        try:
            logging.info("Starting training from checkpoint {} and {} digits".format(checkpoint, num_digits))
            if checkpoint:
                # NOTE(review): loaded without torch_dtype, unlike the bfloat16 model in setup_model. Confirm intended.
                model.model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)

                # Check accuracy up to num_digits-1 digits. With --fast the check still runs,
                # but a failure is ignored.
                perfect_acc, accs = utils.verify_accuracy(trainer, model, num_digits-1, type=args.type, batch_size=model.hparams.batch_size, flash=not args.no_flash)
                if not perfect_acc and not args.fast:
                    # Index of the first accuracy below 0.99. argmax on a boolean array returns the
                    # first True; the previous argmin returned the first *passing* index instead.
                    # int() converts the numpy integer to a plain int so json.dump in write_to_config
                    # can serialize it.
                    # NOTE(review): assumes accs[i] corresponds to digit count i. Confirm against utils.verify_accuracy.
                    min_mistake = int(np.argmax(np.asarray(accs) < .99))
                    message = "Checkpoint {} does not have perfect accuracy up to {} digits, failing at digit {}. Relearning that checkpoint. Accs {}. This is not intended behavior.".format(checkpoint, num_digits-1, min_mistake, accs)

                    num_digits = min_mistake
                    write_to_config(num_digits=num_digits, checkpoint=checkpoint, batch_size=true_batch_size, total_steps=total_steps, message=message)

            logging.info("Ready to begin training")
            train_loader, _, _, _ = prepare_dataloaders(model, tokenizer, trainer, num_digits, true_batch_size, use_flash=not args.no_flash)
            print("Length of trainloader is {}".format(len(train_loader)))
            trainer.fit(model, train_loader)

            # Path recorded in the config. Saving the model itself is currently disabled.
            model_chpt_path = CHECKPOINT_DIR + "/model-supervised-{}-digits.ckpt".format(num_digits)
            # model.model.save_pretrained(model_chpt_path, from_pt=True)

            # results = trainer.validate(model, val_loader)

        except Exception as e:
            if "CUDA out of memory" in str(e):
                # Record a halved batch size and exit cleanly.
                # NOTE(review): main() never reads CONFIG_FILE (load_from_config is disabled above),
                # so a restart will not pick this up as written.
                message = "Ran out of memory with batch size {}. Trying again.".format(true_batch_size)
                write_to_config(num_digits=num_digits, checkpoint=checkpoint, batch_size=true_batch_size // 2, total_steps=total_steps, message=message)
                sys.exit(0)

            # Best-effort failsafe save before the original error propagates.
            try:
                print("Trying to save the model")
                model_chpt_path = CHECKPOINT_DIR + "/model-failsafe.ckpt"
                model.model.save_pretrained(model_chpt_path, from_pt=True)
                print("Finished saving the model")
            except Exception:
                logging.exception("Failsafe model save failed")

            # Re-raise the training error (not a failure from the failsafe save).
            raise

        logging.info("Successfully finished training on {} digits with batch size {}".format(num_digits, true_batch_size))

        # NOTE(review): this accumulates optimizer steps times batch size, i.e. roughly examples seen.
        total_steps += trainer.global_step * true_batch_size

        # The perfect-accuracy check is disabled (validation results are not computed), so every
        # round except a --fast run is reported as a failure. Former check:
        # perfect_accuracy = all([results[0]["val_acc_{}_digits".format(i)] >= .99 for i in range(2 if args.type == "decomp" else 1, num_digits + 1)])
        if args.fast:
            message = "Reached perfect accuracy on {} digit addition in {} steps and batch size {}. Overall total steps {}".format(
                num_digits, trainer.global_step, true_batch_size, total_steps)

            # Don't need to reset batch size. Probably will only get smaller

        else:
            message = "Failed to reach perfect accuracy on {} digit addition. Retrying".format(num_digits)

        write_to_config(num_digits=num_digits, checkpoint=model_chpt_path, batch_size=true_batch_size, total_steps=total_steps, message=message)

        # Don't reload the model
        checkpoint = None

        write_to_master_log("Finished training on BS {} with time {}".format(true_batch_size, time.time() - start_time))
        start_time = time.time()

        true_batch_size *= 2

if __name__ == "__main__":
    # Record the parsed CLI arguments and the exact code revision before training starts.
    logging.info(args)
    repo = git.Repo(search_parent_directories=True)
    current_hash = repo.head.object.hexsha
    logging.info("Current git hash: {}".format(current_hash))
    main()