# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
"""
Create a JSONL dataset file for the QA task (SQuAD, HotpotQA or MuSiQue).

python qa.py \
    --src_data_dir_path=./data \
    --dataset=squad.json \
    --save_dir=./ \
    --save_name=qa_squad \
    --tokenizer_path=tokenizer.model \
    --tokenizer_type=nemo \
    --max_seq_length=4096 \
    --tokens_to_generate=128 \
    --num_samples=10 \
    --template="Answer the question based on the given documents. Only give me the answer and do not output any other words.\n\nThe following are given documents.\n\n{context}\n\nAnswer the question based on the given documents. Only give me the answer and do not output any other words.\n\nQuestion: {query} Answer:"
"""
import os
import re
import json
import argparse
from pathlib import Path
from tqdm import tqdm
import random
import numpy as np
from nemo.collections.asr.parts.utils.manifest_utils import read_manifest, write_manifest
import sys
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")) 
from tokenizer import select_tokenizer


# Prompt template used when --template is omitted (same wording as the example
# in the module docstring). It is filled via str.format(context=..., query=...).
DEFAULT_TEMPLATE = (
    "Answer the question based on the given documents. "
    "Only give me the answer and do not output any other words.\n\n"
    "The following are given documents.\n\n{context}\n\n"
    "Answer the question based on the given documents. "
    "Only give me the answer and do not output any other words.\n\n"
    "Question: {query} Answer:"
)

parser = argparse.ArgumentParser()
# Basic Configurations
parser.add_argument("--src_data_dir_path", type=Path, required=True, help='directory containing the dataset file')
parser.add_argument("--save_dir", type=Path, required=True, help='dataset folder to save dataset')
parser.add_argument("--save_name", type=str, required=True, help='name of the save dataset jsonl file')
# NOTE(review): --subset is not read anywhere in this script.
parser.add_argument("--subset", type=str, default='validation', help='Options: validation or test')
parser.add_argument("--tokenizer_path", type=str, required=True, help='path to the tokenizer model')
parser.add_argument("--tokenizer_type", type=str, default='nemo', help='[Options] nemo, hf, openai.')
parser.add_argument("--max_seq_length", type=int, required=True, help='max sequence length including all input tokens and generated tokens.')
parser.add_argument("--tokens_to_generate", type=int, required=True, help='expected generated token amount.')
parser.add_argument("--num_samples", type=int, required=True, help='number of samples to generate')
parser.add_argument("--pre_samples", type=int, default=0, help='number of samples are already generated')
parser.add_argument("--random_seed", type=int, default=42)
# A default is required: with None, args.template.format(...) fails when the
# prompt is built.
parser.add_argument("--template", type=str, default=DEFAULT_TEMPLATE, help='prompt template with {context} and {query} fields')
# Raw string so the help text shows the escape sequences instead of embedding
# an actual newline/tab character.
# NOTE(review): the generated records do not include the prompt text, so this
# flag does not appear to change the output — confirm.
parser.add_argument("--remove_newline_tab", action='store_true', help=r'remove `\n` and `\t` in all strings.')

# Dataset selection
parser.add_argument("--dataset", type=str, required=True, help='dataset file name inside --src_data_dir_path; must contain squad, hotpotqa or musique')


args = parser.parse_args()
# Seed the global RNGs so document sampling is reproducible across runs.
random.seed(args.random_seed)
np.random.seed(args.random_seed)

# Load Tokenizer
TOKENIZER = select_tokenizer(args.tokenizer_type, args.tokenizer_path)

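# 1. Dataset loaders. Each read_* function returns (qas, docs): docs is the
#    sorted list of unique passages and each qa refers to its passages by
#    index into docs.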
# Read SQuAD QA dataset
def read_squad(file):
    """Load a SQuAD-format JSON file.

    Args:
        file: path to the JSON file (read as UTF-8).

    Returns:
        ``(total_qas, total_docs)``.

        * ``total_docs`` is the sorted list of unique paragraph contexts.
        * Each QA dict has ``query``, ``outputs`` (answer texts), ``context``
          (index of its own paragraph in ``total_docs``) and ``more_context``
          (indices of the other paragraphs of the same article).

        Questions flagged ``is_impossible`` are skipped. A missing
        ``is_impossible`` key is treated as answerable.
    """
    # Explicit encoding: the platform default is locale-dependent.
    with open(file, encoding='utf-8') as f:
        data = json.load(f)

    total_docs = sorted({p['context'] for d in data['data'] for p in d['paragraphs']})
    total_docs_dict = {c: idx for idx, c in enumerate(total_docs)}

    total_qas = []
    for d in data['data']:
        article_docs = [total_docs_dict[p['context']] for p in d['paragraphs']]
        for p in d['paragraphs']:
            # Same for every question of this paragraph, so compute it once.
            doc_idx = total_docs_dict[p['context']]
            other_docs = [idx for idx in article_docs if idx != doc_idx]
            for qas in p['qas']:
                if qas.get('is_impossible', False):
                    continue
                total_qas.append({
                    'query': qas['question'],
                    'outputs': [a['text'] for a in qas['answers']],
                    'context': [doc_idx],
                    'more_context': list(other_docs),
                })

    return total_qas, total_docs

# Read Hotpot QA dataset
def read_hotpotqa(file):
    """Load a HotpotQA-format JSON file.

    Args:
        file: path to the JSON file (read as UTF-8). It holds a list of records
            whose ``context`` is a list of ``(title, sentences)`` pairs.

    Returns:
        ``(total_qas, total_docs)``.

        * ``total_docs`` is the sorted list of unique passages, each rendered
          as ``Passage "<title>":\\n<joined sentences>``.
        * Each QA dict has ``query``, ``outputs`` (single answer) and
          ``context`` (indices into ``total_docs`` of every passage listed
          for that record).
    """
    def passage(title, sentences):
        # One canonical rendering, used both for the doc list and the lookup.
        return f"Passage \"{title}\":\n{''.join(sentences)}"

    # Explicit encoding: the platform default is locale-dependent.
    with open(file, encoding='utf-8') as f:
        data = json.load(f)

    total_docs = sorted({passage(t, p) for d in data for t, p in d['context']})
    total_docs_dict = {c: idx for idx, c in enumerate(total_docs)}

    total_qas = [
        {
            'query': d['question'],
            'outputs': [d['answer']],
            'context': [total_docs_dict[passage(t, p)] for t, p in d['context']],
        }
        for d in data
    ]
    return total_qas, total_docs

def read_musique(file):
    """Load a MuSiQue-format JSONL file (one record per line).

    Args:
        file: path to the JSONL file (read as UTF-8). Blank lines are skipped.
            If the path contains ``musique_1475`` or ``musique_qwen``, each
            record's ``positive_sample``/``negative_sample`` are copied too.

    Returns:
        ``(total_qas, total_docs)``.

        * ``total_docs`` is the sorted list of unique passages, from all
          paragraphs (supporting or not), rendered as
          ``Passage "<title>":\\n<paragraph_text>``.
        * Each QA dict has ``query``, ``outputs`` (answer followed by its
          aliases) and ``context`` (indices of the supporting paragraphs only).
    """
    def passage(para):
        return f"Passage \"{para['title']}\":\n{para['paragraph_text']}"

    data = []
    # Explicit encoding: the platform default is locale-dependent.
    with open(file, encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank lines, e.g. a trailing empty line
            data.append(json.loads(line))

    total_docs = sorted({passage(para) for item in data for para in item["paragraphs"]})
    total_docs_dict = {c: idx for idx, c in enumerate(total_docs)}

    # str() so a pathlib.Path argument also works with the substring test.
    file_name = str(file)
    keep_samples = "musique_1475" in file_name or "musique_qwen" in file_name

    total_qas = []
    for d in data:
        qas = {
            'query': d['question'],
            'outputs': [d['answer']] + d['answer_aliases'],
            'context': [total_docs_dict[passage(para)] for para in d["paragraphs"] if para["is_supporting"]],
        }
        if keep_samples:
            qas["positive_sample"] = d["positive_sample"]
            qas["negative_sample"] = d["negative_sample"]
        total_qas.append(qas)
    return total_qas, total_docs

def read_context_synthesis_qa(file_path):
    """Load a context-synthesis QA JSONL file (one record per line).

    Args:
        file_path: path to the JSONL file (read as UTF-8). Blank lines are
            skipped.

    Returns:
        ``(total_qas, total_docs)``.

        * ``total_docs`` is the sorted list of unique ``synthesized_context``
          strings.
        * Each QA dict has ``query`` (the record's ``instruction``),
          ``outputs`` (single answer) and ``context`` (one index into
          ``total_docs``).
    """
    data = []
    # Explicit encoding: the platform default is locale-dependent.
    with open(file_path, encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank lines, e.g. a trailing empty line
            data.append(json.loads(line))

    total_docs = sorted({item["synthesized_context"] for item in data})
    total_docs_dict = {c: idx for idx, c in enumerate(total_docs)}

    total_qas = [
        {
            'query': d['instruction'],
            'outputs': [d['answer']],
            'context': [total_docs_dict[d["synthesized_context"]]],
        }
        for d in data
    ]
    return total_qas, total_docs


# Per-document rendering inside the prompt; {i} is the 1-based position.
DOCUMENT_PROMPT = "Document {i}:\n{document}"
data_file_path = os.path.join(args.src_data_dir_path, args.dataset)

print(f'loading dataset from:\n {data_file_path}')

# Select the loader by substring of the dataset file name; the first match wins.
# NOTE(review): read_context_synthesis_qa is defined in this file but is not
# reachable from this dispatch — confirm which dataset name should select it.
if 'squad' in args.dataset:
    QAS, DOCS = read_squad(data_file_path)
elif 'hotpotqa' in args.dataset:
    QAS, DOCS = read_hotpotqa(data_file_path)
elif 'musique' in args.dataset:
    QAS, DOCS = read_musique(data_file_path)
else:
    raise NotImplementedError(f'{args.dataset} is not implemented.')

print(f"Dataset: {args.dataset}\tQAS num:{len(QAS)}\tDOCS num:{len(DOCS)}")
def generate_input_output(index, num_docs):
    """Build the prompt for ``QAS[index]`` using roughly ``num_docs`` documents.

    Document selection works on indices into ``DOCS``:

    * The gold documents (``context``) are always included.
    * If the budget exceeds the gold documents plus ``more_context``, all of
      ``more_context`` is added. The rest of the budget is filled with random
      other documents.
    * Otherwise a random subset of ``more_context`` fills the budget left
      after the gold documents.
    * If ``num_docs`` is smaller than the number of gold documents, the gold
      documents plus all of ``more_context`` are used.
    * If ``num_docs >= len(DOCS)``, every document is used.

    The chosen documents are shuffled with a fresh
    ``random.Random(args.random_seed)`` and rendered with ``DOCUMENT_PROMPT``
    into ``args.template``.

    Returns:
        ``(input_text, gold_doc_indices, answers, all_docs, question,
        (positive_sample, negative_sample))``. ``all_docs`` is a new list of
        document strings in prompt order.
    """
    sample = QAS[index]
    curr_q = sample['query']
    curr_a = sample['outputs']
    curr_docs = sample['context']
    curr_more = sample.get('more_context', [])
    curr_response = (sample.get('positive_sample', None), sample.get('negative_sample', None))

    if num_docs < len(DOCS):
        num_extra = num_docs - len(curr_docs)  # budget left after the gold docs
        if num_extra > len(curr_more):
            # Use every related document, then pad with random unrelated ones.
            excluded = set(curr_docs) | set(curr_more)
            addition_docs = [i for i in range(len(DOCS)) if i not in excluded]
            doc_ids = curr_docs + curr_more + random.sample(addition_docs, num_extra - len(curr_more))
        elif num_extra >= 0:
            doc_ids = curr_docs + random.sample(curr_more, num_extra)
        else:
            # Budget smaller than the gold set: same fallback as before, made
            # explicit instead of relying on random.sample raising ValueError.
            doc_ids = curr_docs + curr_more
        all_docs = [DOCS[idx] for idx in doc_ids]
    else:
        # Copy, so the in-place shuffle below never reorders the global DOCS.
        # The QAS entries index into DOCS by position.
        all_docs = list(DOCS)

    random.Random(args.random_seed).shuffle(all_docs)

    context = '\n\n'.join(DOCUMENT_PROMPT.format(i=i + 1, document=d) for i, d in enumerate(all_docs))
    input_text = args.template.format(
        context=context,
        query=curr_q
    )
    return input_text, curr_docs, curr_a, all_docs, curr_q, curr_response

def _probe_num_docs(max_seq_length, tokens_to_generate, incremental):
    """Return the starting document count, measured on sample 0.

    The count grows in steps of ``incremental``. Once the prompt plus the
    answer text plus ``tokens_to_generate`` no longer fits, it steps back
    once. It is capped at ``len(DOCS)``.
    """
    num_docs = incremental
    while True:
        input_text, _, answer, _, _, _ = generate_input_output(0, num_docs)
        # Include the answer text so the estimate leaves room for it.
        total_tokens = len(TOKENIZER.text_to_tokens(input_text + f' {answer}'))
        required = total_tokens + tokens_to_generate
        print(f'Max length {max_seq_length} | Current length {required} | Docs: {num_docs}')
        if required > max_seq_length:
            return num_docs - incremental
        if required == max_seq_length:
            # Exactly full: keep this measured count. More documents would
            # only add tokens.
            return num_docs
        num_docs += incremental
        if num_docs > len(DOCS):
            return len(DOCS)


def _fit_sample(index, num_docs, max_seq_length, tokens_to_generate, incremental):
    """Build sample ``index`` with at most ``num_docs`` documents so that it fits.

    Drops ``incremental`` documents at a time until
    ``prompt tokens + tokens_to_generate <= max_seq_length``.

    Returns:
        ``(generate_input_output(index, used_docs), length)``.

    Raises:
        ValueError: if the sample still does not fit and the document count
            cannot be reduced any further.
    """
    used_docs = num_docs
    while True:
        sample = generate_input_output(index, used_docs)
        length = len(TOKENIZER.text_to_tokens(sample[0])) + tokens_to_generate
        # Explicit check rather than `assert`, which is stripped under -O.
        if length <= max_seq_length:
            return sample, length
        if used_docs <= incremental:
            raise ValueError(
                f"Sample {index}: {length} tokens exceeds max_seq_length={max_seq_length} "
                f"even with {used_docs} documents."
            )
        used_docs -= incremental


def generate_samples(num_samples: int, max_seq_length: int, save_dir: str, incremental: int = 10):
    """Generate ``num_samples`` QA records that fit within ``max_seq_length`` tokens.

    Args:
        num_samples: number of records to build. Record ``i`` comes from
            ``QAS[i + args.pre_samples]``.
        max_seq_length: token budget for the prompt plus
            ``args.tokens_to_generate``.
        save_dir: unused here; kept for caller compatibility.
        incremental: step size used when growing or shrinking the document
            count.

    Returns:
        A list of dicts with keys ``index``, ``all_docs``, ``question``,
        ``outputs`` and ``length``. When the QA entry provides them,
        ``positive_sample`` and ``negative_sample`` are also included.

    Raises:
        ValueError: if a sample cannot be made to fit ``max_seq_length``.
    """
    tokens_to_generate = args.tokens_to_generate

    # Find the largest document count that fits, probing on the first sample.
    num_docs = _probe_num_docs(max_seq_length, tokens_to_generate, incremental)
    print('Number of documents:', num_docs)

    # NOTE(review): records do not include the prompt text, so
    # --remove_newline_tab (which only affects the prompt) is not applied
    # here. Confirm whether the prompt should be saved.
    write_jsons = []
    for index in tqdm(range(num_samples)):
        sample, length = _fit_sample(
            index + args.pre_samples, num_docs, max_seq_length, tokens_to_generate, incremental
        )
        _, _, answer, all_docs, curr_q, curr_response = sample

        formatted_output = {
            "index": index,
            "all_docs": all_docs,
            "question": curr_q,
            "outputs": answer,
            "length": length
        }
        if curr_response[0]:
            formatted_output["positive_sample"] = curr_response[0]
            formatted_output["negative_sample"] = curr_response[1]
        write_jsons.append(formatted_output)

    return write_jsons

def main():
    """Generate the QA records and write them to a JSONL manifest.

    The output path is ``args.save_dir / args.save_name``. A ``.jsonl``
    suffix is appended unless the name already ends with it. Missing parent
    directories are created.
    """
    file_name = args.save_name
    if not file_name.endswith(".jsonl"):
        file_name = f'{file_name}.jsonl'
    save_file = args.save_dir / file_name
    save_file.parent.mkdir(parents=True, exist_ok=True)

    records = generate_samples(
        num_samples=args.num_samples,
        max_seq_length=args.max_seq_length,
        save_dir=args.save_dir,
    )

    write_manifest(save_file, records)
    print(f"save to {save_file}")


if __name__ == "__main__":
    main()
