import os
import json
import random
import sys
import numpy as np
import pandas as pd
from datetime import datetime, timedelta

# Add the gleam-ai-shared directory to path.
# The bundled GLEAM-AI package is expected in a 'gleam-ai-shared' directory
# located next to this script. It is appended (lowest priority) to sys.path so
# that the `import gleam_ai` statements below can resolve it.
gleam_ai_path = os.path.join(os.path.dirname(
    os.path.abspath(__file__)), 'gleam-ai-shared')
sys.path.append(gleam_ai_path)

# Import GLEAM-AI utilities
import gleam_ai
from gleam_ai.utils import set_seed

# --- Constants and Mappings ---

# Human-readable seasonality level -> value passed to the simulator as
# ``seasonality_min``.
SEASONALITY_MAPPING = {
    'no seasonality': 1.0,
    'moderate seasonality': 0.75,
    'high seasonality': 0.5,
}

# Candidate basic reproduction numbers sampled for each scenario.
R0_VALUES = [1.75, 1.85, 2.0, 2.1, 2.2, 2.4, 2.6, 2.8, 3.0]

# Level names in the mapping's insertion order.
SEASONALITY_LEVELS = list(SEASONALITY_MAPPING)

# Question-facing metric name -> simulator compartment label.
TARGET_METRICS_MAPPING = {
    'hospital incidence': 'hosp_inc',
    'hospital prevalence': 'hosp_prev',
    'latent incidence': 'latent_inc',
}

# A state's position in this list is matched against the simulator's
# ``state_id`` column (via ``US_STATES.index``), so the order is significant.
US_STATES = [
    'Alabama', 'Alaska', 'Arizona', 'Arkansas', 'California',
    'Colorado', 'Connecticut', 'Delaware', 'Florida', 'Georgia',
    'Hawaii', 'Idaho', 'Illinois', 'Indiana', 'Iowa',
    'Kansas', 'Kentucky', 'Louisiana', 'Maine', 'Maryland',
    'Massachusetts', 'Michigan', 'Minnesota', 'Mississippi', 'Missouri',
    'Montana', 'Nebraska', 'Nevada', 'New Hampshire', 'New Jersey',
    'New Mexico', 'New York', 'North Carolina', 'North Dakota', 'Ohio',
    'Oklahoma', 'Oregon', 'Pennsylvania', 'Rhode Island', 'South Carolina',
    'South Dakota', 'Tennessee', 'Texas', 'Utah', 'Vermont',
    'Virginia', 'Washington', 'West Virginia', 'Wisconsin', 'Wyoming',
    # NOTE(review): DC sits last rather than in alphabetical/FIPS position;
    # confirm this matches the simulator's state_id ordering.
    'District of Columbia',
]

# Fraction of the population assumed already immune (sampled per scenario).
PRIOR_IMMUNITY_LEVELS = [0.10, 0.20, 0.30, 0.40, 0.50, 0.60]

class SafeDict(dict):
    """Mapping for ``str.format_map`` that leaves unknown placeholders intact.

    Looking up a missing key ``k`` yields the literal text ``{k}`` instead of
    raising ``KeyError``. A partially filled template therefore keeps its
    unresolved fields. The key is not inserted into the dict.
    """

    def __missing__(self, key):
        # Re-emit the placeholder verbatim, braces included.
        return "{" + str(key) + "}"

# --- Helper and Simulation Functions ---

def initialize_gleam_model(config_path='config.yaml', device='cuda'):
    """Load the pretrained GLEAM-AI model.

    Args:
        config_path: Model config path, forwarded to
            ``gleam_ai.load_pretrained_model``. A relative path resolves
            against the current working directory. Defaults to
            ``'config.yaml'`` as before.
        device: Device string forwarded unchanged to
            ``gleam_ai.load_pretrained_model``. Defaults to ``'cuda'`` as
            before; override it on hosts without that device.

    Returns:
        The loaded model, or ``None`` if loading raised any exception. The
        traceback is printed so the caller can abort cleanly.
    """
    try:
        print("🔧 Initializing GLEAM-AI model...")
        model = gleam_ai.load_pretrained_model(config_path=config_path, device=device)
        print("✅ GLEAM-AI model initialized successfully!")
        return model
    except Exception as e:
        # Boundary handler: report and signal failure with None; main() checks for it.
        print(f"❌ Failed to initialize GLEAM-AI model: {e}")
        import traceback
        traceback.print_exc()
        return None

def generate_starting_date():
    """Return a random simulation start date formatted as ``'YYYY-MM-DD'``.

    The day is drawn uniformly from the inclusive window 2022-09-18 to
    2022-10-08, using the module-level ``random`` generator.
    """
    window_start = datetime(2022, 9, 18)
    window_end = datetime(2022, 10, 8)
    span_days = (window_end - window_start).days
    chosen = window_start + timedelta(days=random.randint(0, span_days))
    return chosen.strftime('%Y-%m-%d')

def _classify_initial_trend(median_series, window=28):
    """Classify the early trajectory of a per-date median series.

    Compares the first and last of the first ``window`` values (28 rows,
    i.e. four weeks when the series is daily).

    Returns:
        ``(trend, growth_pct)``. ``trend`` is one of ``"stable"``,
        ``"slow growth"``, ``"rapid growth"`` or ``"declining"``.
        ``growth_pct`` is the percentage change over the window. It is
        ``None`` when the window starts at zero (or below) and then rises,
        because no finite percentage exists in that case.
    """
    head = median_series.iloc[:window]
    if len(head) <= 1:
        return "stable", 0.0

    start_val = head.iloc[0]
    end_val = head.iloc[-1]
    if start_val > 0:
        growth_pct = ((end_val - start_val) / start_val) * 100
    elif end_val > start_val:
        # Growth from a zero baseline: previously this fell through to
        # "stable" with 0 %, mislabelling an outbreak whose first day is 0.
        return "rapid growth", None
    else:
        growth_pct = 0.0

    if growth_pct > 50:
        return "rapid growth", growth_pct
    if growth_pct > 5:
        return "slow growth", growth_pct
    if growth_pct < -5:
        return "declining", growth_pct
    return "stable", growth_pct


def extract_simulation_summary(df_sims, target_metric, target_states, start_date_str):
    """
    Extracts a rich, PURELY quantitative outlook from the daily simulation results.

    Rows of ``df_sims`` are filtered to the compartment mapped from
    ``target_metric`` and to the ``state_id``s of ``target_states``. They are
    then grouped by ``date`` into the median and the 5th/95th percentiles of
    ``value``.

    Args:
        df_sims: Simulator output with ``compartment``, ``state_id``, ``date``
            and ``value`` columns.
        target_metric: A key of ``TARGET_METRICS_MAPPING``.
        target_states: State names. Names not in ``US_STATES`` are ignored.
        start_date_str: Simulation start date (any ``pd.to_datetime`` input).
            Used to express peak timing in weeks.

    Returns:
        A dict with ``peak_magnitude``, ``peak_timing``, ``initial_trend`` and
        ``total_burden`` entries. ``initial_trend.four_week_growth_percentage``
        is ``None`` when growth starts from zero. Returns ``{}`` if no rows
        match or if any error occurs (the traceback is printed).
    """
    try:
        compartment = TARGET_METRICS_MAPPING[target_metric]
        # NOTE(review): assumes the simulator's state_id is the position in US_STATES.
        state_ids_to_analyze = [US_STATES.index(s) for s in target_states if s in US_STATES]

        metric_data = df_sims[
            (df_sims['compartment'] == compartment) &
            (df_sims['state_id'].isin(state_ids_to_analyze))
        ]
        if metric_data.empty:
            return {}

        # groupby sorts by date, so rows come out in chronological order
        # (assuming the date values sort chronologically).
        # NOTE(review): every row sharing a date is pooled. If df_sims holds one
        # row per (run, state, date), these are quantiles of per-state values,
        # not of regional totals. A regional forecast would first sum across
        # states within each run; confirm intent.
        summary_stats = metric_data.groupby('date')['value'].agg(
            median='median',
            q5=lambda x: x.quantile(0.05),
            q95=lambda x: x.quantile(0.95)
        ).reset_index()

        # --- Decompose into Rich, Independent Claims ---
        peak_idx = summary_stats['median'].idxmax()
        peak_median = summary_stats.loc[peak_idx, 'median']
        peak_q5 = summary_stats.loc[peak_idx, 'q5']
        peak_q95 = summary_stats.loc[peak_idx, 'q95']

        peak_date_ts = pd.to_datetime(summary_stats.loc[peak_idx, 'date'])
        start_date = pd.to_datetime(start_date_str)
        peak_time_weeks = (peak_date_ts - start_date).days / 7

        initial_trend, growth_rate = _classify_initial_trend(summary_stats['median'])

        total_burden = summary_stats['median'].sum()

        return {
            "peak_magnitude": {
                "median_value": float(round(peak_median, 1)),
                "90_percent_ci": [float(round(peak_q5, 1)), float(round(peak_q95, 1))]
            },
            "peak_timing": {
                "peak_date": str(peak_date_ts.date()),
                "weeks_from_start": float(round(peak_time_weeks, 1))
            },
            "initial_trend": {
                "qualitative_trend": initial_trend,
                "four_week_growth_percentage": (
                    None if growth_rate is None else float(round(growth_rate, 1))
                )
            },
            "total_burden": {
                "cumulative_median_value": float(round(total_burden, 1))
            }
        }
    except Exception as e:
        # Best-effort: callers treat {} as "no summary" and skip this sample.
        print(f"Error extracting simulation summary: {e}")
        import traceback
        traceback.print_exc()
        return {}

def generate_qa_pair(template, model):
    """Generates a standard 'what-if' forecasting question.

    The function samples a scenario with the module-level ``random``
    generator: R0, seasonality level, start date, prior immunity, a
    hospitalization metric and 2-4 US states. It runs ``model`` on that
    scenario, summarizes the output with ``extract_simulation_summary`` and
    fills both templates.

    Args:
        template: Mapping with ``'question_template'`` and
            ``'answer_template'`` strings. ``{{``/``}}`` are converted to
            ``{``/``}`` so ``{{name}}`` placeholders become format fields.
            Placeholders with no sampled value are left as ``{name}``.
        model: Simulator exposing ``run(**simulation_params)``, expected to
            return a DataFrame or ``None``.

    Returns:
        ``{"question": str, "answer": str}``, or ``None`` if the simulation
        raised, returned no data, or produced an empty summary.
    """
    question_template = template['question_template'].replace('{{', '{').replace('}}', '}')
    answer_template = template['answer_template'].replace('{{', '{').replace('}}', '}')

    params = {}
    params['r0_value'] = random.choice(R0_VALUES)
    params['seasonality_level'] = random.choice(SEASONALITY_LEVELS)
    params['starting_date'] = generate_starting_date()

    prior_immunity = random.choice(PRIOR_IMMUNITY_LEVELS)
    # round() rather than int(): e.g. 0.29 * 100 == 28.999..., which int() truncates to 28.
    params['prior_immunity_level'] = f"{round(prior_immunity * 100)}%"

    frac_latent = 0.00014
    frac_recovered = prior_immunity
    # NOTE(review): frac_latent is subtracted twice, so susceptible + latent +
    # recovered = 1 - frac_latent. Presumably the model seeds an equal fraction
    # into another compartment (e.g. infectious); confirm against gleam_ai.
    frac_susceptible = 1.0 - frac_latent * 2 - frac_recovered

    # Per expert feedback, focus on hospitalization metrics
    params['target_metric'] = random.choice(['hospital incidence', 'hospital prevalence'])

    num_states = random.randint(2, 4)
    selected_states = random.sample(US_STATES, num_states)
    params['target_states'] = ', '.join(selected_states)

    try:
        print("\n--- Calling GLEAM-AI Simulator for Outlook ---")
        print(f"  R0: {params['r0_value']}, Seasonality: {params['seasonality_level']}")
        print(f"  Start Date: {params['starting_date']}, Prior Immunity: {params['prior_immunity_level']}")
        print(f"  States: {params['target_states']}, Metric: {params['target_metric']}")

        simulation_params = {
            'starting_date': params['starting_date'],
            'R0': params['r0_value'],
            'seasonality_min': SEASONALITY_MAPPING[params['seasonality_level']],
            'frac_susceptible': frac_susceptible,
            'frac_latent': frac_latent,
            'frac_recovered': frac_recovered,
            'num_runs': 100
        }

        # set_seed(42) keeps each simulation reproducible. If it also reseeds
        # Python's global `random` (common for such helpers; NOTE(review):
        # verify in gleam_ai.utils), every later call would sample the same
        # scenario. Snapshot and restore the `random` state so scenario
        # sampling keeps advancing independently of the simulator seed.
        rng_state = random.getstate()
        set_seed(42)
        try:
            df_sims = model.run(**simulation_params)
        finally:
            random.setstate(rng_state)

        if df_sims is None or df_sims.empty:
            print("  ❌ Failed to get simulation results")
            return None
        print(f"  ✅ Successfully completed simulation (shape: {df_sims.shape})")

        simulation_outlook = extract_simulation_summary(
            df_sims, params['target_metric'], selected_states, params['starting_date'])

        if not simulation_outlook:
            print("  ❌ Failed to extract simulation summary")
            return None

        params['simulation_outlook'] = json.dumps(simulation_outlook, indent=2)
        print("  📊 Extracted rich simulation outlook.")

    except Exception as e:
        # Best-effort: the caller retries with a freshly sampled scenario.
        print(f"❌ Error during simulation: {e}")
        import traceback
        traceback.print_exc()
        return None

    final_question = question_template.format_map(SafeDict(params))
    final_answer = answer_template.format_map(SafeDict(params))

    return {"question": final_question, "answer": final_answer}

def main():
    """
    Main function to load templates and generate concrete question-answer pairs.

    Changes into the gleam-ai-shared directory next to this script, so that
    relative paths like 'config.yaml' and '../topics.json' resolve. It then
    loads the model and generates up to ``num_questions_per_template`` pairs
    per template, allowing at most three times as many attempts. Results are
    written to ``../generated_questions.json``.
    """
    # Use the script-relative directory already placed on sys.path instead of a
    # machine-specific absolute path.
    os.chdir(gleam_ai_path)

    # Initialize GLEAM-AI model
    model = initialize_gleam_model()
    if model is None:
        print("❌ FATAL: Could not initialize GLEAM-AI model.")
        return

    # Load templates
    template_path = '../topics.json'  # Go back one directory to find topics.json
    try:
        with open(template_path, 'r', encoding='utf-8') as f:
            templates = json.load(f)
    except FileNotFoundError:
        print(
            f"Error: topics.json not found at {template_path}. Please run search_topics.py first.")
        return
    except json.JSONDecodeError as e:
        print(f"Error: topics.json at {template_path} is not valid JSON: {e}")
        return

    if not templates:
        print("Error: topics.json is empty.")
        return

    all_generated_pairs = []
    num_questions_per_template = 10
    # Simulations can fail; allow some retries but bound the total work.
    max_attempts = num_questions_per_template * 3

    for i, template_to_use in enumerate(templates, start=1):
        generated_qa_pairs_for_template = []
        print(
            f"\n--- Generating {num_questions_per_template} Q&A pairs for template {i} ---")

        attempts = 0
        while (len(generated_qa_pairs_for_template) < num_questions_per_template
               and attempts < max_attempts):
            attempts += 1
            qa_pair = generate_qa_pair(template_to_use, model)
            if qa_pair:
                generated_qa_pairs_for_template.append(qa_pair)
                print(
                    f"Generated Q&A pair {len(generated_qa_pairs_for_template)}/{num_questions_per_template} for template {i}")

        if len(generated_qa_pairs_for_template) < num_questions_per_template:
            print(
                f"⚠️ Only generated {len(generated_qa_pairs_for_template)}/{num_questions_per_template} "
                f"Q&A pairs for template {i} after {attempts} attempts")
        all_generated_pairs.extend(generated_qa_pairs_for_template)

    # Save to parent directory
    output_file = "../generated_questions.json"
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(all_generated_pairs, f, ensure_ascii=False, indent=4)

    print(
        f"\nSuccessfully generated and saved {len(all_generated_pairs)} Q&A pairs to {output_file}")


# Run the generator only when executed as a script, not when imported.
if __name__ == "__main__":
    main()