"""
Graph Network-based Simulator with U-Net grid features.
"""

import math
from typing import Dict, Optional, Sequence, Tuple

import haiku as hk
import jax
import jax.numpy as jnp
import jraph

from .utils import NodeType

from .base import BaseModel
from .utils import build_mlp


class CellIndexer(hk.Module):
    """Map continuous positions to integer grid-cell indices.

    Args:
        bounds: Per-dimension ``[min, max]`` pairs, read as ``bounds[:, 0]``
            (lower edges) and ``bounds[:, 1]`` (upper edges).
        num_cells: Number of cells along each dimension.
    """

    def __init__(self, bounds: jnp.ndarray, num_cells: Sequence[int]):
        super().__init__()
        lower = bounds[:, 0].astype(jnp.float32)
        upper = bounds[:, 1].astype(jnp.float32)
        counts = jnp.array(num_cells, dtype=jnp.int32)
        width = (upper - lower) / counts.astype(jnp.float32)

        self._bounds_min = lower
        self._bounds_max = upper
        self._num_cells = counts
        # Zero-extent dimensions get a unit cell size so that __call__ never
        # divides by zero.
        self._cell_size = jnp.where(width == 0, 1.0, width)

    def __call__(self, positions: jnp.ndarray) -> jnp.ndarray:
        """Return the integer cell index of every position.

        Out-of-domain positions are clipped to the nearest boundary cell.
        """
        scaled = (positions - self._bounds_min) / self._cell_size
        cell_ids = jnp.floor(scaled).astype(jnp.int32)
        return jnp.clip(cell_ids, 0, self._num_cells - 1)


class CellSum(hk.Module):
    """Sum per-node values into a dense, row-major D-dimensional grid.

    Args:
        num_cells: Grid size along each dimension.
    """

    def __init__(self, num_cells: Sequence[int]):
        super().__init__()
        self.num_cells = tuple(int(c) for c in num_cells)
        self.total_cells = math.prod(self.num_cells)
        # Row-major strides: axis d steps over the product of all later axis
        # sizes (math.prod of an empty tuple is 1 for the last axis).
        self.strides = jnp.array(
            [math.prod(self.num_cells[d + 1 :]) for d in range(len(self.num_cells))],
            dtype=jnp.int32,
        )

    def _flatten_indices(self, indices: jnp.ndarray) -> jnp.ndarray:
        """Convert ``(N, D)`` multi-indices into flat row-major cell ids."""
        return (indices * self.strides[None, :]).sum(axis=-1)

    def __call__(self, indices: jnp.ndarray, values: jnp.ndarray) -> jnp.ndarray:
        """Accumulate ``values`` (leading axis N) into the cells in ``indices``.

        Returns an array shaped ``num_cells + values.shape[1:]``.
        """
        grid_shape = self.num_cells + values.shape[1:]
        flat_ids = self._flatten_indices(indices)
        per_cell = jax.ops.segment_sum(values, flat_ids, self.total_cells)
        return per_cell.reshape(grid_shape)


class CellLinearScatter(hk.Module):
    r"""Deposit per-particle features on a dense grid with bilinear (2-D)
    or trilinear (3-D) weights.

    Grid entry ``i`` is treated as sitting at ``bmin + i * cell_size``. Every
    particle spreads its features over the ``2**D`` surrounding entries.
    Indices are clipped to the grid, and the weights always sum to one, so
    the total deposited amount is preserved even at the boundaries.

    Args
    ----
    bounds     : (D,2) array, min / max of the domain
    num_cells  : Sequence[int], number of cells along each dim
    """

    def __init__(self, bounds: jnp.ndarray, num_cells: Sequence[int]):
        super().__init__()
        self._bmin = bounds[:, 0].astype(jnp.float32)
        self._bmax = bounds[:, 1].astype(jnp.float32)
        self._cells = jnp.array(num_cells, dtype=jnp.int32)

        spacing = (self._bmax - self._bmin) / self._cells.astype(jnp.float32)
        # Guard against zero-extent dimensions.
        self._cell = jnp.where(spacing == 0, 1.0, spacing)

        self._shape = tuple(int(c) for c in num_cells)
        # Row-major strides (the last axis has stride 1).
        self._strides = jnp.array(
            [math.prod(self._shape[d + 1 :]) for d in range(len(self._shape))],
            dtype=jnp.int32,
        )
        self._tot_cells = math.prod(self._shape)

    def _flat(self, idx: jnp.ndarray) -> jnp.ndarray:
        """Flatten ``(N, D)`` multi-indices to row-major cell ids."""
        return (idx * self._strides[None, :]).sum(axis=-1)

    def _scatter(self, pos, val):
        """Shared multilinear deposit; ``pos`` is (N, D), ``val`` is (N, C)."""
        coords = (pos - self._bmin) / self._cell
        lo = jnp.floor(coords).astype(jnp.int32)
        frac = coords - lo.astype(jnp.float32)
        hi = jnp.clip(lo + 1, 0, self._cells - 1)
        lo = jnp.clip(lo, 0, self._cells - 1)

        ndim = pos.shape[1]
        corner_ids, corner_wts = [], []
        # Corner k takes the upper neighbour along axis d iff bit d of k is
        # set. Axis 0 therefore varies fastest: (x0,y0), (x1,y0), (x0,y1), ...
        for k in range(2**ndim):
            picks, weight = [], None
            for d in range(ndim):
                upper = (k >> d) & 1
                t = frac[:, d, None]
                factor = t if upper else 1 - t
                weight = factor if weight is None else weight * factor
                picks.append((hi if upper else lo)[:, d])
            corner_ids.append(self._flat(jnp.stack(picks, 1)))
            corner_wts.append(weight)

        idxs = jnp.stack(corner_ids, 1).reshape(-1)
        wts = jnp.concatenate(corner_wts, 1)
        accum = (val[:, None, :] * wts[:, :, None]).reshape(-1, val.shape[-1])
        grid = jax.ops.segment_sum(accum, idxs, self._tot_cells)
        return grid.reshape(self._shape + val.shape[1:])

    def _scatter2d(self, pos, val):
        """Bilinear deposit of (N, 2) positions."""
        return self._scatter(pos, val)

    def _scatter3d(self, pos, val):
        """Trilinear deposit of (N, 3) positions."""
        return self._scatter(pos, val)

    def __call__(self, positions: jnp.ndarray, values: jnp.ndarray) -> jnp.ndarray:
        """Return a grid shaped ``num_cells + values.shape[1:]``."""
        handlers = {2: self._scatter2d, 3: self._scatter3d}
        handler = handlers.get(positions.shape[1])
        if handler is None:
            raise NotImplementedError("Only 2-D or 3-D positions are supported.")
        return handler(positions, values)


class GridSampler(hk.Module):
    r"""Sample a dense grid at arbitrary particle positions with order-1
    (bi/tri-linear) interpolation.

    Grid entry ``i`` is treated as sitting at ``bmin + i * cell_size``.
    Neighbour indices are clipped to the grid, so positions outside the
    domain read the boundary values.

    Args
    ----
    bounds : (D, 2)  array with [min, max] per spatial dimension.
    num_cells : Sequence[int] – number of grid cells per dimension
    """

    def __init__(self, bounds: jnp.ndarray, num_cells: Sequence[int]):
        super().__init__()
        self._bmin = bounds[:, 0].astype(jnp.float32)
        self._bmax = bounds[:, 1].astype(jnp.float32)
        self._cells = jnp.array(num_cells, dtype=jnp.int32)

        spacing = (self._bmax - self._bmin) / self._cells.astype(jnp.float32)
        # Zero-extent dimensions get a unit spacing to avoid division by zero.
        self._cell = jnp.where(spacing == 0, 1.0, spacing)

    def _interp(self, grid: jnp.ndarray, pos: jnp.ndarray) -> jnp.ndarray:
        """Multilinear blend of the 2**D neighbours of each (N, D) position."""
        coords = (pos - self._bmin) / self._cell
        lo = jnp.floor(coords).astype(jnp.int32)
        frac = coords - lo.astype(jnp.float32)
        hi = jnp.clip(lo + 1, 0, self._cells - 1)
        lo = jnp.clip(lo, 0, self._cells - 1)

        ndim = pos.shape[1]
        blended = None
        # Corner k uses the upper neighbour on axis d iff bit d of k is set.
        # Terms are accumulated with axis 0 varying fastest.
        for k in range(2**ndim):
            weight, where = None, []
            for d in range(ndim):
                upper = (k >> d) & 1
                t = frac[:, d, None]
                factor = t if upper else 1 - t
                weight = factor if weight is None else weight * factor
                where.append((hi if upper else lo)[:, d])
            term = weight * grid[tuple(where)]
            blended = term if blended is None else blended + term
        return blended

    def _bilinear(self, grid: jnp.ndarray, pos: jnp.ndarray) -> jnp.ndarray:
        """grid: (H, W, C),  pos: (N, 2)"""
        return self._interp(grid, pos)

    def _trilinear(self, grid: jnp.ndarray, pos: jnp.ndarray) -> jnp.ndarray:
        """grid: (D, H, W, C),  pos: (N, 3)"""
        return self._interp(grid, pos)

    def __call__(self, grid: jnp.ndarray, positions: jnp.ndarray) -> jnp.ndarray:
        """Return per-position interpolated grid features, shaped (N, C)."""
        pdim = positions.shape[1]
        if pdim not in (2, 3):
            raise NotImplementedError(
                f"GridSampler only supports 2-D or 3-D, got pdim={positions.shape[1]}"
            )
        if pdim == 2:
            return self._bilinear(grid, positions)
        return self._trilinear(grid, positions)


class ConvBlock(hk.Module):
    """A 2×(Conv3×3 → InstanceNorm → ReLU) block over a dense grid.

    Args:
        num_spatial_dims: Number of spatial axes of the grid.
        filters: Output channels of both convolutions.
        name: Haiku module name.

    Accepts either an unbatched grid ``(*spatial, C)`` or a batched grid
    ``(B, *spatial, C)``. The output has the same batching as the input.
    """

    def __init__(self, num_spatial_dims: int, filters: int, name: str):
        super().__init__(name=name)
        self.num_spatial_dims = num_spatial_dims
        self.filters = filters
        self.inorm_kwargs = dict(create_scale=True, create_offset=True)

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        # hk.InstanceNorm normalises over axes 1..-2, i.e. it assumes a
        # leading batch axis. An unbatched (H, W, C) grid would otherwise be
        # normalised over W only (per row) instead of over all spatial axes,
        # so a singleton batch axis is added for the duration of the block.
        unbatched = x.ndim == self.num_spatial_dims + 1
        h = x[None] if unbatched else x

        # Creation order (conv1, norm, conv2, norm) is kept so the Haiku
        # parameter names match the original two-stage layout.
        for stage in (1, 2):
            h = hk.ConvND(
                num_spatial_dims=self.num_spatial_dims,
                output_channels=self.filters,
                kernel_shape=3,
                stride=1,
                padding="SAME",
                name=f"conv{stage}",
            )(h)
            h = hk.InstanceNorm(**self.inorm_kwargs)(h)
            h = jax.nn.relu(h)

        return h[0] if unbatched else h


class GNLayer(hk.Module):
    """A single Graph Network block with its own edge and node MLPs.

    Edges are updated from the concatenation ``[sender, receiver, edge]``.
    Nodes are updated from ``[node, agg]``, where ``agg`` is the third
    argument that jraph.GraphNetwork passes to the node function (the
    aggregate of received edges).
    """

    def __init__(self, latent: int, blocks: int, name: str = "gn_layer"):
        super().__init__(name=name)
        # The edge MLP is built before the node MLP; keep this order so the
        # auto-generated Haiku module names stay stable.
        self.edge_mlp = build_mlp(latent, latent, blocks)
        self.node_mlp = build_mlp(latent, latent, blocks)
        self._gn = jraph.GraphNetwork(
            update_edge_fn=self._update_edge,
            update_node_fn=self._update_node,
        )

    def _update_edge(self, eb, s, r, _):
        edge_input = jnp.concatenate((s, r, eb), axis=-1)
        return self.edge_mlp(edge_input)

    def _update_node(self, nb, _, agg, __):
        node_input = jnp.concatenate((nb, agg), axis=-1)
        return self.node_mlp(node_input)

    def __call__(self, g: jraph.GraphsTuple) -> jraph.GraphsTuple:
        """Apply one edge-then-node update and return the new graph."""
        updated = self._gn(g)
        return updated


class CoRGI(BaseModel):
    """Graph Network-based Simulator with U-Net grid feature integration.

    Forward pass:
      1. Build a graph from the feature dict. When ``num_particle_types > 1``,
         append a particle-type embedding to the node features.
      2. Encode nodes and edges, then run ``num_mp_steps`` residual GN layers.
      3. Deposit the latent node features on a dense grid (bi/tri-linear).
      4. Run a U-Net over the grid.
      5. Interpolate the U-Net output back to particle positions, concatenate
         it with the latent node features, project, run ``num_mp_steps`` more
         residual GN layers, and decode to ``{"acc": (N, particle_dimension)}``.

    Args:
        particle_dimension: Spatial dimensionality D. The grid scatter and
            sampler only support 2 or 3.
        latent_size: Width of the latent node/edge features.
        blocks_per_step: Forwarded to ``build_mlp`` for every MLP.
        num_mp_steps: Number of GN layers before and after the U-Net.
        particle_type_embedding_size: Width of the type embedding.
        num_particle_types: Embedding vocabulary size. The embedding is only
            applied when this is > 1.
        bounds: ``(D, 2)`` array of ``[min, max]`` per spatial dimension.
        num_cells: Grid cells per dimension (length D). Every entry must be
            divisible by ``2 ** (len(unet_features) - 1)`` so that pooled and
            upsampled feature maps line up at the skip connections.
        unet_features: Channel width of each U-Net level (non-empty).
        skip_connections: Currently unused; kept for config compatibility.
        down_skips: Per-level flags (at least one per U-Net level). A flag
            concatenates the pooled input grid onto that encoder level's input.
        up_skips: Per-level flags (at least one per U-Net level). A flag
            concatenates the matching encoder output onto that decoder
            level's input.

    Raises:
        ValueError: If the configuration is missing or inconsistent.
    """

    def __init__(
        self,
        particle_dimension: int,
        latent_size: int,
        blocks_per_step: int,
        num_mp_steps: int,
        particle_type_embedding_size: int,
        num_particle_types: int = NodeType.SIZE,
        bounds: jnp.ndarray = jnp.array([[0.0, 1.0], [0.0, 1.0]]),
        num_cells: Optional[Sequence[int]] = None,
        unet_features: Optional[Sequence[int]] = None,
        skip_connections: bool = True,
        down_skips: Sequence[bool] = (True, True, True),
        up_skips: Sequence[bool] = (True, True, True),
    ):
        super().__init__()
        if num_cells is None or unet_features is None:
            raise ValueError("`num_cells` and `unet_features` must be provided")
        self.pdim = particle_dimension
        self.latent = latent_size
        self.blocks = blocks_per_step
        self.steps = num_mp_steps
        self.num_types = num_particle_types
        self.bounds = bounds
        self.cells = tuple(int(c) for c in num_cells)
        self.unet_feats = tuple(int(f) for f in unet_features)
        self.unet_layers = len(self.unet_feats)
        self.total_cells = math.prod(self.cells)
        self.skip_connections = skip_connections  # not read anywhere
        self.down_skips = down_skips
        self.up_skips = up_skips
        self._validate_config()

        # Each U-Net level, including the finest, ends with a stride-2
        # transposed conv. The U-Net output therefore has twice the scatter
        # grid's resolution along every axis, and the sampler must be sized
        # for that grid, not for `self.cells`.
        self.out_cells = tuple(2 * c for c in self.cells)

        # NOTE: module construction order below determines Haiku's
        # auto-generated parameter names; keep it stable.

        # Embedding + encoder/decoder MLPs.
        self.embedding = hk.Embed(self.num_types, particle_type_embedding_size)
        self.enc_node = build_mlp(self.latent, self.latent, self.blocks)
        self.enc_edge = build_mlp(self.latent, self.latent, self.blocks)
        self.dec = build_mlp(self.latent, self.pdim, self.blocks, is_layer_norm=False)

        # Independent (non-shared) GN layers, one per message-passing step.
        self.gns_down = [
            GNLayer(self.latent, self.blocks, name=f"gn_down_{i}")
            for i in range(self.steps)
        ]
        self.gns_up = [
            GNLayer(self.latent, self.blocks, name=f"gn_up_{i}")
            for i in range(self.steps)
        ]

        self.up_proj = hk.Linear(self.latent, name="final_processor_input_proj")
        self.scatter = CellLinearScatter(self.bounds, self.cells)
        self.sampler = GridSampler(self.bounds, self.out_cells)

        # U-Net down path.
        self.down_convs = [
            ConvBlock(self.pdim, f, name=f"unet_enc_{i}")
            for i, f in enumerate(self.unet_feats)
        ]
        self.pool = hk.AvgPool(
            window_shape=[2] * self.pdim + [1],
            strides=[2] * self.pdim + [1],
            padding="SAME",
            name="unet_pool",
        )
        # U-Net up path.
        self.up_tconvs = [
            hk.ConvNDTranspose(
                num_spatial_dims=self.pdim,
                output_channels=self.unet_feats[-(i + 1)],
                kernel_shape=2,
                stride=2,
                padding="SAME",
                name=f"unet_dec_up_{i}",
            )
            for i in range(self.unet_layers)
        ]
        self.up_convs = [
            ConvBlock(self.pdim, f, name=f"unet_dec_conv_{i}")
            for i, f in enumerate(reversed(self.unet_feats))
        ]

    def _validate_config(self) -> None:
        """Fail fast on configurations that would otherwise crash deep inside
        the forward pass with an obscure shape or index error."""
        if len(self.cells) != self.pdim:
            raise ValueError(
                f"num_cells has {len(self.cells)} entries, expected {self.pdim}"
            )
        bounds_shape = tuple(jnp.shape(self.bounds))
        if bounds_shape != (self.pdim, 2):
            raise ValueError(f"bounds shape {bounds_shape} != {(self.pdim, 2)}")
        if not self.unet_feats:
            raise ValueError("`unet_features` must contain at least one level")
        for label, flags in (("down_skips", self.down_skips), ("up_skips", self.up_skips)):
            if len(flags) < self.unet_layers:
                raise ValueError(
                    f"`{label}` has {len(flags)} entries but the U-Net has "
                    f"{self.unet_layers} levels"
                )
        # SAME-padded pooling rounds odd sizes up while transposed convs
        # exactly double. Skip concatenations only line up when every axis
        # divides evenly through all pooling stages.
        factor = 2 ** (self.unet_layers - 1)
        if any(c % factor for c in self.cells):
            raise ValueError(
                f"every entry of num_cells {self.cells} must be divisible by "
                f"{factor} for a {self.unet_layers}-level U-Net"
            )

    def _encoder(self, g: jraph.GraphsTuple) -> jraph.GraphsTuple:
        """Project raw node and edge features to the latent width."""
        return g._replace(
            nodes=self.enc_node(g.nodes),
            edges=self.enc_edge(g.edges),
        )

    @staticmethod
    def _residual_stack(g: jraph.GraphsTuple, layers) -> jraph.GraphsTuple:
        """Apply ``layers`` in order, adding each layer's node and edge output
        back onto its input (residual message passing)."""
        out = g
        for gn in layers:
            res = gn(out)
            out = out._replace(
                nodes=out.nodes + res.nodes,
                edges=out.edges + res.edges,
            )
        return out

    def _processor_down(self, g: jraph.GraphsTuple) -> jraph.GraphsTuple:
        """Residual message passing before the grid stage."""
        return self._residual_stack(g, self.gns_down)

    def _processor_up(self, g: jraph.GraphsTuple) -> jraph.GraphsTuple:
        """Project the widened node features back to the latent width, then
        run residual message passing after the grid stage."""
        g = g._replace(nodes=self.up_proj(g.nodes))
        return self._residual_stack(g, self.gns_up)

    def _decoder(self, g: jraph.GraphsTuple) -> jnp.ndarray:
        """Map latent node features to per-particle outputs."""
        return self.dec(g.nodes)

    def _unet(self, grid: jnp.ndarray) -> jnp.ndarray:
        """Run the U-Net over an unbatched ``(*cells, C)`` grid.

        The result has spatial shape ``self.out_cells`` (twice ``cells``),
        because every level, including the finest, applies one stride-2
        transposed conv.
        """
        downs = []
        orig = curr = grid
        for i, conv in enumerate(self.down_convs):
            inp = jnp.concatenate([orig, curr], axis=-1) if self.down_skips[i] else curr
            out = conv(inp)
            downs.append(out)
            orig = self.pool(orig)
            curr = self.pool(out)

        up = downs[-1]
        for i, (tconv, conv, skip) in enumerate(
            zip(self.up_tconvs, self.up_convs, reversed(downs))
        ):
            up_in = jnp.concatenate([up, skip], axis=-1) if self.up_skips[i] else up
            up = conv(tconv(up_in))
        return up

    def _transform(
        self, features: Dict[str, jnp.ndarray], ptype: jnp.ndarray
    ) -> Tuple[jraph.GraphsTuple, jnp.ndarray, jnp.ndarray]:
        """Assemble the input graph and the current particle positions.

        Node features are the present keys among vel_hist, vel_mag, bound and
        force. Edge features are the present keys among rel_disp and rel_dist.
        ``senders``/``receivers`` are required. Positions are
        ``abs_pos[:, -1]`` (presumably the latest time step) and must be
        ``(N, particle_dimension)``.
        """
        N = features["vel_hist"].shape[0]
        node_feats = [
            features[k].reshape(N, -1).astype(jnp.float32)
            for k in ["vel_hist", "vel_mag", "bound", "force"]
            if k in features
        ]
        edge_feats = [
            features[k].reshape(len(features["senders"]), -1).astype(jnp.float32)
            for k in ["rel_disp", "rel_dist"]
            if k in features
        ]
        if not node_feats or not edge_feats:
            raise ValueError("Missing required node/edge features")
        g = jraph.GraphsTuple(
            nodes=jnp.concatenate(node_feats, axis=-1),
            edges=jnp.concatenate(edge_feats, axis=-1),
            receivers=features["receivers"],
            senders=features["senders"],
            n_node=jnp.array([N], jnp.int32),
            n_edge=jnp.array([len(features["senders"])], jnp.int32),
            globals=None,
        )
        positions = features["abs_pos"][:, -1]
        if positions.shape != (N, self.pdim):
            raise ValueError(f"abs_pos shape {positions.shape} != {(N, self.pdim)}")
        return g, ptype, positions

    def __call__(
        self, sample: Tuple[Dict[str, jnp.ndarray], jnp.ndarray]
    ) -> Dict[str, jnp.ndarray]:
        """Predict per-particle outputs; returns ``{"acc": (N, pdim)}``."""
        feats, ptype = sample
        graph, ptype, pos = self._transform(feats, ptype)

        # Optional type embedding.
        if self.num_types > 1:
            e = self.embedding(ptype)
            graph = graph._replace(nodes=jnp.concatenate([graph.nodes, e], -1))

        # Encode and run down-path message passing.
        eg = self._encoder(graph)
        dg = self._processor_down(eg)

        # Particles -> grid -> U-Net -> particles. The sampler is sized for
        # the U-Net's 2x output grid, so positions map to the right cells.
        grid = self.scatter(pos, dg.nodes)
        up = self._unet(grid)
        gathered = self.sampler(up, pos)

        # Splice grid features into the graph, run up-path message passing,
        # then decode.
        fg = dg._replace(nodes=jnp.concatenate([dg.nodes, gathered], axis=-1))
        ug = self._processor_up(fg)
        acc = self._decoder(ug)

        return {"acc": acc}
