import logging
import os
import pickle
import sys
from contextlib import nullcontext

import numpy as np
from tqdm import tqdm

import json
import torch

from torch.utils.data import DataLoader
from transformers import AutoConfig, AutoTokenizer
from transformers import (
    HfArgumentParser,
)

import trainers
import networks
import dataloaders

from arguments import ModelArguments, DataArguments, \
    DenseTrainingArguments as TrainingArguments

from transformers.trainer_utils import is_main_process

# Module-level logger named after this module.
logger = logging.getLogger(__name__)


# Bracket strings that mark the start / end of a query.  main() runs each one
# through `tokenizer.encode` and stores the resulting ids on the tokenizer as
# `bos_token_q` / `eos_token_q`.
# NOTE(review): presumably consumed by the BEIR*BracketCollator classes to wrap
# each input text -- confirm in `dataloaders`.
SPECB_QUE_BOS = "["
SPECB_QUE_EOS = "]"

# Bracket strings that mark the start / end of a document (passage); stored on
# the tokenizer by main() as `bos_token_d` / `eos_token_d`.
SPECB_DOC_BOS = "{"
SPECB_DOC_EOS = "}"


def main():
    """Encode the BEIR test queries and corpus and pickle the embeddings.

    Arguments are read either from a single ``.json`` file given as the only
    command-line argument, or from regular command-line flags, and are parsed
    into ``ModelArguments``, ``DataArguments`` and ``TrainingArguments``.

    Outputs (written by the main process only, using ``pickle`` despite the
    ``.pt`` extension):

    * ``<output_dir>/query/qry.pt`` -- tuple ``(predictions, query_ids)``.
    * ``<output_dir>/corpus/splitNN.pt`` -- one tuple
      ``(predictions, doc_ids)`` per corpus split, where the corpus is cut
      into ``data_args.sub_split_num`` pieces.

    Raises:
        ValueError: if ``model_args.model_name_or_path`` does not contain
            ``"gpt"``; the bracket-token setup below is only applied to
            GPT-style checkpoints.
    """
    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()
        model_args: ModelArguments
        data_args: DataArguments
        training_args: TrainingArguments

    # Setup logging: only rank -1 / 0 logs at INFO, other ranks at WARN.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
    )

    # Validate the checkpoint name up front (before the model and dataset are
    # loaded) and with an explicit exception, so the check is not stripped
    # when Python runs with -O.
    if "gpt" not in model_args.model_name_or_path:
        raise ValueError(
            "Expected a GPT-style checkpoint ('gpt' in model_name_or_path), "
            f"got {model_args.model_name_or_path!r}"
        )

    is_main = is_main_process(training_args.local_rank)

    # open(..., 'wb') does not create parent directories, so make sure the
    # output sub-directories exist before any encoding work is done.
    query_dir = os.path.join(training_args.output_dir, "query")
    corpus_dir = os.path.join(training_args.output_dir, "corpus")
    if is_main:
        os.makedirs(query_dir, exist_ok=True)
        os.makedirs(corpus_dir, exist_ok=True)

    num_labels = 1
    config = AutoConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        num_labels=num_labels,
        cache_dir=model_args.cache_dir,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        use_fast=False,
    )

    ## Model
    model = networks.get_network(
        model_args,
        data_args,
        training_args,
        config=config,
        tokenizer=tokenizer,
        cache_dir=model_args.cache_dir,
        do_train=False,
    )

    print(torch.cuda.device_count())

    ## load dataset
    qrels, queries, corpus = dataloaders.get_beir_dataset(data_args=data_args, split_name="test")

    # Attach the token ids of the bracket markers to the tokenizer so the
    # collators built from it can read them.
    tokenizer.bos_token_q = tokenizer.encode(SPECB_QUE_BOS)
    tokenizer.eos_token_q = tokenizer.encode(SPECB_QUE_EOS)
    tokenizer.bos_token_d = tokenizer.encode(SPECB_DOC_BOS)
    tokenizer.eos_token_d = tokenizer.encode(SPECB_DOC_EOS)

    QryCollator = dataloaders.BEIRQryBracketCollator(tokenizer, max_length=data_args.q_max_len, padding='max_length')
    PsgCollator = dataloaders.BEIRPsgBracketCollator(tokenizer, max_length=data_args.p_max_len, padding='max_length')

    ## ***********************************************************************
    ## encode queries
    # Sorted ids give a deterministic order shared by texts and saved ids.
    query_ids = sorted(queries.keys())
    qry_dataset = dataloaders.BEIRDataset(
        text_lst=[queries[qid] for qid in query_ids],
    )
    qry_outs = trainers.get_trainer(
        args=training_args,
        model=model,
        data_collator=QryCollator,
        train_dataset=None,
        eval_dataset=None,
    ).predict(qry_dataset)

    if is_main:
        with open(os.path.join(query_dir, "qry.pt"), 'wb') as f:
            pickle.dump((qry_outs.predictions, query_ids), f)

    torch.cuda.empty_cache()

    ## ***********************************************************************
    ## encode corpus
    # Order documents by descending whitespace-token count before splitting,
    # so each split holds documents of similar length.
    corpus_ids = np.array(sorted(corpus, key=lambda k: len(corpus[k].split()), reverse=True))
    corpus_splits = np.array_split(corpus_ids, data_args.sub_split_num)
    for seg_i, doc_ids in enumerate(tqdm(corpus_splits, disable=not is_main, desc="Split corpus encoding")):
        doc_dataset = dataloaders.BEIRDataset(
            text_lst=[corpus[did] for did in doc_ids],
        )

        doc_outs = trainers.get_trainer(
            args=training_args,
            model=model,
            data_collator=PsgCollator,
            train_dataset=None,
            eval_dataset=None,
        ).predict(doc_dataset)

        if is_main:
            with open(os.path.join(corpus_dir, f'split{seg_i:02d}.pt'), 'wb') as f:
                pickle.dump((doc_outs.predictions, doc_ids), f)


if __name__ == "__main__":
    main()