from __future__ import annotations

import argparse
import json
import os
import sys
from dataclasses import asdict
from typing import Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# Optional plotting (allow --help without matplotlib)
try:
    import matplotlib

    matplotlib.use("Agg")
    import matplotlib.pyplot as plt
except Exception:  # pragma: no cover
    plt = None  # type: ignore

# Optional resize backends
try:
    import cv2  # type: ignore
except Exception:  # pragma: no cover
    cv2 = None  # type: ignore
try:
    from PIL import Image  # type: ignore
except Exception:  # pragma: no cover
    Image = None  # type: ignore

# Put the parent of this file's directory on sys.path (presumably the project
# root that contains the `visual_gridworld` package) so the absolute imports
# below resolve. Uses `sys` directly rather than the incidental `os.sys` alias.
_THIS_DIR = os.path.dirname(os.path.abspath(__file__))
_ROOT_DIR = os.path.dirname(_THIS_DIR)
if _ROOT_DIR not in sys.path:
    sys.path.insert(0, _ROOT_DIR)

# Optional dependency: minigrid (allow `--help` without installing)
try:
    from visual_gridworld.visual_minigrid import SimpleEnv  # type: ignore
except Exception:
    SimpleEnv = None  # type: ignore
from visual_gridworld.sr_dqn_model import SRDQNNet, SRDQNConfig
from visual_gridworld.sr_dqn_replay import ReplayBuffer


def _rgb_to_gray_u8(img: np.ndarray) -> np.ndarray:
    # img: [H,W,3] uint8
    if img is None:
        raise ValueError("Input image is None")
    if img.ndim == 3 and img.shape[2] >= 3:
        gray = (0.299 * img[..., 0] + 0.587 * img[..., 1] + 0.114 * img[..., 2]).astype(np.uint8)
        return gray
    return img.astype(np.uint8)


def preprocess_obs(obs_rgb: np.ndarray, *, max_side: int = 128) -> np.ndarray:
    """Convert an observation to a grayscale uint8 array of shape ``[1, H, W]``.

    The image is converted with ``_rgb_to_gray_u8`` and then subsampled by
    keeping every ``k``-th row and column (no interpolation), where ``k`` is the
    smallest integer stride that brings the longer side down to at most
    ``max_side`` pixels. Images that already fit are not subsampled.

    Args:
        obs_rgb: Observation image; see ``_rgb_to_gray_u8`` for accepted layouts.
        max_side: Upper bound on the longer output side. Defaults to 128.

    Returns:
        A uint8 array with a leading singleton channel axis.

    Raises:
        ValueError: If ``max_side`` is smaller than 1, or if ``obs_rgb`` is
            ``None`` (raised by ``_rgb_to_gray_u8``).
    """
    if int(max_side) < 1:
        raise ValueError(f"max_side must be >= 1, got {max_side}")
    gray = _rgb_to_gray_u8(obs_rgb)
    if gray.ndim == 3 and gray.shape[2] == 1:
        # Tolerate a trailing single-channel axis so the result stays 3-D.
        gray = gray[..., 0]
    height, width = gray.shape[:2]
    stride = max(1, int(np.ceil(max(height, width) / float(max_side))))
    if stride > 1:
        gray = gray[::stride, ::stride]
    # astype() copies, so the returned array does not alias the strided view.
    return gray.astype(np.uint8)[None, :, :]


def _save_gray_u8(path_no_ext: str, img_u8: np.ndarray) -> None:
    """Save a single-channel uint8 image to PNG if possible; else save .npy."""
    img2d = img_u8
    if img2d.ndim == 3 and img2d.shape[0] == 1:
        img2d = img2d[0]
    if img2d.ndim != 2:
        raise ValueError(f"Expected [H,W] or [1,H,W] uint8, got shape={img_u8.shape}")

    # Prefer PNG for easy inspection.
    png_path = path_no_ext + ".png"
    if cv2 is not None:
        # cv2.imwrite expects [H,W] uint8 for grayscale.
        _ = cv2.imwrite(png_path, img2d)
        return
    if Image is not None:
        Image.fromarray(img2d).save(png_path)
        return

    # Fallback: save raw array.
    np.save(path_no_ext + ".npy", img2d.astype(np.uint8))


@torch.no_grad()
def intrinsic_reward_from_target_psi(
    target_net: SRDQNNet,
    obs_u8: np.ndarray,
    device: torch.device,
    *,
    beta: float,
    eps: float,
) -> float:
    """Return ``beta / max(||psi(s)||_1, eps)`` for a single observation.

    The observation is moved to ``device``, converted to float, given a
    leading batch axis and divided by 255 before being passed through
    ``target_net.encode`` and then ``target_net.psi``.

    Args:
        target_net: Network exposing ``encode`` and ``psi``.
        obs_u8: One observation as a NumPy array (no batch axis).
        device: Device on which the forward pass runs.
        beta: Scale applied to the inverse norm.
        eps: Lower bound applied to the L1 norm before inverting it.

    Returns:
        The intrinsic reward as a Python float.
    """
    batch = torch.from_numpy(obs_u8).to(device).float().unsqueeze(0) / 255.0
    features = target_net.encode(batch)
    # detach() is redundant while gradients are disabled; kept for explicitness.
    successor = target_net.psi(features.detach())
    # L1 norm over dim 1, floored at eps so the reciprocal stays finite.
    l1_norm = successor.norm(p=1, dim=1)
    bounded = torch.clamp(l1_norm, min=float(eps))
    inverse = (1.0 / bounded).item()
    return float(beta) * float(inverse)


def epsilon_by_step(step: int, eps_start: float, eps_end: float, eps_decay_steps: int) -> float:
    """Linearly interpolate epsilon from ``eps_start`` to ``eps_end``.

    For ``step`` below ``eps_decay_steps`` the value moves linearly from
    ``eps_start`` (at step 0) towards ``eps_end``; from ``eps_decay_steps``
    onwards it is ``eps_end``.
    """
    horizon = int(eps_decay_steps)
    if step < horizon:
        # max(1, ...) guards the division when the horizon is not positive.
        progress = float(step) / float(max(1, horizon))
        return float(eps_start + (eps_end - eps_start) * progress)
    return float(eps_end)


@torch.no_grad()
def select_action_eps_greedy(
    online: SRDQNNet,
    obs_u8: np.ndarray,
    device: torch.device,
    *,
    eps: float,
    rng: np.random.Generator,
) -> int:
    """Pick an action epsilon-greedily from the online network's Q-values.

    One uniform draw from ``rng`` decides between exploring and exploiting.
    When exploring, a second draw picks an action index in
    ``[0, online.cfg.num_actions)``. Otherwise the observation is batched,
    divided by 255, and the argmax of ``online.q(online.encode(...))`` is
    returned.

    Args:
        online: Network exposing ``cfg.num_actions``, ``encode`` and ``q``.
        obs_u8: One observation as a NumPy array (no batch axis).
        device: Device on which the forward pass runs.
        eps: Probability of taking a uniformly random action.
        rng: NumPy generator used for both random draws.

    Returns:
        The chosen action index.
    """
    explore = rng.random() < float(eps)
    if explore:
        random_action = rng.integers(0, online.cfg.num_actions)
        return int(random_action)

    obs_batch = torch.from_numpy(obs_u8).to(device).float().unsqueeze(0) / 255.0
    q_values = online.q(online.encode(obs_batch))
    greedy_action = q_values.argmax(dim=1)
    return int(greedy_action.item())


def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the SR intrinsic-reward DQN run."""
    p = argparse.ArgumentParser(description="SR intrinsic reward exploration (DQN) in visual_gridworld.")
    # env
    p.add_argument("--env_size", type=int, default=20)
    p.add_argument("--seed", type=int, default=1)
    p.add_argument("--episode_len", type=int, default=2000000)
    # NOTE: --no_render is parsed but not read anywhere in this function group.
    p.add_argument("--no_render", action="store_true", help="Disable human rendering.")

    # training
    p.add_argument("--total_steps", type=int, default=9000)
    p.add_argument("--learning_starts", type=int, default=100)
    p.add_argument("--replay_capacity", type=int, default=3000)
    p.add_argument("--batch_size", type=int, default=64)
    p.add_argument("--train_freq", type=int, default=4)
    p.add_argument("--updates_per_step", type=int, default=1)
    p.add_argument("--target_update_freq", type=int, default=200)
    p.add_argument("--lr", type=float, default=2.5e-4)
    p.add_argument(
        "--rmsprop_eps",
        type=float,
        default=0.000009765625,
        help="RMSprop epsilon (SR reference uses 9.765625e-06).",
    )
    p.add_argument(
        "--rmsprop_centered",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Use centered RMSprop (matches SR reference implementation).",
    )
    p.add_argument(
        "--log_every",
        type=int,
        default=500,
        help="Print a training summary every K env steps (0 disables).",
    )
    p.add_argument(
        "--save_recon_every",
        type=int,
        default=4000,
        help="If >0, save reconstruction (pred vs target) every K env steps into out_dir/recon_samples/.",
    )

    # discounts
    p.add_argument("--gamma_q", type=float, default=0.99)
    p.add_argument("--gamma_sf", type=float, default=0.9)

    # epsilon schedule (in steps)
    p.add_argument("--eps_start", type=float, default=1.0)
    p.add_argument("--eps_end", type=float, default=0.1)
    p.add_argument("--eps_decay_steps", type=int, default=1000)

    # intrinsic reward
    p.add_argument("--beta", type=float, default=0.1)
    p.add_argument("--eps", type=float, default=1e-6, help="Numerical eps in 1/(||psi||_1+eps).")

    # loss weights
    p.add_argument("--w_q", type=float, default=1.0)
    p.add_argument("--w_sr", type=float, default=10.0)
    p.add_argument("--w_recon", type=float, default=1.0)

    # output
    p.add_argument("--out_dir", type=str, default=os.path.join("runs", "sr_intrinsic_vg"))
    # coverage analysis
    p.add_argument(
        "--coverage_reset_interval",
        type=int,
        default=0,
        help="If >0, also compute within-window unique-position coverage (visited reset every K env steps).",
    )
    return p


def _record_visit(env, counts: np.ndarray, *visited_sets: set) -> None:
    """Add the agent's current cell to every given set and bump its visit count."""
    cell = tuple(env.agent_pos)
    for visited in visited_sets:
        visited.add(cell)
    # counts is indexed [row, col] == [y, x]; agent_pos is read as (x, y).
    counts[env.agent_pos[1], env.agent_pos[0]] += 1


def _save_recon_snapshot(
    online: "SRDQNNet",
    obs_u8: np.ndarray,
    action: int,
    next_u8: np.ndarray,
    device: torch.device,
    path_no_ext: str,
) -> None:
    """Save the predicted next frame next to the real next frame as one image."""
    with torch.no_grad():
        x0 = torch.from_numpy(obs_u8).to(device).float().unsqueeze(0) / 255.0
        a0 = torch.tensor([int(action)], device=device, dtype=torch.long)
        pred01 = online.decode_next(online.encode(x0), a0).squeeze(0).detach().cpu().numpy()
    pred_u8 = np.clip(pred01 * 255.0, 0.0, 255.0).astype(np.uint8)
    # Side-by-side comparison along the last axis: [pred | target].
    comp = np.concatenate([pred_u8, next_u8.astype(np.uint8)], axis=2)
    _save_gray_u8(path_no_ext, comp)


def _sr_dqn_update(
    online: "SRDQNNet",
    target: "SRDQNNet",
    optim: torch.optim.Optimizer,
    huber: nn.Module,
    batch,
    device: torch.device,
    *,
    gamma_q: float,
    gamma_sf: float,
    w_q: float,
    w_sr: float,
    w_recon: float,
) -> Tuple[float, float, float, float]:
    """Run one optimiser step on a replay batch.

    Returns:
        ``(q_loss, sr_loss, recon_loss, total_loss)`` as Python floats.
    """
    ob_t = torch.from_numpy(batch.obs).to(device).float() / 255.0
    nob_t = torch.from_numpy(batch.next_obs).to(device).float() / 255.0
    a_t = torch.from_numpy(batch.actions).to(device).long()
    r_t = torch.from_numpy(batch.rewards).to(device).float()
    d_t = torch.from_numpy(batch.dones).to(device).float()

    # Online forward at s (the returned phi(s) is not needed below).
    _phi_s, psi_s, pred_next, q_s = online(ob_t, a_t)
    q_sa = q_s.gather(1, a_t.view(-1, 1)).squeeze(1)

    with torch.no_grad():
        # Double DQN: the online net picks a*, the target net evaluates it.
        a_star = torch.argmax(online.q(online.encode(nob_t)), dim=1)
        # Encode s' with the target net once and share it between the Q target
        # and the SR target (the target net is put in eval mode by the caller).
        phi_next_target = target.encode(nob_t)
        q_next_target = target.q(phi_next_target).gather(1, a_star.view(-1, 1)).squeeze(1)
        y = r_t + (1.0 - d_t) * (gamma_q * q_next_target)

        # SR target: psi(s) ~ phi_target(s') + gamma_sf * (1 - d) * psi_target(s').
        # NOTE(review): the one-step feature term is phi_target(s'), not phi(s);
        # confirm this matches the intended SR definition.
        psi_next_target = target.psi(phi_next_target)
        sr_target = phi_next_target + gamma_sf * (1.0 - d_t).unsqueeze(1) * psi_next_target

    q_loss = huber(q_sa, y)
    sr_loss = F.mse_loss(psi_s, sr_target)
    # Reconstruction: the predicted next frame is regressed onto s' in [0, 1].
    recon_loss = F.mse_loss(pred_next, nob_t)
    loss = w_q * q_loss + w_sr * sr_loss + w_recon * recon_loss

    optim.zero_grad(set_to_none=True)
    loss.backward()
    nn.utils.clip_grad_norm_(online.parameters(), max_norm=10.0)
    optim.step()

    return (
        float(q_loss.item()),
        float(sr_loss.item()),
        float(recon_loss.item()),
        float(loss.item()),
    )


def _run_training(
    env,
    args: argparse.Namespace,
    rng: np.random.Generator,
    device: torch.device,
) -> Tuple[list, list, np.ndarray]:
    """Run the interaction/learning loop on an already-constructed environment.

    Returns:
        ``(coverage, coverage_windowed, counts)``: the per-step cumulative and
        windowed unique-cell counts (one entry for the start state plus one per
        env step) and the ``[height, width]`` visit-count grid.

    Raises:
        ValueError: If the initial reset yields no observation.
    """
    obs_rgb, _ = env.reset(seed=int(args.seed))
    if obs_rgb is None:
        raise ValueError("Failed to get valid observation from environment. Try different seed or check environment setup.")

    obs_u8 = preprocess_obs(obs_rgb)
    in_h, in_w = int(obs_u8.shape[1]), int(obs_u8.shape[2])

    cfg = SRDQNConfig(in_channels=1, num_actions=int(env.action_space.n), phi_dim=1024, input_hw=(in_h, in_w))
    online = SRDQNNet(cfg).to(device)
    target = SRDQNNet(cfg).to(device)
    target.load_state_dict(online.state_dict())
    target.eval()
    for frozen in target.parameters():
        frozen.requires_grad = False

    # Use SR-reference RMSprop hyperparams; DQN-default eps=0.01 can make updates too damped.
    optim = torch.optim.RMSprop(
        online.parameters(),
        lr=float(args.lr),
        alpha=0.95,
        eps=float(args.rmsprop_eps),
        centered=bool(args.rmsprop_centered),
        momentum=0.0,
    )
    huber = nn.SmoothL1Loss(reduction="mean")
    replay = ReplayBuffer(int(args.replay_capacity), obs_shape=(1, in_h, in_w))

    # Loop-invariant settings, converted once.
    total_steps = int(args.total_steps)
    episode_len = int(args.episode_len)
    learning_starts = int(args.learning_starts)
    batch_size = int(args.batch_size)
    train_freq = int(args.train_freq)
    updates_per_step = int(args.updates_per_step)
    target_update_freq = int(args.target_update_freq)
    save_recon_every = int(args.save_recon_every)
    reset_k = int(args.coverage_reset_interval)
    log_every = int(args.log_every)
    beta = float(args.beta)
    psi_eps = float(args.eps)

    recon_dir: Optional[str] = None
    if save_recon_every > 0:
        recon_dir = os.path.join(args.out_dir, "recon_samples")
        os.makedirs(recon_dir, exist_ok=True)

    visited_cum: set = set()
    visited_win: set = set()
    counts = np.zeros((env.height, env.width), dtype=np.int32)

    # Track the starting position.
    _record_visit(env, counts, visited_cum, visited_win)
    coverage = [len(visited_cum)]
    coverage_windowed = [len(visited_win)]

    ep_step = 0
    # Latest (q, sr, recon, total) losses, averaged over the last training trigger.
    last_losses: tuple = (None, None, None, None)

    for step in range(total_steps):
        explore_eps = epsilon_by_step(step, args.eps_start, args.eps_end, args.eps_decay_steps)
        action = select_action_eps_greedy(online, obs_u8, device, eps=explore_eps, rng=rng)

        next_rgb, _r_ext, terminated, truncated, _info = env.step(int(action))
        next_u8 = preprocess_obs(next_rgb)

        # Optional: save reconstruction snapshots periodically.
        if recon_dir is not None and (step + 1) % save_recon_every == 0:
            xy = tuple(env.agent_pos)
            base = os.path.join(recon_dir, f"step_{step+1:08d}_a{int(action)}_pos{xy[0]}_{xy[1]}")
            _save_recon_snapshot(online, obs_u8, int(action), next_u8, device, base)

        # Episode boundary: env signal or fixed length.
        # NOTE(review): `done` also covers truncation / the fixed episode length,
        # so bootstrapping is cut at time limits too; confirm this is intended.
        ep_step += 1
        done = bool(terminated or truncated) or ep_step >= episode_len

        # Intrinsic reward from TARGET psi(s_t); the extrinsic reward is ignored.
        r_int = intrinsic_reward_from_target_psi(target, obs_u8, device, beta=beta, eps=psi_eps)
        replay.add(obs_u8, int(action), float(r_int), next_u8, bool(done))

        # Coverage stats for the cell reached by this step.
        _record_visit(env, counts, visited_cum, visited_win)
        coverage.append(len(visited_cum))
        coverage_windowed.append(len(visited_win))

        # Periodic reset for windowed coverage; the current cell seeds the new window.
        if reset_k > 0 and (step + 1) % reset_k == 0:
            visited_win = {tuple(env.agent_pos)}

        obs_u8 = next_u8

        if done:
            obs_rgb, _ = env.reset(seed=int(rng.integers(0, 2**31 - 1)))
            obs_u8 = preprocess_obs(obs_rgb)
            ep_step = 0
            # Count the reset position.
            _record_visit(env, counts, visited_cum, visited_win)

        # Learn.
        if len(replay) >= learning_starts and step % train_freq == 0:
            update_stats = [
                _sr_dqn_update(
                    online,
                    target,
                    optim,
                    huber,
                    replay.sample(batch_size, rng),
                    device,
                    gamma_q=float(args.gamma_q),
                    gamma_sf=float(args.gamma_sf),
                    w_q=float(args.w_q),
                    w_sr=float(args.w_sr),
                    w_recon=float(args.w_recon),
                )
                for _ in range(updates_per_step)
            ]
            if update_stats:
                last_losses = tuple(float(np.mean(column)) for column in zip(*update_stats))

        # Target update.
        if (step + 1) % target_update_freq == 0:
            target.load_state_dict(online.state_dict())

        if log_every > 0 and (step + 1) % log_every == 0:
            cov = int(coverage[-1])
            cov_max = int(env.height * env.width)
            q_str, sr_str, rec_str, tot_str = (
                "NA" if value is None else f"{value:.6f}" for value in last_losses
            )
            print(
                f"[Step {step+1}] coverage={cov}/{cov_max} eps={explore_eps:.3f} r_int={r_int:.6f} | "
                f"loss_total={tot_str} q={q_str} sr={sr_str} recon={rec_str}"
            )

    return coverage, coverage_windowed, counts


def _save_run_outputs(out_dir: str, coverage: list, coverage_windowed: list, counts: np.ndarray) -> None:
    """Write the coverage arrays, visit counts and (if matplotlib loaded) the coverage plot."""
    np.save(os.path.join(out_dir, "coverage.npy"), np.asarray(coverage, dtype=np.int32))
    np.save(os.path.join(out_dir, "coverage_windowed.npy"), np.asarray(coverage_windowed, dtype=np.int32))
    np.save(os.path.join(out_dir, "counts.npy"), counts.astype(np.int32))

    if plt is not None:
        fig, ax = plt.subplots(figsize=(8, 4.5))
        ax.plot(np.asarray(coverage, dtype=np.float32), linewidth=2.0)
        ax.set_xlabel("Env steps")
        ax.set_ylabel("Coverage (#unique states)")
        ax.grid(alpha=0.3)
        fig.tight_layout()
        fig.savefig(os.path.join(out_dir, "coverage.png"), dpi=160)
        plt.close(fig)


def main() -> int:
    """Train an SR-DQN agent on intrinsic reward only and save coverage statistics.

    Parses the CLI, writes ``config.json`` into ``--out_dir``, runs the
    training loop, and stores ``coverage.npy``, ``coverage_windowed.npy``,
    ``counts.npy`` and, when matplotlib is available, ``coverage.png``.

    Returns:
        0 on success (used as the process exit status).

    Raises:
        SystemExit: From argparse on invalid arguments, including a
            non-positive ``--train_freq`` or ``--target_update_freq``.
        ModuleNotFoundError: If the environment class could not be imported.
        ValueError: If the initial environment reset yields no observation.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    # Both values are used as modulo divisors inside the loop; 0 would raise
    # ZeroDivisionError mid-run, so reject non-positive values up front.
    if int(args.train_freq) < 1:
        parser.error("--train_freq must be >= 1")
    if int(args.target_update_freq) < 1:
        parser.error("--target_update_freq must be >= 1")

    os.makedirs(args.out_dir, exist_ok=True)
    with open(os.path.join(args.out_dir, "config.json"), "w", encoding="utf-8") as f:
        json.dump(vars(args), f, indent=2)

    rng = np.random.default_rng(int(args.seed))
    torch.manual_seed(int(args.seed))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if SimpleEnv is None:
        raise ModuleNotFoundError(
            "Missing optional dependency for visual_gridworld environment (minigrid). "
            "Please install it (e.g., `pip install minigrid`) to run sr_dqn_explore."
        )

    env = SimpleEnv(size=int(args.env_size), render_mode="rgb_array", highlight=False)
    try:
        coverage, coverage_windowed, counts = _run_training(env, args, rng, device)
    finally:
        # Release the environment even when training raises.
        env.close()

    _save_run_outputs(args.out_dir, coverage, coverage_windowed, counts)

    print(f"[OK] Saved to: {args.out_dir}")
    return 0


# Script entry point: run main() and use its integer return value as the
# process exit status.
if __name__ == "__main__":
    raise SystemExit(main())

