import os, sys
import time
import math
import torch
import pickle
import logging
import argparse

import numpy as np
import transformers
import pytorch_lightning

from tqdm import tqdm
from termcolor import colored

from utils import GLUEDataModule, SuperGLUEDataModule, GLUE_TASKS, SUPERGLUE_TASKS
from llm_module import GLUETransformer

# Module-level logger named after this module.
# NOTE: the ``__main__`` block at the bottom of this file rebinds ``logger``
# to the root logger when the file is executed as a script.
logger = logging.getLogger(__name__)

def parse_arguments(argv=None):
    """Build the command-line parser for fine-tuning and parse the arguments.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) ``argparse`` reads ``sys.argv[1:]``, which is the
            behaviour callers relied on before this parameter existed.
            Passing an explicit list makes the parser usable from tests and
            notebooks.

    Returns:
        The ``argparse.Namespace`` holding every option declared below.
    """
    parser = argparse.ArgumentParser(description='Argument Parser for Model Fine-tuning')

    # --- data / model / algorithm selection ---
    parser.add_argument('--epochs', type=int, default=1, help='Number of training epochs')
    parser.add_argument('--samplesize', type=int, default=1024, help='Training data sample size')
    parser.add_argument('--samplesize_validation', type=int, default=128, help='Validation data sample size')
    parser.add_argument('--model_name', type=str, default='DistilBert', help='Name of the pre-trained model')
    parser.add_argument('--task', type=str, default='mnli', help='Task for model training')
    parser.add_argument('--full_parameter', action='store_true', help='True for full parameter fine-tuning')
    parser.add_argument('--algorithm', type=str, default='FO', help='Algorithm to use ("FO", "ZO", "ZOSVRG")')
    parser.add_argument('--q', type=int, default=2, help='q parameter used only for ZO-SVRG')
    parser.add_argument('--batchsize', type=int, default=64, help='Batch size for training')
    parser.add_argument('--batchsize_limit', type=int, default=64, help='Max batch size to be used to avoid memory error')
    parser.add_argument('--max_seq_length', type=int, default=256, help='Max sequence length for inputs')

    # --- optimisation hyper-parameters, output and checkpointing ---
    parser.add_argument('--anneal', type=float, default=1.5, help='Annealing parameter')
    parser.add_argument('--lr', type=float, default=2e-3, help='Learning rate')
    parser.add_argument('--device', type=int, default=0, help='GPU Number')
    parser.add_argument('--results', type=str, default='results_demo', help='Name of folder to store results')
    parser.add_argument('--lr_mezosvrg_mb', type=float, default=1e-6, help='Mini-batch learning rate for MeZO-SVRG')
    parser.add_argument('--perturbation_scale', type=float, default=1e-3, help='Perturbation scale for SPSA estimators')
    parser.add_argument('--soft_prompt', action='store_true', help='True for using soft prompt')
    parser.add_argument('--adaptivity', type=float, default=1e-8, help='the level of adaptivity')
    parser.add_argument('--save_ckpt', action='store_true', help='Save checkpoint')
    parser.add_argument('--save_every', type=int, default=10, help='Save checkpoint every n epochs')
    parser.add_argument('--eval_every', type=int, default=1, help='Evaluate on validation set every n epochs')

    # --- precision, early stopping and metric logging ---
    parser.add_argument('--fp16', action='store_true', help='Use FP16 precision')
    parser.add_argument('--bf16', action='store_true', help='Use BF16 precision')
    parser.add_argument('--early_stopping', action='store_true', help='Enable early stopping')
    parser.add_argument('--patience', type=int, default=10, help='Early stopping patience;\
                        if the validation accuracy does not improve for this number of evaluations, stop training')
    parser.add_argument('--logging', type=str, default="wandb", help="Choose logging method; either wandb or tensorboard or none")

    # --- experiment bookkeeping ---
    parser.add_argument('--low_bit_adam', type=int, default=0, help='Use Adam with quantized states; options: 4 or 8')
    parser.add_argument('--trial', type=int, default=0, help='Trial number')
    parser.add_argument('--init_seed', type=int, default=None, help='Random seed for model initialization')

    return parser.parse_args(argv)


def save_pickle(data, folder, filename):
    """Pickle ``data`` to ``folder/filename``, creating ``folder`` if needed.

    Args:
        data: Any picklable object.
        folder: Directory to write into. An empty string means the current
            working directory.
        filename: Name of the file created (or overwritten) inside ``folder``.
    """
    # os.makedirs('') raises FileNotFoundError, so only create the directory
    # when one was actually given; os.path.join('', filename) is just filename.
    if folder:
        os.makedirs(folder, exist_ok=True)
    filepath = os.path.join(folder, filename)
    with open(filepath, 'wb') as file:
        pickle.dump(data, file)
    
def get_data_module(args):
    """Create the data module matching ``args.task`` and run its 'fit' setup.

    ``GLUEDataModule`` is used when the task is listed in ``GLUE_TASKS`` and
    ``SuperGLUEDataModule`` when it is listed in ``SUPERGLUE_TASKS``.

    Raises:
        ValueError: If the task appears in neither task list.
    """
    if args.task in GLUE_TASKS:
        module_cls = GLUEDataModule
    else:
        if args.task not in SUPERGLUE_TASKS:
            raise ValueError(f"Task {args.task} is not supported")
        module_cls = SuperGLUEDataModule

    # The same batch size is used for both the training and evaluation loaders.
    settings = dict(
        model_name_or_path=args.model_name,
        task_name=args.task,
        max_seq_length=args.max_seq_length,
        sample_size=args.samplesize,
        train_batch_size=args.batchsize,
        validation_sample_size=args.samplesize_validation,
        eval_batch_size=args.batchsize,
        soft_prompt=args.soft_prompt,
        hf_token=os.getenv('HF_TOKEN'),
    )
    data_module = module_cls(**settings)
    data_module.setup(stage='fit')
    return data_module

def get_model(args, dm):
    """Instantiate the ``GLUETransformer`` wrapper from CLI args and a data module.

    When the algorithm name contains 'SVRG', the mini-batch learning rate
    (``args.lr_mezosvrg_mb``) is passed as the primary learning rate and
    ``args.lr`` as the auxiliary one; otherwise ``args.lr`` is the primary
    rate and no auxiliary rate is given.
    """
    if 'SVRG' in args.algorithm:
        primary_lr, auxiliary_lr = args.lr_mezosvrg_mb, args.lr
    else:
        primary_lr, auxiliary_lr = args.lr, None

    wants_sgd = 'SGD' in args.algorithm

    return GLUETransformer(
        model_name_or_path=args.model_name,
        num_labels=dm.num_labels,
        eval_splits=dm.eval_splits,
        task_name=dm.task_name,
        learning_rate=primary_lr,
        learning_rate_aux=auxiliary_lr,
        q=args.q,
        lr_anneal=args.anneal,
        full_parameter=args.full_parameter,
        batchsize_limit=args.batchsize_limit,
        zero_order_eps=args.perturbation_scale,
        soft_prompt=args.soft_prompt,
        use_SGD=wants_sgd,
        logger_type=args.logging,
        hf_token=os.getenv('HF_TOKEN'),
        model_init_seed=args.init_seed,
    )

def _select_device(args):
    """Return CUDA device number ``args.device``, or the CPU if CUDA is unavailable."""
    if torch.cuda.is_available():
        return torch.device("cuda:" + str(args.device))
    return torch.device("cpu")


def _batch_to_device(batch, device):
    """Return a new dict with every value of ``batch`` moved to ``device``."""
    return {key: value.to(device) for key, value in batch.items()}


def _build_model(args, dm, device):
    """Create the transformer wrapper, optionally cast to bf16, and move its model to ``device``.

    Returns:
        ``(transformer, model)`` where ``model`` is ``transformer.model``.
    """
    transformer = get_model(args, dm)
    if args.bf16:
        transformer.to(torch.bfloat16)
    model = transformer.model
    model.to(device)
    return transformer, model


def _log_metrics(args, tb_writer, metrics, step):
    """Send ``metrics`` (name -> value) to the backend chosen by ``args.logging``.

    Any value other than "tensorboard" or "wandb" (e.g. "none") logs nothing,
    so the documented ``--logging none`` mode no longer fails on an
    unconditional ``wandb.log`` call.
    """
    if args.logging == "tensorboard":
        # A missing writer is tolerated instead of raising AttributeError.
        if tb_writer is not None:
            for tag, value in metrics.items():
                tb_writer.add_scalar(tag, value, step)
    elif args.logging == "wandb":
        # Imported here so these functions do not depend on the import that
        # the ``__main__`` block performs at module scope.
        import wandb
        wandb.log(metrics, step=step)


def _run_validation(args, transformer, val_dataloader, device, tb_writer):
    """Evaluate on ``val_dataloader``, log ``val_loss``/``val_acc`` and return the accuracy.

    Per-batch loss sums from ``transformer.forward_ZO_val`` are added up and
    divided by the total sample count; accuracy is correct / total.
    """
    logging.info('Validation Loop')
    loss_sums, num_correct, num_samples = [], 0, 0
    for batch in tqdm(val_dataloader):
        loss_sum, correct, batch_samples = transformer.forward_ZO_val(_batch_to_device(batch, device))
        loss_sums.append(loss_sum.item())
        num_correct += correct
        num_samples += batch_samples

    val_loss_mean = np.sum(loss_sums) / num_samples
    val_acc = num_correct / num_samples
    logging.info(f"val_loss: {val_loss_mean}, val_acc: {val_acc}")

    _log_metrics(
        args,
        tb_writer,
        {'val_loss': val_loss_mean, 'val_acc': val_acc},
        transformer.state.global_training_steps,
    )
    return val_acc


def _run_training(args, label, transformer, val_dataloader, device, tb_writer, train_epoch):
    """Shared epoch loop: periodic validation, optional early stopping, then training.

    Args:
        args: Parsed CLI namespace (reads ``epochs``, ``eval_every``,
            ``early_stopping``, ``patience`` and ``logging``).
        label: Algorithm name used in the per-epoch log line.
        transformer: Wrapper providing ``forward_ZO_val`` and ``state``.
        val_dataloader: Iterable of validation batches.
        device: Device the batches are moved to.
        tb_writer: TensorBoard writer or ``None``.
        train_epoch: Callable ``train_epoch(epoch)`` running one training epoch.
    """
    best_val_acc = 0
    patience_counter = 0

    try:
        for epoch in range(args.epochs):
            logging.info(f'{label}, Epoch {epoch+1}')

            # Validation runs *before* the epoch's training steps.
            if (epoch + 1) % args.eval_every == 0:
                val_acc = _run_validation(args, transformer, val_dataloader, device, tb_writer)

                if args.early_stopping:
                    step = transformer.state.global_training_steps
                    if val_acc >= best_val_acc:
                        best_val_acc = val_acc
                        patience_counter = 0
                        _log_metrics(args, tb_writer, {'best_val_acc': best_val_acc}, step)
                    else:
                        patience_counter += 1
                        if patience_counter >= args.patience:
                            logging.info(f"Early stopping at epoch {epoch+1}. Best validation accuracy: {best_val_acc}")
                            _log_metrics(args, tb_writer, {'best_val_acc': best_val_acc}, step)
                            break

            train_epoch(epoch)
    finally:
        # Close the writer even when training raises, so buffered events are
        # not lost on a crash.
        if args.logging == "tensorboard" and tb_writer is not None:
            tb_writer.close()


def finetune_ZO(args, tb_writer=None):
    """Fine-tune with ``transformer.training_step_ZO`` on every mini-batch.

    Args:
        args: Parsed CLI namespace from ``parse_arguments``.
        tb_writer: Optional TensorBoard writer; forwarded to the training
            step and closed when the run ends.
    """
    device = _select_device(args)
    dm = get_data_module(args)

    train_dataloader = dm.train_dataloader()
    val_dataloader = dm.val_dataloader()

    transformer, model = _build_model(args, dm, device)
    transformer.configure_params()
    # The model is put in eval mode for the whole run.
    transformer.model.eval()

    def train_epoch(epoch):
        logging.info("Training Loop")
        for batch in tqdm(train_dataloader):
            transformer.training_step_ZO(model, _batch_to_device(batch, device), tb_writer=tb_writer)

    _run_training(args, 'ZO', transformer, val_dataloader, device, tb_writer, train_epoch)


def finetune_ZO_SVRG(args, tb_writer=None):
    """Fine-tune with ``transformer.training_step_MeZO_SVRG``.

    Every ``args.q``-th global iteration the mini-batch is replaced by the
    first batch of ``dm.train_full_dataloader()``; all other iterations use
    the mini-batch drawn from ``dm.train_dataloader()``.

    Args:
        args: Parsed CLI namespace from ``parse_arguments``.
        tb_writer: Optional TensorBoard writer; forwarded to the training
            step and closed when the run ends.
    """
    device = _select_device(args)

    dm = get_data_module(args)
    train_mb_dataloader = dm.train_dataloader()
    train_dataloader = dm.train_full_dataloader()
    val_dataloader = dm.val_dataloader()

    transformer, model = _build_model(args, dm, device)
    transformer.configure_params()

    total_batches = len(train_mb_dataloader)
    # The model is put in eval mode for the whole run.
    transformer.model.eval()

    def train_epoch(epoch):
        logging.info('Training Loop : Minibatch sampling using random permutation')
        for i, batch in enumerate(tqdm(train_mb_dataloader)):
            # Global iteration index across epochs decides full- vs mini-batch.
            curr_iteration = epoch * total_batches + i
            if curr_iteration % args.q == 0:
                logging.debug('Full-Batch Iteration')
                batch = next(iter(train_dataloader))
            else:
                logging.debug('Mini-Batch Iteration')

            transformer.training_step_MeZO_SVRG(
                model, _batch_to_device(batch, device), epoch, i, total_batches, tb_writer=tb_writer
            )

    _run_training(args, 'ZOSVRG', transformer, val_dataloader, device, tb_writer, train_epoch)


def finetune_ZO_Adam(args, tb_writer=None):
    """Fine-tune with ``transformer.training_step_ZO_Adam`` and an Adam optimizer.

    ``--low_bit_adam 8`` selects ``bitsandbytes`` ``Adam8bit``; every other
    value uses ``torch.optim.Adam`` (a warning is logged for unsupported
    non-zero values such as 4).

    Args:
        args: Parsed CLI namespace from ``parse_arguments``.
        tb_writer: Optional TensorBoard writer; forwarded to the training
            step and closed when the run ends.
    """
    device = _select_device(args)
    dm = get_data_module(args)

    train_dataloader = dm.train_dataloader()
    val_dataloader = dm.val_dataloader()

    transformer, model = _build_model(args, dm, device)

    if args.low_bit_adam == 8:
        import bitsandbytes as bnb
        optimizer = bnb.optim.Adam8bit(model.parameters(), lr=args.lr, eps=args.adaptivity, weight_decay=0)
    else:
        if args.low_bit_adam != 0:
            # Previously this fell back silently; keep the fallback but say so.
            logging.warning(
                f"--low_bit_adam={args.low_bit_adam} is not implemented here (only 8 is); "
                "falling back to torch.optim.Adam"
            )
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, eps=args.adaptivity)

    transformer.configure_params()
    # The model is put in eval mode for the whole run.
    transformer.model.eval()

    def train_epoch(epoch):
        logging.info("Training Loop")
        for batch in tqdm(train_dataloader):
            transformer.training_step_ZO_Adam(model, _batch_to_device(batch, device), optimizer, tb_writer=tb_writer)

    _run_training(args, 'ZOAdam', transformer, val_dataloader, device, tb_writer, train_epoch)

if __name__ == "__main__":
    # Script entry point: parse CLI args, derive the run name and output
    # directory, set up metric and text logging, then dispatch to the
    # fine-tuning routine selected by --algorithm.
    args = parse_arguments()

    # Run name encodes model, task, algorithm and key hyper-parameters.
    trimmed_model_name = args.model_name.split('/')[-1]
    num_steps = math.ceil(args.samplesize/args.batchsize) * args.epochs
    args.run_name = f'{trimmed_model_name}_{args.task}_{args.algorithm}_lr{args.lr:.0e}'

    if args.algorithm == 'ZOSVRG':
        args.run_name += f'_lr_mb_{args.lr_mezosvrg_mb:.0e}'
    elif args.algorithm == 'ZOAdam':
        args.run_name += f'_eps_{args.adaptivity:.0e}'
        if args.low_bit_adam == 4 or args.low_bit_adam == 8:
            args.run_name += f'_{args.low_bit_adam}bit'

    args.run_name += f'_bsz{args.batchsize}_steps{num_steps}'
    if args.init_seed is not None:
        args.run_name += f'_init_seed_{args.init_seed}'
    args.run_name += f'_trial_{args.trial}'

    args.output_dir = os.path.join('results', trimmed_model_name, args.task, args.run_name)

    ### wandb init ###
    # Default to no writer so every logging mode (including "none") has a
    # value to pass to the fine-tuning functions; previously ``writer`` was
    # unbound unless --logging was "wandb" or "tensorboard".
    writer = None
    if args.logging == "wandb":
        # Imported at module scope on purpose: the fine-tuning functions may
        # reference the global name ``wandb``.
        import wandb
        # wandb.login(key=os.environ["WANDB_KEY"])
    elif args.logging == "tensorboard":
        from tensorboardX import SummaryWriter

    print(f"Start from scratch: {args.output_dir}")

    if args.logging == "wandb":
        wandb.init(project='zo-a3dam', name=args.run_name, config=args)
    elif args.logging == "tensorboard":
        if 'FO' in args.algorithm:
            writer = pytorch_lightning.loggers.TensorBoardLogger(os.path.join('runs', args.run_name))
        else:
            writer = SummaryWriter(logdir=os.path.join('runs', args.run_name))
    os.makedirs(args.output_dir, exist_ok=True)

    ### logging setup ###
    # Configure the root logger with identical console and file handlers.
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    logger.propagate = False

    fmt = colored('[%(asctime)s %(name)s]', 'green') + \
            colored('(%(filename)s %(lineno)d)', 'yellow') + ': %(levelname)s %(message)s'

    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(
        logging.Formatter(fmt=fmt, datefmt='%Y-%m-%d %H:%M:%S'))
    logger.addHandler(console_handler)

    # create file handlers
    file_handler = logging.FileHandler(os.path.join(args.output_dir, f'log_rank0_{int(time.time())}.txt'), mode='a')
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(logging.Formatter(fmt=fmt, datefmt='%Y-%m-%d %H:%M:%S'))
    logger.addHandler(file_handler)

    # The default of training_args.log_level is passive, so we set log level at info here to have that default.
    transformers.utils.logging.set_verbosity_info()
    transformers.utils.logging.set_verbosity(logging.INFO)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    logger.info(f"Training Arguments {args}")

    # Dispatch table instead of an if/elif chain; an unknown algorithm
    # (including the parser default 'FO', which has no routine in this file)
    # now fails loudly instead of exiting without training anything.
    finetune_routines = {
        'ZO': finetune_ZO,
        'ZOSVRG': finetune_ZO_SVRG,
        'ZOAdam': finetune_ZO_Adam,
    }
    finetune = finetune_routines.get(args.algorithm)
    if finetune is None:
        raise ValueError(
            f"Algorithm {args.algorithm} is not supported by this script; "
            f"choose one of {sorted(finetune_routines)}"
        )
    finetune(args, writer)
