"""
Run experiments for FCAF-RL on synthetic Medicaid data.

This script demonstrates training and evaluation of the FCAF-RL algorithm
using the synthetic dataset generated by `synthetic_data_generation.py`. It
implements a simple leave-one-state-out evaluation: train on all states except
one and evaluate on the held-out state. For each test state it reports the
observed acute event rate (the risk-based baseline) and the relative reduction
achieved by FCAF-RL.

To keep the example self-contained and computationally light, the evaluation
here does not implement an off-policy estimator: the learned policy is queried
for recommendations, but the reported reward is the observed reward of the
logged data, so the resulting "reduction" is a placeholder rather than a real
estimate of policy performance.
"""
import argparse
import pandas as pd
import numpy as np
import torch

from synthetic_data_generation import generate_dataset
from fcaf_rl import TransitionDataset, FCAFConfig, FCAFTrainer


def evaluate_policy(trainer: "FCAFTrainer", df_test: pd.DataFrame, action_to_idx: dict, fairness_threshold: float = 0.05) -> float:
    """Evaluate the learned policy on a test dataset.

    The learned policy is queried for a recommendation on every test row, but
    the returned score is simply the mean *observed* reward of the logged
    data: the recommended action is not used to filter or re-weight rewards,
    so the result does not depend on the policy. In practice, a proper
    off-policy estimator (e.g. inverse propensity weighting or doubly robust)
    should be used instead.

    Args:
        trainer: trained FCAFTrainer instance; only its ``select_action``
            method is called, once per test row.
        df_test: test DataFrame. Must contain the columns ``age``, ``sex``,
            ``race``, ``comorbidity_count``, ``prior_ed_visits``,
            ``mental_health_flag``, ``substance_use_flag``,
            ``social_need_score`` and ``acute_event``.
        action_to_idx: mapping from action names to indices. Currently unused;
            kept for interface compatibility.
        fairness_threshold: forwarded unchanged to ``trainer.select_action``.

    Returns:
        average reward (1 - acute_event rate) over the rows of ``df_test``;
        NaN for an empty frame.

    Raises:
        KeyError: if a required column is missing, or if ``sex``/``race``
            contains a category not present in the encoding maps below.
    """
    feature_columns = [
        "age", "sex", "race", "comorbidity_count", "prior_ed_visits",
        "mental_health_flag", "substance_use_flag", "social_need_score",
    ]
    # NOTE(review): this encoding must match the one used by TransitionDataset
    # during training -- confirm against fcaf_rl.
    sex_map = {"Female": 0.0, "Male": 1.0}
    race_map = {"White": 0.0, "Black": 1.0, "Hispanic": 2.0, "Other": 3.0}

    # Encode the categorical columns *before* casting to float32: "sex" and
    # "race" hold strings, so casting the raw frame first raises ValueError.
    # Plain dict lookups (rather than Series.map) make unknown categories
    # raise KeyError instead of silently becoming NaN.
    features = df_test[feature_columns].copy()
    features["sex"] = [sex_map[x] for x in features["sex"]]
    features["race"] = [race_map[x] for x in features["race"]]
    states = features.to_numpy(dtype=np.float32)

    # Query the policy for each row. The recommendations are not yet used by
    # the estimate below (TODO: replace with an off-policy estimator).
    for state in states:
        trainer.select_action(state, fairness_threshold)

    # Observed reward per row is 1 - acute_event.
    rewards = 1.0 - df_test["acute_event"].to_numpy(dtype=np.float64)
    return float(np.mean(rewards))


def main():
    """Run leave-one-state-out training and evaluation on synthetic data.

    For every state in the generated dataset, a fresh FCAFTrainer is trained
    on the remaining states and evaluated on the held-out state with
    ``evaluate_policy``. For each state it prints the baseline (observed)
    acute event rate and the relative reduction in that rate. It finishes
    with a per-state summary and the mean reduction across states.
    """
    parser = argparse.ArgumentParser(description="Run FCAF-RL experiments on synthetic data.")
    parser.add_argument("--n_patients", type=int, default=2000, help="Number of synthetic patients to generate")
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    args = parser.parse_args()

    # Seed the global RNGs as well as the data generator, so that model
    # training is reproducible for a given --seed.
    # NOTE(review): assumes fcaf_rl draws its randomness from the global
    # torch / numpy RNGs -- confirm.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # Generate synthetic dataset
    df = generate_dataset(args.n_patients, seed=args.seed)
    print(f"Generated dataset with {len(df)} rows")

    action_to_idx = {name: i for i, name in enumerate([
        "SubstanceUseSupport",
        "MentalHealthSupport",
        "ChronicConditionManagement",
        "FoodAssistance",
        "HousingAssistance",
        "TransportationAssistance",
        "UtilitiesAssistance",
        "ChildcareAssistance",
        "WatchfulWaiting",
        "None",
    ])}

    # Partition by state: each state is held out once.
    states = df["state"].unique()
    reductions = []
    for hold_out in states:
        train_df = df[df["state"] != hold_out]
        test_df = df[df["state"] == hold_out]
        # Build dataset for training
        train_dataset = TransitionDataset(train_df, action_to_idx)
        config = FCAFConfig(state_dim=train_dataset.states.shape[1], device="cpu")
        trainer = FCAFTrainer(config)
        trainer.train(train_dataset)
        # Baseline: observed acute event rate in the held-out state.
        baseline_rate = float(test_df["acute_event"].mean())
        policy_reward = evaluate_policy(trainer, test_df, action_to_idx)
        # Reward is 1 - event rate, so convert back before comparing rates.
        policy_rate = 1.0 - policy_reward
        # Relative reduction in the acute event *rate* (the quantity reported),
        # not the relative change in reward. Undefined when there are no
        # baseline events.
        if baseline_rate > 0:
            reduction = (baseline_rate - policy_rate) / baseline_rate * 100
        else:
            reduction = float("nan")
        reductions.append((hold_out, baseline_rate, reduction))
        print(f"Hold-out state {hold_out}: baseline event rate {baseline_rate:.3f}, policy reward {policy_reward:.3f}, reduction {reduction:.1f}%")

    print("\nPer-state relative reduction:")
    for state, base, red in reductions:
        print(f"  {state}: baseline event rate={base:.3f}, relative reduction={red:.1f}%")
    # nanmean skips states whose reduction is undefined (zero baseline).
    mean_reduction = float(np.nanmean([red for _, _, red in reductions]))
    print(f"Average relative reduction across states: {mean_reduction:.1f}%")


if __name__ == "__main__":
    # Run the experiments only when executed as a script, not when imported.
    main()