import jsonlines
import csv
import pandas as pd
import argparse
import os
import re
from beir.datasets.data_loader import GenericDataLoader
import pdb

# Command-line interface: which BEIR dataset to convert and where it lives.
parser = argparse.ArgumentParser(
    description='Convert a downloaded BEIR dataset (corpus/queries JSONL and '
                'test qrels) into TSV and TREC qrel files.')
# Both arguments are required, so a ``default`` value could never be used.
parser.add_argument('--dataset_name', required=True,
                    help='dataset name from BEIR (e.g. robust04, cqadupstack)')
parser.add_argument('--save_dir', required=True,
                    help='folder containing the dataset folder; converted files '
                         'are written into that dataset folder')
args = parser.parse_args()
dataset_name = args.dataset_name
save_dir = args.save_dir

# Characters kept when cleaning robust04 documents; anything else becomes a space.
_ROBUST04_STRIP_RE = re.compile(r"[^A-Za-z0-9=(),!?\'\`]")


def _convert_corpus(data_dir, clean_text):
    """Write ``corpus.tsv`` rows of (_id, title, text) from ``corpus.jsonl``.

    Args:
        data_dir: Directory holding a single BEIR dataset.
        clean_text: If true, replace every character outside the
            ``_ROBUST04_STRIP_RE`` whitelist with a space and collapse runs of
            whitespace (applied to robust04 only).

    Returns:
        Number of documents written.
    """
    doc_count = 0
    with open(os.path.join(data_dir, 'corpus.jsonl'), 'r', encoding='utf-8') as fin:
        # Explicit UTF-8 so non-ASCII text round-trips regardless of locale.
        with open(os.path.join(data_dir, 'corpus.tsv'), 'w', newline='', encoding='utf-8') as fout:
            tsv_w = csv.writer(fout, delimiter='\t')
            for item in jsonlines.Reader(fin):
                text = item['text']
                if clean_text:
                    text = " ".join(_ROBUST04_STRIP_RE.sub(" ", text).split())
                tsv_w.writerow([item['_id'], item['title'], text])
                doc_count += 1
    return doc_count


def _convert_test_qrels(data_dir, has_header):
    """Write ``qrel.test.tsv`` and ``qrel.test.trec`` from ``qrels/test.tsv``.

    Each output row is ``query-id 0 corpus-id score`` (tab-separated in the
    TSV, space-separated in the TREC file). The first 300 rows are echoed to
    stdout for inspection.

    Args:
        data_dir: Directory holding a single BEIR dataset.
        has_header: Whether ``qrels/test.tsv`` starts with a header row. When
            false, the columns are named query-id, corpus-id, score.

    Returns:
        Set of query ids (as strings) that appear in the test qrels.
    """
    qrels_path = os.path.join(data_dir, 'qrels', 'test.tsv')
    # Keep ids as the exact strings from the file: dtype inference would turn
    # numeric-looking ids into ints (dropping leading zeros, e.g. "0042" -> 42),
    # and default NA parsing would turn ids such as "NA" into NaN.
    read_kwargs = dict(sep='\t', dtype={'query-id': str, 'corpus-id': str},
                       keep_default_na=False)
    if has_header:
        test_df = pd.read_csv(qrels_path, header=0, **read_kwargs)
    else:
        test_df = pd.read_csv(qrels_path, names=['query-id', 'corpus-id', 'score'],
                              **read_kwargs)

    tsv_path = os.path.join(data_dir, 'qrel.test.tsv')
    test_q_ids = set()
    with open(tsv_path, 'w', newline='', encoding='utf-8') as fout:
        tsv_w = csv.writer(fout, delimiter='\t')
        rows = zip(test_df['query-id'], test_df['corpus-id'], test_df['score'])
        for i, (q_id, d_id, score) in enumerate(rows):
            tsv_w.writerow([q_id, 0, d_id, score])
            if i < 300:
                print([q_id, 0, d_id, score])
            test_q_ids.add(str(q_id))

    # TREC copy: same rows with tabs replaced by spaces. The TSV handle is now
    # closed via ``with`` instead of being leaked.
    with open(tsv_path, 'r', encoding='utf-8') as fin:
        with open(os.path.join(data_dir, 'qrel.test.trec'), 'w', encoding='utf-8') as fout:
            for line in fin:
                fout.write(line.replace('\t', ' '))
    return test_q_ids


def _convert_test_queries(data_dir, test_q_ids):
    """Write ``queries.test.tsv`` rows of (_id, text) for queries in *test_q_ids*.

    Args:
        data_dir: Directory holding a single BEIR dataset.
        test_q_ids: Set of query ids (strings) to keep.

    Returns:
        Number of queries written.
    """
    que_count = 0
    with open(os.path.join(data_dir, 'queries.jsonl'), 'r', encoding='utf-8') as fin:
        with open(os.path.join(data_dir, 'queries.test.tsv'), 'w', newline='', encoding='utf-8') as fout:
            tsv_w = csv.writer(fout, delimiter='\t')
            for item in jsonlines.Reader(fin):
                _id = item['_id']
                if str(_id) in test_q_ids:  # set lookup, O(1) per query
                    tsv_w.writerow([_id, item['text']])
                    que_count += 1
    return que_count


def _convert_dataset(data_dir, is_robust04):
    """Convert corpus, test qrels and test queries of one dataset directory.

    Prints ``[doc_count, que_count]`` when done.
    """
    doc_count = _convert_corpus(data_dir, clean_text=is_robust04)
    # robust04's qrels file has no header row; the other datasets have one.
    test_q_ids = _convert_test_qrels(data_dir, has_header=not is_robust04)
    que_count = _convert_test_queries(data_dir, test_q_ids)
    print([doc_count, que_count])


if dataset_name != 'cqadupstack':
    _convert_dataset(os.path.join(save_dir, dataset_name),
                     is_robust04=(dataset_name == 'robust04'))
else:
    # cqadupstack is a folder of sub-datasets, each converted independently.
    cqa_root = os.path.join(save_dir, dataset_name)
    for sub_name in sorted(os.listdir(cqa_root)):
        sub_dir = os.path.join(cqa_root, sub_name)
        # Skip stray files (e.g. an archive or .DS_Store); only sub-folders
        # contain a dataset.
        if os.path.isdir(sub_dir):
            _convert_dataset(sub_dir, is_robust04=False)