import sys, csv, json, random, argparse, logging
from pathlib import Path
from multiprocessing import Process, Pipe
from tqdm import tqdm
from transformers import AutoTokenizer
from dpr.options import setup_logger

from dpr.utils.sampler_utils import PyseriniSampler

# Raise the csv module's per-field size cap to the largest value Python can
# express so rows with very large fields can still be parsed.
# NOTE(review): sys.maxsize may not fit in a C long on some platforms
# (e.g. 64-bit Windows), where this call would raise OverflowError - confirm
# the platforms this script has to run on.
csv.field_size_limit(sys.maxsize)
logger = logging.getLogger(__name__)
# Project helper imported from dpr.options; presumably attaches handlers and
# formatting to the logger - confirm against its definition.
setup_logger(logger)
# File suffixes accepted by the reader functions in this file.
SUPPORT_FORMAT = ['.json', '.jsonl', '.csv', '.tsv']

def get_argument():
    '''
    Parse the command-line options of the conversion script.

    Returns the argparse.Namespace built from sys.argv. An unknown --type is
    rejected here by argparse (usage message, exit status 2) so the check does
    not depend on an assert that is stripped under ``python -O``.
    '''
    parser = argparse.ArgumentParser(
        description='Convert passage / query / answer files into DPR input files.'
    )
    parser.add_argument(
        '--type',
        default='train',
        type=str,
        choices=['train', 'context', 'query'],
        help='Kind of output to produce (default: train).',
    )
    parser.add_argument(
        '--n_process',
        default=16,
        type=int,
        help='Number of worker processes the input is split across (default: 16).',
    )
    parser.add_argument(
        '--n_negative',
        default=16,
        type=int,
        help='Contexts per training sample counting the positive; '
             'n_negative-1 negatives are sampled (default: 16). '
             'Overridden when --bm25_topk is set.',
    )
    parser.add_argument(
        '--context_file',
        type=str,
        required=True,
        help='Passage collection: tab-separated [passage_id, passage_text] rows, or JSON.',
    )
    parser.add_argument(
        '--answer_file',
        type=str,
        required=True,
        help='Relevance file: tab-separated [qid, 0, pid, 0] rows, or JSON.',
    )
    parser.add_argument(
        '--query_file',
        type=str,
        required=True,
        help='Query collection: tab-separated [query_id, query_text] rows, or JSON.',
    )
    parser.add_argument(
        '--output',
        type=str,
        required=True,
        help='Path of the file to write.',
    )
    parser.add_argument(
        '--bm25_topk',
        type=int,
        help='Train mode only: use the BM25 top-k passages of every query as positives.',
    )
    parser.add_argument(
        '--score_file',
        type=str,
        help='Train mode only: JSON file that is loaded when --bm25_topk is not set.',
    )
    return parser.parse_args()

def read_context(filepath : str):
    '''
    Load the passage collection.

    ``.csv`` / ``.tsv`` files are parsed as tab-separated rows of
    [passage_id, passage_text] and returned as a dict mapping passage_id to
    passage_text (a later row with the same id overwrites an earlier one).
    Any other supported suffix is decoded with json.load and the decoded
    object is returned unchanged.

    Raises AssertionError when the path is not a file or its suffix is not in
    SUPPORT_FORMAT, and ValueError when a non-empty row does not have exactly
    two columns.
    '''
    corpus_path = Path(filepath)
    assert corpus_path.is_file(), 'Given dir is not a file'
    assert corpus_path.suffix in SUPPORT_FORMAT, '%s is not supported' % (corpus_path.suffix)

    # Decode as UTF-8 instead of the locale default so the result does not
    # depend on the machine; newline='' is what the csv module asks for so
    # line breaks inside quoted fields are handled by the reader itself.
    with open(corpus_path, 'r', encoding='utf-8', newline='') as corpus_file:
        if corpus_path.suffix in ('.csv', '.tsv'):
            passages = dict()
            for row in csv.reader(corpus_file, delimiter='\t'):
                if not row:
                    # csv.reader yields [] for a blank line (e.g. a trailing
                    # empty line); skip it instead of failing the unpacking.
                    continue
                pid, psg = row
                passages[pid] = psg
        else:
            passages = json.load(corpus_file)
    return passages

def read_query(filepath : str):
    '''
    Load the query collection.

    ``.csv`` / ``.tsv`` files are parsed as tab-separated rows of
    [query_id, query_text] and returned as a dict mapping query_id to
    query_text (a later row with the same id overwrites an earlier one).
    Any other supported suffix is decoded with json.load and the decoded
    object is returned unchanged.

    Raises AssertionError when the path is not a file or its suffix is not in
    SUPPORT_FORMAT, and ValueError when a non-empty row does not have exactly
    two columns.
    '''
    query_path = Path(filepath)
    assert query_path.is_file(), 'Given dir is not a file'
    assert query_path.suffix in SUPPORT_FORMAT, '%s is not supported' % (query_path.suffix)

    # Decode as UTF-8 instead of the locale default so the result does not
    # depend on the machine; newline='' is what the csv module asks for so
    # line breaks inside quoted fields are handled by the reader itself.
    with open(query_path, 'r', encoding='utf-8', newline='') as query_file:
        if query_path.suffix in ('.csv', '.tsv'):
            queries = dict()
            for row in csv.reader(query_file, delimiter='\t'):
                if not row:
                    # csv.reader yields [] for a blank line (e.g. a trailing
                    # empty line); skip it instead of failing the unpacking.
                    continue
                qid, que = row
                queries[qid] = que
        else:
            queries = json.load(query_file)
    return queries

def read_answers(filepath : str):
    '''
    Load the (query, relevant passage) pairs.

    ``.csv`` / ``.tsv`` files are parsed as tab-separated four-column rows
    [qid, 0, pid, 0] (the docstring of the original implementation calls this
    the MSMARCO format); the second and fourth columns are ignored and a list
    of [qid, pid] pairs is returned. Any other supported suffix is decoded
    with json.load and the (key, value) items of the decoded object are
    returned as a list.

    Raises AssertionError when the path is not a file or its suffix is not in
    SUPPORT_FORMAT, and ValueError when a non-empty row does not have exactly
    four columns.
    '''
    answer_path = Path(filepath)
    assert answer_path.is_file(), 'Given dir is not a file'
    assert answer_path.suffix in SUPPORT_FORMAT, '%s is not supported' % (answer_path.suffix)

    # Decode as UTF-8 instead of the locale default so the result does not
    # depend on the machine; newline='' is what the csv module asks for.
    with open(answer_path, 'r', encoding='utf-8', newline='') as answer_file:
        if answer_path.suffix in ('.csv', '.tsv'):
            answers = list()
            for row in csv.reader(answer_file, delimiter='\t'):
                if not row:
                    # csv.reader yields [] for a blank line (e.g. a trailing
                    # empty line); skip it instead of failing the unpacking.
                    continue
                qid, _, pid, _ = row
                answers.append([qid, pid])
        else:
            answers = list(json.load(answer_file).items())
    return answers

def _sample_excluding(pool, excluded, k):
    '''
    Draw k distinct entries of pool uniformly at random, never returning excluded.

    pool is left untouched. Raises ValueError when k is negative or when pool
    holds fewer than k entries other than excluded.
    '''
    if k < 0:
        raise ValueError('Sample larger than population or is negative')
    # One spare draw covers the case where `excluded` itself gets picked, so
    # neither a copy of the pool nor a remove() on it is needed.
    drawn = random.sample(pool, k=min(k + 1, len(pool)))
    kept = [item for item in drawn if item != excluded]
    if len(kept) < k:
        raise ValueError('Sample larger than population or is negative')
    return kept[:k]


def _make_ctx(psgs, pid):
    '''Build the context record of passage pid (the title is always empty).'''
    return {
        'title' : '',
        'text' : psgs[pid],
        'passage_id' : pid
    }


def split_qa_dataset(qrels, psgs, qrys, negatives, neg_pool : dict, conn):
    '''
    Worker: turn (qid, pid) pairs into training entries and send them through conn.

    qrels     : iterable of (qid, pid) pairs; pid is the positive passage.
    psgs      : mapping passage id -> passage text.
    qrys      : mapping query id -> query text.
    negatives : negatives-1 passage ids are sampled per pair; the first
                negatives//2 of them go to 'negative_ctxs', the remaining
                ones to 'hard_negative_ctxs'.
    neg_pool  : optional mapping qid -> list of candidate passage ids; when it
                is empty or None the candidates are all keys of psgs.
    conn      : connection the finished list is sent on; closed afterwards.

    The positive passage is never sampled as a negative. neg_pool is not
    modified: the lists in it may be shared by several pairs of the same
    query, so the positive is filtered out of the sample instead of being
    removed from the pool.

    Raises KeyError for an unknown qid / pid and ValueError when fewer than
    negatives-1 candidates are available.
    '''
    # Built once per worker rather than once per pair; only needed when the
    # whole corpus is the candidate pool.
    corpus_pids = None if neg_pool else list(psgs.keys())
    half = negatives // 2

    entries = list()
    for qid, pid in tqdm(qrels):
        entry = {
            'question' : qrys[qid],
            'answers' : list(),
            'positive_ctxs' : [_make_ctx(psgs, pid)],
            'negative_ctxs' : list(),
            'hard_negative_ctxs' : list()
        }

        sample_pool = neg_pool[qid] if neg_pool else corpus_pids
        sample_pids = _sample_excluding(sample_pool, pid, negatives - 1)
        entry['negative_ctxs'] = [_make_ctx(psgs, n_pid) for n_pid in sample_pids[:half]]
        entry['hard_negative_ctxs'] = [_make_ctx(psgs, hn_pid) for hn_pid in sample_pids[half:]]
        entries.append(entry)
    conn.send(entries)
    conn.close()
    
def split_ctx_dataset(pids, psgs, conn):
    '''
    Worker: send one [pid, passage_text, ''] row per id in pids through conn,
    then close conn. The last column is always the empty string.
    '''
    rows = list()
    for passage_id in pids:
        passage_text = psgs[passage_id]
        rows.append([passage_id, passage_text, ''])
    conn.send(rows)
    conn.close()

def split_query_dataset(qids, queries, conn):
    '''
    Worker: send one [query_text, qid, ['']] row per id in qids through conn,
    then close conn. Progress over qids is reported with tqdm.
    '''
    rows = [[queries[query_id], query_id, ['']] for query_id in tqdm(qids)]
    conn.send(rows)
    conn.close()
    
def get_dpr_train(args):
    '''
    Build the DPR training set and write it to args.output as one JSON list.

    The (qid, pid) pairs come from args.answer_file unless args.bm25_topk is
    set. In that case every query is searched with PyseriniSampler (depth 50):
    its first bm25_topk hits become the positives and the last bm25_topk hits
    of the list become its negative pool; queries with fewer than 50 hits are
    skipped. The pairs are cut into args.n_process contiguous slices, each
    converted by split_qa_dataset in its own process and sent back through a
    pipe.

    Raises EOFError if a worker exits without sending its result.
    '''
    passages = read_context(args.context_file)
    queries = read_query(args.query_file)
    answers = read_answers(args.answer_file)

    n_negs = args.n_negative
    negative_pool = None
    if args.bm25_topk:
        bm25_depth = 50
        answers = list()
        n_negs = sys.maxsize

        bm25_sampler = PyseriniSampler.from_raw_corpus(args.context_file)
        negative_pool = dict()
        for qid in queries:
            bm25_tops = bm25_sampler.get_topk(queries[qid], k=bm25_depth, id_only=True)
            if len(bm25_tops) < bm25_depth:
                continue
            pool = bm25_tops[bm25_depth - args.bm25_topk:]
            negative_pool[qid] = pool
            # Every worker samples the same amount, so it is bounded by the
            # smallest pool.
            n_negs = min(n_negs, len(pool))
            answers.extend([qid, did] for did in bm25_tops[:args.bm25_topk])
    elif args.score_file:
        # NOTE(review): the scores are loaded (which validates the file) but
        # nothing below reads them yet.
        with open(args.score_file, 'r') as f:
            neg_scores = json.load(f)

    # Logged after the BM25 branch so the number reflects the pairs that are
    # actually converted.
    logger.info('# samples to write : %d', len(answers))

    processes = []
    pipes = []
    n_answers = len(answers)
    for proc in range(args.n_process):
        logger.info('Starting process %d', proc)
        start_index = int(proc * n_answers / args.n_process)
        end_index = int(min(n_answers, (proc + 1) * n_answers / args.n_process))
        parent_conn, child_conn = Pipe()
        p = Process(
            target=split_qa_dataset,
            args=(
                answers[start_index:end_index],
                passages,
                queries,
                n_negs,
                negative_pool,
                child_conn,
            )
        )
        p.start()
        # The worker holds its own handle now. Closing ours lets recv() below
        # raise EOFError instead of blocking forever if the worker dies
        # before sending anything.
        child_conn.close()
        pipes.append(parent_conn)
        processes.append(p)

    train_samples = list()
    for p, conn in zip(processes, pipes):
        # Receive before join: a worker cannot exit while its result is still
        # waiting in the pipe.
        split_samples = conn.recv()
        assert isinstance(split_samples, list)
        train_samples.extend(split_samples)
        p.join()
        conn.close()

    # 'w' truncates, so a separate truncate-then-append sequence is not needed.
    with open(args.output, 'w', encoding='utf-8') as out_f:
        json.dump(train_samples, out_f)
    logger.info('# samples written : %d', len(train_samples))

def get_dpr_test_ctx(args):
    '''
    Write the passage collection to args.output as tab-separated rows.

    The passage ids read from args.context_file are cut into args.n_process
    contiguous slices; each slice is turned into rows by split_ctx_dataset in
    its own process and sent back through a pipe. The rows are written in
    slice order, UTF-8 encoded.

    Raises EOFError if a worker exits without sending its result.
    '''
    passages = read_context(args.context_file)
    pids = list(passages.keys())
    logger.info('# samples to write : %d', len(pids))

    processes = []
    pipes = []
    n_pids = len(pids)
    for proc in range(args.n_process):
        logger.info('Starting process %d', proc)
        start_index = int(proc * n_pids / args.n_process)
        end_index = int(min(n_pids, (proc + 1) * n_pids / args.n_process))
        parent_conn, child_conn = Pipe()
        p = Process(
            target=split_ctx_dataset,
            args=(
                pids[start_index:end_index],
                passages,
                child_conn
            )
        )
        p.start()
        # The worker holds its own handle now. Closing ours lets recv() below
        # raise EOFError instead of blocking forever if the worker dies
        # before sending anything.
        child_conn.close()
        pipes.append(parent_conn)
        processes.append(p)

    test_ctxs = list()
    for p, conn in zip(processes, pipes):
        # Receive before join: a worker cannot exit while its result is still
        # waiting in the pipe.
        split_samples = conn.recv()
        assert isinstance(split_samples, list)
        test_ctxs.extend(split_samples)
        p.join()
        conn.close()

    # Single open: the encoding has to be set on the handle that is written
    # to, and newline='' keeps csv.writer's own line terminator untouched.
    with open(args.output, 'w', encoding='utf-8', newline='') as out_file:
        csv.writer(out_file, delimiter='\t').writerows(test_ctxs)
    logger.info('# samples written : %d', len(test_ctxs))
    
def get_dpr_test_queries(args):
    '''
    Write the query collection to args.output as tab-separated rows.

    The query ids read from args.query_file are cut into args.n_process
    contiguous slices; each slice is turned into rows by split_query_dataset
    in its own process and sent back through a pipe. The rows are written in
    slice order, UTF-8 encoded.

    Raises EOFError if a worker exits without sending its result.
    '''
    queries = read_query(args.query_file)
    qids = list(queries.keys())
    logger.info('# samples to write : %d', len(qids))

    processes = []
    pipes = []
    n_qids = len(qids)
    for proc in range(args.n_process):
        logger.info('Starting process %d', proc)
        start_index = int(proc * n_qids / args.n_process)
        end_index = int(min(n_qids, (proc + 1) * n_qids / args.n_process))
        parent_conn, child_conn = Pipe()
        p = Process(
            target=split_query_dataset,
            args=(
                qids[start_index:end_index],
                queries,
                child_conn
            )
        )
        p.start()
        # The worker holds its own handle now. Closing ours lets recv() below
        # raise EOFError instead of blocking forever if the worker dies
        # before sending anything.
        child_conn.close()
        pipes.append(parent_conn)
        processes.append(p)

    test_queries = list()
    for p, conn in zip(processes, pipes):
        # Receive before join: a worker cannot exit while its result is still
        # waiting in the pipe.
        split_samples = conn.recv()
        assert isinstance(split_samples, list)
        test_queries.extend(split_samples)
        p.join()
        conn.close()

    # Single open: the encoding has to be set on the handle that is written
    # to, and newline='' keeps csv.writer's own line terminator untouched.
    with open(args.output, 'w', encoding='utf-8', newline='') as out_file:
        csv.writer(out_file, delimiter='\t').writerows(test_queries)
    logger.info('# samples written : %d', len(test_queries))



def main(args):
    '''
    Run the conversion selected by args.type: 'train', 'context' or 'query'.
    Any other value fails the assertion below.
    '''
    format_type = ['train','context','query']
    assert args.type in format_type

    # Map every supported type to the function that produces its output.
    converters = {
        'train': get_dpr_train,
        'context': get_dpr_test_ctx,
        'query': get_dpr_test_queries,
    }
    convert = converters.get(args.type)
    if convert is not None:
        convert(args)

if __name__ == "__main__":
    # Script entry point: parse the command line, then run the conversion.
    main(get_argument())
    