#!/usr/bin/env python
"""
=======================================================================
  BenchMarl + VMAS runner with SmartGridScenario
  ------------------------------------------------
  • TRAIN  : python run.py --algorithm mappo --task customenv
  • EVAL   : python run.py --algorithm mappo --task customenv \
                           --mode eval \
                           --checkpoint path/to/ckpt_last.ckpt
     ↳ BenchMarl loads the checkpoint and runs its built-in evaluator.
       Your SmartGridScenario.step_end() writes per-step actions to CSV.
=======================================================================
"""

import argparse
import copy
import os
from pathlib import Path
from typing import Optional

from omegaconf import OmegaConf

from benchmarl.experiment import Experiment, ExperimentConfig
from benchmarl.algorithms import MappoConfig, MaddpgConfig, MasacConfig, IddpgConfig, IsacConfig
from benchmarl.environments import VmasTask
from benchmarl.models.mlp import MlpConfig
from torchrl.envs import VmasEnv

from smart_grid import SmartGridScenario

# ---------------------------------------------------------------------
# 1.  Monkey-patch get_env_fun so the NAVIGATION task builds SmartGridScenario
# ---------------------------------------------------------------------
# Module-level flag read by get_env_fun() and forwarded to
# SmartGridScenario(eval=...); main() sets it from the --mode CLI option.
EVAL_MODE = False  # will be flipped by CLI


def get_env_fun(
    self,
    num_envs: int,
    continuous_actions: bool,
    seed: Optional[int],
    device: str,
):
    """Return a zero-argument factory that builds the VMAS environment.

    Replacement for the task's ``get_env_fun``: when the task is named
    ``navigation`` (case-insensitive) the environment is built around a
    ``SmartGridScenario`` instance instead of the scenario name; every
    other task keeps using its lower-cased name.

    Args:
        self: Task object; only ``self.name`` and ``self.config`` are read.
        num_envs: Number of vectorised environments requested by the caller.
            Used unless the task config already defines ``num_envs``.
        continuous_actions: Forwarded to ``VmasEnv``.
        seed: Forwarded to ``VmasEnv``.
        device: Forwarded to ``VmasEnv``.

    Returns:
        A callable taking no arguments that constructs a ``VmasEnv``.
    """
    # Work on a copy so the task's own config is never mutated.
    cfg = copy.deepcopy(self.config)

    # Forward the requested number of environments. ``setdefault`` keeps a
    # value already present in the task config, so configs that set
    # ``num_envs`` themselves behave exactly as before and no duplicate
    # keyword is passed to VmasEnv.
    cfg.setdefault("num_envs", num_envs)

    if self.name.lower() == "navigation":
        # EVAL_MODE is read here, i.e. when the factory is created.
        scenario_obj = SmartGridScenario(eval=EVAL_MODE)
    else:
        scenario_obj = self.name.lower()

    return lambda: VmasEnv(
        scenario=scenario_obj,
        continuous_actions=continuous_actions,
        seed=seed,
        device=device,
        categorical_actions=False,
        clamp_actions=True,
        **cfg,
    )


# Install the patched factory on whichever class this BenchMarl release
# exposes: VmasClass when it can be imported, otherwise VmasTask itself.
try:
    from benchmarl.environments import VmasClass
except ImportError:  # BenchMarl release without VmasClass
    VmasTask.get_env_fun = get_env_fun
else:
    VmasClass.get_env_fun = get_env_fun

# ---------------------------------------------------------------------
# 2.  Config helpers
# ---------------------------------------------------------------------


def load_experiment_cfg() -> ExperimentConfig:
    """Load the default experiment config and apply local overrides.

    Starts from ``ExperimentConfig.get_from_yaml()``. If
    ``configs/experiment.yaml`` exists (relative to the working directory),
    each top-level key that matches an existing attribute of the config is
    assigned onto it. Keys that do not match any attribute are reported and
    skipped.

    Returns:
        The (possibly overridden) ``ExperimentConfig``.
    """
    cfg = ExperimentConfig.get_from_yaml()
    local = Path("configs") / "experiment.yaml"
    if not local.is_file():
        return cfg

    # Convert to plain Python containers so nested values are stored as
    # ordinary dict/list objects rather than OmegaConf nodes.
    overrides = OmegaConf.to_container(OmegaConf.load(local), resolve=True)
    for key, value in overrides.items():
        if hasattr(cfg, key):
            setattr(cfg, key, value)
        else:
            # A silently dropped key usually means a typo in the YAML.
            print(f"[WARNING] Ignoring unknown experiment option '{key}' in {local}.")
    return cfg


def load_task_cfg(name: str) -> VmasTask:
    """Resolve a task name to a configured VMAS task.

    ``customenv`` (case-insensitive) maps to ``VmasTask.NAVIGATION``. Any
    other name is looked up as an upper-cased attribute of ``VmasTask``;
    unknown names print a warning and fall back to ``NAVIGATION``.

    If ``configs/task.yaml`` exists (relative to the working directory) its
    content REPLACES the task's default config entirely; it is not merged
    with it.

    Args:
        name: Task name as given on the command line.

    Returns:
        The task returned by ``get_from_yaml()`` with its config applied.

    Raises:
        TypeError: If ``configs/task.yaml`` does not contain a mapping.
    """
    selected = VmasTask.NAVIGATION
    if name.lower() != "customenv":
        try:
            selected = getattr(VmasTask, name.upper())
        except AttributeError:
            print(f"[WARNING] Unknown task '{name}', defaulting to NAVIGATION.")
    task = selected.get_from_yaml()

    local = Path("configs") / "task.yaml"
    if local.is_file():
        # Recursive conversion: nested lists/dicts become plain Python
        # objects instead of OmegaConf nodes before they reach the scenario.
        overrides = OmegaConf.to_container(OmegaConf.load(local), resolve=True)
        if not isinstance(overrides, dict):
            raise TypeError(
                f"{local} must contain a mapping, got {type(overrides).__name__}"
            )
        task.config = overrides
    return task


def load_algo_cfg(name: str):
    """Return the YAML-loaded algorithm config for ``name``.

    Supported names (case-insensitive): ``mappo``, ``maddpg``, ``masac``,
    ``iddpg``, ``isac``. Any other name prints a warning and falls back to
    MAPPO.

    Args:
        name: Algorithm name as given on the command line.

    Returns:
        The result of ``get_from_yaml()`` on the selected config class.
    """
    known = {
        "mappo": MappoConfig,
        "maddpg": MaddpgConfig,
        "masac": MasacConfig,
        "iddpg": IddpgConfig,
        "isac": IsacConfig,
    }
    config_cls = known.get(name.lower())
    if config_cls is None:
        # Same treatment as an unknown task: tell the user about the fallback.
        print(f"[WARNING] Unknown algorithm '{name}', defaulting to MAPPO.")
        config_cls = MappoConfig
    return config_cls.get_from_yaml()


# ---------------------------------------------------------------------
# 3.  Main
# ---------------------------------------------------------------------
def _build_cli() -> argparse.ArgumentParser:
    """Create the command-line parser for the runner."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--algorithm", default="mappo")
    cli.add_argument("--task", default="customenv")
    cli.add_argument("--mode", choices=["train", "eval"], default="train")
    cli.add_argument(
        "--checkpoint",
        help="Path to BenchMarl .ckpt to restore (required for --mode eval)",
    )
    cli.add_argument(
        "--building-types",
        type=str,
        default=None,
        help="Comma‑separated list for SmartGrid building types.",
    )
    cli.add_argument(
        "--episode-length",
        type=int,
        default=None,
        help="Override episode length set in task.yaml (eval or train).",
    )
    cli.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed to pass to BenchMarl Experiment (default 42).",
    )
    return cli


def _parse_building_types(raw: str) -> list:
    """Turn a comma-separated string such as ``"1, 2,3"`` into a list of ints."""
    return [int(item) for item in raw.split(",")]


def _next_run_dir(base: Path, algo: str) -> Path:
    """Return ``base/<algo>_<n>`` with n one above the highest existing index."""
    prefix = f"{algo}_"
    # isdecimal() (not isdigit()) so every accepted suffix is valid for int().
    taken = [
        int(entry.name[len(prefix):])
        for entry in base.iterdir()
        if entry.is_dir()
        and entry.name.startswith(prefix)
        and entry.name[len(prefix):].isdecimal()
    ]
    return base / f"{prefix}{max(taken, default=0) + 1}"


def main():
    """Parse the CLI, assemble the BenchMarl experiment, then train or evaluate.

    Train mode creates a fresh ``<algorithm>_<n>`` folder in the working
    directory, stores it as the experiment's ``save_folder`` and calls
    ``Experiment.run()``.

    Eval mode requires ``--checkpoint``; it sets ``restore_file`` to the
    absolute checkpoint path, sets ``max_n_frames`` to 0, exports the
    lower-cased algorithm name as the ``BM_ALGO`` environment variable and
    calls ``Experiment.evaluate()``.

    Invalid command-line input is reported through ``argparse`` (exit
    status 2) before any configuration file is read.
    """
    global EVAL_MODE

    cli = _build_cli()
    args = cli.parse_args()
    EVAL_MODE = args.mode == "eval"
    algo = args.algorithm.lower()

    # --- validate CLI input up-front (fail fast) ----------------------
    checkpoint = None
    if EVAL_MODE:
        if not args.checkpoint:
            cli.error("--mode eval needs --checkpoint FILE")
        checkpoint = os.path.abspath(args.checkpoint)
        if not os.path.isfile(checkpoint):
            cli.error(f"checkpoint file not found: {checkpoint}")
    elif args.checkpoint:
        # The checkpoint is only consumed by the eval branch below.
        print("[WARNING] --checkpoint is ignored unless --mode eval is given.")

    building_types = None
    if args.building_types:
        try:
            building_types = _parse_building_types(args.building_types)
        except ValueError:
            cli.error(
                "--building-types must be comma-separated integers, "
                f"got {args.building_types!r}"
            )

    # --- load configs ------------------------------------------------
    exp_cfg = load_experiment_cfg()
    task = load_task_cfg(args.task)

    # CLI overrides are injected into the task config under both key names.
    if args.episode_length is not None:
        task.config["episode_length"] = args.episode_length
        task.config["max_steps"] = args.episode_length

    if building_types is not None:
        task.config["building_types"] = building_types

    algo_cfg = load_algo_cfg(args.algorithm)
    model_cfg = MlpConfig.get_from_yaml()
    critic_cfg = MlpConfig.get_from_yaml()

    if EVAL_MODE:
        # --- evaluation-only tweaks ----------------------------------
        exp_cfg.restore_file = checkpoint
        exp_cfg.max_n_frames = 0  # intended to skip training completely
        # Export before the Experiment (and its environments) is built so
        # the variable is visible no matter when it is read.
        # NOTE(review): the consumer of BM_ALGO is not visible in this file.
        os.environ["BM_ALGO"] = algo
    else:
        # --- auto-increment wrapper folder ---------------------------
        run_dir = _next_run_dir(Path("."), algo)
        run_dir.mkdir(parents=True, exist_ok=True)
        exp_cfg.save_folder = str(run_dir)

    # --- build experiment -------------------------------------------
    experiment = Experiment(
        task=task,
        algorithm_config=algo_cfg,
        model_config=model_cfg,
        critic_model_config=critic_cfg,
        seed=args.seed,
        config=exp_cfg,
    )

    # --- run mode ----------------------------------------------------
    if not EVAL_MODE:
        experiment.run()  # normal training run
        return

    # -------- evaluation-only run -----------------------------------
    experiment.evaluate()
    print("[✓] Evaluation done — see CSV logs produced by SmartGridScenario.")


# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
