import gleam_ai
from gleam_ai.utils import set_seed
import os
import json
from dotenv import load_dotenv
import google.generativeai as genai
import time
import traceback
import sys
import random
import numpy as np
import pandas as pd
from datetime import datetime, timedelta

# Make the gleam-ai-shared directory that sits next to this script importable.
# NOTE(review): `import gleam_ai` / `from gleam_ai.utils import set_seed` at the
# top of the file execute *before* this append, so this sys.path entry cannot be
# what lets those imports resolve; it only affects imports executed afterwards.
# If gleam_ai is only importable from this directory, move those imports below.
gleam_ai_path = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    'gleam-ai-shared',
)
sys.path.append(gleam_ai_path)

# Seasonality level -> `seasonality_min` value passed to the GLEAM-AI model
# (see execute_tool_and_get_answer). Values as specified by the epidemiology
# expert; stronger seasonality maps to a smaller value.
SEASONALITY_MAPPING = {
    'no seasonality': 1.0,
    'moderate seasonality': 0.75,
    'high seasonality': 0.5
}

# Discrete R0 values in the expert-specified range (1.75 to 3.0).
# NOTE(review): not referenced elsewhere in this file; parameter validation
# only checks the continuous range 1.75 <= R0 <= 3.0.
R0_VALUES = [1.75, 1.85, 2.0, 2.1, 2.2, 2.4, 2.6, 2.8, 3.0]

# Accepted `seasonality_level` strings (the keys of SEASONALITY_MAPPING).
SEASONALITY_LEVELS = list(SEASONALITY_MAPPING.keys())

# Prior immunity levels (as specified by epidemiology expert); these appear to
# be population fractions.
# NOTE(review): not referenced elsewhere in this file; the simulation uses the
# LLM-derived percentage string instead.
PRIOR_IMMUNITY_LEVELS = [0.10, 0.20, 0.30, 0.40, 0.50, 0.60]

# User-facing metric name -> value of the `compartment` column in the
# simulation output (latent prevalence was intentionally removed).
TARGET_METRICS_MAPPING = {
    'hospital incidence': 'hosp_inc',
    'hospital prevalence': 'hosp_prev',
    'latent incidence': 'latent_inc'
}

# US states plus the District of Columbia. A state's position in this list is
# used as its `state_id` when filtering simulation output, so the order must
# match create_open.py exactly.
US_STATES = ['Alabama', 'Alaska', 'Arizona', 'Arkansas', 'California', 'Colorado',
            'Connecticut', 'Delaware', 'Florida', 'Georgia', 'Hawaii', 'Idaho',
            'Illinois', 'Indiana', 'Iowa', 'Kansas', 'Kentucky', 'Louisiana', 'Maine',
            'Maryland', 'Massachusetts', 'Michigan', 'Minnesota', 'Mississippi', 'Missouri',
            'Montana', 'Nebraska', 'Nevada', 'New Hampshire', 'New Jersey', 'New Mexico',
            'New York', 'North Carolina', 'North Dakota', 'Ohio', 'Oklahoma', 'Oregon',
            'Pennsylvania', 'Rhode Island', 'South Carolina', 'South Dakota', 'Tennessee',
            'Texas', 'Utah', 'Vermont', 'Virginia', 'Washington', 'West Virginia', 'Wisconsin',
            'Wyoming', 'District of Columbia']


def initialize_gleam_model(
        gleam_dir='/home/ubuntu/Adapting-While-Learning/src/Epidemiology2/gleam-ai-shared',
        device='cuda'):
    """Initialize the GLEAM-AI model for simulation.

    The model is loaded from ``config.yaml`` inside ``gleam_dir``. Because that
    config path is relative, the working directory is switched to ``gleam_dir``
    for the load and is always restored afterwards, even if loading fails.

    Args:
        gleam_dir: Directory containing the GLEAM-AI ``config.yaml``. Defaults to
            the path this script has always used.
        device: Device string forwarded to ``gleam_ai.load_pretrained_model``.
            The default ``'cuda'`` matches the previous hard-coded value; pass
            ``'cpu'`` when no GPU is available.

    Returns:
        The loaded model, or ``None`` if any step raised (the error and its
        traceback are printed).
    """
    try:
        print("🔧 Initializing GLEAM-AI model...")

        original_dir = os.getcwd()
        os.chdir(gleam_dir)
        try:
            model = gleam_ai.load_pretrained_model(
                config_path='config.yaml',
                device=device,
            )
        finally:
            # Restore cwd on success *and* failure so relative paths used later
            # in the process (input/output JSON files) keep resolving correctly.
            os.chdir(original_dir)

        print("✅ GLEAM-AI model initialized successfully!")
        return model

    except Exception as e:
        print(f"❌ Failed to initialize GLEAM-AI model: {e}")
        traceback.print_exc()
        return None


def generate_starting_date(year=2022):
    """Return a random flu-season start date between Sep 1 and Nov 30 of ``year``.

    The window covers roughly three months (91 calendar days, both ends
    inclusive). The draw uses the module-level ``random`` generator, so call
    ``random.seed`` first for reproducible output.

    Args:
        year: Calendar year of the season start. Defaults to 2022, the season
            used throughout this script.

    Returns:
        The chosen date formatted as ``'YYYY-MM-DD'``.
    """
    start_range = datetime(year, 9, 1)   # September 1
    end_range = datetime(year, 11, 30)   # November 30

    # randint is inclusive on both ends, so Nov 30 itself can be drawn.
    random_days = random.randint(0, (end_range - start_range).days)
    return (start_range + timedelta(days=random_days)).strftime('%Y-%m-%d')


def extract_simulation_summary(df_sims, target_metric, target_states, start_date_str):
    """
    Summarize daily GLEAM-AI simulation rows into a purely quantitative outlook.

    The output layout mirrors the one produced by create_open.py, so answers
    generated by both scripts stay comparable.

    Args:
        df_sims: Simulation output; must contain the columns 'compartment',
            'state_id', 'date' and 'value'.
        target_metric: Key of TARGET_METRICS_MAPPING; the mapped value selects
            which 'compartment' rows are analysed.
        target_states: State names. Each name's index in US_STATES is used as
            its 'state_id'; names not in US_STATES are skipped silently.
        start_date_str: Simulation start date (anything pd.to_datetime parses),
            used to express the peak timing in weeks.

    Returns:
        Dict with 'peak_magnitude', 'peak_timing', 'initial_trend' and
        'total_burden' sections, or {} when a validation step fails or any
        exception is raised (diagnostics are printed in both cases).
    """
    try:
        missing = [col for col in ('compartment', 'state_id', 'date', 'value')
                   if col not in df_sims.columns]
        if missing:
            print(f"  ❌ Missing required columns: {missing}")
            print(f"  📊 Available columns: {list(df_sims.columns)}")
            return {}

        compartment = TARGET_METRICS_MAPPING[target_metric]
        if compartment not in df_sims['compartment'].unique():
            print(f"  ❌ Compartment {compartment} not found in simulation data")
            return {}

        requested_ids = [US_STATES.index(name) for name in target_states if name in US_STATES]
        print(f"  📊 State IDs to analyze: {requested_ids}")

        present_ids = df_sims['state_id'].unique()
        state_ids = [sid for sid in requested_ids if sid in present_ids]
        if not state_ids:
            print("  ❌ No valid state IDs found in simulation data")
            return {}

        keep = df_sims['compartment'].eq(compartment) & df_sims['state_id'].isin(state_ids)
        rows = df_sims.loc[keep]
        print(f"  📊 Filtered metric data shape: {rows.shape}")
        if rows.empty:
            print(f"  ❌ No data found for compartment {compartment} and states {state_ids}")
            return {}

        # NOTE(review): the quantiles are computed over every row sharing a date
        # (all selected states pooled together, plus presumably all runs), not
        # over a per-run sum across states; kept for parity with create_open.py.
        daily = (
            rows.groupby('date')['value']
            .agg(median='median',
                 q5=lambda s: s.quantile(0.05),
                 q95=lambda s: s.quantile(0.95))
            .reset_index()
        )
        print(f"  📊 Summary stats shape: {daily.shape}")
        if len(daily) < 2:
            print("  ❌ Insufficient summary statistics data")
            return {}

        # Peak = the date whose median value is highest.
        peak_idx = daily['median'].idxmax()
        peak_median, peak_q5, peak_q95 = (
            daily.loc[peak_idx, column] for column in ('median', 'q5', 'q95'))
        peak_date = pd.to_datetime(daily.loc[peak_idx, 'date'])
        weeks_to_peak = (peak_date - pd.to_datetime(start_date_str)).days / 7

        # Initial trend: compare the first and last medians among the first
        # 28 rows (one row per date after the groupby).
        window = daily['median'].head(28)
        growth_rate = 0.0
        trend = "stable"
        if len(window) > 1:
            first, last = window.iloc[0], window.iloc[-1]
            # A non-positive starting median leaves growth_rate at 0.0 ("stable").
            if first > 0:
                growth_rate = ((last - first) / first) * 100
            if growth_rate < -5:
                trend = "declining"
            elif growth_rate > 50:
                trend = "rapid growth"
            elif growth_rate > 5:
                trend = "slow growth"

        return {
            "peak_magnitude": {
                "median_value": float(round(peak_median, 1)),
                "90_percent_ci": [float(round(peak_q5, 1)), float(round(peak_q95, 1))]
            },
            "peak_timing": {
                "peak_date": str(peak_date.date()),
                "weeks_from_start": float(round(weeks_to_peak, 1))
            },
            "initial_trend": {
                "qualitative_trend": trend,
                "four_week_growth_percentage": float(round(growth_rate, 1))
            },
            "total_burden": {
                "cumulative_median_value": float(round(daily['median'].sum(), 1))
            }
        }
    except Exception as e:
        print(f"  ❌ Error extracting simulation summary: {e}")
        traceback.print_exc()
        return {}


def derive_tool_call_parameters(model, open_question):
    """
    Uses the LLM to parse an open-ended question and extract all required
    parameters for the GLEAM-AI influenza forecasting simulator.
    """
    prompt = f"""
    You are an expert AI system that translates natural language questions about influenza forecasting into structured parameters for the GLEAM-AI simulation tool.

    **Your Task:**
    Given an open-ended question about influenza seasonal patterns, extract the precise parameters needed to run the GLEAM-AI simulator.

    **Tool Available:** GLEAM-AI influenza forecasting simulator

    **Available Parameters:**
    - `r0_value` (float): Basic reproduction number (range: 1.75 to 3.0)
    - `seasonality_level` (string): "no seasonality", "moderate seasonality", or "high seasonality"
    - `starting_date` (string): Date in YYYY-MM-DD format (September-November for flu season)
    - `prior_immunity_level` (string): Prior immunity percentage (e.g., "20%", "40%")
    - `target_states` (list): US states for analysis
    - `target_metric` (string): Single health metric to analyze (focus on hospitalizations)

    **Available Target Metrics:**
    - "hospital incidence": New hospitalizations per time period
    - "hospital prevalence": Current hospitalizations
    - "latent incidence": New infections

    **Tool Parameters:**
    1. `r0_value` (float): Disease transmissibility (1.75-3.0)
    2. `seasonality_level` (string): Seasonality pattern
    3. `starting_date` (string): Simulation start date (YYYY-MM-DD)
    4. `prior_immunity_level` (string): Prior immunity percentage
    5. `target_states` (list): States for analysis (2-5 states)
    6. `target_metric` (string): Single health metric to analyze

    **Output Format:** Single JSON object with these exact keys.

    **Examples:**

    **Example 1: Comprehensive Epidemiological Forecast**
    Question: "Considering an upcoming influenza season in New York, South Carolina, Minnesota, and Maryland, beginning on 2022-10-01, with an estimated R0 of 1.85, high seasonality, and an initial population immunity of 40%, what is the comprehensive epidemiological forecast for hospital incidence?"

    Output:
    {{
        "r0_value": 1.85,
        "seasonality_level": "high seasonality",
        "starting_date": "2022-10-01",
        "prior_immunity_level": "40%",
        "target_states": ["New York", "South Carolina", "Minnesota", "Maryland"],
        "target_metric": "hospital incidence"
    }}

    **Example 2: Seasonal Trajectory Analysis**
    Question: "For the upcoming influenza season in Missouri, Washington, Ohio, and Georgia, commencing around October 28th, what is the comprehensive epidemiological forecast for hospital prevalence, considering an R0 of 2.1, an absence of seasonality, and an initial 30% prior immunity in the population?"

    Output:
    {{
        "r0_value": 2.1,
        "seasonality_level": "no seasonality",
        "starting_date": "2022-10-28",
        "prior_immunity_level": "30%",
        "target_states": ["Missouri", "Washington", "Ohio", "Georgia"],
        "target_metric": "hospital prevalence"
    }}

    **Example 3: Epidemiological Dynamics Assessment**
    Question: "Considering a scenario for the influenza season in California, Washington, Illinois beginning 2022-09-20, characterized by an R0 of 2.1, high seasonality, and 60% prior immunity, what are the expected overall epidemiological dynamics for hospital incidence?"

    Output:
    {{
        "r0_value": 2.1,
        "seasonality_level": "high seasonality",
        "starting_date": "2022-09-20",
        "prior_immunity_level": "60%",
        "target_states": ["California", "Washington", "Illinois"],
        "target_metric": "hospital incidence"
    }}

    **Question to Parse:**
    "{open_question}"

    Extract the parameters following the exact format shown in the examples above. If specific values are not mentioned, use reasonable defaults within the specified ranges.
    """

    retries = 3
    for attempt in range(retries):
        try:
            response = model.generate_content(prompt)
            cleaned_response = response.text.strip().replace(
                "```json", "").replace("```", "").strip()
            params = json.loads(cleaned_response)

            # Check basic required keys
            basic_keys = ['r0_value', 'seasonality_level', 'starting_date',
                          'prior_immunity_level', 'target_states', 'target_metric']
            if not all(k in params for k in basic_keys):
                print(f"  ⚠️  Missing required keys: {[k for k in basic_keys if k not in params]}")
                continue

            # Validate parameters
            if not (1.75 <= params['r0_value'] <= 3.0):
                print(f"  ⚠️  Invalid R0 value: {params['r0_value']} (must be 1.75-3.0)")
                continue

            if params['seasonality_level'] not in SEASONALITY_LEVELS:
                print(f"  ⚠️  Invalid seasonality: {params['seasonality_level']}")
                continue

            # Validate target metric
            if params['target_metric'] not in TARGET_METRICS_MAPPING:
                print(f"  ⚠️  Invalid target metric: {params['target_metric']}")
                continue

            # Validate states
            if not params['target_states'] or len(params['target_states']) == 0:
                print(f"  ⚠️  No target states specified")
                continue

            # Validate date format
            try:
                datetime.strptime(params['starting_date'], '%Y-%m-%d')
            except ValueError:
                print(f"  ⚠️  Invalid date format: {params['starting_date']}")
                continue

            print(f"  ✅ Successfully validated parameters")
            return params

        except json.JSONDecodeError as e:
            print(f"  ⚠️  JSON parsing failed on attempt {attempt + 1}: {e}")
        except Exception as e:
            print(f"  ⚠️  Parameter derivation failed on attempt {attempt + 1}: {e}")

        if attempt < retries - 1:
            time.sleep(10)

    print("  ❌ Could not derive valid parameters after multiple retries.")
    return None


def execute_tool_and_get_answer(
        params, model,
        gleam_dir='/home/ubuntu/Adapting-While-Learning/src/Epidemiology2/gleam-ai-shared',
        num_runs=100, seed=42):
    """
    Calls the GLEAM-AI simulator with extracted parameters and formats results.

    Args:
        params: Parameter dict with keys 'r0_value', 'seasonality_level',
            'starting_date', 'prior_immunity_level' (a percentage such as
            "40%"), 'target_states' and 'target_metric'.
        model: Loaded GLEAM-AI model exposing ``run(**simulation_params)``.
        gleam_dir: Directory the simulation must run from. The working
            directory is switched there only for the ``model.run`` call and is
            always restored afterwards.
        num_runs: Number of simulation runs (100 matches create_open.py).
        seed: Value passed to ``set_seed`` before the run, for reproducibility.

    Returns:
        A Markdown heading followed by the JSON-formatted outlook, or None if
        the inputs are invalid, the simulation fails or the summary is empty.
    """
    try:
        print(f"\n--- Calling GLEAM-AI Simulator ---")
        print(f"  R0: {params['r0_value']}")
        print(
            f"  Seasonality: {params['seasonality_level']} ({SEASONALITY_MAPPING[params['seasonality_level']]})")
        print(f"  Starting date: {params['starting_date']}")
        print(f"  Prior immunity: {params['prior_immunity_level']}")
        print(
            f"  States ({len(params['target_states'])}): {params['target_states']}")
        print(f"  Target metric: {params['target_metric']}")

        # Convert seasonality to the numerical `seasonality_min` value.
        seasonality_min = SEASONALITY_MAPPING[params['seasonality_level']]

        # Prior immunity is a percentage; accept "40%", "40" or a bare number.
        prior_immunity_str = str(params['prior_immunity_level']).strip().rstrip('%')
        prior_immunity = float(prior_immunity_str) / 100.0

        # Compartment fractions (matching create_open.py).
        # NOTE(review): frac_latent is subtracted twice, presumably because the
        # model seeds two compartments with this fraction — confirm upstream.
        frac_latent = 0.00014
        frac_recovered = prior_immunity
        frac_susceptible = 1.0 - frac_latent * 2 - frac_recovered
        # Written as a negated >= test so NaN fractions are rejected as well.
        if not (frac_recovered >= 0.0 and frac_susceptible >= 0.0):
            print(f"  ❌ Prior immunity {params['prior_immunity_level']} gives invalid "
                  f"compartment fractions (susceptible={frac_susceptible:.5f})")
            return None

        simulation_params = {
            'starting_date': params['starting_date'],
            'R0': params['r0_value'],
            'seasonality_min': seasonality_min,
            'frac_susceptible': frac_susceptible,
            'frac_latent': frac_latent,
            'frac_recovered': frac_recovered,
            'num_runs': num_runs
        }

        print(f"  🔍 Running GLEAM-AI simulation...")
        # The model must run from gleam_dir. Restore cwd even when the run
        # raises; otherwise the process stays in gleam_dir and later relative
        # paths (and the next call's `original_dir`) point to the wrong place.
        original_dir = os.getcwd()
        os.chdir(gleam_dir)
        try:
            set_seed(seed)  # reproducible results
            df_sims = model.run(**simulation_params)
        finally:
            os.chdir(original_dir)

        if df_sims is None or df_sims.empty:
            print(f"  ❌ Failed to get simulation results")
            return None

        print(f"  ✅ Successfully completed simulation")
        print(f"     Simulation data shape: {df_sims.shape}")

        # Extract structured outlook from simulation (matching create_open.py)
        simulation_outlook = extract_simulation_summary(
            df_sims, params['target_metric'], params['target_states'], params['starting_date']
        )

        if not simulation_outlook:
            print(f"  ❌ Failed to extract simulation summary")
            return None

        print(f"  📊 Extracted structured simulation outlook")

        # Format the answer with structured outlook (matching create_open.py)
        answer_parts = [
            f"**Epidemiological Outlook for {params['target_metric']}:**",
            json.dumps(simulation_outlook, indent=2),
        ]
        return "\n".join(answer_parts)

    except Exception as e:
        print(f"Error during simulation: {e}")
        traceback.print_exc()
        return None


def format_quantitative_question(params):
    """
    Build the quantitative question text from derived simulation parameters.

    Args:
        params: Dict with 'target_states' (list of names), 'starting_date',
            'r0_value', 'seasonality_level', 'prior_immunity_level' and
            'target_metric'.

    Returns:
        A single-sentence scenario followed by a request for numerical results.
    """
    scenario = (
        "Considering an influenza season in {} starting {}, "
        "with an R0 of {}, {}, and {} prior immunity, "
    ).format(
        ", ".join(params['target_states']),
        params['starting_date'],
        params['r0_value'],
        params['seasonality_level'],
        params['prior_immunity_level'],
    )
    request = (
        "what is the anticipated epidemiological progression for {}? "
        "Please provide specific numerical results including peak timing, "
        "magnitude, growth rates, and seasonal patterns."
    ).format(params['target_metric'])
    return scenario + request


def process_single_question_with_retry(gemini_model, model, open_q, item, max_retries=3):
    """
    Process a single question with retry mechanism when simulation fails.

    Every attempt re-queries the LLM for parameters and re-runs the simulation,
    so a retry can recover from either a bad parse or a failed simulation.

    Args:
        gemini_model: LLM client forwarded to derive_tool_call_parameters.
        model: Loaded GLEAM-AI model forwarded to execute_tool_and_get_answer.
        open_q: The open-ended question text.
        item: Source record for the question; its 'open_answer' (None if
            absent) is copied into the output record.
        max_retries: Maximum number of attempts.

    Returns:
        The output record dict, or None if every attempt failed (also None
        when max_retries <= 0).
    """
    for attempt in range(1, max_retries + 1):
        has_more_attempts = attempt < max_retries
        print(f"\n=== Attempt {attempt}/{max_retries} ===")

        # Step 1: Use LLM to derive tool call parameters
        print("Step 1: Deriving tool parameters from open question...")
        derived_params = derive_tool_call_parameters(gemini_model, open_q)
        if not derived_params:
            print("-> Failed to derive parameters.")
            if not has_more_attempts:
                return None
            print("-> Retrying with new LLM query...")
            time.sleep(10)  # Wait before retry
            continue

        print("-> Success. Derived Params:")
        print(f"   R0: {derived_params['r0_value']}")
        print(f"   Seasonality: {derived_params['seasonality_level']}")
        print(f"   Starting date: {derived_params['starting_date']}")
        print(f"   Prior immunity: {derived_params['prior_immunity_level']}")
        print(f"   States: {derived_params['target_states']}")
        print(f"   Target Metric: {derived_params['target_metric']}")

        # Step 2: Execute the GLEAM-AI tool
        print("Step 2: Running GLEAM-AI simulation...")
        quantitative_answer = execute_tool_and_get_answer(derived_params, model)
        if not quantitative_answer:
            print("-> Failed to get quantitative answer from GLEAM-AI.")
            if not has_more_attempts:
                print("-> All retry attempts exhausted. Skipping this question.")
                return None
            print("-> Retrying with new parameters...")
            time.sleep(15)  # Wait before retry
            continue

        print("-> Success. Quantitative analysis completed.")

        # Step 3: Format the output
        print("Step 3: Formatting final output...")
        return {
            "open_question": open_q,
            # .get: a record lacking 'open_answer' must not raise KeyError
            # after an expensive, successful simulation.
            "original_open_answer": item.get('open_answer'),
            "derived_quantitative_question": format_quantitative_question(derived_params),
            "derived_quantitative_answer": quantitative_answer,
            "derived_parameters": derived_params,
            "retry_attempts": attempt
        }

    return None


def main():
    """
    Main function to process open questions and generate quantitative analysis using GLEAM-AI.

    Cheap prerequisites (API key, input file) are checked before the expensive
    GLEAM-AI model load. A failure on one question is logged and counted
    without aborting the batch. The results collected so far are always
    written to ``derived_quantitative_questions.json``, even if the loop is
    interrupted.
    """
    # Load API key and configure Gemini first: fail fast before loading the model.
    # abspath keeps the .env lookup anchored to this file even when __file__ is
    # relative (same approach as the sys.path setup at module level).
    script_dir = os.path.dirname(os.path.abspath(__file__))
    dotenv_path = os.path.join(script_dir, '..', '..', '.env')
    load_dotenv(dotenv_path=dotenv_path)
    # The original variable name is kept; the conventional spelling is a fallback.
    api_key = os.getenv("Google_API_KEY") or os.getenv("GOOGLE_API_KEY")
    if not api_key:
        raise ValueError("Google_API_KEY not found in .env file.")

    genai.configure(api_key=api_key)
    gemini_model = genai.GenerativeModel('gemini-2.5-flash')

    # Load open questions (relative to the current working directory)
    try:
        with open("open_questions.json", 'r', encoding='utf-8') as f:
            open_questions = json.load(f)
    except FileNotFoundError:
        print("Error: open_questions.json not found. Please run upsample_open.py first.")
        return

    # Initialize the GLEAM-AI model only once all cheap checks have passed.
    model = initialize_gleam_model()
    if model is None:
        print("FATAL: Could not initialize GLEAM-AI model.")
        return

    total = len(open_questions)
    print(f"Processing {total} open questions...")

    final_data = []
    successful_count = 0
    failed_count = 0
    output_file = "derived_quantitative_questions.json"

    try:
        for i, item in enumerate(open_questions, start=1):
            print(f"\n--- Processing open question {i}/{total} ---")
            try:
                open_q = item['open_question']
                # Process with retry mechanism
                result = process_single_question_with_retry(
                    gemini_model, model, open_q, item, max_retries=3)
            except Exception as e:
                # One malformed item or unexpected error must not abort the batch.
                print(f"  ❌ Unexpected error: {e}")
                traceback.print_exc()
                result = None

            if result:
                final_data.append(result)
                successful_count += 1
                print(f"✅ Question {i} processed successfully (attempts: {result['retry_attempts']})")
            else:
                failed_count += 1
                print(f"❌ Question {i} failed after all retry attempts")
    finally:
        # Persist whatever was collected, even on Ctrl-C or an unexpected crash.
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(final_data, f, ensure_ascii=False, indent=4)

    print(f"\n🎯 Processing Summary:")
    print(f"   ✅ Successfully processed: {successful_count}/{total} questions")
    print(f"   ❌ Failed: {failed_count}/{total} questions")
    print(f"   📄 Output saved to {output_file}")


if __name__ == "__main__":
    main()
