import sys
import os
import logging
from datetime import datetime
from tqdm import tqdm
from pathlib import Path
import argparse
import json
import multiprocessing as mp

from organisation.env.clinical_trial.core.logging_setup import setup_sim_logging
from organisation.env.main_sim import OrganisationEnv, ClinicalTrialSimulation
from organisation.env.config import (
    OUTPUT_BASE_DIR,
    NUM_TRAINING_EPISODES,
    POLICY,
    SEED,
)
from organisation.env import config
from organisation.env.clinical_trial.policies import initialise_policy
from organisation.env.clinical_trial.core.viz import save_timeline_from_episode_logs

# ─────────────────────────────────────────────────────────────────────────────
# 0) Prepare output directory for this run
# ─────────────────────────────────────────────────────────────────────────────
# Every run writes into its own timestamped folder under OUTPUT_BASE_DIR.
run_ts = datetime.now().strftime("%Y%m%d_%H%M%S")
output_dir = Path(OUTPUT_BASE_DIR).joinpath(run_ts)
output_dir.mkdir(parents=True, exist_ok=True)

# ─────────────────────────────────────────────────────────────────────────────
# 1) Instantiate the environment *before* wiring up sim_time in logs
# ─────────────────────────────────────────────────────────────────────────────
# The sim-time logging filter reads env.simulation.env.now, so env must exist
# before any handler that uses it is installed.
env = OrganisationEnv(ClinicalTrialSimulation)

# —————————————————————————————————————————————————
# 2) Set up the logging (console + file) with sim_time injection
# —————————————————————————————————————————————————
# NOTE(review): any handlers this attaches to the root logger are removed again
# in step 3 below; confirm whether this call is still needed for other effects.
setup_sim_logging(env, output_dir)

# ─────────────────────────────────────────────────────────────────────────────
# 3) Reconfigure the root logger to inject the current sim time
# ─────────────────────────────────────────────────────────────────────────────

root = logging.getLogger()
# Drop every handler currently on the root logger (presumably including those
# installed by setup_sim_logging() — not visible here) so only the handlers
# configured below stay active. Each one is also closed: removeHandler() alone
# leaves a FileHandler's file descriptor open for the rest of the process.
for h in list(root.handlers):
    root.removeHandler(h)
    try:
        h.close()
    except Exception:
        # Best-effort cleanup: a handler that fails to close must not abort startup.
        pass

# Shared format for every handler; `sim_time` is supplied per record by the
# inject_sim_time filter, which must be attached wherever this formatter is used.
formatter = logging.Formatter(
    fmt="[sim={sim_time:.1f}] {name} {levelname}: {message}",
    style="{",
)


def inject_sim_time(record):
    """Logging filter that stamps *record* with the current simulation clock.

    Sets ``record.sim_time`` from the module-level ``env.simulation.env.now`` so
    the ``{sim_time:.1f}`` field of the shared formatter can be resolved.

    Returns:
        Always True, so the filter never suppresses a record.
    """
    sim_clock = env.simulation.env
    record.sim_time = sim_clock.now
    return True


# Console output plus an appended run-wide log file; both share the sim-time
# formatter and filter so every line carries the simulation clock.
console_handler = logging.StreamHandler(sys.stdout)
file_handler = logging.FileHandler(output_dir / "simulation.log", mode="a")
for _root_handler in (console_handler, file_handler):
    _root_handler.setFormatter(formatter)
    _root_handler.addFilter(inject_sim_time)
    root.addHandler(_root_handler)
del _root_handler

# Log full config
with open(output_dir / "config.json", "w", encoding="utf-8") as f:
    # Snapshot the scalar settings of the config module. Module dunders
    # (__name__, __file__, __package__, __doc__, ...) are strings too, but they
    # are not configuration, so they are excluded. Non-scalar values (None,
    # lists, dicts, tuples, ...) are still skipped, as before.
    config_dict = {
        k: v
        for k, v in vars(config).items()
        if not (k.startswith("__") and k.endswith("__"))
        and isinstance(v, (str, int, float, bool))
    }

    json.dump(config_dict, f, indent=2, ensure_ascii=False)

# Root threshold is INFO; this module's own logger is opened up to DEBUG below.
logging.getLogger().setLevel(logging.INFO)

# Module-level logger used by the episode runner.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

num_episodes = NUM_TRAINING_EPISODES
# NOTE(review): not referenced in the visible part of this script; per-episode
# records are read from policy.log_action_obs instead.
log_action_obs = list()

# ——————————————————————————————
# 5) policy
# ——————————————————————————————
# Built once at import time and shared by every episode run in this process.
policy = initialise_policy(POLICY)


def _attach_episode_events_log(ep):
    """Route the "monitoring.events" logger to a fresh per-episode file.

    Removes and closes every handler currently on "monitoring.events" (e.g. a
    default events.log handler), then attaches a FileHandler that truncates and
    writes ``events_epNN.log`` in ``output_dir``. The handler uses the shared
    formatter and sim-time filter.

    Returns:
        ``(events_logger, handler)`` so the caller can detach the handler later.
    """
    events_logger = logging.getLogger("monitoring.events")
    for old_handler in list(events_logger.handlers):
        events_logger.removeHandler(old_handler)
        try:
            old_handler.close()
        except Exception:
            # Best-effort: a stale handler that fails to close must not abort
            # the episode, but the failure should not be invisible either.
            logger.debug("Could not close events handler %r", old_handler, exc_info=True)

    events_path = output_dir / f"events_ep{ep + 1:02d}.log"
    events_file = logging.FileHandler(events_path, mode="w")
    events_file.setFormatter(formatter)
    events_file.addFilter(inject_sim_time)
    events_logger.addHandler(events_file)
    return events_logger, events_file


def _write_episode_artifacts(ep, info):
    """Write monitoring_epNN.json and action_obs_epNN.txt for 0-based episode *ep*."""
    # Save Metrics
    metrics = info["monitoring"]["summary"]
    with open(
        output_dir / f"monitoring_ep{ep + 1:02d}.json", "w", encoding="utf-8"
    ) as f:
        json.dump(
            {
                "episode": ep + 1,  # episode number (1-based)
                "metrics": metrics,  # A dict of metrics we collect
            },
            f,
            indent=2,
            ensure_ascii=False,
        )

    # Save log of actions and observations, one blank line between entries
    with open(
        output_dir / f"action_obs_ep{ep + 1:02d}.txt", "w", encoding="utf-8"
    ) as f:
        for line in policy.log_action_obs:
            f.write(str(line))
            f.write("\n\n")


def run_episode(ep):
    """Run a single simulation episode and write its output files.

    Resets the environment (seeded with ``SEED + ep``) and the policy, then
    routes "monitoring.events" to ``events_epNN.log``. It steps the policy until
    the environment reports done, logs a summary, and writes the per-episode
    monitoring JSON and action/observation dump. Finally it renders the
    timeline; a rendering failure is logged, not raised.

    Args:
        ep: 0-based episode index; file names and log messages use ``ep + 1``.
    """
    # blue, bold episode banner
    logger.info("\033[1;34m=== Starting episode %d/%d ===\033[0m", ep + 1, num_episodes)

    obs, _ = env.reset(seed=SEED + ep)
    policy.reset()
    done = False

    events_logger, events_file = _attach_episode_events_log(ep)
    try:
        env.wait_time = 0

        while not done:
            # No plan: take normal RL action
            action = policy.select_action(obs, env)
            obs, _, done, _, info = env.step(action)

        # end-of-episode summary (lazy %-formatting, like the banner above)
        logger.info(
            "Episode %2d done. total messages=%3d, sim_time=%.1f",
            ep + 1,
            info["metrics"]["messages_sent"],
            info["metrics"]["sim_time"],
        )

        _write_episode_artifacts(ep, info)

        # Build timeline PNG from the per-episode events log
        try:
            save_timeline_from_episode_logs(output_dir, ep + 1, env, info)
        except Exception:
            logger.exception("Failed to render timeline for episode %d", ep + 1)
    finally:
        # Detach and close this episode's events file so later events (e.g.
        # those emitted by the next episode's env.reset()) cannot be appended
        # to it, and so the file descriptor is released even on failure.
        events_logger.removeHandler(events_file)
        events_file.close()


# ——————————————————————————————
# 6) Main running loop
# ——————————————————————————————
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--parallel", action="store_true", help="Run episodes in parallel"
    )
    parser.add_argument(
        "--incentives",
        action="store_true",
        help="Require incentives to be enabled in config",
    )
    args = parser.parse_args()

    # Looked up once and reused in both messages below.
    incentives_file = os.getenv("INCENTIVES_FILE", "<not set>")

    # If the flag is passed but config says incentives are OFF, abort loudly.
    # NOTE(review): the message names organisation.env.clinical_trial.core.config,
    # while this script imports organisation.env.config — confirm which module
    # users should actually edit.
    if args.incentives and not getattr(config, "INCENTIVES_ENABLED", False):
        raise SystemExit(
            "ERROR: --incentives was requested, but INCENTIVES_ENABLED is False in "
            "organisation.env.clinical_trial.core.config.\n"
            "Please set INCENTIVES_ENABLED = True there (and optionally set INCENTIVES_FILE=path/to/file.json) "
            "before running again."
        )

    if args.incentives:
        logging.getLogger("incentives").info(
            "Incentives requested and enabled (config.INCENTIVES_ENABLED=True). "
            "INCENTIVES_FILE=%s",
            incentives_file,
        )

    print(
        f"Running {num_episodes} episodes with policy {POLICY} "
        f"with parallel={args.parallel}; incentives={args.incentives}; "
        f"INCENTIVES_FILE={incentives_file}"
    )

    if args.parallel:
        # Pool() raises ValueError for processes < 1, so clamp for the
        # zero-episode case; the 64-worker cap is kept as before.
        n_workers = max(1, min(num_episodes, 64))
        # NOTE(review): workers use the module-level env/policy/output_dir built
        # at import time. With the 'fork' start method they inherit the parent's
        # objects. Under 'spawn', each worker re-imports this module and
        # recomputes run_ts, so its files may land in a different directory —
        # confirm the start method used in practice.
        with mp.Pool(processes=n_workers) as p:
            for _ in tqdm(
                p.imap_unordered(run_episode, range(num_episodes)), total=num_episodes
            ):
                pass
    else:
        for ep in tqdm(range(num_episodes)):
            run_episode(ep)
