import sys, csv, json, random, argparse, logging
from pathlib import Path
from multiprocessing import Process, Pipe
from tqdm import tqdm
from transformers import AutoTokenizer
#from dpr.options import setup_logger

# Raise the csv module's per-field size limit as high as the platform allows.
# sys.maxsize does not fit in a C long on every platform (e.g. 64-bit
# Windows); there csv.field_size_limit raises OverflowError, so keep halving
# the requested value until one is accepted.
_field_size_limit = sys.maxsize
while True:
    try:
        csv.field_size_limit(_field_size_limit)
        break
    except OverflowError:
        _field_size_limit //= 2
logger = logging.getLogger(__name__)
#setup_logger(logger)
# File suffixes accepted by the read_* helpers in this module.
SUPPORT_FORMAT = ['.json', '.jsonl', '.csv', '.tsv']

def get_argument():
    """Parse this script's command-line options.

    Optional flags: ``--type`` (str, default ``'train'``), ``--n_process``
    (int, default ``16``), ``--n_negative`` (int) and ``--negative_file``
    (str).  Required string flags: ``--context_file``, ``--answer_file``,
    ``--query_file`` and ``--output``.

    Returns:
        The ``argparse.Namespace`` produced by parsing ``sys.argv``.
    """
    # (flag, keyword arguments) pairs, kept in the order they are registered.
    optional_flags = (
        ('--type', {'default': 'train', 'type': str}),
        ('--n_process', {'default': 16, 'type': int}),
        ('--n_negative', {'type': int}),
        ('--negative_file', {'type': str}),
    )
    required_flags = (
        '--context_file',
        '--answer_file',
        '--query_file',
        '--output',
    )

    parser = argparse.ArgumentParser()
    for flag, options in optional_flags:
        parser.add_argument(flag, **options)
    for flag in required_flags:
        parser.add_argument(flag, type=str, required=True)
    return parser.parse_args()

def read_context(filepath: str):
    """Load the passage collection stored at ``filepath``.

    ``.csv`` and ``.tsv`` files are both parsed as tab-separated rows of
    ``(pid, passage)`` and returned as a ``{pid: passage}`` dict; when a pid
    repeats, the last row wins.  Blank rows are skipped.  Files with any other
    supported suffix are decoded with ``json.load`` and the result is returned
    unchanged.

    Args:
        filepath: Path to the corpus file.

    Returns:
        A dict mapping pid to passage text for csv/tsv input, otherwise the
        object decoded from the JSON document.

    Raises:
        FileNotFoundError: ``filepath`` is not an existing regular file.
        ValueError: the suffix is not in ``SUPPORT_FORMAT``, or a non-blank
            row does not have exactly two columns.
    """
    corpus_path = Path(filepath)
    # Explicit raises instead of ``assert`` so validation survives ``python -O``.
    if not corpus_path.is_file():
        raise FileNotFoundError(f'context file not found: {corpus_path}')
    if corpus_path.suffix not in SUPPORT_FORMAT:
        raise ValueError(f'unsupported context file format: {corpus_path.suffix!r}')

    if corpus_path.suffix in ('.csv', '.tsv'):
        passages = dict()
        # newline='' lets the csv module do its own line-ending handling.
        with open(corpus_path, 'r', encoding='utf-8', newline='') as corpus_file:
            for row in csv.reader(corpus_file, delimiter='\t'):
                if not row:
                    # csv.reader yields [] for blank lines (e.g. a trailing one).
                    continue
                pid, psg = row
                passages[pid] = psg
        return passages

    # NOTE(review): '.jsonl' passes the suffix check but is read with
    # json.load, which parses a single JSON document; a multi-line JSON Lines
    # file would raise json.JSONDecodeError -- confirm the intended format.
    with open(corpus_path, 'r', encoding='utf-8') as corpus_file:
        return json.load(corpus_file)

def read_query(filepath: str):
    """Load the queries stored at ``filepath``.

    ``.csv`` and ``.tsv`` files are both parsed as tab-separated rows of
    ``(qid, query)`` and returned as a ``{qid: query}`` dict; when a qid
    repeats, the last row wins.  Blank rows are skipped.  Files with any other
    supported suffix are decoded with ``json.load`` and the result is returned
    unchanged.

    Args:
        filepath: Path to the query file.

    Returns:
        A dict mapping qid to query text for csv/tsv input, otherwise the
        object decoded from the JSON document.

    Raises:
        FileNotFoundError: ``filepath`` is not an existing regular file.
        ValueError: the suffix is not in ``SUPPORT_FORMAT``, or a non-blank
            row does not have exactly two columns.
    """
    query_path = Path(filepath)
    # Explicit raises instead of ``assert`` so validation survives ``python -O``.
    if not query_path.is_file():
        raise FileNotFoundError(f'query file not found: {query_path}')
    if query_path.suffix not in SUPPORT_FORMAT:
        raise ValueError(f'unsupported query file format: {query_path.suffix!r}')

    if query_path.suffix in ('.csv', '.tsv'):
        queries = dict()
        # newline='' lets the csv module do its own line-ending handling.
        with open(query_path, 'r', encoding='utf-8', newline='') as query_file:
            for row in csv.reader(query_file, delimiter='\t'):
                if not row:
                    # csv.reader yields [] for blank lines (e.g. a trailing one).
                    continue
                qid, que = row
                queries[qid] = que
        return queries

    # NOTE(review): '.jsonl' passes the suffix check but is read with
    # json.load, which parses a single JSON document; a multi-line JSON Lines
    # file would raise json.JSONDecodeError -- confirm the intended format.
    with open(query_path, 'r', encoding='utf-8') as query_file:
        return json.load(query_file)

def read_answers(filepath: str):
    """Load ``[qid, pid]`` answer pairs from ``filepath``.

    ``.csv`` and ``.tsv`` files are both parsed as tab-separated rows of four
    columns; the first column is taken as the qid and the third as the pid,
    the second and fourth are ignored.  Blank rows are skipped.  Files with
    any other supported suffix are decoded with ``json.load`` into a mapping
    of qid to a sequence of pids, and one pid per qid is picked with
    ``random.choice``.

    Args:
        filepath: Path to the answer file.

    Returns:
        A list of ``[qid, pid]`` lists, in file (or mapping) order.

    Raises:
        FileNotFoundError: ``filepath`` is not an existing regular file.
        ValueError: the suffix is not in ``SUPPORT_FORMAT``, or a non-blank
            row does not have exactly four columns.
        IndexError: a qid in the JSON mapping has an empty pid sequence.
    """
    answer_path = Path(filepath)
    # Explicit raises instead of ``assert`` so validation survives ``python -O``.
    if not answer_path.is_file():
        raise FileNotFoundError(f'answer file not found: {answer_path}')
    if answer_path.suffix not in SUPPORT_FORMAT:
        raise ValueError(f'unsupported answer file format: {answer_path.suffix!r}')

    if answer_path.suffix in ('.csv', '.tsv'):
        answers = list()
        # newline='' lets the csv module do its own line-ending handling.
        with open(answer_path, 'r', encoding='utf-8', newline='') as answer_file:
            for row in csv.reader(answer_file, delimiter='\t'):
                if not row:
                    # csv.reader yields [] for blank lines (e.g. a trailing one).
                    continue
                qid, _, pid, _ = row
                answers.append([qid, pid])
        return answers

    # NOTE(review): '.jsonl' passes the suffix check but is read with
    # json.load, which parses a single JSON document; a multi-line JSON Lines
    # file would raise json.JSONDecodeError -- confirm the intended format.
    with open(answer_path, 'r', encoding='utf-8') as answer_file:
        ans_dict = json.load(answer_file)
    # One randomly chosen pid per query; not seeded here, so results vary
    # between runs unless the caller seeds ``random``.
    return [[qid, random.choice(pids)] for qid, pids in ans_dict.items()]

def process_passage(args):
    """Placeholder for the passage-processing step; currently does nothing.

    Args:
        args: Value forwarded by ``main``; unused for now.

    Returns:
        None.
    """
    return None

def main(args) -> None:
    """Script entry point: forward ``args`` to ``process_passage``.

    Args:
        args: Parsed command-line arguments, passed through unchanged.
    """
    process_passage(args)

# Run only when executed as a script (not on import): parse the command-line
# arguments and pass them to main().
if __name__ == "__main__":
    args = get_argument()
    main(args)
    