
import logging
import glob
import os
import sys
import time
import numpy as np
from tqdm import tqdm
from typing import Dict, Optional
from contextlib import nullcontext

import torch
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import (
    AutoConfig,
    AutoTokenizer,
    PreTrainedTokenizer,
    HfArgumentParser
)

from pecos.utils import smat_util
from pecos.ann.hnsw import HNSW
from scipy.special import softmax
from pecos.core import clib
import scipy.sparse as smat
from scipy.sparse import diags
from sklearn.preprocessing import normalize
from sup_con_xmc.arguments import (
    ModelArguments,
    SearcherDataArguments,
    MyTrainingArguments as TrainingArguments
)
from sup_con_xmc.data import (
    EncodePreProcessor,
    EncodeDataset,
    EncodeCollator
)
from sup_con_xmc.models import EncoderOutput, DenseModel
from sup_con_xmc.base_utils import setup_hf_logging_and_seed


# Module-level logger, named after this module, used by every function below.
logger = logging.getLogger(__name__)


class SearcherQ2Z():
    """Encode text with a dual-encoder under DDP and search it with a PECOS HNSW index.

    Encoding is a collective operation: every rank must call the encode /
    train / predict methods in the same order, because each batch is gathered
    across ranks.  Only the process whose ``training_args.local_rank`` is 0
    materializes embeddings, owns the HNSW index and returns predictions;
    the other ranks get ``None``.
    """

    # Values accepted for ``inference_method`` in ``predict_by_q2xz``.
    _INFERENCE_METHODS = ('q2z', 'q2xz', 'q2x')

    def __init__(self, model_path, training_args):
        """Load tokenizer and model from ``model_path`` and move the model to the device.

        Args:
            model_path: path or name handed to ``AutoTokenizer.from_pretrained``
                and ``DenseModel.load``.
            training_args: object providing ``device``, ``local_rank``, ``fp16``,
                ``dataloader_num_workers`` and ``temperature``.

        Raises:
            ValueError: if ``torch.distributed`` has not been initialized.
        """
        if not dist.is_initialized():
            raise ValueError('DDP has not been initialized for representation all gather.')
        self.training_args = training_args

        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = DenseModel.load(model_name_or_path=model_path)
        self.model = self.model.to(training_args.device)
        #TODO think about better way to get emb_dim from model
        self.emb_dim = self.model.lm_q.config.hidden_size
        # for DDP inference
        self.model.process_rank = dist.get_rank()
        self.model.world_size = dist.get_world_size()

        # Populated later by train_from_files / load_indexer / predict_from_files.
        # Defined here so that using them too early fails with a clear message
        # instead of an AttributeError on a missing attribute.
        self.indexer = None
        self.X_trn = None
        self.X_tst = None

    def _require_indexer(self):
        """Return the HNSW index, raising if none has been trained or loaded yet."""
        if self.indexer is None:
            raise RuntimeError(
                "No indexer available: call train_from_files() or load_indexer() "
                "on the main process first."
            )
        return self.indexer

    @staticmethod
    def _partition_stats(Y_part):
        """Return (average stored entries per row, mean stored value) of a sparse matrix.

        Empty partitions are reported as ``0.0`` / ``nan`` rather than dividing
        by zero or triggering numpy's empty-mean warning.
        """
        num_rows = Y_part.shape[0]
        num_vals = len(Y_part.data)
        avg_per_row = num_vals / num_rows if num_rows else 0.0
        mean_score = Y_part.data.mean() if num_vals else float('nan')
        return avg_per_row, mean_score

    def _get_dataset(self, inp_folder, key_col, max_length=32, threads=64):
        """Load every ``*.parquet`` file under ``inp_folder`` and tokenize ``key_col``.

        Ranks with ``local_rank > 0`` block on a barrier until local rank 0 has
        run the preprocessing, then run the same calls themselves.

        Args:
            inp_folder: directory containing the parquet shards.
            key_col: column name handed to ``EncodePreProcessor`` / ``EncodeDataset``.
            max_length: maximum length handed to ``EncodePreProcessor``.
            threads: number of worker processes for ``dataset.map``.

        Returns:
            An ``EncodeDataset`` wrapping the tokenized data.
        """
        if self.training_args.local_rank > 0:
            logger.warning(f"LOCAL_RANK {self.training_args.local_rank} waiting main process to perform data preprocessing")
            torch.distributed.barrier()

        # Sorted so that every rank sees the shards in the same order.
        paths = sorted(glob.glob(f"{inp_folder}/*.parquet"))
        dataset = load_dataset("parquet", data_files={"test": paths}, split="test")
        dataset = dataset.map(
            EncodePreProcessor(self.tokenizer, key_col, max_length),
            batched=True,
            num_proc=threads,
            remove_columns=list(dataset.column_names),
            desc="Running tokenizer for dataset",
        )
        dataset = EncodeDataset(key_col, dataset)

        if self.training_args.local_rank == 0:
            logger.warning(f"LOCAL_RANK {self.training_args.local_rank} loading results from main process")
            torch.distributed.barrier()

        return dataset

    def _get_dataloader(self, eval_dataset, per_device_eval_batch_size=128):
        """Build an unshuffled, rank-sharded ``DataLoader`` over ``eval_dataset``."""
        dist_sampler = DistributedSampler(eval_dataset, shuffle=False, drop_last=False)
        eval_collator = EncodeCollator(self.tokenizer)
        data_loader = DataLoader(
            eval_dataset,
            batch_size=per_device_eval_batch_size,
            collate_fn=eval_collator,
            sampler=dist_sampler,
            num_workers=self.training_args.dataloader_num_workers,
        )
        return data_loader

    def encode_from_files(self, inp_folder_or_emb, key_col, max_length, bsz_per_gpu, encode_is_qry=True):
        """Encode a parquet folder into a dense embedding matrix.

        Args:
            inp_folder_or_emb: a folder path (``str``) to encode; any non-string
                value is treated as already-computed embeddings and returned as is.
            key_col: text column to encode.
            max_length: maximum token length.
            bsz_per_gpu: per-process batch size.
            encode_is_qry: if True use the model's query side (``q_reps``),
                otherwise the passage side (``p_reps``).

        Returns:
            On local rank 0, a float32 array of shape ``(len(dataset), emb_dim)``
            whose rows are addressed by the ids yielded by the data loader;
            ``None`` on every other rank.  Non-string input is returned unchanged
            on all ranks.
        """
        if not isinstance(inp_folder_or_emb, str):
            return inp_folder_or_emb
        logger.info(f"Encode dataset from {inp_folder_or_emb} | key_col {key_col}")
        dataset = self._get_dataset(inp_folder_or_emb, key_col, max_length)
        data_loader = self._get_dataloader(dataset, bsz_per_gpu)
        num_items = len(dataset)
        logger.info(f"len(dataset) = {num_items} | encode_is_qry = {encode_is_qry}")

        is_main = self.training_args.local_rank == 0
        device = self.training_args.device
        embeddings = np.zeros((num_items, self.emb_dim), dtype=np.float32) if is_main else None

        self.model.eval()
        for (key_ids, batch) in tqdm(data_loader):
            with torch.cuda.amp.autocast() if self.training_args.fp16 else nullcontext():
                with torch.no_grad():
                    key_ids = key_ids.to(device)
                    # Move tensors in place so the batch keeps its original container type.
                    for k, v in batch.items():
                        batch[k] = v.to(device)
                    if encode_is_qry:
                        reps = self.model(query=batch).q_reps
                    else:
                        reps = self.model(passage=batch).p_reps
            ## dist gather on main node
            # Every rank must take part in the gather, even though only the
            # main process stores the result.
            agg_key_arr = self.model._dist_gather_tensor(key_ids).cpu().detach().numpy()
            agg_val_arr = self.model._dist_gather_tensor(reps).cpu().detach().numpy()
            if is_main:
                # Rows are written by id, so an id that shows up more than once
                # simply overwrites its own row.
                embeddings[agg_key_arr, :] = agg_val_arr
        return embeddings

    def train_from_files(self, inp_folder_or_emb, key_col=None, max_length=None, bsz_per_gpu=None, encode_is_qry=True, fast_train=True):
        """Encode (or accept) embeddings and train an HNSW index on local rank 0.

        Args:
            inp_folder_or_emb: folder to encode, or precomputed embeddings.
            key_col, max_length, bsz_per_gpu, encode_is_qry: forwarded to
                ``encode_from_files``.
            fast_train: if True train with ``M=32, efC=300``; otherwise with
                ``M=64, efC=500``.

        Sets ``self.X_trn`` and ``self.indexer`` (``None`` on non-main ranks).
        """
        self.X_trn = self.encode_from_files(inp_folder_or_emb, key_col, max_length, bsz_per_gpu, encode_is_qry)
        if self.training_args.local_rank != 0:
            self.indexer = None
            return

        t0 = time.time()
        if fast_train:
            train_params = HNSW.TrainParams(M=32, efC=300, threads=-1)
        else:
            train_params = HNSW.TrainParams(M=64, efC=500, threads=-1)
        self.indexer = HNSW.train(self.X_trn, train_params=train_params)
        logger.info(f"Trained index_X in {time.time()-t0:8.2f} seconds")

    def predict_from_files(self, inp_folder_or_emb, key_col, max_length, bsz_per_gpu, encode_is_qry=True, topk=100):
        """Encode (or accept) queries and retrieve their ``topk`` neighbours from the index.

        Args:
            inp_folder_or_emb: folder to encode, or precomputed embeddings.
            key_col, max_length, bsz_per_gpu, encode_is_qry: forwarded to
                ``encode_from_files``.
            topk: number of neighbours requested per query.

        Returns:
            On local rank 0 the matrix returned by the index with its stored
            values negated (presumably turning distances into scores — confirm
            against the PECOS HNSW documentation); ``None`` on other ranks.

        Raises:
            RuntimeError: on local rank 0 if no index has been trained or loaded.
        """
        self.X_tst = self.encode_from_files(inp_folder_or_emb, key_col, max_length, bsz_per_gpu, encode_is_qry)

        if self.training_args.local_rank != 0:
            return None

        indexer = self._require_indexer()
        pred_params = HNSW.PredParams(efS=300, topk=topk)
        searchers = indexer.searchers_create(num_searcher=96)
        Yp = indexer.predict(self.X_tst, pred_params=pred_params, searchers=searchers)
        Yp.data = - Yp.data
        return Yp

    def predict_by_q2xz(self, Yt, inp_folder_or_emb, key_col, max_length, bsz_per_gpu, encode_is_qry=True, topk=100, inference_method='q2xz', lamb=0.5, indexer_dir=None):
        """Retrieve neighbours and turn them into label predictions.

        Args:
            Yt: sparse instance-by-label matrix of the indexed training set.
            inp_folder_or_emb, key_col, max_length, bsz_per_gpu, encode_is_qry:
                forwarded to ``predict_from_files``.
            topk: neighbours retrieved per query and labels kept per query.
            inference_method: ``'q2z'`` returns the raw retrieval result;
                ``'q2x'`` propagates neighbour scores through ``Yt``;
                ``'q2xz'`` propagates through ``lamb * Yt`` stacked on top of
                ``(1 - lamb) * I``.
            lamb: weight of the training-instance part in ``'q2xz'``.
            indexer_dir: if given, the raw retrieval result is cached in
                ``<indexer_dir>/Yrp_<topk>.npz`` and reused when present.

        Returns:
            The prediction matrix on local rank 0, ``None`` on other ranks.

        Raises:
            ValueError: if ``inference_method`` is not one of the supported values.
        """
        # Validate before any collective work so that every rank fails the same
        # way, instead of rank 0 hitting an undefined variable much later.
        if inference_method not in self._INFERENCE_METHODS:
            raise ValueError(
                f"Unknown inference_method {inference_method!r}; "
                f"expected one of {self._INFERENCE_METHODS}."
            )

        t0 = time.time()
        is_main = self.training_args.local_rank == 0
        cache_path = os.path.join(indexer_dir, f"Yrp_{topk}.npz") if indexer_dir else None

        if cache_path is not None and os.path.exists(cache_path):
            # Cached retrieval result: nothing collective is left to do.
            if not is_main:
                return None
            Yp = smat_util.load_matrix(cache_path).astype(np.float32)
        else:
            Yp = self.predict_from_files(
                inp_folder_or_emb,
                key_col,
                max_length,
                bsz_per_gpu,
                encode_is_qry,
                topk=topk,
            )
            if is_main and cache_path is not None:
                smat_util.save_matrix(cache_path, Yp)

        if not is_main:
            return None
        if inference_method == 'q2z':
            return Yp

        if inference_method == 'q2xz':
            Yt = Yt * lamb
            z_weight = 1 - lamb
            Z_csr_mat = diags(np.ones(Yt.shape[1],dtype=np.float32), 0, format="csr") * z_weight
            XZ_csr_mat = smat.vstack([Yt, Z_csr_mat])
        else:  # 'q2x'
            XZ_csr_mat = Yt

        # Row-wise softmax over the retrieved scores; this relies on every row
        # of Yp storing exactly ``topk`` entries.
        Yp.data = softmax(Yp.data.reshape(-1,topk)/self.training_args.temperature,axis=1).reshape(-1)

        # Columns below Yt.shape[0] are training instances, the rest are labels.
        Yp_x = Yp[:,:Yt.shape[0]]
        Yp_z = Yp[:,Yt.shape[0]:]

        x_avg, x_mean = self._partition_stats(Yp_x)
        z_avg, z_mean = self._partition_stats(Yp_z)
        logger.info(f"avg X prediction num: {x_avg}, score mean:{x_mean}")
        logger.info(f"avg Z prediction num: {z_avg}, score mean:{z_mean}")

        Yp_q2XZ = clib.sparse_matmul(Yp, XZ_csr_mat)
        Yp_q2XZ = smat_util.sorted_csr(Yp_q2XZ, only_topk=topk)
        logger.info(f" Predicted Indexer in {time.time()-t0:8.2f} seconds")
        return Yp_q2XZ

    def save_indexer(self, model_folder):
        """ Save an XZ-Index to file
        Args:
            model_folder (str): model directory to which the model is saved.
        Raises:
            RuntimeError: if no index has been trained or loaded.
        """
        # Check first so a failed save does not leave an empty directory behind.
        indexer = self._require_indexer()
        os.makedirs(model_folder, exist_ok=True)
        indexer.save(f"{model_folder}")

    def load_indexer(self, model_folder, lazy_load=False):
        """Load a previously saved HNSW index from ``model_folder`` into ``self.indexer``."""
        self.indexer = HNSW.load(f"{model_folder}", lazy_load=lazy_load)

def get_pifa_z(X_trn, Y_trn, temp):
    """Build label embeddings by aggregating the embeddings of their instances (PIFA).

    Each row of the transposed label matrix is L1-normalized and multiplied
    with the instance embeddings, so a label's embedding is the weighted
    average of the embeddings of the instances tagged with it.

    Args:
        X_trn: instance embeddings, one row per instance.
        Y_trn: sparse instance-by-label matrix.
        temp: temperature; when it is below 1.0 the rows of the result are
            L2-normalized, otherwise they are returned unscaled.

    Returns:
        Dense array with one embedding row per label.
    """
    logger.info("Getting Z_emb by PIFA")
    instance_feats = smat.csr_matrix(X_trn)
    label_weights = normalize(Y_trn.transpose().tocsr(), norm='l1')
    label_emb = clib.sparse_matmul(label_weights, instance_feats).toarray()
    needs_unit_norm = temp < 1.0
    return normalize(label_emb, norm='l2') if needs_unit_norm else label_emb


def main():
    """Build (or load) the HNSW index and write top-k predictions for the test set.

    Steps:
      1. Parse ``ModelArguments``, ``SearcherDataArguments`` and
         ``TrainingArguments`` from a single JSON file or from the command line.
      2. If ``<output_dir>/param.json`` is missing, encode the training
         instances and the labels (from ``lbl_folder``, or by PIFA when no
         label folder is given), train the index for the chosen inference
         method on local rank 0 and save it; otherwise load the saved index.
      3. Predict for ``tst_folder`` and save the result to
         ``<output_dir>/P.k-<topk>.npz`` on local rank 0.

    Raises:
        ValueError: if ``data_args.inference_method`` is not ``'q2xz'``,
            ``'q2x'`` or ``'q2z'``.
    """
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.
    parser = HfArgumentParser((ModelArguments, SearcherDataArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Set logging
    setup_hf_logging_and_seed(model_args, data_args, training_args)

    # Reject an unsupported method before loading the model or encoding
    # anything: otherwise no index gets trained and the run only fails when
    # trying to save it, after all the expensive work.
    inference_method = data_args.inference_method
    supported_methods = ('q2xz', 'q2x', 'q2z')
    if inference_method not in supported_methods:
        raise ValueError(
            f"Unknown inference_method {inference_method!r}; expected one of {supported_methods}."
        )

    is_main = training_args.local_rank == 0

    # Load Dual-encoder Model
    searcher = SearcherQ2Z(model_args.model_name_or_path, training_args)

    logger.info("Building Searcher index via PECOS-HNSW")
    indexer_dir = f"{training_args.output_dir}"
    if not os.path.exists(f"{indexer_dir}/param.json"):
        # Encoding is collective: every rank takes part, rank 0 keeps the result.
        trn_emb = searcher.encode_from_files(
            data_args.trn_folder,
            data_args.inp_key_col,
            data_args.text_max_len,
            training_args.per_device_eval_batch_size,
            encode_is_qry=True,
        )

        # if label folder is not provided, construct lbl_emb by PIFA
        if data_args.lbl_folder:
            lbl_emb = searcher.encode_from_files(
                data_args.lbl_folder,
                data_args.lbl_key_col,
                data_args.text_max_len,
                training_args.per_device_eval_batch_size,
                encode_is_qry=False,
            )
        elif is_main:
            Yt = smat_util.load_matrix(data_args.y_npz_path).astype(np.float32)
            lbl_emb = get_pifa_z(trn_emb, Yt, training_args.temperature)

        if not is_main:
            logger.warning("waiting main thread")
        else:
            # Pick what goes into the index for the chosen method.
            if inference_method == 'q2xz':
                logger.info("Start training indexer q2xZ")
                index_emb = np.concatenate([trn_emb, lbl_emb], axis=0)
            elif inference_method == 'q2x':
                logger.info("Start training indexer q2x")
                index_emb = trn_emb
            else:  # 'q2z'
                logger.info("Start training indexer Q2Z")
                index_emb = lbl_emb
            searcher.train_from_files(index_emb, fast_train=False)
            searcher.save_indexer(indexer_dir)
    elif is_main:
        searcher.load_indexer(indexer_dir)

    logger.info(f"Build Searcher prediction using {data_args.inference_method}, topk={data_args.inference_topk}")
    Yt = smat_util.load_matrix(data_args.y_npz_path).astype(np.float32)
    Yp = searcher.predict_by_q2xz(
        Yt,
        data_args.tst_folder,
        data_args.inp_key_col,
        data_args.text_max_len,
        training_args.per_device_eval_batch_size,
        encode_is_qry=True,
        topk=data_args.inference_topk,
        inference_method=data_args.inference_method,
        lamb=data_args.lamb,
        indexer_dir=indexer_dir,
    )
    if is_main:
        smat_util.save_matrix(f"{training_args.output_dir}/P.k-{data_args.inference_topk}.npz", Yp)


# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
