from sacred import Experiment
from reward_modeling.reward_wrapper_pbrr import RewardModel
from learn_reward.train_policy import ex
import learn_reward.unique_id_state as unique_id_state
import warnings

# Install a process-wide "ignore" filter so no warnings are shown while the
# PBRR loop runs. NOTE(review): this also hides warnings that may signal real
# problems; consider narrowing it by category.
warnings.filterwarnings(action="ignore")

# Sacred experiment driving the PBRR loop; git information is not saved.
iterative_ex = Experiment(name="PBRR", save_git_info=False)


@iterative_ex.config
def config():
    # Sacred config scope for the PBRR loop: every local assigned below becomes
    # a config entry that Sacred injects into main() by name.
    # Core experiment settings
    env_to_run = "pandemic"  # must contain one of: pandemic, glucose, traffic, tomato
    level = 4
    reward_fun = "proxy"  # passed unchanged to every train_policy run
    exp_algo = "ORPO"
    seed = 0  # in-loop reference re-collection uses seed + iteration index
    checkpoint_to_load_current_policy = None  # current-policy checkpoint for each over-optimization run
    # ORPO / rollout settings
    num_rollout_workers = 10
    num_gpus = 1
    experiment_parts = [env_to_run]

    # Training iters for the over-optimization run
    num_training_iters = 260

    # Divergence coefficients for the over-optimization run (reference runs
    # always use ["0.0"]).
    om_divergence_coeffs = ["0.0"]
    # Must be overridden with a non-empty list: main() uses its first entry as
    # the reference-policy checkpoint.
    checkpoint_to_load_policies = None
    experiment_tag = "state"
    om_divergence_type = ["kl"]

    # If True, after each iteration the reference checkpoint is replaced by the
    # latest over-optimized policy's checkpoint and reference rollouts are
    # re-collected at the start of the next iteration.
    move_ref_policy=False


# Prefix prepended to the checkpoint path returned by an over-optimization run
# (element 1 of its result) when the reference policy is moved forward.
# NOTE(review): machine-specific absolute path; consider making it a config entry.
_CHECKPOINT_ROOT = "/next/u/stephhk/orpo/"

# Number of PBRR outer iterations (over-optimize -> compare -> update reward).
_NUM_PBRR_ITERATIONS = 10


def _reward_model_config(env_to_run):
    """Return the RewardModel keyword arguments for the environment in ``env_to_run``.

    The environment is matched by substring (e.g. "pandemic_level4" selects the
    "pandemic" entry); the first matching key in definition order wins.

    Raises:
        ValueError: if no known environment name occurs in ``env_to_run``.
    """
    # Per-environment RewardModel settings (cannot infer from gym without init)
    rm_configs = {
        "pandemic": dict(
            obs_dim=2 * 24 * 13,
            action_dim=3,
            sequence_lens=193,
            discrete_actions=True,
            env_name="pandemic_sas",
            n_epochs=200,
            n_layers=5,
            layer_size=512,
            use_weight_decay=True,
        ),
        "glucose": dict(
            obs_dim=2 * 48 * 2,
            action_dim=1,
            sequence_lens=5760,
            discrete_actions=False,
            env_name="glucose_sas",
            n_epochs=200,
            n_layers=5,
            layer_size=512,
            use_weight_decay=False,
        ),
        "traffic": dict(
            obs_dim=2 * 50,
            action_dim=10,
            sequence_lens=4000,
            discrete_actions=False,
            env_name="traffic_sas",
            n_epochs=50,
            n_layers=3,
            layer_size=256,
            use_weight_decay=True,
        ),
        "tomato": dict(
            obs_dim=2 * 36,  # assumed obs size 36
            action_dim=4,
            sequence_lens=100,
            discrete_actions=True,
            env_name="tomato",
            n_epochs=200,
            n_layers=5,
            layer_size=512,
            use_weight_decay=True,
        ),
    }

    selected_key = next((k for k in rm_configs if k in env_to_run), None)
    if selected_key is None:
        raise ValueError(f"Unknown environment in env_to_run='{env_to_run}'. "
                         f"Expected one of: {list(rm_configs.keys())}")
    return rm_configs[selected_key]


def _collect_reference_rollouts(
    *,
    env_to_run,
    level,
    reward_fun,
    exp_algo,
    reference_checkpoint,
    seed,
    num_rollout_workers,
    num_gpus,
    experiment_parts,
):
    """Roll out the reference policy without training and return its eval batch.

    Runs ``train_policy`` with zero training iterations and a divergence
    coefficient of "0.0", loading ``reference_checkpoint`` as the current
    policy. Returns element 2 of the run result, whose ``"current"`` entry
    holds the rollouts compared against the over-optimized policy.
    """
    reference_result = ex.run(
        config_updates={
            "env_to_run": env_to_run,
            "level": level,
            "reward_fun": reward_fun,
            "exp_algo": exp_algo,
            "om_divergence_coeffs": ["0.0"],
            "checkpoint_to_load_policies": None,
            "checkpoint_to_load_current_policy": reference_checkpoint,
            "seed": seed,
            "num_rollout_workers": num_rollout_workers,
            "num_gpus": num_gpus,
            "experiment_parts": experiment_parts,
            "num_training_iters": 0,  # collect only
        }
    )
    return reference_result.result[2]


def _align_rollouts(over_opt_rollouts, reference_rollouts):
    """Truncate both rollout batches to their common length.

    The batches are consumed as (over-optimized vs reference) comparisons, so
    both must have the same length. The shorter side decides the length; the
    previous version only truncated the reference batch.
    """
    n = min(len(over_opt_rollouts), len(reference_rollouts))
    return over_opt_rollouts[:n], reference_rollouts[:n]


# NOTE: when this file is launched as a script, automain executes main() as
# soon as the decorator runs. Every helper must therefore be defined above it,
# and main() must remain the last definition in the file.
@iterative_ex.automain
def main(
    env_to_run,
    level,
    reward_fun,
    exp_algo,
    checkpoint_to_load_current_policy,
    seed,
    num_rollout_workers,
    num_gpus,
    experiment_parts,
    num_training_iters,
    om_divergence_coeffs,
    checkpoint_to_load_policies,
    experiment_tag,
    om_divergence_type,
    move_ref_policy,
    _log,
):
    """
    Run PBRR loop:
      1) Train an over-optimized policy on the current (proxy or repaired) reward.
      2) Collect rollouts from a "reference" (no divergence pressure) policy.
      3) Update the reward model from trajectory pairs.
      4) Repeat.

    All arguments are injected by Sacred from ``config()``; ``_log`` is the
    Sacred-provided logger.

    NOTE(review): ``reward_fun`` is passed unchanged every iteration; presumably
    train_policy picks up the updated reward model (e.g. via unique_id_state).
    Confirm this.

    Raises:
        ValueError: if ``env_to_run`` names no known environment, or if
            ``checkpoint_to_load_policies`` is None or empty (its first entry
            is the reference-policy checkpoint).
    """
    # Validate inputs before any side effect (saving reward-model params,
    # launching runs).
    rm_config = _reward_model_config(env_to_run)
    if not checkpoint_to_load_policies:
        raise ValueError(
            "checkpoint_to_load_policies must be a non-empty list; its first "
            "entry is used as the reference-policy checkpoint."
        )

    reward_model = RewardModel(
        unique_id=unique_id_state.state["unique_id"],
        lr=0.0001,
        **rm_config,
    )

    # Initialize and checkpoint the zeroed reward model
    reward_model.zero_model_params()
    reward_model.save_params()

    # 2) Reference policy (no divergence pressure), collect rollouts
    eval_batch_reference = _collect_reference_rollouts(
        env_to_run=env_to_run,
        level=level,
        reward_fun=reward_fun,
        exp_algo=exp_algo,
        reference_checkpoint=checkpoint_to_load_policies[0],
        seed=seed,
        num_rollout_workers=num_rollout_workers,
        num_gpus=num_gpus,
        experiment_parts=experiment_parts,
    )

    for i in range(_NUM_PBRR_ITERATIONS):
        _log.info("PBRR iteration %d/%d", i + 1, _NUM_PBRR_ITERATIONS)

        # 1) Over-optimized policy wrt current reward
        over_opt_result = ex.run(
            config_updates={
                "env_to_run": env_to_run,
                "level": level,
                "reward_fun": reward_fun,
                "exp_algo": exp_algo,
                "checkpoint_to_load_policies": checkpoint_to_load_policies,
                "checkpoint_to_load_current_policy": checkpoint_to_load_current_policy,
                "seed": seed,
                "num_rollout_workers": num_rollout_workers,
                "num_gpus": num_gpus,
                "experiment_parts": experiment_parts,
                "num_training_iters": num_training_iters,
                "om_divergence_coeffs": om_divergence_coeffs,
                "experiment_tag": experiment_tag,
                "om_divergence_type": om_divergence_type,
            }
        )
        eval_batch_over_opt = over_opt_result.result[2]

        if move_ref_policy:
            # 2) Re-collect reference rollouts from the (possibly moved) reference.
            # NOTE(review): at i == 0 the checkpoint and seed are identical to
            # the initial reference run above, so that collection is repeated.
            eval_batch_reference = _collect_reference_rollouts(
                env_to_run=env_to_run,
                level=level,
                reward_fun=reward_fun,
                exp_algo=exp_algo,
                reference_checkpoint=checkpoint_to_load_policies[0],
                seed=seed + i,
                num_rollout_workers=num_rollout_workers,
                num_gpus=num_gpus,
                experiment_parts=experiment_parts,
            )

        # Align batch sizes (truncate both sides to the shorter batch)
        over_opt_rollouts, reference_rollouts = _align_rollouts(
            eval_batch_over_opt["current"], eval_batch_reference["current"]
        )

        # 3) Update reward model from (over-optimized vs reference) comparisons
        reward_model.update_params(
            over_opt_rollouts,
            reference_rollouts,
            iteration=i,
            use_minibatch=True,
            use_regularization=False,
        )

        if move_ref_policy:
            # The latest over-optimized checkpoint becomes the next reference.
            checkpoint_to_load_policies = [_CHECKPOINT_ROOT + over_opt_result.result[1]]
