import argparse
import glob
import logging
import os
from pytorch_lightning.callbacks import LearningRateMonitor
from pathlib import Path
from pytorch_lightning.loggers import WandbLogger

import pandas as pd
import pytorch_lightning as pl
import torch
from modeling.lightning_base import generic_train


from dataset import AutoregLMDataset

from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback
from modeling.next_token_lm import NextTokenLMModule
from modeling.seq2seq_lm import Seq2SeqLMModule


# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Configure the root logger so INFO-level records are emitted.
logging.basicConfig(level=logging.INFO)


def _build_model(args):
    """Instantiate the Lightning module that corresponds to ``args.mode``.

    Raises:
        ValueError: if ``args.mode`` is neither ``"seq2seq"`` nor
            ``"next-token-prediction"``.
    """
    if args.mode == "seq2seq":
        return Seq2SeqLMModule(args)
    if args.mode == "next-token-prediction":
        return NextTokenLMModule(args)
    # Fail here with a clear message instead of letting ``None`` propagate and
    # surface later as an AttributeError.
    raise ValueError(f"Unknown mode: {args.mode}")


def _collect_additional_tokens(args):
    """Return the list of extra tokens that must be added to the tokenizer.

    In ``next-token-prediction`` mode the list starts with
    ``AutoregLMDataset.IO_SEP``. If ``<data_dir>/special_tokens.txt`` exists,
    its rows (read with ``pandas.read_csv`` as a single ``special_tokens``
    column) are appended.
    """
    tokens = []
    if args.mode == "next-token-prediction":
        tokens.append(AutoregLMDataset.IO_SEP)

    special_tokens_file = Path(f"{args.data_dir}/special_tokens.txt")
    if special_tokens_file.is_file():
        frame = pd.read_csv(special_tokens_file, names=["special_tokens"])
        tokens.extend(frame["special_tokens"].tolist())
    return tokens


def _last_checkpoint(output_dir):
    """Return the lexicographically last ``*.ckpt`` path directly inside
    ``output_dir``, or ``None`` when there is none.
    """
    checkpoints = sorted(glob.glob(os.path.join(output_dir, "*.ckpt")))
    return checkpoints[-1] if checkpoints else None


def main(args, model=None) -> "Seq2SeqLMModule | NextTokenLMModule":
    """Build (or reuse) a model, hand it to ``generic_train`` and optionally test it.

    Args:
        args: Parsed command-line namespace. This function reads
            ``output_dir``, ``data_dir``, ``mode``, ``val_check_interval``,
            ``gradient_clip_val``, ``gpus`` and ``do_predict``; the whole
            namespace is also forwarded to the model constructor and to
            ``generic_train``.
        model: Optional pre-built module. When ``None`` a module is created
            according to ``args.mode``.

    Returns:
        The module that was passed in or created.

    Raises:
        ValueError: if ``model`` is ``None`` and ``args.mode`` is not a
            supported mode.
    """
    # exist_ok=True: an already existing output directory is accepted; no
    # emptiness check is performed.
    Path(args.output_dir).mkdir(exist_ok=True, parents=True)

    if model is None:
        model = _build_model(args)

    additional_tokens_list = _collect_additional_tokens(args)
    if additional_tokens_list:
        # len(tokenizer) is the full vocabulary size including added tokens, see
        # https://stackoverflow.com/questions/67412925/what-is-the-difference-between-lentokenizer-and-tokenizer-vocab-size
        logging.info(
            f"Tokenizer size before additional tokens: {len(model.tokenizer)}"
        )
        logging.info(f"Adding additional tokens: {additional_tokens_list}")
        num_added_toks = model.tokenizer.add_tokens(additional_tokens_list)
        logging.info(f" {num_added_toks} special tokens added!")
        logging.info(
            f"Tokenizer size after additional tokens: {len(model.tokenizer)}"
        )
        # Keep the embedding matrix in sync with the enlarged tokenizer.
        model.model.resize_token_embeddings(len(model.tokenizer))

    if args.mode == "next-token-prediction":
        # Expose the input/output separator on the tokenizer object itself.
        model.tokenizer.io_sep_token = AutoregLMDataset.IO_SEP
        # NOTE(review): takes the first id produced when encoding IO_SEP; this
        # assumes the tokenizer does not prepend another token - confirm.
        model.tokenizer.io_sep_token_id = model.tokenizer(AutoregLMDataset.IO_SEP)[
            "input_ids"
        ][0]
        logging.info(f"IO_SEP token is {AutoregLMDataset.IO_SEP}")
        logging.info(f"IO_SEP token id is {model.tokenizer.io_sep_token_id}")

    # The W&B project is named after the data directory, the run after the
    # model's output directory.
    project_name = Path(args.data_dir).name
    # Named wandb_logger so it does not shadow the module-level ``logger``.
    wandb_logger = WandbLogger(name=model.output_dir.name, project=project_name)

    lr_monitor = LearningRateMonitor(logging_interval="step")
    trainer: pl.Trainer = generic_train(
        model,
        args,
        logging_callback=Seq2SeqLoggingCallback(),
        checkpoint_callback=get_checkpoint_callback(
            args.output_dir,
            model.hparams.val_metric,
            save_top_k=model.hparams.save_top_k,
        ),
        logger=wandb_logger,
        extra_callbacks=[lr_monitor],
        val_check_interval=args.val_check_interval,
        gradient_clip_val=args.gradient_clip_val,
        # Request no GPUs when CUDA is unavailable.
        gpus=args.gpus if torch.cuda.is_available() else 0,
        log_every_n_steps=50,
    )

    if not args.do_predict:
        return model

    # Point the test run at the last checkpoint (by file name) when one exists.
    model.hparams.test_checkpoint = ""
    checkpoint_path = _last_checkpoint(args.output_dir)
    if checkpoint_path is not None:
        model.hparams.test_checkpoint = checkpoint_path
        trainer.resume_from_checkpoint = checkpoint_path

    trainer.test(model)
    return model


if __name__ == "__main__":
    # Stage 1: Trainer flags plus the two options needed before the model
    # class is known.
    parser = pl.Trainer.add_argparse_args(argparse.ArgumentParser())
    parser.add_argument(
        "--mode",
        choices=["seq2seq", "next-token-prediction", "t5"],
        default="seq2seq",
    )
    # Switch that disables wandb (see the wandb.init call further down).
    parser.add_argument("--disable_wandb", action="store_true")

    # Stage 2: peek at --mode, ignoring options that are not registered yet,
    # so that the matching module can contribute its own arguments.
    known_args, _unparsed = parser.parse_known_args()
    selected_mode = known_args.mode
    module_cls = {
        "seq2seq": Seq2SeqLMModule,
        "next-token-prediction": NextTokenLMModule,
    }.get(selected_mode)
    if module_cls is None:
        # "t5" is accepted by argparse above but has no module class here.
        raise ValueError(f"Unknown mode: {selected_mode}")
    parser = module_cls.add_model_specific_args(parser, os.getcwd())

    # Stage 3: full parse with every option registered, then run.
    args = parser.parse_args()
    if args.disable_wandb:
        import wandb

        wandb.init(mode="disabled")
    main(args)
