import torch

class DiffGrid:
    """Batched walker on the integer grid ``[-width, width] x [-height, height]``.

    Every batch element has a position ``pos`` (shape ``(B, 2)``) and a step
    counter ``t`` (shape ``(B,)``).  ``reset`` puts all walkers at the origin
    with ``t = 0``; each forward step applies one of five displacements
    (``+x``, ``-x``, ``+y``, ``-y``, stay) and increments ``t`` until it
    reaches the horizon ``T``.

    The environment keeps two ``(B, 5)`` masks with 1 for allowed actions and
    0 for forbidden ones:

    * ``forward_mask``: the move stays inside the grid and ``t < T``;
    * ``backward_mask``: the parent position is inside the grid, ``t > 0`` and
      the parent's L1 distance to the origin does not exceed ``t - 1``.
    """

    # Logit given to forbidden actions before the softmax.
    masked_value = -1e9

    def __init__(self, size, batch_size, log_reward, device='cpu', eps=0.1, seed=None, T=None):
        """Create the environment and place every walker at the origin.

        Args:
            size: Half-width of the grid; used for both ``width`` and ``height``.
            batch_size: Number of walkers simulated in parallel.
            log_reward: Callable invoked as ``log_reward(env)`` by ``log_reward()``.
            device: Device on which all state tensors are created.
            eps: Weight of the uniform exploration policy during training.
            seed: Optional seed for a private ``torch.Generator``; when ``None``
                the global torch RNG is used for sampling.
            T: Horizon (number of steps).  Defaults to ``2 * max(width, height)``.
        """
        self.device = device
        self.batch_size = int(batch_size)
        self.width = int(size)
        self.height = int(size)
        self.T = int(2 * max(self.width, self.height)) if T is None else int(T)

        self._log_reward = log_reward
        self.eps = float(eps)
        self.seed = seed
        # Always define the attribute so unseeded environments can be inspected too.
        self.g = None
        if seed is not None:
            self.g = torch.Generator(device=self.device)
            self.g.manual_seed(seed)

        # Displacement per action index: +x, -x, +y, -y, stay.
        self.actions = torch.tensor([[1, 0], [-1, 0], [0, 1], [0, -1], [0, 0]],
                                    dtype=torch.get_default_dtype(), device=device)

        # The initial state is exactly the reset state.
        self.reset()

    @property
    def state_dim(self):
        """Size of one observation: ``(x, y, t_norm)``."""
        return 3

    @property
    def forward_action_dim(self):
        """Number of forward actions."""
        return 5

    @property
    def backward_action_dim(self):
        """Number of backward actions."""
        return 5

    @torch.no_grad()
    def obs(self):
        """Return the ``(B, 3)`` observation ``[x, y, t / T]``."""
        # Use the dtype of `pos` so the concatenation never mixes float widths.
        t_norm = (self.t.to(self.pos.dtype) / max(1, self.T)).unsqueeze(-1)
        return torch.cat([self.pos, t_norm], dim=-1)

    @torch.no_grad()
    def log_reward(self):
        """Evaluate the user-supplied log-reward callable on this environment."""
        return self._log_reward(self)

    @torch.no_grad()
    def _load_state(self, pos, t):
        """Install ``pos``/``t`` as the current state and rebuild all derived fields."""
        self.batch_size = int(pos.shape[0])
        self.batch_ids = torch.arange(self.batch_size, device=self.device)
        self.pos = pos
        self.t = t

        # Placeholders: the update_* methods cast the new masks to the dtype of
        # the existing ones, so they must exist before the first update.
        n_actions = self.actions.shape[0]
        self.forward_mask = torch.zeros((self.batch_size, n_actions), device=self.device)
        self.backward_mask = torch.zeros((self.batch_size, n_actions), device=self.device)

        self.update_forward_mask()
        self.update_backward_mask()

    @torch.no_grad()
    def reset(self, batch_size=None):
        """Put every walker back at the origin with ``t = 0``.

        Args:
            batch_size: Optional new batch size; the current one is kept when ``None``.
        """
        if batch_size is not None:
            self.batch_size = int(batch_size)
        origin = torch.zeros((self.batch_size, 2), device=self.device, dtype=torch.get_default_dtype())
        clock = torch.zeros((self.batch_size,), device=self.device, dtype=torch.long)
        self._load_state(origin, clock)

    @torch.no_grad()
    def set_full_grid_T(self):
        """Replace the batch by one walker per grid cell, all at ``t = T``.

        The batch size becomes ``(2 * width + 1) * (2 * height + 1)``; cells are
        enumerated with ``x`` varying slowest.
        """
        xs = torch.arange(-self.width, self.width + 1,
                          device=self.device, dtype=torch.get_default_dtype())
        ys = torch.arange(-self.height, self.height + 1,
                          device=self.device, dtype=torch.get_default_dtype())
        X, Y = torch.meshgrid(xs, ys, indexing="ij")
        pos = torch.stack([X.flatten(), Y.flatten()], dim=-1)
        t = torch.full((pos.shape[0],), int(self.T), device=self.device, dtype=torch.long)
        self._load_state(pos, t)

    def _masked_policy(self, net, mask):
        """Softmax of ``net(obs)`` with forbidden actions pushed to ``masked_value``."""
        logits = net(self.obs())
        if logits.shape[-1] != 5:
            raise ValueError(f"Esperava logits com 5 saídas, recebi {logits.shape[-1]}")
        masked = logits * mask + (1. - mask) * self.masked_value
        return torch.softmax(masked, dim=-1)

    def get_forward_pol(self, net):
        """Return the ``(B, 5)`` forward policy given by ``net`` under ``forward_mask``.

        Raises:
            ValueError: If ``net`` does not produce 5 logits per state.
        """
        return self._masked_policy(net, self.forward_mask)

    def get_backward_pol(self, net):
        """Return the ``(B, 5)`` backward policy given by ``net`` under ``backward_mask``.

        Raises:
            ValueError: If ``net`` does not produce 5 logits per state.
        """
        return self._masked_policy(net, self.backward_mask)

    def _in_bounds(self, pos):
        """Boolean tensor telling which positions (last dim = ``(x, y)``) lie on the grid."""
        x, y = pos[..., 0], pos[..., 1]
        return (
            (x >= -self.width) & (x <= self.width) &
            (y >= -self.height) & (y <= self.height)
        )

    @torch.no_grad()
    def update_forward_mask(self):
        """Recompute ``forward_mask``, ``stopped`` and ``is_initial`` from ``pos``/``t``."""
        next_pos = self.pos[:, None, :] + self.actions[None, :, :]  # (B,5,2)
        time_ok = (self.t < self.T).unsqueeze(-1)                   # (B,1), broadcasts over actions
        self.forward_mask = (self._in_bounds(next_pos) & time_ok).to(self.forward_mask.dtype)

        self.stopped = (self.t == self.T)
        self.is_initial = (self.t == 0) & (self.pos == 0).all(dim=1)

    @torch.no_grad()
    def update_backward_mask(self):
        """Recompute ``backward_mask`` from ``pos``/``t``.

        A backward action is allowed when the parent cell is on the grid,
        ``t > 0`` and the parent is reachable from the origin in ``t - 1``
        steps, i.e. its L1 norm is at most ``t - 1``.
        """
        prev_pos = self.pos[:, None, :] - self.actions[None, :, :]  # (B,5,2)
        time_ok = (self.t > 0).unsqueeze(-1)                        # (B,1)

        tprev = (self.t - 1).clamp_min(0).view(-1, 1).to(torch.long)   # (B,1)
        d_prev = prev_pos.abs().sum(dim=-1).to(torch.long)              # (B,5)  L1 = |x'|+|y'|
        dist_ok = (d_prev <= tprev)

        self.backward_mask = (self._in_bounds(prev_pos) & time_ok & dist_ok).to(self.backward_mask.dtype)

    def _sample(self, probs):
        """Draw one action index per row of ``probs``; uses the private generator when seeded."""
        generator = self.g if self.seed is not None else None
        drawn = torch.multinomial(probs, num_samples=1, replacement=True, generator=generator)
        return drawn.squeeze(-1)

    @torch.no_grad()
    def get_forward_actions(self, pol, training=True):
        """Sample forward actions from ``pol`` mixed with uniform exploration.

        Args:
            pol: ``(B, 5)`` action probabilities.
            training: When true, ``pol`` is mixed with the uniform distribution
                over allowed actions using weight ``eps``; otherwise ``pol`` is
                sampled as is.

        Returns:
            ``(B,)`` tensor of sampled action indices.
        """
        uniform_pol = torch.where(self.forward_mask == 1., 1., 0.)
        # Rows without any allowed action keep an all-zero uniform part.
        denom = uniform_pol.sum(dim=1, keepdim=True).clamp_min(1.0)
        uniform_pol = uniform_pol / denom

        eps = self.eps if training else 0.0
        exp_pol = pol * (1 - eps) + eps * uniform_pol
        return self._sample(exp_pol)

    @torch.no_grad()
    def get_backward_actions(self, pol, greedy=False):
        """Pick backward actions from ``pol``: argmax when ``greedy``, otherwise sampled.

        Returns:
            ``(B,)`` tensor of action indices.
        """
        if greedy:
            return pol.argmax(dim=-1, keepdim=True).squeeze(-1)
        return self._sample(pol)

    @torch.no_grad()
    def _move(self, indices, can_move, sign):
        """Shift the walkers selected by ``can_move`` and refresh the masks.

        ``sign`` is ``+1`` for a forward step and ``-1`` for a backward step; it
        scales both the displacement and the change of ``t``.
        """
        idx = indices.long()
        if can_move.any():
            delta = self.actions[idx[can_move]]
            upos = self.pos.clone()
            upos[can_move] = upos[can_move] + sign * delta
            # Safety net: an action outside the mask can never leave the grid.
            upos[..., 0].clamp_(-self.width, self.width)
            upos[..., 1].clamp_(-self.height, self.height)
            self.pos = upos
            self.t[can_move] += sign
        self.update_forward_mask()
        self.update_backward_mask()

    @torch.no_grad()
    def apply(self, indices):
        """Apply forward actions ``indices`` (shape ``(B,)``); walkers with ``t == T`` stay put."""
        self._move(indices, self.t < self.T, 1)

    @torch.no_grad()
    def backward(self, indices):
        """Undo actions ``indices`` (shape ``(B,)``); walkers with ``t == 0`` stay put.

        Returns:
            ``indices`` unchanged.
        """
        self._move(indices, self.t > 0, -1)
        return indices

    def _grid_shape(self) -> tuple[int, int]:
        """Return ``(nx, ny)``, the number of grid cells along x and y."""
        nx = 2 * self.width + 1
        ny = 2 * self.height + 1
        return nx, ny

    @torch.no_grad()
    def _to_grid_image(self, values: torch.Tensor) -> torch.Tensor:
        """Scatter one value per walker into a ``[ny, nx]`` CPU image.

        Cells not occupied by any walker are NaN.  ``values`` must have
        ``batch_size`` entries (shape ``[B]`` or ``[B, 1]``).
        """
        assert values.shape[0] == self.batch_size, "values precisa ter shape [batch_size]"
        nx, ny = self._grid_shape()
        flat = values.detach().reshape(-1).cpu()
        img = torch.full((ny, nx), float("nan"), dtype=flat.dtype)

        # Map (x, y) to grid indices (column = x, row = y).  The indices are
        # moved to the CPU because the image lives there even when `pos` does not.
        col = (self.pos[:, 0].round().long() + self.width).clamp_(0, nx - 1).cpu()
        row = (self.pos[:, 1].round().long() + self.height).clamp_(0, ny - 1).cpu()

        # In 2D arrays the first dimension is the row (y), the second the column (x).
        img[row, col] = flat
        return img  # [ny, nx]

    @staticmethod
    def _color_range(true_log_r, prob):
        """Return ``(vmin, vmax)`` floats of the true reward on the plotted scale.

        The range is taken over ``exp(true_log_r)`` when ``prob`` is true and
        over ``true_log_r`` otherwise, ignoring non-finite entries.  Returns
        ``(None, None)`` when no finite entry exists.
        """
        ref = true_log_r.detach()
        if prob:
            ref = ref.exp()
        finite = ref[torch.isfinite(ref)]
        if finite.numel() == 0:
            return None, None
        return finite.min().item(), finite.max().item()

    @torch.no_grad()
    def plot_samples(self, samples, kind="scatter", bins=None, cmap="viridis", title=None):
        """Plot sampled positions as a scatter plot or as a 2D histogram.

        Args:
            samples: Tensor whose first two columns are the x and y coordinates.
            kind: ``"scatter"`` for a scatter plot; any other value draws a histogram.
            bins: Histogram bins; defaults to one bin per grid cell.
            cmap: Colormap of the histogram.
            title: Optional figure title.
        """
        import numpy as np
        import matplotlib.pyplot as plt

        samples = samples.detach().cpu()
        x = samples[:, 0].numpy()
        y = samples[:, 1].numpy()

        plt.figure(figsize=(6, 5))

        if kind == "scatter":
            plt.scatter(x, y, s=5, alpha=0.5)
            plt.xlim(-self.width, self.width)
            plt.ylim(-self.height, self.height)
            plt.gca().set_aspect("equal", adjustable="box")
        else:
            if bins is None:
                nx, ny = self._grid_shape()
                bins = [nx, ny]
            extent = [-self.width, self.width, -self.height, self.height]
            H, xedges, yedges = np.histogram2d(x, y, bins=bins,
                                               range=[[-self.width, self.width],
                                                      [-self.height, self.height]])
            H = H.T  # becomes [ny, nx] with the origin at the bottom
            im = plt.imshow(H, origin="lower", extent=extent, cmap=cmap)
            plt.colorbar(im)
            plt.gca().set_aspect("equal", adjustable="box")

        if title:
            plt.title(title)
        plt.xlabel("x")
        plt.ylabel("y")
        plt.tight_layout()
        plt.show()

    @torch.no_grad()
    def plot_log_r_hat(
        self,
        log_r_hat: torch.Tensor,
        *,
        compare_true: bool = False,
        prob: bool = False,
        kind: str = "imshow",     # "imshow" or "contour"
        levels: int = 100,
        cmap: str = "viridis",
        share_vrange: bool = True,
        title = None
    ):
        """Plot an estimated log-reward over the grid, optionally next to the truth.

        Args:
            log_r_hat: One estimated log-reward per walker (see ``_to_grid_image``).
            compare_true: When true, draw three panels: estimate, true reward
                and the absolute error ``|r_hat - r|`` (always on the reward scale).
            prob: Plot ``exp`` of the values instead of the log-values.
            kind: ``"contour"`` for filled contours; anything else uses ``imshow``.
            levels: Number of contour levels.
            cmap: Colormap name.
            share_vrange: When true, the colour limits of the estimate (and of
                the true panel) are the finite range of the true reward on the
                plotted scale; when false each panel is auto-scaled.
            title: Optional title (figure title in comparison mode).
        """
        import numpy as np
        import matplotlib.pyplot as plt

        img_est = self._to_grid_image(log_r_hat)
        data_est = img_est.exp().numpy() if prob else img_est.numpy()

        # The true reward is needed for the comparison and/or the colour limits.
        true_log_r = self.log_reward() if (compare_true or share_vrange) else None
        vmin, vmax = self._color_range(true_log_r, prob) if share_vrange else (None, None)

        extent = [-self.width, self.width, -self.height, self.height]  # [xmin, xmax, ymin, ymax]
        nx, ny = self._grid_shape()
        X, Y = np.meshgrid(np.linspace(-self.width, self.width, nx),
                           np.linspace(-self.height, self.height, ny))

        def draw(ax, data, lo, hi):
            # Both kinds use grid coordinates so shared axes line up.
            if kind == "contour":
                return ax.contourf(X, Y, data, levels=levels, cmap=cmap, vmin=lo, vmax=hi)
            return ax.imshow(data, origin="lower", cmap=cmap, extent=extent, vmin=lo, vmax=hi)

        if compare_true:
            img_true = self._to_grid_image(true_log_r)
            data_true = img_true.exp().numpy() if prob else img_true.numpy()
            err = np.abs(np.exp(img_est.numpy()) - np.exp(img_true.numpy()))

            panels = (
                (data_est, "Estimado " + ("r̂" if prob else "log r̂"), vmin, vmax),
                (data_true, "Verdadeiro " + ("r" if prob else "log r"), vmin, vmax),
                (err, "|r̂ - r|", None, None),
            )

            fig, axs = plt.subplots(1, 3, figsize=(15, 4.5), sharex=True, sharey=True)
            for ax, (data, label, lo, hi) in zip(axs, panels):
                artist = draw(ax, data, lo, hi)
                if kind != "contour":
                    fig.colorbar(artist, ax=ax, fraction=0.046, pad=0.04)
                ax.set_title(label)
                ax.set_xlabel("x")
                ax.set_ylabel("y")
            if title:
                fig.suptitle(title)
            plt.tight_layout()
            plt.show()
            return

        fig, ax = plt.subplots(figsize=(6, 5))
        artist = draw(ax, data_est, vmin, vmax)
        if kind == "contour":
            fig.colorbar(artist, ax=ax)
        else:
            ax.axis("off")
        if title:
            ax.set_title(title)
        plt.tight_layout()
        plt.show()

# --------------------------- TESTS ---------------------------

def _uniform_from_mask(mask: torch.Tensor) -> torch.Tensor:
    row_sums = mask.sum(dim=1, keepdim=True)
    assert (row_sums > 0).all(), "Linha sem ações válidas (soma=0)"
    return mask / row_sums

@torch.no_grad()
def _uniform_backward_toward_origin(env: DiffGrid) -> torch.Tensor:
    """Backward policy that is uniform over the moves approaching the origin.

    For every walker the candidates are the backward actions allowed by
    ``env.backward_mask`` whose parent cell is one L1 step closer to the
    origin.  Walkers already at the origin may only stay (action 4), and
    walkers left without any candidate fall back to all allowed actions.

    Returns:
        ``(B, 5)`` tensor of action probabilities.
    """
    allowed = env.backward_mask.bool()

    dist_here = env.pos.abs().sum(dim=-1)                         # (B,)
    parents = env.pos.unsqueeze(1) - env.actions.unsqueeze(0)     # (B,5,2)
    dist_parent = parents.abs().sum(dim=-1)                       # (B,5)

    wanted = (dist_here - 1).clamp(min=0).unsqueeze(-1)           # (B,1)
    chosen = (dist_parent == wanted) & allowed

    # At the origin the only candidate is "stay", regardless of the mask.
    stay_only = torch.zeros_like(chosen)
    stay_only[:, 4] = True
    chosen = torch.where((dist_here == 0).unsqueeze(-1), stay_only, chosen)

    # Rows without candidates use every action the environment allows.
    no_candidate = ~chosen.any(dim=1)
    chosen = torch.where(no_candidate.unsqueeze(-1), allowed, chosen)

    weights = chosen.float()
    return weights / weights.sum(dim=1, keepdim=True).clamp_min(1.0)  # (B,5)

def test_1_set_full_grid_T_correct():
    """Check the state and both masks produced by ``set_full_grid_T``."""
    size = 4
    side = 2 * size + 1
    n_cells = side * side

    env = DiffGrid(size, batch_size=1, log_reward=lambda e: torch.zeros(e.batch_size), seed=123)
    env.set_full_grid_T()

    assert env.batch_size == n_cells, "batch_size != (2W+1)*(2H+1)"
    assert (env.t == env.T).all(), "t precisa ser T em todo o batch"

    # Every grid cell must appear exactly once.
    cells = {(int(cx), int(cy)) for cx, cy in env.pos.cpu().tolist()}
    assert len(cells) == n_cells, "posições (x,y) repetidas ou faltando"

    assert float(env.forward_mask.sum().item()) == 0.0, "forward_mask deveria ser 0 em t==T"

    # Rebuild the expected backward mask independently of the environment.
    parents = env.pos.unsqueeze(1) - env.actions.unsqueeze(0)
    px, py = parents[..., 0], parents[..., 1]
    inside = (px >= -size) & (px <= size) & (py >= -size) & (py <= size)
    started = env.t.view(-1, 1) > 0
    budget = (env.t - 1).clamp_min(0).view(-1, 1).to(torch.long)
    reachable = parents.abs().sum(dim=-1).to(torch.long) <= budget
    expected = (inside & started & reachable).to(env.backward_mask.dtype)

    assert torch.equal(env.backward_mask, expected), "backward_mask incorreta em t==T"
    return "OK - test_1_set_full_grid_T_correct"

def test_2_backward_returns_origin_with_uniform_decrease():
    """Walk every grid cell backwards for ``T`` steps and expect all to reach the origin."""
    size = 3
    side = 2 * size + 1
    env = DiffGrid(
        size,
        batch_size=side * side,
        log_reward=lambda e: torch.zeros(e.batch_size),
        seed=7,
    )
    env.set_full_grid_T()  # one walker per (x, y), all at t = T

    remaining = env.T
    while remaining > 0:
        probs = _uniform_backward_toward_origin(env)  # (B, 5)
        env.backward(env.get_backward_actions(probs, greedy=False))
        remaining -= 1

    assert (env.t == 0).all(), "t final deveria ser 0"
    assert torch.all(env.pos == 0), "pos final deveria ser (0,0) para todos"
    return "OK - test_2_backward_returns_origin_with_uniform_decrease"

class _UpOnlyNet(torch.nn.Module):
    def forward(self, obs):
        # logits: [→, ←, ↑, ↓, stay]
        B = obs.shape[0]
        out = torch.full((B, 5), -10.0, dtype=obs.dtype, device=obs.device)
        out[:, 2] = 10.0
        return out

def test_3_forward_up_bias_respects_bounds():
    """Drive one walker with an up-biased policy and check it never leaves the grid."""
    size = 5
    env = DiffGrid(
        size,
        batch_size=1,
        log_reward=lambda e: torch.zeros(e.batch_size),
        eps=0.0,
        seed=0,
    )
    policy_net = _UpOnlyNet()

    for _ in range(env.T):
        probs = env.get_forward_pol(policy_net)                 # (B, 5)
        chosen = env.get_forward_actions(probs, training=True)
        y_prev = float(env.pos[0, 1].item())
        env.apply(chosen)
        y_next = float(env.pos[0, 1].item())

        # The walker never goes past the top edge (y <= +H).
        assert y_next <= env.height + 1e-6

        # At the top edge the "up" action is masked, so another action must be drawn.
        was_at_top = y_prev >= env.height - 1e-9
        if was_at_top:
            assert chosen.item() != 2, "bounds violated"

    return "OK - test_3_forward_up_bias_respects_bounds"

if __name__ == "__main__":
    # Run the self-tests in order; each returns its own "OK - ..." message.
    for _case in (
        test_1_set_full_grid_T_correct,
        test_2_backward_returns_origin_with_uniform_decrease,
        test_3_forward_up_bias_respects_bounds,
    ):
        print(_case())