import logging
import os
import sys
import pdb
from pathlib import Path
import deepspeed
from transformers import AutoConfig, AutoTokenizer, HfArgumentParser, set_seed
import torch
from torch.utils.tensorboard import SummaryWriter
from transformers.integrations import TensorBoardCallback
from projects.T5_ANCEPROMPT.project_lib.arguments import DataArguments
from projects.T5_ANCEPROMPT.project_lib.arguments import DenseTrainingArguments as TrainingArguments
from projects.T5_ANCEPROMPT.project_lib.arguments import ModelArguments
from projects.T5_ANCEPROMPT.project_lib.dataset import DistillCollator, DistillDataset, EvalDistillDataset
from projects.T5_ANCEPROMPT.project_lib.modeling import DistillModel
from projects.T5_ANCEPROMPT.project_lib.trainer import DenseTrainer
from projects.T5_ANCEPROMPT.project_lib.trainer import GCDenseTrainer
from projects.T5_ANCEPROMPT.project_lib.utils import is_rank_0

# Module-level logger, named after this module; configured inside main().
logger = logging.getLogger(__name__)

def main():
    """Parse arguments, build the model and datasets, then train and save.

    Arguments are read either from a single ``.json`` file given as the only
    command-line argument, or from regular command-line flags. After training,
    the model is saved through the trainer and the tokenizer is written to
    ``training_args.output_dir`` by the world-process-zero rank.

    Raises:
        ValueError: If ``training_args.output_dir`` exists, is not empty,
            ``do_train`` is set and ``overwrite_output_dir`` is not.
    """
    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))

    # Declared up front so the annotations cover both parsing branches.
    model_args: ModelArguments
    data_args: DataArguments
    training_args: TrainingArguments
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        model_args, data_args, training_args = parser.parse_json_file(
            json_file=os.path.abspath(sys.argv[1])
        )
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Refuse to clobber a previous run unless explicitly asked to.
    if (
        os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
        )

    # Setup logging: INFO when local_rank is -1 or 0, WARNING on every other rank.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARNING,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, fp16 training: %s, bf16 training: %s",
        training_args.local_rank,
        training_args.device,
        training_args.n_gpu,
        bool(training_args.local_rank != -1),
        training_args.fp16,
        training_args.bf16,
    )
    logger.info("Training/evaluation parameters %s", training_args)
    logger.info("MODEL parameters %s", model_args)

    set_seed(training_args.seed)

    num_labels = 1
    config = AutoConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        num_labels=num_labels,
        cache_dir=model_args.cache_dir,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        use_fast=False,
    )
    model = DistillModel.build(
        model_args,
        data_args,
        training_args,
        config=config,
        cache_dir=model_args.cache_dir,
    )

    # Both datasets share one cache location: the data-specific cache dir when
    # set, otherwise the model cache dir.
    dataset_cache_dir = data_args.data_cache_dir or model_args.cache_dir
    train_dataset = DistillDataset(
        tokenizer,
        data_args,
        training_args,
        shuffle_seed=training_args.seed,
        cache_dir=dataset_cache_dir,
    )
    if data_args.eval_path is not None:
        eval_dataset = EvalDistillDataset(
            tokenizer,
            data_args,
            training_args,
            cache_dir=dataset_cache_dir,
        )
    else:
        eval_dataset = None
    print(f"length of training dataset: {len(train_dataset)} examples")

    # Set up tensorboard logging. The rank check is evaluated exactly once so
    # the callback list can never refer to a callback that was not created.
    callbacks = []
    if is_rank_0():
        tb_writer = SummaryWriter(training_args.logging_dir)
        callbacks.append(TensorBoardCallback(tb_writer))

    trainer_cls = GCDenseTrainer if training_args.grad_cache else DenseTrainer
    trainer = trainer_cls(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=DistillCollator(
            tokenizer,
            max_p_len=data_args.p_max_len,
            max_q_len=data_args.q_max_len,
        ),
        callbacks=callbacks,
    )
    # Give the dataset a back-reference to the trainer.
    # NOTE(review): how DistillDataset uses it is not visible in this file.
    train_dataset.trainer = trainer
    trainer.train()  # TODO: resume training
    trainer.save_model()
    if trainer.is_world_process_zero():
        tokenizer.save_pretrained(training_args.output_dir)

# Run main() only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
