import os
import sys

from projects.T5_ANCEPROMPT.project_lib.arguments import DataArguments
from projects.T5_ANCEPROMPT.project_lib.arguments import DenseEncodingArguments as EncodingArguments
from projects.T5_ANCEPROMPT.project_lib.arguments import ModelArguments
from projects.T5_ANCEPROMPT.project_lib.dataset import InferenceDataset
from projects.T5_ANCEPROMPT.project_lib.modeling import DenseModelForInference
from projects.T5_ANCEPROMPT.project_lib.retriever import Retriever
from transformers import AutoConfig, AutoTokenizer, HfArgumentParser


def main():
    """Encode the corpus with a dense model via ``Retriever.build_embeddings``.

    Command-line handling:
      * If exactly one argument is given and it ends in ``.json``, all three
        argument dataclasses (model, data, encoding) are read from that JSON
        file (path resolved with ``os.path.abspath``).
      * Otherwise they are parsed from regular command-line flags.

    Pipeline:
      1. Load the model config (``config_name`` if set, else
         ``model_name_or_path``) with ``num_labels=1``.
      2. Load the tokenizer (``tokenizer_name`` if set, else
         ``model_name_or_path``).
      3. Build a ``DenseModelForInference``, passing ``data_args.use_ground``.
      4. Load the corpus side of ``InferenceDataset`` (``is_query=False``).
      5. Hand model, tokenizer, corpus and encoding arguments to
         ``Retriever.build_embeddings``.
    """
    # Local annotations are not evaluated at runtime; they only document the
    # types produced by whichever parsing branch runs below.
    model_args: ModelArguments
    data_args: DataArguments
    encoding_args: EncodingArguments

    arg_parser = HfArgumentParser((ModelArguments, DataArguments, EncodingArguments))

    cli_args = sys.argv[1:]
    single_json_config = len(cli_args) == 1 and cli_args[0].endswith(".json")
    if single_json_config:
        json_path = os.path.abspath(cli_args[0])
        model_args, data_args, encoding_args = arg_parser.parse_json_file(json_file=json_path)
    else:
        model_args, data_args, encoding_args = arg_parser.parse_args_into_dataclasses()

    # An empty/None explicit name falls back to the model checkpoint path.
    config_source = model_args.config_name or model_args.model_name_or_path
    tokenizer_source = model_args.tokenizer_name or model_args.model_name_or_path

    config = AutoConfig.from_pretrained(
        config_source,
        num_labels=1,  # fixed at 1 here; NOTE(review): presumably one score head — confirm
        cache_dir=model_args.cache_dir,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_source,
        cache_dir=model_args.cache_dir,
    )

    encoder = DenseModelForInference.build(
        model_name_or_path=model_args.model_name_or_path,
        model_args=model_args,
        config=config,
        cache_dir=model_args.cache_dir,
        use_ground=data_args.use_ground,
    )

    corpus = InferenceDataset.load(
        tokenizer=tokenizer,
        data_args=data_args,
        is_query=False,  # load passages/documents, not queries
        cache_dir=model_args.cache_dir,
    )

    Retriever.build_embeddings(encoder, tokenizer, corpus, encoding_args)


if __name__ == "__main__":
    # Run the encoding pipeline only when executed as a script, not on import.
    main()
