from __future__ import annotations

import argparse
import json
import os
import sys
from typing import Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# Optional plotting (allow --help without matplotlib)
try:
    import matplotlib

    matplotlib.use("Agg")
    import matplotlib.pyplot as plt
except Exception:  # pragma: no cover
    plt = None  # type: ignore

# Optional resize backends
try:
    import cv2  # type: ignore
except Exception:  # pragma: no cover
    cv2 = None  # type: ignore
try:
    from PIL import Image  # type: ignore
except Exception:  # pragma: no cover
    Image = None  # type: ignore

# Ensure the project root (the parent of this script's directory) is importable so the
# absolute `visual_gridworld.*` imports below resolve when this file is run as a script.
# Uses the `sys` module directly; `os.sys` is an undocumented implementation detail of `os`.
_THIS_DIR = os.path.dirname(os.path.abspath(__file__))
_ROOT_DIR = os.path.dirname(_THIS_DIR)
if _ROOT_DIR not in sys.path:
    sys.path.insert(0, _ROOT_DIR)

# Optional dependency: minigrid (allow `--help` without installing)
try:
    from visual_gridworld.visual_minigrid import SimpleEnv  # type: ignore
except Exception:
    SimpleEnv = None  # type: ignore

from visual_gridworld.sp_dqn_model import SPDQNNet, SPDQNConfig
from visual_gridworld.sp_dqn_replay import ReplayBuffer


def _rgb_to_gray_u8(img: np.ndarray) -> np.ndarray:
    if img is None:
        raise ValueError("Input image is None")
    if img.ndim == 3 and img.shape[2] >= 3:
        return (0.299 * img[..., 0] + 0.587 * img[..., 1] + 0.114 * img[..., 2]).astype(np.uint8)
    return img.astype(np.uint8)


def preprocess_obs(obs_rgb: np.ndarray, *, max_side: int = 128) -> np.ndarray:
    """Convert a frame to a downsampled uint8 grayscale image of shape [1, H', W'].

    The frame is converted to grayscale, then strided by a single integer step
    ``ceil(max(H, W) / max_side)`` along both axes, so each output side is at most
    ``max_side`` (not necessarily equal to it). No interpolation is performed.

    Args:
        obs_rgb: RGB(A) frame ``[H, W, C]`` with C >= 3, a single-channel ``[H, W, 1]``
            frame, or an already-gray ``[H, W]`` image.
        max_side: upper bound on the longer output side (keyword-only, default 128).

    Returns:
        uint8 array of shape ``[1, H', W']``.

    Raises:
        ValueError: if ``max_side < 1`` or the frame does not reduce to a 2-D image.
    """
    max_side = int(max_side)
    if max_side < 1:
        raise ValueError(f"max_side must be >= 1, got {max_side}")
    gray = _rgb_to_gray_u8(obs_rgb)
    # A single-channel [H, W, 1] frame passes through the gray conversion untouched.
    if gray.ndim == 3 and gray.shape[2] == 1:
        gray = gray[..., 0]
    if gray.ndim != 2:
        raise ValueError(f"Expected an image reducible to [H,W], got shape={gray.shape}")
    h, w = gray.shape
    step = max(1, int(np.ceil(max(h, w) / float(max_side))))
    if step > 1:
        gray = gray[::step, ::step]
    # astype copies, so the result is contiguous (safe for torch.from_numpy).
    return gray.astype(np.uint8)[None, :, :]


def _save_gray_u8(path_no_ext: str, img_u8: np.ndarray) -> None:
    img2d = img_u8
    if img2d.ndim == 3 and img2d.shape[0] == 1:
        img2d = img2d[0]
    if img2d.ndim != 2:
        raise ValueError(f"Expected [H,W] or [1,H,W] uint8, got shape={img_u8.shape}")
    png_path = path_no_ext + ".png"
    if cv2 is not None:
        _ = cv2.imwrite(png_path, img2d)
        return
    if Image is not None:
        Image.fromarray(img2d).save(png_path)
        return
    np.save(path_no_ext + ".npy", img2d.astype(np.uint8))


def epsilon_by_step(step: int, eps_start: float, eps_end: float, eps_decay_steps: int) -> float:
    """Return the exploration rate for ``step`` under a linear annealing schedule.

    Epsilon moves linearly from ``eps_start`` (at step 0) toward ``eps_end`` over
    ``eps_decay_steps`` steps and is held at ``eps_end`` from then on. A non-positive
    decay horizon therefore yields ``eps_end`` for every non-negative step.
    """
    horizon = int(eps_decay_steps)
    if step < horizon:
        progress = float(step) / float(max(1, horizon))
        return float(eps_start + (eps_end - eps_start) * progress)
    return float(eps_end)


@torch.no_grad()
def select_action_eps_greedy(
    online: SPDQNNet,
    obs_u8: np.ndarray,
    device: torch.device,
    *,
    eps: float,
    rng: np.random.Generator,
) -> int:
    """Pick an action epsilon-greedily from the online network's Q-values.

    With probability ``eps`` (one ``rng.random()`` draw) a uniformly random action in
    ``[0, online.cfg.num_actions)`` is drawn from ``rng``. Otherwise the observation is
    scaled from uint8 to [0, 1], given a batch axis of 1, encoded, and the argmax
    Q-value action is returned. Runs without autograd.
    """
    explore = rng.random() < float(eps)
    if explore:
        n_actions = online.cfg.num_actions
        return int(rng.integers(0, n_actions))
    batch = torch.from_numpy(obs_u8).to(device).float().unsqueeze(0) / 255.0
    q_values = online.q(online.encode(batch))
    best = q_values.argmax(dim=1)
    return int(best.item())


@torch.no_grad()
def intrinsic_reward_sf_pf(
    target: SPDQNNet,
    obs_u8: np.ndarray,
    action: int,
    next_obs_u8: np.ndarray,
    device: torch.device,
    *,
    eps: float,
) -> float:
    """
    Compute the SF-PF intrinsic reward for one transition from target-network outputs.

        r(t) = 1 / max(||xi(s_{t+1})||_1, eps) - 1 / max(||psi(s_t, a_t)||_1, eps)

    where xi is the target PF head applied to the encoding of s_{t+1} and psi is the
    target SF head's output for action a_t at s_t. ``eps`` is a lower clamp on each
    L1 norm (not an additive smoothing term), bounding each reciprocal by 1/eps.

    Args:
        target: network exposing ``encode``, ``pf`` and ``sf_all``.
        obs_u8, next_obs_u8: uint8 observations for s_t and s_{t+1}; scaled to [0, 1]
            and given a batch axis of 1 before encoding.
        action: index of a_t into the SF head's action axis.
        device: device to run the forward passes on.
        eps: norm floor; must be > 0.

    Raises:
        ValueError: if ``eps`` is not strictly positive (the reward would be infinite
            for a zero-norm feature).
    """
    floor = float(eps)
    if not floor > 0.0:
        raise ValueError(f"eps must be > 0, got {eps}")
    x_t = torch.from_numpy(obs_u8).to(device).float().unsqueeze(0) / 255.0
    x_tp1 = torch.from_numpy(next_obs_u8).to(device).float().unsqueeze(0) / 255.0

    # xi(s_{t+1}) from the target PF head: [1, d].
    xi_tp1 = target.pf(target.encode(x_tp1))

    # psi(s_t, .) from the target SF head ([1, A, d]), then select a_t -> [d].
    psi_ta = target.sf_all(target.encode(x_t))[0, int(action), :]

    inv_xi = 1.0 / torch.norm(xi_tp1, p=1, dim=1).clamp_min(floor)
    inv_psi = 1.0 / torch.norm(psi_ta, p=1).clamp_min(floor)
    return float((inv_xi - inv_psi).item())


def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser (defaults marked "match SR" mirror the SR script)."""
    p = argparse.ArgumentParser(description="SP exploration (SF+PF) with DQN in visual_gridworld.")

    # env (match SR defaults)
    p.add_argument("--env_size", type=int, default=20)
    p.add_argument("--seed", type=int, default=1)
    p.add_argument("--episode_len", type=int, default=2000000)
    p.add_argument("--no_render", action="store_true", help="Disable human rendering.")

    # training (match SR defaults)
    p.add_argument("--total_steps", type=int, default=9000)
    p.add_argument("--learning_starts", type=int, default=100)
    p.add_argument("--replay_capacity", type=int, default=3000)
    p.add_argument("--batch_size", type=int, default=64)
    p.add_argument("--train_freq", type=int, default=4)
    p.add_argument("--updates_per_step", type=int, default=1)
    p.add_argument("--target_update_freq", type=int, default=200)
    p.add_argument("--lr", type=float, default=2.5e-4)
    p.add_argument("--rmsprop_eps", type=float, default=0.000009765625)
    p.add_argument(
        "--rmsprop_centered",
        action=argparse.BooleanOptionalAction,
        default=True,
    )

    # logging/saving (match SR defaults)
    p.add_argument("--log_every", type=int, default=500)
    p.add_argument("--save_recon_every", type=int, default=4000)

    # discounts
    p.add_argument("--gamma_q", type=float, default=0.99)
    p.add_argument("--gamma_sf", type=float, default=0.9)

    # epsilon schedule
    p.add_argument("--eps_start", type=float, default=1.0)
    p.add_argument("--eps_end", type=float, default=0.1)
    p.add_argument("--eps_decay_steps", type=int, default=1000)

    # intrinsic reward
    p.add_argument("--beta", type=float, default=0.1)
    p.add_argument("--eps", type=float, default=1e-6)

    # loss weights (keep shared same; add w_sf/w_pf)
    p.add_argument("--w_q", type=float, default=10.0)
    p.add_argument("--w_sf", type=float, default=10.0)
    p.add_argument("--w_pf", type=float, default=10.0)
    p.add_argument("--w_recon", type=float, default=0.01)

    # output
    p.add_argument("--out_dir", type=str, default=os.path.join("runs", "sp_intrinsic_vg"))
    # coverage analysis
    p.add_argument(
        "--coverage_reset_interval",
        type=int,
        default=0,
        help="If >0, also compute within-window unique-position coverage (visited reset every K env steps).",
    )
    return p


def _mark_visited(agent_pos, visited_cum: set, visited_win: set, counts: np.ndarray) -> None:
    """Record agent_pos in both visited sets and increment counts[agent_pos[1], agent_pos[0]]."""
    xy = tuple(agent_pos)
    visited_cum.add(xy)
    visited_win.add(xy)
    counts[agent_pos[1], agent_pos[0]] += 1


def _sgd_update(
    online: SPDQNNet,
    target: SPDQNNet,
    optim: torch.optim.Optimizer,
    huber: nn.Module,
    batch,
    args: argparse.Namespace,
    device: torch.device,
) -> tuple[float, float, float, float, float]:
    """Run one optimizer step on a replay batch; return (q, sf, pf, recon, total) losses."""
    ob_t = torch.from_numpy(batch.obs).to(device).float() / 255.0
    nob_t = torch.from_numpy(batch.next_obs).to(device).float() / 255.0
    a_t = torch.from_numpy(batch.actions).to(device).long()
    a_tp1 = torch.from_numpy(batch.next_actions).to(device).long()
    r_t = torch.from_numpy(batch.rewards).to(device).float()
    d_t = torch.from_numpy(batch.dones).to(device).float()
    not_done = 1.0 - d_t
    rows = torch.arange(ob_t.size(0), device=device)
    gamma_sf = float(args.gamma_sf)

    # Online phi(s_t): gradients reach the encoder only through the Q and recon losses;
    # the SF head sees a detached feature.
    phi_s = online.encode(ob_t)
    q_sa = online.q(phi_s).gather(1, a_t.view(-1, 1)).squeeze(1)
    psi_sa = online.sf_all(phi_s.detach())[rows, a_t]  # [B, d]

    with torch.no_grad():
        # Each encoder feature needed by the targets is computed exactly once.
        phi_sn_online = online.encode(nob_t)
        phi_s_target = target.encode(ob_t)
        phi_sn_target = target.encode(nob_t)

        # Q target (Double DQN): online picks a*, target evaluates it.
        a_star = torch.argmax(online.q(phi_sn_online), dim=1)
        q_next_target = target.q(phi_sn_target).gather(1, a_star.view(-1, 1)).squeeze(1)
        y = r_t + not_done * (float(args.gamma_q) * q_next_target)

        # SF TD target: phi(s_t) + gamma * psi(s_{t+1}, a_{t+1}).
        psi_next = target.sf_all(phi_sn_target)[rows, a_tp1]
        sf_target = phi_s_target + gamma_sf * not_done.unsqueeze(1) * psi_next

        # PF TD target: mu(s_{t+1}) + gamma * xi(s_t), with mu(s) = phi(s).
        # NOTE(review): the (1 - done) mask drops xi(s_t) on terminal transitions; for a
        # predecessor recursion the natural boundary is the episode start — confirm intent.
        xi_targ_s = target.pf(phi_s_target)
        pf_target = phi_sn_target + gamma_sf * not_done.unsqueeze(1) * xi_targ_s

    q_loss = huber(q_sa, y)
    sf_loss = torch.mean(torch.sum((sf_target - psi_sa) ** 2, dim=1))

    # xi(s_{t+1}) from the online PF head. phi_sn_online carries no grad, so only the
    # PF head's parameters receive gradients (equivalent to detaching the encoder output).
    xi_sn = online.pf(phi_sn_online)
    pf_loss = torch.mean(torch.sum((pf_target - xi_sn) ** 2, dim=1))

    # Recon head: predict s_{t+1} from phi(s_t).
    recon_loss = F.mse_loss(online.recon_next(phi_s), nob_t)

    loss = (
        float(args.w_q) * q_loss
        + float(args.w_sf) * sf_loss
        + float(args.w_pf) * pf_loss
        + float(args.w_recon) * recon_loss
    )

    optim.zero_grad(set_to_none=True)
    loss.backward()
    nn.utils.clip_grad_norm_(online.parameters(), max_norm=10.0)
    optim.step()

    return (
        float(q_loss.item()),
        float(sf_loss.item()),
        float(pf_loss.item()),
        float(recon_loss.item()),
        float(loss.item()),
    )


def _save_recon_snapshot(
    online: SPDQNNet,
    obs_u8: np.ndarray,
    next_u8: np.ndarray,
    device: torch.device,
    recon_dir: str,
    step: int,
    action: int,
    agent_pos,
) -> None:
    """Save a side-by-side [predicted next frame | actual next frame] grayscale image."""
    with torch.no_grad():
        x0 = torch.from_numpy(obs_u8).to(device).float().unsqueeze(0) / 255.0
        # Presumably [1, in_h, in_w], matching the preprocessed observation — confirm.
        pred01 = online.recon_next(online.encode(x0)).squeeze(0).detach().cpu().numpy()
    pred_u8 = np.clip(pred01 * 255.0, 0.0, 255.0).astype(np.uint8)
    comp = np.concatenate([pred_u8, next_u8.astype(np.uint8)], axis=2)
    xy = tuple(agent_pos)
    base = os.path.join(recon_dir, f"step_{step+1:08d}_a{int(action)}_pos{xy[0]}_{xy[1]}")
    _save_gray_u8(base, comp)


def _train_and_track(
    args: argparse.Namespace,
    env,
    rng: np.random.Generator,
    device: torch.device,
) -> tuple[list[int], list[int], np.ndarray]:
    """Run the exploration/training loop; return (coverage, windowed coverage, visit counts)."""
    obs_rgb, _ = env.reset(seed=int(args.seed))
    if obs_rgb is None:
        raise ValueError(
            "Failed to get valid observation from environment. Try different seed or check environment setup."
        )

    obs_u8 = preprocess_obs(obs_rgb)
    in_h, in_w = int(obs_u8.shape[1]), int(obs_u8.shape[2])

    cfg = SPDQNConfig(in_channels=1, num_actions=int(env.action_space.n), feat_dim=1024, input_hw=(in_h, in_w))
    online = SPDQNNet(cfg).to(device)
    target = SPDQNNet(cfg).to(device)
    target.load_state_dict(online.state_dict())
    target.eval()
    target.requires_grad_(False)

    optim = torch.optim.RMSprop(
        online.parameters(),
        lr=float(args.lr),
        alpha=0.95,
        eps=float(args.rmsprop_eps),
        centered=bool(args.rmsprop_centered),
        momentum=0.0,
    )
    huber = nn.SmoothL1Loss(reduction="mean")
    replay = ReplayBuffer(int(args.replay_capacity), obs_shape=(1, in_h, in_w))

    save_every = int(args.save_recon_every)
    recon_dir: str | None = None
    if save_every > 0:
        recon_dir = os.path.join(args.out_dir, "recon_samples")
        os.makedirs(recon_dir, exist_ok=True)

    # Loop invariants, converted once.
    reset_k = int(args.coverage_reset_interval)
    log_every = int(args.log_every)
    train_freq = int(args.train_freq)
    target_update_freq = int(args.target_update_freq)
    learning_starts = int(args.learning_starts)
    episode_len = int(args.episode_len)
    updates_per_step = int(args.updates_per_step)
    batch_size = int(args.batch_size)
    beta = float(args.beta)
    reward_eps = float(args.eps)
    cov_max = int(env.height * env.width)

    visited_cum: set = set()
    visited_win: set = set()
    counts = np.zeros((env.height, env.width), dtype=np.int32)
    _mark_visited(env.agent_pos, visited_cum, visited_win, counts)
    coverage = [len(visited_cum)]
    coverage_windowed = [len(visited_win)]

    ep_step = 0
    last_losses: tuple[float, ...] | None = None

    for step in range(int(args.total_steps)):
        eps_t = epsilon_by_step(step, args.eps_start, args.eps_end, args.eps_decay_steps)
        action = select_action_eps_greedy(online, obs_u8, device, eps=eps_t, rng=rng)

        next_rgb, _r_ext, terminated, truncated, _info = env.step(int(action))
        next_u8 = preprocess_obs(next_rgb)

        # a_{t+1} for the SARSA-style SF bootstrap: an independent draw from the
        # epsilon-greedy policy at s_{t+1} (not necessarily the action executed next).
        eps_tp1 = epsilon_by_step(step + 1, args.eps_start, args.eps_end, args.eps_decay_steps)
        next_action = select_action_eps_greedy(online, next_u8, device, eps=eps_tp1, rng=rng)

        ep_step += 1
        done = bool(terminated or truncated) or (ep_step >= episode_len)

        # Intrinsic reward from target outputs is the only reward stored (extrinsic is ignored).
        r_sfpf = intrinsic_reward_sf_pf(target, obs_u8, action, next_u8, device, eps=reward_eps)
        r_aug = beta * float(r_sfpf)

        replay.add(obs_u8, int(action), int(next_action), float(r_aug), next_u8, bool(done))

        _mark_visited(env.agent_pos, visited_cum, visited_win, counts)
        coverage.append(len(visited_cum))
        coverage_windowed.append(len(visited_win))

        # Periodic reset for windowed coverage.
        if reset_k > 0 and (step + 1) % reset_k == 0:
            visited_win = {tuple(env.agent_pos)}

        if recon_dir is not None and (step + 1) % save_every == 0:
            _save_recon_snapshot(online, obs_u8, next_u8, device, recon_dir, step, action, env.agent_pos)

        obs_u8 = next_u8

        if done:
            obs_rgb, _ = env.reset(seed=int(rng.integers(0, 2**31 - 1)))
            obs_u8 = preprocess_obs(obs_rgb)
            ep_step = 0
            _mark_visited(env.agent_pos, visited_cum, visited_win, counts)

        if len(replay) >= learning_starts and step % train_freq == 0:
            per_update: list[tuple[float, float, float, float, float]] = []
            for _ in range(updates_per_step):
                batch = replay.sample(batch_size, rng)
                per_update.append(_sgd_update(online, target, optim, huber, batch, args, device))
            if per_update:
                last_losses = tuple(float(np.mean(col)) for col in zip(*per_update))

        if (step + 1) % target_update_freq == 0:
            target.load_state_dict(online.state_dict())

        if log_every > 0 and (step + 1) % log_every == 0:
            if last_losses is None:
                q_str = sf_str = pf_str = rec_str = tot_str = "NA"
            else:
                q_str, sf_str, pf_str, rec_str, tot_str = (f"{v:.6f}" for v in last_losses)
            cov = int(coverage[-1])
            print(
                f"[Step {step+1}] coverage={cov}/{cov_max} eps={eps_t:.3f} r_sfpf={r_sfpf:.6f} r_aug={r_aug:.6f} | "
                f"loss_total={tot_str} q={q_str} sf={sf_str} pf={pf_str} recon={rec_str}"
            )

    return coverage, coverage_windowed, counts


def _save_results(
    out_dir: str,
    coverage: list[int],
    coverage_windowed: list[int],
    counts: np.ndarray,
) -> None:
    """Write coverage curves and visit counts as .npy, plus coverage.png if matplotlib loaded."""
    np.save(os.path.join(out_dir, "coverage.npy"), np.asarray(coverage, dtype=np.int32))
    np.save(os.path.join(out_dir, "coverage_windowed.npy"), np.asarray(coverage_windowed, dtype=np.int32))
    np.save(os.path.join(out_dir, "counts.npy"), counts.astype(np.int32))
    if plt is None:
        return
    fig, ax = plt.subplots(figsize=(8, 4.5))
    ax.plot(np.asarray(coverage, dtype=np.float32), linewidth=2.0)
    ax.set_xlabel("Env steps")
    ax.set_ylabel("Coverage (#unique states)")
    ax.grid(alpha=0.3)
    fig.tight_layout()
    fig.savefig(os.path.join(out_dir, "coverage.png"), dpi=160)
    plt.close(fig)


def main() -> int:
    """Train an SP (SF+PF) DQN explorer and save coverage statistics under --out_dir.

    Writes config.json up front, then coverage.npy, coverage_windowed.npy, counts.npy
    (and coverage.png when matplotlib is available) after training. Returns 0.

    Raises:
        ModuleNotFoundError: if the visual_gridworld environment could not be imported.
        SystemExit: (via argparse) for invalid command-line values.
    """
    p = _build_arg_parser()
    args = p.parse_args()
    # These are used as modulo divisors in the training loop.
    for name in ("train_freq", "target_update_freq"):
        if int(getattr(args, name)) < 1:
            p.error(f"argument --{name} must be >= 1")

    os.makedirs(args.out_dir, exist_ok=True)
    with open(os.path.join(args.out_dir, "config.json"), "w", encoding="utf-8") as f:
        json.dump(vars(args), f, indent=2)

    rng = np.random.default_rng(int(args.seed))
    torch.manual_seed(int(args.seed))
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if SimpleEnv is None:
        raise ModuleNotFoundError(
            "Missing optional dependency for visual_gridworld environment (minigrid). "
            "Please install it (e.g., `pip install minigrid`) to run sp_dqn_explore."
        )

    env = SimpleEnv(size=int(args.env_size), render_mode="rgb_array", highlight=False)
    try:
        coverage, coverage_windowed, counts = _train_and_track(args, env, rng, device)
    finally:
        # Release the environment even if training raises.
        env.close()

    _save_results(args.out_dir, coverage, coverage_windowed, counts)
    print(f"[OK] Saved to: {args.out_dir}")
    return 0


if __name__ == "__main__":
    # Propagate main()'s integer return value as the process exit status.
    exit_code = main()
    raise SystemExit(exit_code)

