import os
import sys
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F

# Make the parent of this file's directory importable so that the absolute
# import of `visual_gridworld.fpvr_model` below resolves even when this module
# is run directly rather than through an installed package.
_THIS_DIR = os.path.dirname(os.path.abspath(__file__))  # directory containing this file
_ROOT_DIR = os.path.dirname(_THIS_DIR)  # one level up from this file's directory
if _ROOT_DIR not in sys.path:
    # Prepend (rather than append) so this root takes precedence on lookup;
    # the membership check avoids adding a duplicate entry.
    sys.path.insert(0, _ROOT_DIR)

from visual_gridworld.fpvr_model import FPVRVisualModel


def canonical_sf_target_mode(mode: str) -> str:
    """Canonicalize SR target mode naming used across scripts/configs.

    The value is converted to ``str``, surrounding whitespace is removed, the
    text is lower-cased and hyphens are replaced by underscores, so that
    spellings such as ``" Min-Redundancy "`` map to ``"min_redundancy"``.
    Already-canonical names (e.g. ``"uniform_policy"``) are returned unchanged.

    Args:
        mode: Mode name as written in a script or config.

    Returns:
        The normalized mode string. No check is made that the result names a
        known mode.
    """
    return str(mode).strip().lower().replace("-", "_")


class FPVRVisualAgent:
    """Future-Past Visitation Redundancy (FPVR) agent (visual version).

    Mechanism:
    - φ: state features (CNN encoder output; stable scale)
    - φ̃ = ZCA(φ): whitened features (explicit ZCA transform)
    - ψ(s,a): successor features (slow time-scale future-visitation representation)
    - c (persistence representation): discounted accumulator of past whitened features (short time-scale)
    - redundancy: overlap score between c and ψ(s,a); prefer actions with lower redundancy
    """

    def __init__(
        self,
        obs_shape,
        n_actions,
        *,
        lr=0.001,
        phi_dim=128,
        psi_dim: int | None = None,
        sf_gamma=0.999,
        beta=1.0,
        capacity=10000,
        batch_size=64,
        update_after=100,
        update_every=1,
        sf_target_mode: str = "uniform_policy",
        whitening_update_every: int = 100,
        device=None,
    ):
        self.device = device or (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
        psi_dim = psi_dim if psi_dim is not None else phi_dim

        # Model and hyperparameters
        self.model = FPVRVisualModel(obs_shape, n_actions, phi_dim=phi_dim, psi_dim=psi_dim).to(self.device)
        self.sf_gamma = float(sf_gamma)
        self.beta = float(beta)
        self.n_actions = n_actions
        self.sf_target_mode = canonical_sf_target_mode(str(sf_target_mode))
        self.phi_dim = int(phi_dim)

        # Persistence representation c: discounted accumulator of past whitened features φ̃
        self.c = torch.zeros((1, psi_dim), device=self.device)

        # Replay buffer
        self._init_replay_buffer(obs_shape, capacity, batch_size, update_after, update_every, psi_dim)

        # Optimizer (model parameters only)
        self.opt = torch.optim.Adam(self.model.parameters(), lr=lr)

        # ZCA whitening statistics (no gradient; updated periodically)
        self.whitening_update_every = int(whitening_update_every)
        self.whitening_update_count = 0
        self.running_mean = torch.zeros(phi_dim, device=self.device)
        self.running_cov = torch.eye(phi_dim, device=self.device)
        self.whitening_matrix = torch.eye(phi_dim, device=self.device)  # W = Σ^{-1/2}

        # (verbose/vicreg/inv-dyn branches removed)

    def _update_whitening_matrix(self, phi_batch):
        """Update the ZCA whitening matrix (periodic; no gradients).

        ZCA whitening: φ̃ = W^T (φ - μ), where W = U S^{-1/2} U^T.
        This makes Cov(φ̃) ≈ I (decorrelated, unit-variance features).
        """
        with torch.no_grad():
            # EMA update of mean/covariance
            batch_mean = phi_batch.mean(dim=0)
            phi_centered = phi_batch - batch_mean
            batch_cov = (phi_centered.T @ phi_centered) / phi_batch.size(0)

            alpha = 0.01  # EMA rate
            self.running_mean = (1 - alpha) * self.running_mean + alpha * batch_mean
            self.running_cov = (1 - alpha) * self.running_cov + alpha * batch_cov

            # Recompute whitening matrix every N steps
            if self.whitening_update_count % self.whitening_update_every == 0:
                # ZCA whitening: W = U S^{-1/2} U^T
                U, S, _ = torch.svd(self.running_cov)
                S_inv_sqrt = 1.0 / torch.sqrt(S.clamp(min=1e-5))  # numerical guard
                self.whitening_matrix = U @ torch.diag(S_inv_sqrt) @ U.T

            self.whitening_update_count += 1

    def _apply_whitening(self, phi_raw):
        """Apply ZCA whitening: φ̃ = W^T (φ - μ).

        Args:
            phi_raw: [B, phi_dim] raw features
        Returns:
            phi_tilde: [B, phi_dim] whitened features
        """
        phi_centered = phi_raw - self.running_mean
        phi_tilde = phi_centered @ self.whitening_matrix
        return phi_tilde

    def _init_replay_buffer(self, obs_shape, capacity, batch_size, update_after, update_every, psi_dim):
        """Initialize the replay buffer."""
        c, h, w = obs_shape
        # Limit memory usage
        bytes_per = c * h * w  # uint8
        max_bytes = 512 * 1024 * 1024  # ~512MB
        max_capacity = max(10000, int(max_bytes // max(bytes_per, 1)))
        self.capacity = int(min(capacity, max_capacity))
        self.batch_size = int(batch_size)
        self.update_after = int(update_after)
        self.update_every = int(update_every)

        self.obs_buf = np.zeros((self.capacity, c, h, w), dtype=np.uint8)
        self.nxt_buf = np.zeros((self.capacity, c, h, w), dtype=np.uint8)
        self.act_buf = np.zeros((self.capacity,), dtype=np.int64)
        self.done_buf = np.zeros((self.capacity,), dtype=np.float32)
        self.nxt_act_buf = np.zeros((self.capacity,), dtype=np.int64)

        self.ptr = 0
        self.size = 0

    @torch.no_grad()
    def act(self, obs, greedy=False):
        """Select an action using FPVR (prefer lower future-past redundancy)."""
        if obs.ndim == 3:
            obs = np.expand_dims(obs, 0)

        obs_t = torch.from_numpy(obs).to(self.device)

        # φ_raw -> update whitening stats elsewhere; here just compute φ̃ and ψ̃(·)
        phi_raw = self.model(obs_t.float())
        phi_tilde = self._apply_whitening(phi_raw)  # ZCA whitening

        # ψ-head input is always φ̃ (stop-grad)
        psi_all = self.model.psi_head(phi_tilde.detach()).view(-1, self.n_actions, phi_tilde.size(-1))

        # redundancy in whitened SF space: cos(ψ̃, c)
        redundancy = self._compute_redundancy(psi_all)

        # Sample actions biased toward lower redundancy
        action = self._sample_action(redundancy, greedy)
        return int(action.item()), phi_tilde.squeeze(0).cpu().numpy()

    def _compute_redundancy(self, psi_all: torch.Tensor) -> torch.Tensor:
        """Future-past redundancy score: cos(ψ̃(s,a), c). Returns [B, A]."""
        c_norm = F.normalize(self.c, p=2, dim=1)  # [1,D]
        psi_flat = psi_all.view(-1, psi_all.size(2))
        psi_normed = F.normalize(psi_flat, p=2, dim=1).view(psi_all.size(0), psi_all.size(1), -1)
        return torch.einsum('bad,bd->ba', psi_normed, c_norm)

    def _sample_action(self, redundancy: torch.Tensor, greedy: bool = False) -> torch.Tensor:
        """Sample actions biased toward lower redundancy."""
        mu = redundancy.mean(dim=1, keepdim=True)
        sd = redundancy.std(dim=1, keepdim=True)
        redundancy_norm = (redundancy - mu) / (sd + 1e-6)
        logits = -self.beta * redundancy_norm
        dist = torch.distributions.Categorical(logits=logits)
        probs = dist.probs
        return torch.argmax(probs, dim=1) if greedy else dist.sample()

    def update_c_from_phi(self, phi_tilde_np, lambda_c=0.99):
        """Update persistence vector c: c <- lambda_c * c + φ̃."""
        with torch.no_grad():
            phi_tilde_t = torch.from_numpy(phi_tilde_np[None]).to(self.device)
            self.c.mul_(float(lambda_c)).add_(phi_tilde_t)

    def store(self, obs, action, next_obs, done, *, next_action: int | None = None):
        """Store a transition in the replay buffer."""
        self.obs_buf[self.ptr] = obs
        self.nxt_buf[self.ptr] = next_obs
        self.act_buf[self.ptr] = int(action)
        self.done_buf[self.ptr] = float(done)
        self.nxt_act_buf[self.ptr] = int(0 if next_action is None else next_action)
        self.ptr = (self.ptr + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def train_step(self, step_idx=0):
        """One training step: update successor features ψ (SR TD loss)."""
        if self.size < self.update_after or (step_idx % self.update_every) != 0:
            return None

        batch = self._sample_batch()
        phi_raw = self.model(batch['obs'].float())
        self._update_whitening_matrix(phi_raw.detach())
        phi_tilde = self._apply_whitening(phi_raw)
        psi_all = self.model.psi_head(phi_tilde.detach()).view(-1, self.n_actions, phi_tilde.size(-1))

        sr_loss = self._compute_sr_loss(batch, phi_tilde, psi_all)
        self.opt.zero_grad()
        sr_loss.backward()
        self.opt.step()
        return float(sr_loss.item())

    def _sample_batch(self):
        idxs = np.random.randint(0, self.size, size=self.batch_size)
        return {
            'obs': torch.from_numpy(self.obs_buf[idxs]).to(self.device),
            'nxt': torch.from_numpy(self.nxt_buf[idxs]).to(self.device),
            'act': torch.from_numpy(self.act_buf[idxs]).to(self.device),
            'nxt_act': torch.from_numpy(self.nxt_act_buf[idxs]).to(self.device),
            'done': torch.from_numpy(self.done_buf[idxs]).to(self.device),
        }

    def _compute_sr_loss(self, batch, phi_tilde, psi_all):
        """Compute the successor-feature TD loss (SR TD loss)."""
        psi_curr = psi_all[torch.arange(psi_all.size(0), device=self.device), batch['act']]

        with torch.no_grad():
            phi_raw_next = self.model(batch['nxt'].float())
            phi_tilde_next = self._apply_whitening(phi_raw_next)
            psi_next_all = self.model.psi_head(phi_tilde_next.detach()).view(-1, self.n_actions, phi_tilde_next.size(-1))

            if self.sf_target_mode == "min_redundancy":
                B = psi_next_all.size(0)
                c_used = self.c.expand(B, -1)
                c_norm = F.normalize(c_used, p=2, dim=1)
                psi_flat = psi_next_all.view(-1, psi_next_all.size(2))
                psi_normed = F.normalize(psi_flat, p=2, dim=1).view(B, self.n_actions, -1)
                redundancy_next = torch.einsum('bad,bd->ba', psi_normed, c_norm)
                mu = redundancy_next.mean(dim=1, keepdim=True)
                sd = redundancy_next.std(dim=1, keepdim=True)
                redundancy_z = (redundancy_next - mu) / (sd + 1e-6)
                a_min = torch.argmin(redundancy_z, dim=1)
                psi_exp = psi_next_all[torch.arange(B, device=self.device), a_min]
            elif self.sf_target_mode == "current_policy":
                psi_exp = psi_next_all[torch.arange(psi_next_all.size(0), device=self.device), batch['nxt_act']]
            else:
                psi_exp = psi_next_all.mean(dim=1)

        sr_target = phi_tilde.detach() + self.sf_gamma * (1.0 - batch['done'].unsqueeze(-1)) * psi_exp
        return F.mse_loss(psi_curr, sr_target.detach())

