import argparse
import sys

sys.path.append('/home/minsu/bias/Boolean/')
from datamodule_trainer import BooleandataModule
from model import BooleanModel
from common import *
import wandb
import os
from sklearn.model_selection import ParameterGrid
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import LearningRateMonitor
import numpy as np
import pytorch_lightning as pl
import time
from pytorch_lightning.loggers import WandbLogger

# Authenticate with Weights & Biases. The key is read from the WANDB_API_KEY
# environment variable instead of being hard-coded and committed to source
# control. If the variable is unset, key=None is passed and wandb.login()
# falls back to its own credential lookup.
# NOTE(review): the key previously committed here should be revoked/rotated.
api_key = os.environ.get("WANDB_API_KEY")
wandb.login(key=api_key)


def set_seed(seed):
    """Seed every random number generator used by this script.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU and
    all CUDA devices). It also configures cuDNN for deterministic kernels
    with autotuning disabled, so repeated runs are reproducible.

    Args:
        seed: Integer seed applied to every generator.
    """
    # Imported locally so this helper does not depend on ``from common import *``
    # happening to re-export these modules.
    import random

    import torch

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Prefer reproducibility over cuDNN's (non-deterministic) autotuned kernels.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def main(args, wandb_logger):
    """Evaluate a saved BooleanModel checkpoint with ``trainer.test``.

    Seeds all RNGs, builds the data module from the CLI arguments, restores
    the model from ``args.checkpoint_dir`` and runs the Lightning test loop
    with ``wandb_logger`` attached.

    Args:
        args: Parsed command-line namespace (see the ``__main__`` block).
        wandb_logger: Logger handed directly to ``pl.Trainer``.
    """
    set_seed(args.seed)

    # The data module receives the train/val/test paths; only the test stage
    # is run below.
    datamodule = BooleandataModule(
        model_type=args.model_type,
        model_name=args.model_name,
        max_input_len=args.max_input_len,
        max_output_len=args.max_output_len,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        task=args.task,
        path_train=args.path_train,
        path_val=args.path_val,
        path_test=args.path_test,
        method=args.method,
        tokenizer_type=args.tokenizer_type,
    )

    # NOTE(review): args.num_gpus is not forwarded here, so device selection
    # uses pl.Trainer's defaults -- confirm this is intended.
    trainer = pl.Trainer(logger=wandb_logger)

    # load_from_checkpoint is called without constructor arguments, so the
    # model is rebuilt from what the checkpoint stores. A separate
    # BooleanModel built from the CLI arguments would never be used.
    model = BooleanModel.load_from_checkpoint(args.checkpoint_dir)
    trainer.test(model=model, datamodule=datamodule)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="PyTorch Lightning Trainer")
    # NOTE: --small, --supplement, --epoch, --val_epoch, --num_gpus and
    # --group_name are accepted on the command line but never read by main().
    parser.add_argument("--small", action='store_true')
    parser.add_argument("--supplement", action='store_true')
    parser.add_argument("--seed", type=int, default=42, help="set seed")
    parser.add_argument("--epoch", type=int, default=100, help="epoch")
    parser.add_argument("--lr", type=float, default=0.0001, help="Learning rate for the model")
    parser.add_argument("--warmup_steps", type=int, default=1000, help="WarmUp step")
    parser.add_argument("--model_type", type=str, default='encoder', help="model type, encoder/seq2seq")
    parser.add_argument("--model_name", type=str, default='roberta-base', help="model name")
    parser.add_argument("--max_input_len", type=int, default=256, help="max input length")
    parser.add_argument("--num_beams", type=int, default=5, help="number of beam search")
    parser.add_argument("--max_num_steps", type=int, default=20, help="maximum number of recursive proof generation")
    parser.add_argument("--method", type=str, default='un-pretrained', help="pretrained/un-pretrained")
    parser.add_argument("--task", type=str, default='classification',
                        help="Choose task classification/one-shot/stepwise")
    parser.add_argument("--tokenizer_type", type=str, default='default', help="default/custom")
    parser.add_argument("--val_epoch", type=int, default=10, help="check_val_every_n_epoch")
    parser.add_argument("--num_gpus", type=int, default=1, help="num_gpus")
    parser.add_argument("--batch_size", type=int, default=64, help="batch_size")
    parser.add_argument("--num_workers", type=int, default=2, help="num_workers")
    parser.add_argument("--max_output_len", type=int, default=256, help="max_output_len")
    parser.add_argument("--group_name", type=str, default='large', help="wandb group_name")
    parser.add_argument("--checkpoint_dir", type=str,
                        default='/projects/boolean/Checkpoint_epoch_100/Atomic/GPT2_Medium_pretrained_Oneshot/last.ckpt',
                        help="path to the .ckpt checkpoint file to evaluate")
    parser.add_argument("--path_train", type=str,
                        default='/userhomes/minsu/Boolean/augmented_data/train3/train_depth_2_to_3.jsonl',
                        help="train dataset dir")
    parser.add_argument("--path_val", type=str,
                        default='/userhomes/minsu/Boolean/augmented_data/train3/new_valid_depth_3.jsonl',
                        help="val dataset dir")
    parser.add_argument("--path_test", type=str,
                        default='/userhomes/minsu/Boolean/augmented_data/atomic/atomic.jsonl',
                        help="test dataset dir")

    args = parser.parse_args()

    # Architecture hyperparameters for the recognised model families; all three
    # (roberta / t5 / gpt2) currently share the same base-size configuration.
    # NOTE(review): for any other model_name these attributes are left unset.
    base_arch = {"num_layers": 12, "d_model": 768, "d_ff": 3072, "num_heads": 12}
    if any(family in args.model_name for family in ("roberta", "t5", "gpt2")):
        vars(args).update(base_arch)

    run_name = f'Test_{args.model_name}_{args.task}_{args.method}_{args.tokenizer_type}'
    wandb.init(project='Boolean_new', entity='kimminsu', name=run_name)
    wandb_logger = WandbLogger()
    # Run the evaluation with the parsed arguments.
    main(args, wandb_logger)
    wandb.finish()