import json
import sys
import os
from tqdm import tqdm
import time

# Append the parent directory to sys.path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from utils.model_utils import generate_gemini_responses_threaded_modified
from utils.string_utils import extract_list_from_code, extract_dict_from_string

# Load the full dataset of facts.
with open("../../data/full_dataset_2_25_newest.json", 'r') as f:
    data = json.load(f)
facts = list(data.keys())

N_PARAPHRASES = 1000
N_PARAPHRASES_PER_PROMPT = 30
PROMPT = """
Paraphrase this fact concisely:
{fact}

Return {n} paraphrases for this fact in python list format.
Do not return anything else other than the paraphrase.
"""

# Path to the paraphrases file.
paraphrases_file = "../../data/fact_paraphrases_1000.json"

# Try to load existing paraphrases if available.
if os.path.exists(paraphrases_file):
    with open(paraphrases_file, 'r') as f:
        loaded = json.load(f)
    # Convert each list to a set for easy updating.
    fact_paraphrases = {fact: set(paras) for fact, paras in loaded.items()}
else:
    fact_paraphrases = {}

# Process each fact.
for fact in tqdm(facts, desc="Processing facts"):
    # Skip facts that already have enough paraphrases.
    if fact in fact_paraphrases and len(fact_paraphrases[fact]) >= N_PARAPHRASES:
        continue

    # Initialize if not already present.
    if fact not in fact_paraphrases:
        fact_paraphrases[fact] = set()

    prompt_for_this_fact = PROMPT.format(fact=fact, n=N_PARAPHRASES_PER_PROMPT)

    # Create a progress bar for accumulating paraphrases for this fact.
    pbar = tqdm(total=N_PARAPHRASES, desc="Generating paraphrases", leave=False)
    pbar.update(len(fact_paraphrases[fact]))
    retry_cnt = 0
    while len(fact_paraphrases[fact]) < N_PARAPHRASES and retry_cnt < 5:
        need_n_more = N_PARAPHRASES - len(fact_paraphrases[fact])
        responses = generate_gemini_responses_threaded_modified(
            [prompt_for_this_fact] * 50,
            model_name="gemini-1.5-pro",
            max_tokens=800,
            temperature=2.0,
            max_workers=30,
            wait_time=10
        )
        all_responses = []
        for response in responses:
            try:
                response_list = extract_list_from_code(response)
            except Exception:
                response_list = []
            if response_list:
                all_responses.extend([resp.strip() for resp in response_list])

        if not all_responses:
            continue

        prev_count = len(fact_paraphrases[fact])
        fact_paraphrases[fact].update(all_responses)
        new_count = len(fact_paraphrases[fact])
        pbar.update(new_count - prev_count)

        # Trim to exactly N_PARAPHRASES if we overshoot.
        if new_count > N_PARAPHRASES:
            fact_paraphrases[fact] = set(list(fact_paraphrases[fact])[:N_PARAPHRASES])
        retry_cnt+=1
        print(f"RETRYING {retry_cnt} time...")

    pbar.close()

    # Print current status for the processed fact.
    print(fact)
    print(len(fact_paraphrases[fact]))
    print()

    # Immediately save the updated paraphrases to file.
    with open(paraphrases_file, 'w') as f:
        # Convert each set back to a list for JSON serialization.
        json.dump({fact: list(paras) for fact, paras in fact_paraphrases.items()}, f, indent=2)
