# script: 
# python3.11 explain_inconsistencies/generate.py --n 10

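# Generated examples obey the following rules:
#   - only pairs whose gold_label is "entailment" are used;
#   - pairs are optionally filtered by --difficulty (comma / wordcount / sentence / none);
#   - at most --n pairs are kept, taken after a random shuffle;
#   - every kept pair yields six prompts: zero-shot, CoT and explicit-CoT,
#     each in a "conflict" and a "no conflict" variant.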

from utils import load_jsonl, load_tsv, make_chat_call, load_json
import argparse, random
import nltk
# Check if 'punkt' is already downloaded
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    nltk.download('punkt', quiet=True)
from nltk.tokenize import sent_tokenize
from pathlib import Path
import json
from rich import print
from openai import AzureOpenAI

def filter_dataset(args, dataset):
    """Pick a random sample of entailment pairs matching the difficulty setting.

    Args:
        args: Parsed CLI namespace; reads ``difficulty``, ``wordcount`` (only
            for the "wordcount" difficulty) and ``n``.
        dataset: Iterable of mappings with the keys ``gold_label``,
            ``sentence1`` and ``sentence2``.

    Returns:
        A new list with at most ``args.n`` examples in shuffled order. The
        input ``dataset`` itself is not reordered.
    """
    text_keys = ("sentence1", "sentence2")
    mode = args.difficulty

    # Entailment pairs are the only candidates, whatever the difficulty.
    kept = [row for row in dataset if row["gold_label"] == "entailment"]

    if mode == "comma":
        # Drop the pair as soon as either sentence contains a comma.
        kept = [row for row in kept
                if not any(',' in row[key] for key in text_keys)]
        print(f"Filtered dataset to {len(kept)} examples with no commas")

    elif mode == "wordcount":
        print(f"Filtering by word count: {args.wordcount}")
        # Both sentences must be short enough AND tokenize to exactly one
        # sentence; the cheap word-count test runs first.
        kept = [
            row for row in kept
            if all(len(row[key].split()) <= args.wordcount for key in text_keys)
            and all(len(sent_tokenize(row[key])) == 1 for key in text_keys)
        ]
        print(f"Filtered dataset to {len(kept)} examples with {args.wordcount} words or less and single sentences")

    elif mode == "sentence":
        # Here "at most one sentence" also lets through text that tokenizes
        # to zero sentences.
        kept = [row for row in kept
                if all(len(sent_tokenize(row[key])) <= 1 for key in text_keys)]
        print(f"Filtered dataset to {len(kept)} examples with single sentences")

    elif mode == "none":
        print("No difficulty filtering applied")

    # Randomise first so the truncation below draws a random sample.
    random.shuffle(kept)
    kept = kept[:args.n]

    print(f"Filtered dataset to {len(kept)} examples with gold_label == 'entailment'")
    return kept

def _decapitalize(text):
    """Lower-case the first character of *text* if it is upper-case (empty-safe)."""
    if text and text[0].isupper():
        return text[0].lower() + text[1:]
    return text


def construct_prompt(sentence1, sentence2, args=None, client=None):
    """Build the six prompt variants for one premise/hypothesis pair.

    Two statements are derived from the pair: an implication
    ("If <sentence1>, then it is always the case that <sentence2>") and a
    second statement that either does not conflict with it ("It is not the
    case that <sentence2>") or does ("<sentence1>, but it is not the case
    that <sentence2>"). Each combination is paired with a zero-shot, a
    chain-of-thought and an explicit chain-of-thought question.

    Args:
        sentence1: Premise text. Surrounding whitespace and one trailing
            period are removed.
        sentence2: Hypothesis text. Surrounding whitespace is removed; its
            own trailing punctuation is kept and ends the statements.
        args: Optional namespace; when it has a truthy ``fix_grammar``
            attribute each statement is passed through ``fix_grammar``.
            Defaults to the module-level ``args`` if one exists.
        client: Optional chat client handed to ``fix_grammar``. Defaults to
            the module-level ``client`` if one exists.

    Returns:
        dict mapping ``zero_shot_no_conflict``, ``zero_shot_conflict``,
        ``cot_no_conflict``, ``cot_conflict``, ``cot_explicit_no_conflict``
        and ``cot_explicit_conflict`` to the prompt strings.
    """
    # Existing callers pass only the two sentences and rely on the globals
    # created by the script entry point; keep that working, but do not crash
    # with NameError when the module is imported and no globals exist.
    if args is None:
        args = globals().get("args")
    if client is None:
        client = globals().get("client")

    prefix = "The following are two statements:"

    # Strip BEFORE looking at the first/last character: otherwise "Foo. "
    # keeps its period and " Foo" is never decapitalized. Slicing-free
    # helpers also make empty strings safe.
    sentence1 = sentence1.strip().removesuffix(".").rstrip()
    sentence1_decap = _decapitalize(sentence1)
    sentence2 = _decapitalize(sentence2.strip())

    statement1 = f"If {sentence1_decap}, then it is always the case that {sentence2}"
    no_conflict_statement2 = f"It is not the case that {sentence2}"
    conflict_statement2 = f"{sentence1}, but it is not the case that {sentence2}"

    if getattr(args, "fix_grammar", False):
        statement1 = fix_grammar(args, client, statement1)
        no_conflict_statement2 = fix_grammar(args, client, no_conflict_statement2)
        conflict_statement2 = fix_grammar(args, client, conflict_statement2)

    question = 'Can both of these statements, as explicitly stated, be true at the same time? '
    suffixes = {
        "zero_shot": question + 'Please ONLY answer with "Yes" or "No".',
        "cot": question + 'Please reason about your answer and then answer "Yes" or "No".',
        "cot_explicit": question + 'Please first explain why statement 2 could be true and then answer "Yes" or "No".',
    }
    second_statements = {
        "no_conflict": no_conflict_statement2,
        "conflict": conflict_statement2,
    }

    def assemble(statement2, suffix):
        return prefix + "\n1. " + statement1 + "\n2. " + statement2 + "\n\n" + suffix

    # Key order matches the historical output: for each prompting style,
    # the no-conflict variant first, then the conflict variant.
    return {
        f"{style}_{variant}": assemble(statement2, suffix)
        for style, suffix in suffixes.items()
        for variant, statement2 in second_statements.items()
    }

def fix_grammar(args, client, text):
    """Ask the chat model to repair the grammar of *text*.

    Args:
        args: Namespace providing ``model`` (passed to ``make_chat_call``)
            and ``verbose`` (echo the prompt and reply when truthy).
        client: Chat client forwarded to ``make_chat_call``.
        text: Statement to check.

    Returns:
        The model's rewritten text with surrounding whitespace removed, or
        the original *text* when the model reports it is already correct or
        returns no content.
    """
    # Use an LLM to verify and fix grammar of text
    prompt = "Fix the grammar of the following text but do not change its meaning. If the text is correct, just say 'correct'.\n" + text
    response = make_chat_call(client, args.model, prompt, max_tokens=200).choices[0].message.content
    if args.verbose: print(f"Input: {prompt}\n\nEdited: {response}")

    # The message content can be missing or blank; never substitute that
    # for the statement.
    if response is None or not response.strip():
        return text

    # Compare case-insensitively and ignore wrapping quotes/whitespace:
    # replies such as "Correct." or "'correct'" mean "no change" and must
    # not be mistaken for the rewritten statement.
    cleaned = response.strip()
    verdict = cleaned.strip("'\"`").lower()
    if "correct" in verdict and len(verdict) < 10:
        return text
    return cleaned

def construct_prompts(args, client, dataset):
    """Build the prompt dictionary for every example in *dataset*.

    Args:
        args: Accepted for the caller's convenience; not used directly here.
        client: Accepted for the caller's convenience; not used directly here.
        dataset: Iterable of mappings with ``sentence1`` and ``sentence2``.

    Returns:
        List with one ``construct_prompt`` result per example, in input order.
    """
    built = [
        construct_prompt(pair["sentence1"], pair["sentence2"])
        for pair in dataset
    ]
    print(f"Constructed {len(built)} prompts")
    return built
    

if __name__ == "__main__":
    # Script entry point: parse options, load the chosen NLI dataset, filter
    # it, build the prompts and write them to a JSON file.
    #
    # NOTE: ``args`` and ``client`` are deliberately module-level names;
    # construct_prompt() picks them up from the module namespace, so do not
    # move this body into a function without passing them explicitly.

    parser = argparse.ArgumentParser()
    parser.add_argument("--output_dir", type=str, default="dataset")
    parser.add_argument("--n", type=int, default=1000)
    parser.add_argument("--dataset_name", type=str, default="rte", choices=["snli", "mnli", "rte", "gpt_generated"])
    parser.add_argument("--difficulty", type=str, default="wordcount", choices=["wordcount", "sentence", "comma", "none"])
    parser.add_argument("--wordcount", type=int, default=7)
    parser.add_argument("--model", type=str, default="gpt-4o")
    parser.add_argument("--verbose", action="store_true")
    parser.add_argument("--fix_grammar", action="store_true")
    args = parser.parse_args()

    # NOTE(review): credentials are hard-coded placeholders; consider loading
    # them from the environment instead of the source file.
    api_key = "redacted"
    api_base= "redacted"
    client = AzureOpenAI(api_key = api_key,  
                            api_version = "2023-05-15",
                            azure_endpoint = api_base)

    if args.dataset_name == "gpt_generated":
        dataset = load_json("explain_inconsistencies/GPT_generated.json")
    elif args.dataset_name == "snli":
        path1 = "explain_inconsistencies/snli_1.0/snli_1.0_train.jsonl"
        path2 = "explain_inconsistencies/snli_1.0/snli_1.0_dev.jsonl"
        path3 = "explain_inconsistencies/snli_1.0/snli_1.0_test.jsonl"
        dataset = load_jsonl([path1, path2, path3])
    elif args.dataset_name == "mnli":
        path4 = "explain_inconsistencies/GLUE/MNLI/dev_matched.tsv"
        path5 = "explain_inconsistencies/GLUE/MNLI/dev_mismatched.tsv"
        dataset = load_tsv([path4, path5])
    elif args.dataset_name == "rte":
        path6 = "explain_inconsistencies/GLUE/RTE/train.tsv"
        path7 = "explain_inconsistencies/GLUE/RTE/dev.tsv"
        dataset = load_tsv([path6, path7])
        # RTE rows carry "label"; rename it to the "gold_label" key that
        # filter_dataset() reads.
        for example in dataset:
            example["gold_label"] = example.pop("label")

    dataset = filter_dataset(args, dataset)
    prompts = construct_prompts(args, client, dataset)

    # Preview up to two prompts; slicing avoids an IndexError when the
    # filters (or a small --n) leave fewer than two examples.
    for preview in prompts[:2]:
        print(preview)
    # run prompts here
    
    output_dir = Path(args.output_dir, args.dataset_name)
    data_size = len(prompts)
    output_dir.mkdir(parents=True, exist_ok=True)
    if args.difficulty == "none":
        output_file = output_dir / f"datasets_{data_size}_{args.dataset_name}.json"
    else:
        # The wordcount value is part of the file name for every non-"none"
        # difficulty, even those that do not filter by word count.
        output_file = output_dir / f"datasets_{data_size}_{args.dataset_name}_{args.difficulty}_{args.wordcount}.json"
    print(f"Saving prompts to {output_file}")
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(prompts, f, indent=2)
    
