import os
import json
from dotenv import load_dotenv
import google.generativeai as genai
import time


def upsample_question(model, original_question, original_answer):
    """
    Ask Gemini to turn one quantitative epidemiology Q&A pair into an
    open-ended forecasting question with a multi-claim narrative answer.

    The prompt embeds the original question and the structured simulation
    outlook, shows three few-shot examples, and asks for a JSON object with
    ``open_question`` and ``open_answer`` keys.

    Args:
        model: Object exposing ``generate_content(prompt)`` whose result has a
            ``.text`` attribute (e.g. a ``genai.GenerativeModel``).
        original_question: The original quantitative question text.
        original_answer: The structured simulation outlook; interpolated
            verbatim into the prompt.

    Returns:
        The parsed JSON object (a ``dict``) on success, or ``None`` when the
        API call fails, the response carries no text, or the text does not
        contain a JSON object. Errors are printed, never raised.
    """
    prompt = f"""
    You are an expert epidemiologist at the CDC, tasked with transforming a structured simulation summary into a comprehensive, claim-based forecast. Your goal is to create a challenging open-ended question and a structured, multi-claim answer suitable for advanced AI evaluation.

    **Original Quantitative Question:** "{original_question}"
    **Structured Simulation Outlook (JSON):** "{original_answer}"

    **CRITICAL REQUIREMENTS:**
    1.  **Preserve Core Inputs in the Question:** The new `open_question` must retain all key forecasting parameters from the original question: R0, seasonality, start date, target states, and prior immunity.
    2.  **Frame as a Comprehensive Forecasting Question:** The `open_question` must ask for a single, comprehensive "outlook" or "forecast." It MUST NOT be a list of sub-questions. It should guide the answer by asking to consider multiple perspectives, for example: "...what is the comprehensive epidemiological forecast? Please consider the peak's severity, its timing, and the initial speed of the outbreak."
    3.  **Generate a Multi-Claim, Parallel-Structured Answer:** The `open_answer` MUST be a single, natural-flowing paragraph. Internally, it must be structured with **several parallel, independent, verifiable claims**, followed by a **final concluding synthesis**.
        * **Claim 1:** A statement about the initial trend based on given simulation data.
        * **Claim 2:** A statement about the peak magnitude based on given simulation data.
        * **Claim 3:** A statement about the peak timing based on given simulation data.
        * **Synthesis:** A concluding sentence that summarizes the overall situation.
    4.  **Selectively Include Quantitative Details:** The `open_answer` MUST include a few (**maximum of three**) of the most important quantitative values from the simulation outlook to add precision to its claims. These values should be woven naturally into the text. Do not simply list numbers; use them to support the qualitative assessment.
    5.  **Focus on Forecasting, Not Policy:** The answer must be strictly about the predicted epidemiological trajectory. AVOID any language about policy interventions or resource allocation.
    6.  Maintain the length of the answer similar to the examples.

    ---
    **HIGH-QUALITY EXAMPLES:**

    **Example 1: High Immunity / Suppressed Outbreak**

    * **Input `original_answer` (JSON):**
        ```json
        "**...**\n{{
        \"peak_magnitude\": {{
            \"median_value\": 8.0,
            \"90_percent_ci\": [2.0, 19.0]
        }},
        \"peak_timing\": {{
            \"peak_date\": \"2022-10-01\",
            \"weeks_from_start\": 0.0
        }},
        \"initial_trend\": {{
            \"qualitative_trend\": \"declining\",
            \"four_week_growth_percentage\": -75.0
        }},
        \"total_burden\": {{
            \"cumulative_median_value\": 151.0
        }}
        }}"
        ```
    * **Your Output (JSON):**
        ```json
        {{
        "open_question": "For an upcoming influenza season in New York, South Carolina, Minnesota, and Maryland, assuming a start date of October 1st, an R0 of 1.85, high seasonality, and a robust 40% prior population immunity, what is the comprehensive epidemiological forecast for hospital incidence? Please assess the expected trajectory, including the initial trend, peak severity, and timing.",
        "open_answer": "The forecast indicates a very low-risk season where an outbreak is effectively suppressed. The season is expected to begin with a sharp declining trend, showing a 75.0% change in the first four weeks. The peak of the outbreak is projected to be negligible in magnitude and is forecast to occur immediately around the start date of October 1st. This trajectory suggests the 40% prior immunity level is sufficient to prevent any significant seasonal wave."
        }}
        ```

    ---
    **Example 2: High R0 / Explosive Outbreak**

    * **Input `original_answer` (JSON):**
        ```json
        "**...**\n{{
        \"peak_magnitude\": {{
            \"median_value\": 479.0,
            \"90_percent_ci\": [138.0, 910.0]
        }},
        \"peak_timing\": {{
            \"peak_date\": \"2022-12-14\",
            \"weeks_from_start\": 6.7
        }},
        \"initial_trend\": {{
            \"qualitative_trend\": \"rapid growth\",
            \"four_week_growth_percentage\": 4012.5
        }},
        \"total_burden\": {{
            \"cumulative_median_value\": 18123.0
        }}
        }}"
        ```
    * **Your Output (JSON):**
        ```json
        {{
        "open_question": "What is the projected epidemiological outlook for hospital prevalence in Missouri, Washington, Ohio, and Georgia, given a scenario with a concerning R0 of 2.1, no seasonality, a late October start date, and 30% prior immunity? Analyze the key forecasting characteristics of this scenario.",
        "open_answer": "The epidemiological outlook for this scenario appears severe and explosive. The season is forecast to begin with an extremely rapid growth phase, with an initial four-week growth rate exceeding 4000%. The peak of hospitalizations is expected to be high, reaching a median value of approximately 479 cases. This severe peak is projected to materialize moderately late in the season, around 6.7 weeks after onset, pointing to a dangerous and sustained wave that would heavily strain healthcare resources."
        }}
        ```

    ---
    **Example 3: Low R0 / Mild Outbreak**

    * **Input `original_answer` (JSON):**
        ```json
        "**...**\n{{
        \"peak_magnitude\": {{
            \"median_value\": 15.5,
            \"90_percent_ci\": [5.0, 29.0]
        }},
        \"peak_timing\": {{
            \"peak_date\": \"2022-12-11\",
            \"weeks_from_start\": 6.3
        }},
        \"initial_trend\": {{
            \"qualitative_trend\": \"rapid growth\",
            \"four_week_growth_percentage\": 80.0
        }},
        \"total_burden\": {{
            \"cumulative_median_value\": 738.5
        }}
        }}"
        ```
    * **Your Output (JSON):**
        ```json
        {{
        "open_question": "For a flu season in California and Washington starting around October 28th, what is the forecasted outlook for hospital incidence given a low R0 of 1.85, no seasonality, and 40% prior immunity? Please assess the expected trajectory in terms of its peak, timing, and initial trend.",
        "open_answer": "The forecast suggests a mild and manageable season. Although the season is expected to begin with a rapid initial growth of 80.0% in the first month, the overall peak magnitude of hospitalizations is projected to be low, with a median of only about 16 cases. The peak of this mild wave is anticipated to arrive moderately late in the season, approximately 6.3 weeks after it begins. This outlook indicates a fast start but a low overall burden on the healthcare system."
        }}
        ```
    ---

    Transform the provided quantitative data into a similarly structured question and a multi-claim answer. Ensure the answer is a single, natural-flowing paragraph that contains independent claims and is supported by a maximum of three key quantitative values from the source data.
    """

    # Step 1: call the model and read its text. Both steps can fail (for
    # instance ``response.text`` may raise ValueError when the response has
    # no usable content), so they are guarded together. The error path must
    # NOT touch ``response`` again: it may be unbound, or its ``.text`` may be
    # exactly what raised.
    try:
        response = model.generate_content(prompt)
        raw_text = response.text
    except ValueError as e:
        print(f"Error reading text from Gemini response: {e}")
        return None
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
        return None

    if not isinstance(raw_text, str) or not raw_text.strip():
        print(f"Empty response from Gemini.\nRaw response: {raw_text!r}")
        return None

    # Step 2: strip markdown code fences the model tends to wrap JSON in.
    cleaned_response = raw_text.strip().replace(
        "```json", "").replace("```", "").strip()
    # Tolerate stray prose before/after the object by keeping only the
    # outermost {...} span when one is present.
    start, end = cleaned_response.find("{"), cleaned_response.rfind("}")
    if start != -1 and end > start:
        cleaned_response = cleaned_response[start:end + 1]

    # Step 3: parse and make sure we got an object (callers index it by key).
    try:
        upsampled_pair = json.loads(cleaned_response)
    except json.JSONDecodeError as e:
        print(
            f"Error decoding JSON from Gemini response: {e}\nRaw response: {raw_text}")
        return None

    if not isinstance(upsampled_pair, dict):
        print(
            f"Expected a JSON object from Gemini, got {type(upsampled_pair).__name__}.\nRaw response: {raw_text}")
        return None

    return upsampled_pair


def main():
    """
    Load generated epidemiology questions and upsample each into open-ended pairs.

    Loads ``Google_API_KEY`` from a ``.env`` file located two directories above
    this script's directory. Reads ``generated_questions.json`` from the current
    working directory. For each entry, requests two open-ended variants via
    ``upsample_question``, retrying each variant up to three times with a pause
    between attempts. Finally writes every successful pair to
    ``open_questions.json`` in the current working directory.

    Raises:
        ValueError: If ``Google_API_KEY`` is not set.
    """
    dotenv_path = os.path.join(os.path.dirname(__file__), '..', '..', '.env')
    load_dotenv(dotenv_path=dotenv_path)

    api_key = os.getenv("Google_API_KEY")
    if not api_key:
        raise ValueError(
            "Google_API_KEY not found. Please create a .env file in src/Epidemiology/ with your key.")

    genai.configure(api_key=api_key)
    model = genai.GenerativeModel('gemini-2.5-pro')

    try:
        with open("generated_questions.json", 'r', encoding='utf-8') as f:
            original_questions = json.load(f)
    except FileNotFoundError:
        print("Error: generated_questions.json not found. Please run create_open.py first.")
        return

    variants_per_question = 2   # open-ended pairs requested per source question
    retries = 3                 # attempts per variant before giving up
    retry_delay_seconds = 30    # pause between failed attempts

    questions_to_process = original_questions  # Process all questions
    upsampled_questions = []
    target_count = len(questions_to_process) * variants_per_question

    print(
        f"Upsampling {len(questions_to_process)} questions into {target_count} open-ended Q&A pairs...")

    for i, qa_pair in enumerate(questions_to_process):
        for j in range(variants_per_question):
            print(
                f"Processing question {i+1}/{len(questions_to_process)}, variant {j+1}/{variants_per_question}...")
            for attempt in range(retries):
                new_pair = upsample_question(
                    model, qa_pair['question'], qa_pair['answer'])
                if (isinstance(new_pair, dict)
                        and 'open_question' in new_pair
                        and 'open_answer' in new_pair):
                    upsampled_questions.append(new_pair)
                    print(
                        f"  -> Success. Total generated: {len(upsampled_questions)}/{target_count}")
                    break
                if attempt < retries - 1:
                    print(
                        f"  -> Failed to generate a valid pair (attempt {attempt+1}/{retries}), retrying in {retry_delay_seconds} seconds...")
                    time.sleep(retry_delay_seconds)
                else:
                    # Last attempt: no point sleeping before moving on.
                    print(
                        f"  -> Failed to generate a valid pair (attempt {attempt+1}/{retries}).")
            else:
                # for-else: runs only when no attempt hit ``break``, i.e. every
                # attempt returned None OR a malformed object (the old
                # ``new_pair is None`` check missed the latter case).
                print(
                    f"  -> Could not generate a valid pair for question {i+1} after {retries} attempts. Skipping.")

    output_file = "open_questions.json"
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(upsampled_questions, f, ensure_ascii=False, indent=4)

    print(
        f"\nSuccessfully generated and saved {len(upsampled_questions)} open-ended Q&A pairs to {output_file}")


# Entry point: run the upsampling pipeline only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()
