
import logging
import gc
import glob
import os
import sys
import time
import numpy as np
import pandas as pd

import torch
from datasets import load_dataset
from transformers import (
    AutoConfig,
    AutoTokenizer,
    PreTrainedTokenizer,
    HfArgumentParser,
    TrainerCallback,
)

from pecos.utils import smat_util
from sup_con_xmc.arguments import (
    ModelArguments,
    TrainingDataArguments,
    MyTrainingArguments as TrainingArguments
)
from sup_con_xmc.base_utils import setup_hf_logging_and_seed
from sup_con_xmc.data import (
    TrainPreProcessor,
    TrainDataset,
    TrainCollator
)
from sup_con_xmc.models import DenseModel
from sup_con_xmc.trainer import TevatronTrainer as Trainer
from sup_con_xmc.searcher import SearcherQ2Z


# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


class EarlyStoppingCallback(TrainerCallback):
    """
    A [`TrainerCallback`] that requests a training stop after a fixed number of steps.

    Args:
        early_stopping_steps (`int`, *optional*, defaults to 50):
            Number of `on_step_end` events this callback instance observes before it sets
            `control.should_training_stop` to `True`.

    No evaluation metric is consulted. The step counter lives on the instance and is never reset,
    so each training run that needs its own step budget should be given a fresh callback.
    """

    def __init__(self, early_stopping_steps: int = 50):
        # Steps observed so far by this instance.
        self.early_stopping_counter = 0
        # Step budget after which a stop is requested.
        self.early_stopping_steps = early_stopping_steps

    def on_step_end(self, args, state, control, **kwargs):
        """Count one finished step and request a stop once the budget is reached."""
        self.early_stopping_counter += 1
        budget_reached = self.early_stopping_counter >= self.early_stopping_steps
        if not budget_reached:
            return
        control.should_training_stop = True


def write_df_to_parquet_files(df_inp, output_dir, num_partitions=256):
    """Split a DataFrame into contiguous row chunks and write one parquet file per chunk.

    Files are written to ``output_dir`` (created if missing) as ``part-00000.parquet``,
    ``part-00001.parquet``, and so on. Rows are selected by position, so the frame's
    index labels do not need to be ``0..n-1``.

    Args:
        df_inp: Frame to write; must support ``len()``, ``.iloc`` and ``.to_parquet``.
        output_dir: Directory that receives the partition files.
        num_partitions: Number of files to produce. Must be at least 1 and at most
            ``len(df_inp)`` so that no partition is empty.

    Raises:
        ValueError: If ``num_partitions`` is outside ``[1, len(df_inp)]``. The check runs
            before the output directory is created.
    """
    num_rows = len(df_inp)
    if num_partitions < 1 or num_rows < num_partitions:
        raise ValueError(
            f"num_partitions={num_partitions} must be between 1 and the number of rows ({num_rows})"
        )
    os.makedirs(output_dir, exist_ok=True)
    # array_split yields near-equal contiguous position ranges covering every row exactly once.
    for part_id, positions in enumerate(np.array_split(np.arange(num_rows), num_partitions)):
        cur_path = os.path.join(output_dir, f"part-{part_id:05d}.parquet")
        # Positional selection: label-based .loc would misbehave on a non-default index.
        df_inp.iloc[positions].to_parquet(cur_path)


def get_train_datasets(
    tokenizer: PreTrainedTokenizer,
    data_args: TrainingDataArguments,
    training_args: TrainingArguments,
    overridden_trn_folder=None,
):
    """Tokenize the query and label parquet folders and wrap them in a `TrainDataset`.

    Args:
        tokenizer: Tokenizer handed to `TrainPreProcessor` for both datasets.
        data_args: Supplies `trn_folder`, `lbl_folder`, `q_max_len`, `p_max_len` and
            `dataset_proc_num`.
        training_args: Only `local_rank` is read, to order the preprocessing across ranks.
        overridden_trn_folder: When truthy, query parquet files are read from this folder
            instead of `data_args.trn_folder`.

    Returns:
        The `TrainDataset` built from the tokenized query and label datasets.
    """
    rank = training_args.local_rank

    # Ranks above 0 wait here until rank 0 reaches its own barrier at the end of this
    # function, i.e. until rank 0 has finished preprocessing.
    # NOTE(review): presumably the waiting ranks then hit the `datasets` cache - confirm.
    if rank > 0:
        logger.warning(f"LOCAL_RANK {rank} waiting main process to perform data preprocessing")
        torch.distributed.barrier()

    def _tokenize_parquet_folder(folder, max_len, keep_column, desc):
        # Load every parquet file of `folder` as one split, tokenize it, and drop the
        # original columns that `keep_column` rejects.
        paths = sorted(glob.glob(f"{folder}/*.parquet"))
        raw = load_dataset("parquet", data_files={"train": paths}, split="train")
        return raw.map(
            TrainPreProcessor(tokenizer, max_len),
            batched=True,
            num_proc=data_args.dataset_proc_num,
            remove_columns=[col for col in raw.column_names if not keep_column(col)],
            desc=desc,
        )

    # Query-side text dataset (with its label columns).
    required_cols = TrainDataset.get_required_cols()
    inp_dataset = _tokenize_parquet_folder(
        overridden_trn_folder or data_args.trn_folder,
        data_args.q_max_len,
        lambda col: col in required_cols,
        "Running tokenizer on query text of trn_dataset",
    )

    # Label-side text dataset; only 'pos_trn_ids' survives from the original columns.
    lbl_dataset = _tokenize_parquet_folder(
        data_args.lbl_folder,
        data_args.p_max_len,
        lambda col: col == 'pos_trn_ids',
        "Running tokenizer on label text of lbl_dataset",
    )

    train_dataset = TrainDataset(data_args, inp_dataset, lbl_dataset)

    # Rank 0 joins the barrier last, releasing the ranks that were parked above.
    if rank == 0:
        logger.warning(f"LOCAL_RANK {rank} loading results from main process")
        torch.distributed.barrier()
    return train_dataset


def main():
    """Train the dual-encoder in stages, re-mining hard negatives between stages.

    Flow:
      1. Parse `ModelArguments`, `TrainingDataArguments` and `TrainingArguments` from a
         single JSON file argument or from the command line.
      2. Validate the `hnm_steps` / `save_steps` relationship before any training starts.
      3. Stage 1: train without hard negatives, stopped by `EarlyStoppingCallback` after
         `min(hnm_steps, max_steps)` steps.
      4. Stage 2: for every `cur_steps` in `range(hnm_steps, max_steps, hnm_steps)`, mine
         hard negatives with `SearcherQ2Z`, write them to
         `<output_dir>/trn_hn_<cur_steps>`, and resume training from the latest checkpoint
         with a step budget of `hnm_steps`.

    The model and tokenizer are saved only by the process with `local_rank == 0`, and
    `torch.distributed.barrier()` is called unconditionally, so the script expects an
    initialized `torch.distributed` process group.

    Raises:
        ValueError: If `hnm_steps < save_steps` or `hnm_steps % save_steps != 0` (checked up
            front), or if `hnm_type` is not one of "q2xz", "q2x", "q2z" when an HNM stage runs.
    """
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.
    parser = HfArgumentParser((ModelArguments, TrainingDataArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Set logging
    setup_hf_logging_and_seed(model_args, data_args, training_args)

    # Constants
    hnm_steps = training_args.hnm_steps
    max_steps = training_args.max_steps
    run_time_dict = {}

    # Stage 2 relies on "resume_from_checkpoint" in HF Trainer, which requires
    # hnm_steps >= save_steps and hnm_steps % save_steps == 0.
    # Checked here, before building the model, so an invalid configuration fails
    # immediately instead of after stage 1 has finished training.
    if hnm_steps < training_args.save_steps or hnm_steps % training_args.save_steps != 0:
        raise ValueError("We require that hnm_steps >= save_steps AND hnm_steps % save_steps == 0!")

    # Set Dual-encoder Model
    config = AutoConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        num_labels=1,
        cache_dir=model_args.cache_dir,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
    )
    model = DenseModel.build(
        model_args,
        training_args,
        config=config,
        cache_dir=model_args.cache_dir,
    )

    # Stage 1: Training without Hard-negative Mining (HNM)
    t0 = time.time()
    stage1_max_steps = min(hnm_steps, max_steps)
    logger.info(f" NO HNM training | training_steps [{0:6d}/{stage1_max_steps:6d}]")
    # Ground-truth label matrix; loaded by every process and reused by each HNM stage.
    Yt = smat_util.load_matrix(data_args.y_npz_path).astype(np.float32)
    train_dataset = get_train_datasets(tokenizer, data_args, training_args)
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        data_collator=TrainCollator(tokenizer, data_args.q_max_len, data_args.p_max_len),
        callbacks=[EarlyStoppingCallback(early_stopping_steps=stage1_max_steps)],
    )
    train_dataset.trainer = trainer
    trainer.train(resume_from_checkpoint=False)
    if training_args.local_rank == 0:
        trainer.save_model()
        tokenizer.save_pretrained(training_args.output_dir)
    torch.distributed.barrier()
    del train_dataset, trainer
    gc.collect()
    run_time_dict[stage1_max_steps] = time.time() - t0

    # Stage 2: Training with Hard-negative Mining (HNM)
    eval_batch_size = training_args.per_device_eval_batch_size
    for idx, cur_steps in enumerate(range(hnm_steps, max_steps, hnm_steps)):
        overridden_trn_folder = f"{training_args.output_dir}/trn_hn_{cur_steps:06d}"
        nxt_steps = cur_steps + hnm_steps
        logger.info(f" HNM-stage {idx} | training_steps [{cur_steps:6d}/{nxt_steps:6d}]")

        # Step-1: Build Hard-negative Mining dataset
        searcher = SearcherQ2Z(training_args.output_dir, training_args)
        hnm_type = training_args.hnm_type
        if hnm_type not in ("q2xz", "q2x", "q2z"):
            raise ValueError(f"hnm_type={training_args.hnm_type} is NOT VALID!")

        # Encode only the sides the chosen mining type indexes.
        trn_emb = lbl_emb = None
        if hnm_type in ("q2xz", "q2x"):
            trn_emb = searcher.encode_from_files(
                data_args.trn_folder,
                data_args.inp_key_col,
                data_args.q_max_len,
                eval_batch_size,
                encode_is_qry=True,
            )
        if hnm_type in ("q2xz", "q2z"):
            lbl_emb = searcher.encode_from_files(
                data_args.lbl_folder,
                data_args.lbl_key_col,
                data_args.p_max_len,
                eval_batch_size,
                encode_is_qry=False,
            )

        # Only local_rank 0 assembles the embeddings to index; other ranks pass None.
        if training_args.local_rank == 0:
            if hnm_type == "q2xz":
                all_emb = np.concatenate([trn_emb, lbl_emb], axis=0)
            elif hnm_type == "q2x":
                all_emb = trn_emb
            else:  # "q2z"
                all_emb = lbl_emb
        else:
            all_emb = None
        torch.distributed.barrier()
        searcher.train_from_files(all_emb, fast_train=True)
        Yp = searcher.predict_by_q2xz(
            Yt,
            data_args.trn_folder,
            data_args.inp_key_col,
            data_args.q_max_len,
            eval_batch_size,
            encode_is_qry=True,
            topk=training_args.hnm_topk,
            inference_method=training_args.hnm_type,
        )
        if training_args.local_rank == 0:
            # Yt - binarized(Yp): positive entries (in Yt but not predicted) are dropped;
            # what remains is presumably the predicted-but-unlabelled pairs, used as
            # hard negatives - TODO confirm against smat_util.binarized semantics.
            Y_hn = Yt - smat_util.binarized(Yp)
            Y_hn.data[Y_hn.data > 0] = 0.0  # ignore false negative
            Y_hn.eliminate_zeros()
            df_trn_hn = pd.read_parquet(data_args.trn_folder)
            # One array of hard-negative column ids per training row (CSR row slices).
            df_trn_hn["neg_labels"] = [
                Y_hn.indices[Y_hn.indptr[i] : Y_hn.indptr[i + 1]] for i in range(Y_hn.shape[0])
            ]
            logger.info(f"Avg hard negative for each X: {Y_hn.data.shape[0]/Y_hn.shape[0]}")
            logger.info(f"Writing HNM training dataset into disk {overridden_trn_folder}")
            write_df_to_parquet_files(df_trn_hn, overridden_trn_folder)
        torch.distributed.barrier()
        # Release the searcher and the embedding arrays before training starts.
        del searcher, Yp, trn_emb, lbl_emb, all_emb
        gc.collect()

        # Step-2: Train on HNM training dataset via resuming from the prev checkpoint
        train_dataset = get_train_datasets(
            tokenizer, data_args, training_args,
            overridden_trn_folder=overridden_trn_folder,
        )
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            data_collator=TrainCollator(tokenizer, data_args.q_max_len, data_args.p_max_len),
            callbacks=[EarlyStoppingCallback(early_stopping_steps=hnm_steps)],
        )
        train_dataset.trainer = trainer
        trainer.train(resume_from_checkpoint=True)
        run_time_dict[nxt_steps] = time.time() - t0

        # Step-3: save current model via main node (local_rank == 0)
        if training_args.local_rank == 0:
            trainer.save_model()
            tokenizer.save_pretrained(training_args.output_dir)
        torch.distributed.barrier()
        del train_dataset, trainer
        gc.collect()
    # end training loop of HNM
    logger.info(f"Finished Training! Runtime {run_time_dict}")

# Run the staged training pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()
