import argparse
import sys

from datamodule_trainer import BooleandataModule
from model import BooleanModel
from common import *
import wandb
import os
from sklearn.model_selection import ParameterGrid
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import LearningRateMonitor
import numpy as np
import pytorch_lightning as pl
import time
from pytorch_lightning.loggers import WandbLogger


# Authenticate with Weights & Biases.
# The key is taken from the WANDB_API_KEY environment variable so that no
# credential lives in the source tree. When the variable is unset, `api_key`
# is None and wandb.login() is left to resolve credentials on its own
# (NOTE(review): per the wandb docs this means netrc / interactive prompt).
# SECURITY: the key that used to be hard-coded on this line must be treated
# as leaked and revoked in the W&B account settings.
api_key = os.environ.get("WANDB_API_KEY")
wandb.login(key=api_key)


def set_seed(seed):
    """Seed the Python, NumPy and PyTorch random number generators.

    Also switches cuDNN to deterministic mode and turns its benchmark
    auto-tuner off.

    Args:
        seed: Value passed unchanged to every seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True


def main(args, wandb_logger):
    """Run one training job configured by ``args``.

    Seeds the RNGs, creates the checkpoint and learning-rate callbacks,
    builds the ``BooleandataModule`` and the ``BooleanModel`` from the
    parsed options, and fits the model with a GPU ``pl.Trainer`` that logs
    to ``wandb_logger``.

    Args:
        args: Namespace providing every option read below (``seed``,
            ``checkpoint_dir``, the data/model settings, ``num_gpus``,
            ``epoch`` and ``val_epoch``).
        wandb_logger: Logger instance passed to the trainer.
    """
    set_seed(args.seed)

    # No `monitor` is given, so no metric is named for ranking checkpoints.
    ckpt_saver = ModelCheckpoint(
        dirpath=args.checkpoint_dir,
        save_top_k=1,
        save_last=True,
        every_n_epochs=100,
        save_on_train_epoch_end=False,
    )
    lr_monitor = LearningRateMonitor(logging_interval='step')

    # Options that the data module and the model receive identically.
    shared_options = dict(
        task=args.task,
        model_type=args.model_type,
        composition=args.composition,
        model_name=args.model_name,
        max_input_len=args.max_input_len,
        method=args.method,
        tokenizer_type=args.tokenizer_type,
    )

    datamodule = BooleandataModule(
        max_output_len=args.max_output_len,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        path_train=args.path_train,
        path_val=args.path_val,
        path_test=args.path_test,
        **shared_options,
    )

    model = BooleanModel(
        warmup_steps=args.warmup_steps,
        lr=args.lr,
        num_beams=args.num_beams,
        max_num_steps=args.max_num_steps,
        num_layers=args.num_layers,
        d_model=args.d_model,
        d_ff=args.d_ff,
        num_heads=args.num_heads,
        **shared_options,
    )

    trainer = pl.Trainer(
        accelerator="gpu",
        devices=args.num_gpus,
        max_epochs=args.epoch,
        check_val_every_n_epoch=args.val_epoch,
        accumulate_grad_batches=16,
        gradient_clip_val=0.5,
        log_every_n_steps=5,
        callbacks=[lr_monitor, ckpt_saver],
        logger=wandb_logger,
    )

    trainer.fit(model=model, datamodule=datamodule)


if __name__ == "__main__":
    # Script entry point: parse the command line, attach the transformer
    # dimensions for the requested backbone, then run one W&B-tracked
    # training job through main().
    parser = argparse.ArgumentParser(description="PyTorch Lightning Trainer")
    started_at = time.gmtime()  # UTC launch time; not read by the code visible here

    # Add arguments
    parser.add_argument("--small", action='store_true')
    parser.add_argument("--supplement", action='store_true')
    parser.add_argument("--seed", type=int, default=42, help="set seed")
    parser.add_argument("--epoch", type=int, default=100, help="epoch")
    parser.add_argument("--lr", type=float, default=0.0001, help="Learning rate for the model")
    parser.add_argument("--warmup_steps", type=int, default=1000, help="WarmUp step")
    parser.add_argument("--composition", action='store_true')
    parser.add_argument("--model_type", type=str, default='encoder', help="model type, encoder/seq2seq")
    parser.add_argument("--model_name", type=str, default='roberta-base', help="model name")
    parser.add_argument("--max_input_len", type=int, default=256, help="max input length")
    parser.add_argument("--num_beams", type=int, default=5, help="number of beam search")
    parser.add_argument("--max_num_steps", type=int, default=20, help="maximum number of recursive proof generation")
    parser.add_argument("--method", type=str, default='un-pretrained', help="pretrained/un-pretrained")
    parser.add_argument("--task", type=str, default='classification',
                        help="Choose task classification/one-shot/stepwise")
    parser.add_argument("--tokenizer_type", type=str, default='default', help="type of tokenizer..default/custom")
    parser.add_argument("--val_epoch", type=int, default=10, help="check_val_every_n_epoch")
    parser.add_argument("--num_gpus", type=int, default=1, help="num_gpus")
    parser.add_argument("--batch_size", type=int, default=64, help="batch_size")
    parser.add_argument("--num_workers", type=int, default=2, help="num_workers")
    parser.add_argument("--max_output_len", type=int, default=256, help="max_output_len")
    parser.add_argument("--group_name", type=str, default='large', help="wandb group_name")
    parser.add_argument("--checkpoint_dir", type=str, default='/projects/boolean/Checkpoint_epoch_100')
    parser.add_argument("--path_train", type=str,
                        default='/userhomes/Boolean/augmented_data/train3/train_depth_2_to_3.jsonl',
                        help="train dataset dir")
    parser.add_argument("--path_val", type=str,
                        default='/userhomes/Boolean/augmented_data/train3/new_valid_depth_3.jsonl',
                        help="val dataset dir")
    parser.add_argument("--path_test", type=str,
                        default='/userhomes/Boolean/augmented_data/train3/new_test_depth_3.jsonl',
                        help="test dataset dir")

    # Parse the command-line arguments
    args = parser.parse_args()

    # main() reads args.num_layers / d_model / d_ff / num_heads, which are not
    # CLI options. All recognised backbone families use the same dimensions,
    # so a single branch covers them. An unrecognised name is rejected here
    # instead of surfacing later as an AttributeError inside main().
    supported_families = ('roberta', 't5', 'gpt2')
    if not any(family in args.model_name for family in supported_families):
        parser.error(
            f"unsupported --model_name {args.model_name!r}: "
            f"expected a name containing one of {supported_families}"
        )
    args.num_layers = 12
    args.d_model = 768
    args.d_ff = 3072
    args.num_heads = 12

    # Start the W&B run; the run name encodes model, task, method and tokenizer.
    wandb.init(project='Boolean_new',
               name=f'{args.model_name}_{args.task}_{args.method}_{args.tokenizer_type}')
    wandb_logger = WandbLogger()
    main(args, wandb_logger)
    wandb.finish()

    # Disabled grid-search variant, kept for reference only (never executed):
    #
    # # Call the main function with the parsed arguments
    # hyperparameters = {
    #     "num_layers": [2, 4, 6, 8],
    #     "d_model": [32, 64, 128],
    #     "d_ff": [256, 512, 1024],
    #     "num_heads": [4, 6, 8, ],
    #     "num_layers": [2, 4, 6, 8],
    #     # "d_model": [32,64,128,256,512,768],
    #     # "d_ff": [256,512,1024,2048,3072],
    #     # "num_heads":[4,6,8,10,12,14,16]
    # }
    #
    # param_grid = ParameterGrid(hyperparameters)
    #
    # for params in list(param_grid):
    #     print(params)
    #     args.num_layers = params['num_layers']
    #     args.d_model = params['d_model']
    #     args.d_ff = params['d_ff']
    #     args.num_heads = params['num_heads']
    #     from pytorch_lightning.loggers import WandbLogger
    #
    #     wandb.init(project='grid_search_boolean',
    #                                name=f'{args.model_name}_{args.num_layers}_{args.d_model}_{args.d_ff}_{args.num_heads}_{args.num_layers}')
    #     wandb_logger = WandbLogger()
    #     main(args, wandb_logger)
    #     wandb.finish()
"""
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def main(args,wandb_logger):
    # Set the random seed
    set_seed(args.seed)

    checkpoint_callback = ModelCheckpoint(
        monitor='Classification_ACC_val',
        mode='max',
        dirpath='/projects/boolean//un-pretrained/train_2_3/t5-small_2layers/classification/valid_3',
        save_top_k=1,
        every_n_epochs=5,
        save_on_train_epoch_end=False,
        save_last=True,
    )

    # Create a WandbLogger instance
    wandb_logger = wandb_logger

    learningrate_callback=LearningRateMonitor(logging_interval='step')


    # Instantiate the data module
    datamodule = BooleandataModule(model_type=args.model_type,
        model_name = args.model_name,
        max_input_len = args.max_input_len,
        max_output_len = args.max_output_len,
        batch_size = args.batch_size,
        num_workers = args.num_workers,
        task=args.task,
        path_train = args.path_train,
        path_val = args.path_val,
        path_test = args.path_test,
        method=args.method)

    # Instantiate the Lightning model with hyperparameters
    model = BooleanModel(
        task=args.task,
        model_type=args.model_type,
        model_name=args.model_name,
        warmup_steps=args.warmup_steps,
        lr=args.lr,
        max_input_len=args.max_input_len,
        num_beams=args.num_beams,
        max_num_steps=args.max_num_steps,
        method=args.method,
        num_layers=args.num_layers,
        d_model=args.d_model,
        d_ff=args.d_ff,
        num_heads=args.num_heads,
    )

    # Initialize the trainer
    trainer = pl.Trainer(
        devices=1,
        accelerator="gpu",
        gradient_clip_val=0.5,
        max_epochs=100,  # use 100 for warmup, 400 for dynamic, 500 for static
        accumulate_grad_batches=16,
        log_every_n_steps=5,
        check_val_every_n_epoch=10, # use 1 when running auto-sample with warmup
        callbacks=[checkpoint_callback,learningrate_callback],
        logger=wandb_logger,
    )

    # Train the model
    trainer.fit(model=model, datamodule=datamodule)

if __name__ == "__main__":
    # Create an argument parser
    parser = argparse.ArgumentParser(description="PyTorch Lightning Trainer")


    # Add arguments
    parser.add_argument("--seed", type=int, default=42, help="set seed")
    parser.add_argument("--lr", type=float, default=0.0001, help="Learning rate for the model")
    parser.add_argument("--warmup_steps", type=int, default=1000, help="WarmUp step")
    parser.add_argument("--model_type", type=str, default='encoder', help="model type, encoder/encoder-decoder/decoder")
    parser.add_argument("--model_name", type=str, default='roberta-base', help="model name")
    parser.add_argument("--max_input_len", type=int, default=256, help="max input length")
    parser.add_argument("--num_beams", type=int, default=5, help="number of beam search")
    parser.add_argument("--max_num_steps", type=int, default=20, help="maximum number of recursive proof generation")
    parser.add_argument("--method", type=str, default='un-pretrained', help="Pretrained/Un-pretrained")
    parser.add_argument("--task", type=str, default='classification ', help="Choose task classification/one-shot/stepwise")
    parser.add_argument("--batch_size", type=int, default=20, help="batch_size")
    parser.add_argument("--num_workers", type=int, default=2, help="num_workers")
    parser.add_argument("--max_output_len", type=int, default=256, help="max_output_len")
    parser.add_argument("--path_train", type=str, default='/userhomes/Boolean/data/real_train_with_2_3/train_depth_2_3.jsonl', help="train dataset dir")
    parser.add_argument("--path_val", type=str, default='/userhomes/Boolean/data/real_train_with_2_3/new_valid_depth_3.jsonl', help="val dataset dir")
    parser.add_argument("--path_test", type=str, default='/userhomes/Boolean/data/real_train_with_2_3/new_valid_depth_3.jsonl', help="test dataset dir")

    # Parse the command-line arguments
    args = parser.parse_args()

    # Call the main function with the parsed arguments
    hyperparameters = {
         "num_layers": [2, 4, 6,8],
         "d_model": [32,64,128],
         "d_ff": [256,512,1024],
         "num_heads":[4,6,8,],
        "num_layers": [2, 4, 6,8],
                                                #"d_model": [32,64,128,256,512,768],
                                                #"d_ff": [256,512,1024,2048,3072],
                                                #"num_heads":[4,6,8,10,12,14,16]
    }

    param_grid=ParameterGrid(hyperparameters)

    for params in list(param_grid):
        print(params)
        args.num_layers=params['num_layers']
        args.d_model=params['d_model']
        args.d_ff=params['d_ff']
        args.num_heads=params['num_heads']
        from pytorch_lightning.loggers import WandbLogger

        wandb.init()
        wandb_logger = WandbLogger(project='grid_search_boolean',
                                   name=f'{args.model_name}_{args.num_layers}_{args.d_model}_{args.d_ff}_{args.num_heads}_{args.num_layers}')
        main(args,wandb_logger)

"""