import argparse

import torch
from transformers import TrainingArguments, IntervalStrategy, EarlyStoppingCallback
from model.tokenizer import SmilesTokenizer
from model.antbrain import AntBrain
from model.trainer import CustomTrainer
from datasets import load_dataset
from model.utils import prepare_input_and_labels
import glob2

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='AntBrain training')
    parser.add_argument(
        '-t', '--train', default='./data/train_graph.jsonl.gz',
        type=str,
        help='Root directory with the training data')
    parser.add_argument(
        '-v', '--validation', default='./data/train_graph.jsonl.gz', type=str,
        help='Root directory with the validation data')
    parser.add_argument("--learning-rate", default=3e-05, type=float)
    parser.add_argument("--per-device-train-batch-size", default=8, type=int)
    parser.add_argument("--per-device-eval-batch-size", default=8, type=int)
    parser.add_argument("--weight-decay", default=0.01, type=float)
    parser.add_argument("--epochs", default=5, type=int)
    parser.add_argument("--save-total-limit", default=3, type=int)
    parser.add_argument("--saving_steps", default=1000, type=int)
    parser.add_argument("--evaluation_steps", default=1, type=int)
    parser.add_argument("--adam-eps", default=1e-08, type=float)
    parser.add_argument("--adam-betas", default=(0.9, 0.999), nargs="+", type=float)
    parser.add_argument("--warmup-updates", default=500, type=int)
    parser.add_argument("--warmup_steps", default=500, type=int)
    parser.add_argument("--max_steps", default=500, type=int)
    parser.add_argument("--patience", default=200, type=int)
    parser.add_argument("--num_workers", default=10, type=int)
    parser.add_argument("--logging-steps", default=1, type=int)
    parser.add_argument("--fp16", action='store_true')
    parser.add_argument("--deepspeed", default=None, type=str, help="Deep speed configuration file")
    parser.add_argument("--local_rank", type=int, default=-1)
    parser.add_argument('-m', '--model', type=str, default='bert-base-uncased', help='model name')
    parser.add_argument('--max_length', type=int, default=128, help='Max sequence length')
    parser.add_argument('--num_chunks', type=int, default=2, help='number of chunks per training step')
    parser.add_argument('--vocab_path', type=str, default="./data/vocab.txt", help='vocab file path.')
    parser.add_argument('--prediction_objective', type=str, default="lm", help='prediction objective: all, 3d, lm')
    parser.add_argument('--output_dir', type=str, default="./results", help='output dir where the models are saved')
    parser.add_argument('--checkpoint', type=str, default=None, help='Path to the check point.')
    parser.add_argument('--ignore_data_skip', type=str, default='yes',
                        help='whether skip checking data before training '
                             'from checkpoint')

    args = parser.parse_args()
    tokenizer = SmilesTokenizer(vocab_file=args.vocab_path)

    # DataLoaders
    file_lists = args.train.replace("'", "").split(",")
    train_files = []
    for file_list in file_lists:
        train_files += glob2.glob(file_list)
    train_files.sort()
    print("Training data")
    for file in train_files:
        print(file)
    validation_files = glob2.glob(args.validation.replace("'", ""))
    validation_files.sort()
    print("Validation data", validation_files)
    data_set = load_dataset('json',
                            data_files={'train': train_files,
                                        'val': validation_files},
                            streaming=True)
    data_set = data_set.map(lambda e: prepare_input_and_labels(tokenizer=tokenizer, input_sequences=e['smiles'],
                                                               max_length=args.max_length),
                            batched=True,
                            remove_columns=["atom", "smiles", "bond_edges", "bond_types"],
                            batch_size=100)

    data_set = data_set.with_format('torch')
    # checkpoint = torch.load(args.checkpoint, map_location=torch.device("cpu"))
    model = AntBrain(model_name=args.model,
                     prediction_objective=args.prediction_objective,
                     tokenizer=tokenizer)
    # model.load_state_dict(checkpoint, strict=False)

    training_args = TrainingArguments(
        output_dir=args.output_dir,
        evaluation_strategy=IntervalStrategy.STEPS,
        learning_rate=args.learning_rate,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        weight_decay=args.weight_decay,
        save_total_limit=args.save_total_limit,
        save_steps=args.saving_steps,
        eval_steps=args.evaluation_steps,
        num_train_epochs=args.epochs,
        logging_steps=args.logging_steps,
        fp16=args.fp16,
        dataloader_num_workers=args.num_workers,
        load_best_model_at_end=True,
        deepspeed=args.deepspeed,
        max_steps=args.max_steps,
        warmup_steps=args.warmup_steps,
        ignore_data_skip=args.ignore_data_skip == 'yes'
    )

    trainer = CustomTrainer(
        model=model,
        args=training_args,
        train_dataset=data_set['train'],
        eval_dataset=data_set['val'],
        num_chunks=args.num_chunks,
        my_tokenizer=tokenizer,
        max_length=args.max_length,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=args.patience)]
    )
    if args.checkpoint:
        trainer.train(args.checkpoint)
    else:
        trainer.train()
    trainer.evaluate()
