import re
import json
import os
import time
import argparse
from tqdm import trange, tqdm


def read_data(task: str, data_dir: str = None, split: str = "train"):
    """Load every example of one task split from its JSON file.

    The file is ``<data_dir>/<task>/<split>.json``; when ``data_dir`` is None the
    root defaults to ``./data``.

    Args:
        task: Dataset name. ``"table_mwp"`` is stored as a ``{pid: example}``
            mapping; every other task is iterated as-is.
        data_dir: Root directory of the datasets, or None for ``./data``.
        split: Split name, used as the JSON file stem.

    Returns:
        A list of examples. For ``table_mwp`` each example additionally carries
        its key under ``"pid"``.
    """
    is_table_mwp = task == "table_mwp"

    # Resolve the location of the split file.
    if data_dir is None:
        input_file = f"./data/table_mwp/{split}.json" if is_table_mwp else f"./data/{task}/{split}.json"
    elif is_table_mwp:
        input_file = os.path.join(data_dir, "table_mwp", f"{split}.json")
    else:
        input_file = os.path.join(data_dir, f"{task}/{split}.json")

    dataset = []
    with open(input_file, "r") as reader:
        records = json.load(reader)
        if is_table_mwp:
            # TableMWP is keyed by pid; keep the key as an extra field of each example
            for pid, example in records.items():
                example["pid"] = pid
                dataset.append(example)
        else:
            dataset.extend(records)
        print(f"read {len(records)} examples from {input_file}")

    return dataset


def read_data_with_response(task: str, data_dir: str = None, split: str = "train"):
    """Read previously augmented examples from ``reverse_aug_questions.jsonl``.

    The file is expected directly inside ``data_dir`` and holds one JSON object
    per line. Blank lines (for example a trailing empty line) are skipped.

    Args:
        task: Task name; only used in the error message.
        data_dir: Directory containing ``reverse_aug_questions.jsonl``.
        split: Unused; kept so the signature parallels ``read_data``.

    Returns:
        The list of decoded examples, in file order.

    Raises:
        ValueError: If ``data_dir`` is None.
    """
    if data_dir is None:
        raise ValueError(f"Data directory is not provided for task {task}")

    input_file = os.path.join(data_dir, "reverse_aug_questions.jsonl")
    dataset = []
    # read the jsonl file line by line
    with open(input_file, "r") as reader:
        for line in reader:
            if not line.strip():
                # tolerate blank lines instead of failing inside json.loads
                continue
            dataset.append(json.loads(line))
    print(f"read {len(dataset)} examples from {input_file}")

    return dataset


def _format_rule_text(rule):
    """Render a rule (plain string or rule dict) as the text placed in the system prompt."""
    if not isinstance(rule, dict):
        return rule

    # A full rule record looks like {"rule": {"title", "content"}, "example": [...]}.
    # A record that was already reduced to its inner {"title", "content"} dict is
    # accepted as well instead of failing with KeyError on "rule".
    body = rule.get("rule", rule)
    rule_str = "**"+body["title"]+"**\n"+body["content"]
    if "example" in rule:
        rule_str += "\n\n**Example:**\n"
        for example in rule["example"]:
            rule_str += "- Original: "+example["original"]+"\n"
            rule_str += "- Augmented: "+example["augmented"]+"\n\n"
    return rule_str


def _question_aug_user_prompt(x, task, provide_answer):
    """Build the task-specific user prompt for one example of the question augmentation."""
    # Answer fields are only read when provide_answer is set, so examples without
    # gold answers can still be augmented.
    if task == "strategy_qa":
        user_prompt = "Original Question: Yes or no: "+x["question"]
        if provide_answer:
            answer_str = "Yes" if x["answer"] else "No"
            user_prompt += "\n\n"+"Correct Answer: "+answer_str
        return user_prompt

    if task in ("commonsense_qa", "arc_challenge", "date"):
        labels = x["choices"]["label"]
        # combine the choices together with the keys
        choices_str = "".join(
            " (" + labels[i] + ") " + text for i, text in enumerate(x["choices"]["text"])
        )
        user_prompt = "Original Question: "+x["question"]+"\n\n"+"Options: "+choices_str
        if provide_answer:
            answer_str = " (" + x["answerKey"] + ") " + x["answer"]
            user_prompt += "\n\n"+"Correct Answer: "+answer_str
        return user_prompt

    if task == "table_mwp":
        table_title = x.get("table_title", "")
        table_unit = x.get("unit", "")
        choices = x["choices"]
        # Construct the table title and content
        table_prompt = "Table to use:\n"
        if table_title is not None and table_title != "":
            table_prompt += "[TITLE]: "+table_title+"\n"
        table_prompt += x["table"]
        # Construct the question's unit if provided
        unit_prompt = ""
        if table_unit is not None and table_unit != "":
            unit_prompt = " (Unit: "+table_unit+")"

        user_prompt = table_prompt+"\nOriginal Question: "+x["question"]+unit_prompt
        if choices is not None:
            # Labels are generated (A, B, C, ...) rather than taken from a fixed
            # six-letter list, so longer option lists do not raise IndexError.
            choices_str = "".join(
                " (" + chr(ord("A") + i) + ") " + choice for i, choice in enumerate(choices)
            )
            user_prompt += "\nOptions:"+choices_str
        if provide_answer:
            user_prompt += "\nCorrect Answer: "+x["answer"]
        return user_prompt

    if task in ("gsm8k", "math"):
        user_prompt = "Original Question: "+x["question"]
        if provide_answer:
            user_prompt += "\n\n"+"Correct Answer: "+x["answer"]
        return user_prompt

    if task == "anli":
        user_prompt = "Original Question: Given that \""+x['premise']+"\"\nQuestion: "+x['hypothesis']+" True, False, or Neither?"
        if provide_answer:
            # Any label other than entailment/contradiction maps to "Neither"
            answer = {"entailment": "True", "contradiction": "False"}.get(x['label'], "Neither")
            user_prompt += "\nCorrect Answer: "+answer
        return user_prompt

    raise ValueError(f"Task {task} is not supported for reverse augmentation")


def reverse_aug_prompt_question(
    batch_examples,
    task: str = "strategy_qa",                  # Task name
    rule: str = None,                           # Rule to follow
    provide_answer: bool = False,               # Whether to provide the ground truth answer in the question
    backbone_model: str = "gemini-1.5-flash",
    response_key: str = "response",             # The key to store the response in the output
    num_generations: int = 1,                   # Number of samples to generate
    temperature: float = 0.8,                   # Temperature for sampling
    top_p: float = 1.0,                         # Top-p for sampling
    max_tokens: int = 1024,                     # Maximum number of tokens to generate

):
    """Ask the LLM to reformulate each question of a batch according to a rule.

    One prompt per example is built (system instructions + rule + task-specific
    question text), the whole batch is sent to ``batch_call_gemini_api``, and
    each returned response is stored in place under ``x[response_key]``.

    Args:
        batch_examples: List of example dicts; mutated in place.
        task: One of strategy_qa, commonsense_qa, arc_challenge, date,
            table_mwp, gsm8k, math, anli.
        rule: Rule text, or a rule dict: either
            ``{"rule": {"title", "content"}, "example": [...]}`` or just the
            inner ``{"title", "content"}`` dict.
        provide_answer: Include the gold answer in the prompt.
        backbone_model: Model name forwarded to ``batch_call_gemini_api``.
        response_key: Key under which the response is stored.
        num_generations, temperature, top_p, max_tokens: Accepted for interface
            compatibility but currently not forwarded to the API call.

    Returns:
        None.

    Raises:
        AssertionError: If ``rule`` is None.
        ValueError: If ``task`` is not supported.
        RuntimeError: If the API returned None on every attempt.
    """
    assert rule is not None, "Rule is required for reverse augmentation"

    rule_str = _format_rule_text(rule)

    basis = "the input question and its correct answer" if provide_answer else "the input question"
    system_prompt = (
        "Explanatory inversion is a technique that reformulates questions to emphasize their underlying explanations or reasoning processes. Your task is to reformulate the given question using the rule below, based on "
        + basis
        + ". For multiple choice questions, you may also consider the provided answer choices if needed."
        + "\n\n" + "**Rule:**\n" + rule_str
    )

    batch_prompts = [
        system_prompt + "\n\n" + _question_aug_user_prompt(x, task, provide_answer)
        for x in batch_examples
    ]

    # Query the LLM
    # NOTE(review): batch_call_gemini_api is neither defined nor imported in the
    # visible part of this file; presumably it returns one response per prompt,
    # in order, or None on failure. Confirm against its definition.
    max_retries = 1
    responses = None
    for _ in range(max_retries):
        responses = batch_call_gemini_api(batch_prompts, model_name=backbone_model)
        if responses is not None:
            break
    if responses is None:
        raise RuntimeError(
            f"No response from {backbone_model} for task {task} after {max_retries} attempt(s)"
        )

    # Apply responses to all examples in the batch
    for x, response in zip(batch_examples, responses):
        x[response_key] = response


def _rationale_user_prompt(x, task, question_key):
    """Build the prompt asking for a rationale to one example's question."""
    # question_key == "question" means the ORIGINAL dataset question: it needs its
    # context (choices, premise, ...) and an explicit answer-format instruction.
    # Any other key holds an already generated question that is sent as-is.
    original = question_key == "question"

    if task == "strategy_qa":
        if original:
            return "Question: "+x[question_key]+"\n\n"+"Please provide a step-by-step reasoning and conclude with either \"Yes\" or \"No\"."
        return "Question: "+x[question_key]

    if task in ("commonsense_qa", "arc_challenge", "date"):
        if original:
            labels = x["choices"]["label"]
            # combine the choices together with the keys; "(A) text" with a space
            # after the label, matching the question-augmentation prompt
            choices_str = "".join(
                " (" + labels[i] + ") " + text for i, text in enumerate(x["choices"]["text"])
            )
            return "Question: "+x["question"]+"\n\n"+"Choices: "+choices_str+"\n\n"+"Please provide a step-by-step reasoning and conclude with your choice."
        return "Question: "+x[question_key]

    if task == "table_mwp":
        # The table is always needed, for both the original and the generated question
        table_prompt = "Table to use: "+x["table"]
        if original:
            return table_prompt + "\n\n" + "Question: "+x["question"]+"\n\n"+"Please provide a step-by-step reasoning and conclude with your final answer."
        return table_prompt + "\n\n" + "Question: "+x[question_key]

    if task in ("gsm8k", "math"):
        if original:
            return "Question: "+x["question"]+"\n\n"+"Please provide a step-by-step reasoning and conclude with your final answer."
        return "Question: "+x[question_key]

    if task == "anli":
        # anli examples have premise/hypothesis instead of a "question" field
        if original:
            return "Question: Given that \""+x["premise"]+"\"\nQuestion: "+x["hypothesis"]+" True, False, or Neither?"+"\n\n"+"Please provide a step-by-step reasoning and conclude with \"True\", \"False\", or \"Neither\"."
        return "Question: "+x[question_key]

    raise ValueError(f"Task {task} is not supported for reverse augmentation")


def reverse_aug_prompt_response(
    batch_examples,
    task: str = "strategy_qa",                      # Task name
    backbone_model: str = "gemini-1.5-flash",
    question_key: str = "response",                 # The key to store the question in the output
    response_key: str = "response_rationale",       # The key to store the response in the output
    num_generations: int = 1,               # Number of samples to generate
    temperature: float = 0.8,               # Temperature for sampling
    top_p: float = 1.0,                     # Top-p for sampling
    max_tokens: int = 1024,                 # Maximum number of tokens to generate
):
    """Ask the LLM for a rationale to the question stored under ``question_key``.

    One prompt per example is built, the batch is sent to
    ``batch_call_gemini_api``, and each response is stored in place under
    ``x[response_key]``.

    Args:
        batch_examples: List of example dicts; mutated in place.
        task: One of strategy_qa, commonsense_qa, arc_challenge, date,
            table_mwp, gsm8k, math, anli.
        backbone_model: Model name forwarded to ``batch_call_gemini_api``.
        question_key: Field holding the question. ``"question"`` selects the
            original dataset question; any other key is sent verbatim.
        response_key: Key under which the response is stored.
        num_generations, temperature, top_p, max_tokens: Accepted for interface
            compatibility but currently not forwarded to the API call.

    Returns:
        None.

    Raises:
        ValueError: If ``task`` is not supported.
        RuntimeError: If the API returned None on every attempt.
    """
    batch_prompts = [_rationale_user_prompt(x, task, question_key) for x in batch_examples]

    # Query the LLM
    # NOTE(review): batch_call_gemini_api is neither defined nor imported in the
    # visible part of this file; presumably it returns one response per prompt,
    # in order, or None on failure. Confirm against its definition.
    max_retries = 1
    responses = None
    for _ in range(max_retries):
        responses = batch_call_gemini_api(batch_prompts, model_name=backbone_model)
        if responses is not None:
            break
    if responses is None:
        raise RuntimeError(
            f"No response from {backbone_model} for task {task} after {max_retries} attempt(s)"
        )

    # Apply responses to all examples in the batch
    for x, response in zip(batch_examples, responses):
        x[response_key] = response



def _load_rules(rule_dir, rule_type, tasks):
    """Load the rules of every task from ``rules_{task}.jsonl`` under ``rule_dir``."""
    rules = {}
    for task in tasks:
        with open(os.path.join(rule_dir, f"rules_{task}.jsonl"), "r") as f:
            task_rules = [json.loads(line) for line in f if line.strip()]
        if rule_type == "overall":
            # Only keep the 'rule' field, example field are left to be used for further icl demonstrations
            task_rules = [rule['rule'] for rule in task_rules]
        rules[task] = task_rules
    return rules


def _split_into_batches(dataset, batch_size):
    """Cut ``dataset`` into consecutive slices of at most ``batch_size`` items."""
    return [dataset[i:i+batch_size] for i in range(0, len(dataset), batch_size)]


def _write_jsonl(output_file, dataset):
    """Write ``dataset`` to ``output_file``, one JSON object per line."""
    with open(output_file, 'w') as w:
        for item in dataset:
            w.write(json.dumps(item, ensure_ascii=False) + '\n')


def main():
    """Command-line entry point: run question or rationale augmentation for all tasks."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--num_process", type=int, default=10)
    parser.add_argument("--backbone", type=str, default="gemini-1.5-flash")
    parser.add_argument("--response_key", type=str, default="response")
    parser.add_argument("--data_dir", type=str, default=None)
    parser.add_argument("--output_dir", type=str, default=None)
    parser.add_argument("--split", type=str, default="train")               # Split to use, default is "train", can be "test"
    # LLM generation parameters
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--num_generations", type=int, default=1)
    parser.add_argument("--temperature", type=float, default=0.8)           # Temperature for sampling, default is 0.8
    parser.add_argument("--top_p", type=float, default=1.0)                 # Top-p for sampling, default is 1.0
    parser.add_argument("--max_tokens", type=int, default=1024)             # Maximum number of tokens to generate, default is 1024
    # Task parameters
    parser.add_argument("--rule", type=str, default=None)                   # Directory holding one rules_{task}.jsonl file per task
    parser.add_argument("--rule_type", type=str, default="per_task")        # Can be "per_task" or "overall"
    # Data augmentation task, default is "question-augmentation", can be "rationale-augmentation"
    parser.add_argument("--provide_answer", action="store_true",
                        help="If specified, include the correct answer in the prompt")
    parser.add_argument("--data_aug_task", type=str, default="question-augmentation")
    args = parser.parse_args()

    # Validate everything up front, so a missing option is reported before any
    # LLM call is made rather than when the first output file is written.
    if args.rule_type not in ("overall", "per_task"):
        raise ValueError(f"Rule type {args.rule_type} is not supported")
    if args.data_aug_task not in ("question-augmentation", "rationale-augmentation"):
        raise ValueError(f"Data augmentation task {args.data_aug_task} is not supported")
    if args.rule is None:
        parser.error("--rule is required for reverse augmentation")
    if args.output_dir is None:
        parser.error("--output_dir is required")

    tasks = ["strategy_qa", "commonsense_qa", "arc_challenge", "date", "table_mwp", "gsm8k", "math", "anli"]
    # Load the rules
    rules = _load_rules(args.rule, args.rule_type, tasks)

    print(f"\nStarting augmentation with {len(tasks)} tasks...")

    if args.data_aug_task == "question-augmentation":
        for task in tasks:
            for rule_idx, rule in enumerate(rules[task]):
                print(f"\nStarting augmentation for task {task} with rule {rule_idx}")
                # Read data
                dataset = read_data(task, args.data_dir, args.split)
                # Split into batches
                batches = _split_into_batches(dataset, args.batch_size)
                # Apply reverse augmentation
                total_batches = len(batches)
                for batch in tqdm(batches, total=total_batches, desc=f"Processing {task} with rule {rule_idx}", unit="batch", ncols=100):
                    reverse_aug_prompt_question(batch, task=task, rule=rule, provide_answer=args.provide_answer,
                                                backbone_model=args.backbone, response_key=args.response_key,
                                                num_generations=args.num_generations, temperature=args.temperature,
                                                top_p=args.top_p, max_tokens=args.max_tokens)
                # Save the augmented dataset
                output_folder = os.path.join(args.output_dir, task, args.split, f"rule_{rule_idx}")
                os.makedirs(output_folder, exist_ok=True)
                output_file = f"{output_folder}/reverse_aug_questions.jsonl"
                _write_jsonl(output_file, dataset)

    else:  # "rationale-augmentation"
        for task in tasks:
            # Iterate over this task's rules (not over the dict of tasks), so the
            # indices match the rule_{idx} folders written by question augmentation.
            for rule_idx in range(len(rules[task])):
                print(f"\nStarting generating augmented question's rationale for task {task} with rule {rule_idx}")
                # Read data from the folder layout used by question augmentation:
                # <output_dir>/<task>/<split>/rule_<idx>
                dataset_dir = os.path.join(args.output_dir, task, args.split, f"rule_{rule_idx}")
                dataset = read_data_with_response(task, dataset_dir)
                # Split into batches
                batches = _split_into_batches(dataset, args.batch_size)
                # Apply reverse augmentation
                total_batches = len(batches)
                for batch in tqdm(batches, total=total_batches, desc=f"Processing {task} with rule {rule_idx}", unit="batch", ncols=100):
                    reverse_aug_prompt_response(batch, task=task,
                                                backbone_model=args.backbone, question_key="response",
                                                response_key="response_rationale", num_generations=args.num_generations,
                                                temperature=args.temperature, top_p=args.top_p, max_tokens=args.max_tokens)
                # Replace the original dataset file with the augmented dataset
                output_file = os.path.join(dataset_dir, "reverse_aug_questions.jsonl")
                _write_jsonl(output_file, dataset)

        for task in tasks:
            print(f"\nStarting generating original question's rationale for task {task}")
            # Read data; the original questions are the same under every rule, so
            # only the rule_0 file receives the "question_rationale" field.
            dataset_dir = os.path.join(args.output_dir, task, args.split, f"rule_{0}")
            dataset = read_data_with_response(task, dataset_dir)
            # Split into batches
            batches = _split_into_batches(dataset, args.batch_size)
            # Apply reverse augmentation
            total_batches = len(batches)
            for batch in tqdm(batches, total=total_batches, desc=f"Processing {task}", unit="batch", ncols=100):
                reverse_aug_prompt_response(batch, task=task,
                                            backbone_model=args.backbone, question_key="question",
                                            response_key="question_rationale", num_generations=args.num_generations,
                                            temperature=args.temperature, top_p=args.top_p, max_tokens=args.max_tokens)
            # Replace the original dataset file with the augmented dataset
            output_file = os.path.join(dataset_dir, "reverse_aug_questions.jsonl")
            _write_jsonl(output_file, dataset)


if __name__ == '__main__':
    main()
