#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Fast observation sampling utilities with:
  - Per-z graph caching (CSR adjacency, fluid candidates, coords)
  - Random sampling using cached CSR for hop expansion
  - Drone swarm sampler that simulates drones moving over the graph (random walk)

This avoids rebuilding adjacency each batch and can simulate trajectory-like observations.
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Dict, Tuple, Optional

import numpy as np


@dataclass
class GraphCache:
    """Precomputed per-graph arrays shared by the samplers in this module.

    Bundles a CSR adjacency (``indptr`` / ``indices``), the node ids that may
    be picked as base/start nodes (``fluid_cand``), the node coordinates and
    the node count, so they are built once and reused across batches.
    """
    indptr: np.ndarray   # (N+1,) int64; neighbors of node u are indices[indptr[u]:indptr[u+1]]
    indices: np.ndarray  # (2E,) int32 (symmetric adjacency: every edge is stored in both directions)
    fluid_cand: np.ndarray  # (<=N,) int32 indices of fluid nodes (or all nodes if not provided)
    coords: np.ndarray   # (N,2) float32
    N: int  # number of nodes (first dimension of coords)


class GraphCacheManager:
    """Caches CSR adjacency and candidate arrays per z value.

    Entries live in a class-level dict shared by the whole process and are
    never evicted. The key is ``(z_value, fluid_only)``; the *contents* of
    ``graph_struct`` are not part of the key, so callers must always pass the
    same graph for a given ``z_value``.
    """
    _cache: Dict[Tuple[float, bool], GraphCache] = {}

    @classmethod
    def get(cls, graph_struct: Dict[str, np.ndarray], z_value: float, *, fluid_only: bool = True) -> GraphCache:
        """Return the cached :class:`GraphCache` for ``z_value``, building it on first use.

        Args:
            graph_struct: mapping that must provide ``"original_coordinates"``
                (one row per node), ``"o2o_senders"`` and ``"o2o_receivers"``
                (edge endpoints, same length). An optional ``"node_types"``
                entry is used to select candidate nodes.
            z_value: cache key identifying the graph.
            fluid_only: when True and ``"node_types"`` is present, candidates
                are restricted to nodes whose type equals 1; if no node has
                that type, or the flag is False, every node is a candidate.

        Returns:
            The cached or newly built ``GraphCache``.

        Raises:
            ValueError: if senders and receivers differ in length.
            IndexError: if an edge endpoint is outside ``[0, N)``.
        """
        # fluid_only changes the candidate set, so it has to be part of the key;
        # keying on z alone would hand back the wrong candidates on a flag change.
        key = (z_value, bool(fluid_only))
        cached = cls._cache.get(key)
        if cached is not None:
            return cached

        coords = np.asarray(graph_struct["original_coordinates"], dtype=np.float32)
        n_nodes = int(coords.shape[0])
        send = np.asarray(graph_struct["o2o_senders"], dtype=np.int32).reshape(-1)
        recv = np.asarray(graph_struct["o2o_receivers"], dtype=np.int32).reshape(-1)
        if send.shape != recv.shape:
            raise ValueError(
                f"o2o_senders and o2o_receivers must have the same length, "
                f"got {send.size} and {recv.size}"
            )

        # Symmetric adjacency: store every edge in both directions.
        e_src = np.concatenate([send, recv])
        e_dst = np.concatenate([recv, send])
        # e_src holds every endpoint, so checking it alone validates both arrays.
        if e_src.size and (int(e_src.min()) < 0 or int(e_src.max()) >= n_nodes):
            raise IndexError(f"edge endpoints must lie in [0, {n_nodes})")

        # CSR build: row pointers are the running sum of per-node degrees ...
        indptr = np.zeros(n_nodes + 1, dtype=np.int64)
        indptr[1:] = np.cumsum(np.bincount(e_src, minlength=n_nodes))
        # ... and a stable sort by source groups the destinations per node while
        # keeping the original edge order inside each group.
        indices = e_dst[np.argsort(e_src, kind="stable")]

        everyone = np.arange(n_nodes, dtype=np.int32)
        if fluid_only and ("node_types" in graph_struct):
            ntypes = np.asarray(graph_struct["node_types"], dtype=np.int32)
            fluid_cand = np.where(ntypes == 1)[0].astype(np.int32)
            if fluid_cand.size == 0:
                # No node carries type 1: fall back to every node.
                fluid_cand = everyone
        else:
            fluid_cand = everyone

        built = GraphCache(indptr=indptr, indices=indices, fluid_cand=fluid_cand, coords=coords, N=n_nodes)
        cls._cache[key] = built
        return built


def _gather_csr_neighbors(nodes: np.ndarray, indptr: np.ndarray, indices: np.ndarray) -> np.ndarray:
    """Return the concatenated CSR neighbor lists of `nodes` (duplicates kept)."""
    starts = indptr[nodes]
    counts = indptr[nodes + 1] - starts
    total = int(counts.sum())
    if total == 0:
        return indices[:0]
    # Output slot t of segment j must read indices[starts[j] + (t - offset[j])],
    # where offset[j] is the exclusive prefix sum of the segment lengths.
    out_offsets = np.cumsum(counts) - counts
    flat = np.arange(total) + np.repeat(starts - out_offsets, counts)
    return indices[flat]


def expand_mask_hops(base_mask: np.ndarray, indptr: np.ndarray, indices: np.ndarray, hops: int) -> np.ndarray:
    """Expand a mask to every node within `hops` hops of a marked node.

    Args:
        base_mask: per-node mask; any non-zero entry counts as marked.
            The input is never modified.
        indptr: CSR row pointers; the neighbors of node ``u`` are
            ``indices[indptr[u]:indptr[u + 1]]``.
        indices: CSR column indices (neighbor node ids).
        hops: number of breadth-first layers to add; values <= 0 add none.

    Returns:
        float32 mask (0/1) with the same shape as `base_mask`.
    """
    # astype always copies, so the caller's array stays untouched.
    visited = np.asarray(base_mask).astype(np.bool_)
    if hops <= 0:
        return visited.astype(np.float32)

    frontier = np.nonzero(visited)[0]
    for _ in range(hops):
        if frontier.size == 0:
            # Nothing new was reached in the previous layer: the ball is complete.
            break
        reached = np.zeros_like(visited)
        reached[_gather_csr_neighbors(frontier, indptr, indices)] = True
        reached &= ~visited  # keep only nodes seen for the first time
        frontier = np.nonzero(reached)[0]
        visited |= reached
    return visited.astype(np.float32)


def sample_random_obs(gc: GraphCache, *, k: int, hops: int, rng: np.random.Generator,
                      base_weights: Optional[np.ndarray] = None) -> np.ndarray:
    """Draw up to `k` distinct base nodes from the candidate set and grow them by `hops` hops.

    Candidates are ``gc.fluid_cand`` (or every node when that array is empty).
    When `base_weights` is given it is indexed by node id; the candidate
    weights are clamped to at least 1e-12 and normalised into sampling
    probabilities. Returns a float32 0/1 mask of shape (N,).
    """
    pool = gc.fluid_cand if gc.fluid_cand.size != 0 else np.arange(gc.N, dtype=np.int32)
    n_pick = min(k, pool.size)

    probs = None  # None means uniform sampling
    if base_weights is not None:
        # Clamp away zero/negative weights first, then normalise in float64.
        clamped = np.clip(base_weights[pool], 1e-12, None).astype(np.float64)
        probs = clamped / clamped.sum()
    chosen = rng.choice(pool, size=n_pick, replace=False, p=probs)

    seeds = np.zeros(gc.N, dtype=np.float32)
    seeds[chosen] = 1.0
    return expand_mask_hops(seeds.astype(bool), gc.indptr, gc.indices, hops) if hops > 0 else seeds


class DroneSwarmSampler:
    """Simulate drones moving on graph nodes via random walks.

    Observations are unions of BFS balls (radius ``hops_radius``) around the
    drone positions. Every random draw uses a generator seeded from
    ``(seed, step[, drone_id])``, so for a fixed graph cache the masks are
    deterministic given ``(seed, step)``.
    """

    def __init__(self, gc: GraphCache, *, num_drones: int = 10, hops_radius: int = 1,
                 move_prob: float = 0.9, seed: int = 0) -> None:
        """Store the graph cache and the (sanitised) walk parameters.

        Args:
            gc: graph cache providing ``indptr``, ``indices``, ``fluid_cand`` and ``N``.
            num_drones: number of drones; clamped to at least 1.
            hops_radius: BFS radius observed around each drone; clamped to >= 0.
            move_prob: probability that a drone moves on a step; clipped to [0, 1].
            seed: base seed for all derived generators.
        """
        self.gc = gc
        self.num_drones = int(max(1, num_drones))
        self.hops_radius = int(max(0, hops_radius))
        self.move_prob = float(np.clip(move_prob, 0.0, 1.0))
        self.seed = int(seed)

    def _rng_for(self, step: int, drone_id: int) -> np.random.Generator:
        """Return a generator with a stable 32-bit seed derived from (seed, step, drone_id)."""
        s = (self.seed * 0x9E3779B1 + step * 0x85EBCA6B + drone_id * 0xC2B2AE35) & 0xFFFFFFFF
        return np.random.default_rng(int(s))

    def _initial_positions(self, step: int) -> np.ndarray:
        """Draw the `num_drones` starting nodes for `step` from a step-seeded generator."""
        gc = self.gc
        init_rng = np.random.default_rng((self.seed ^ (step * 0xA24BAED)) & 0xFFFFFFFF)
        if gc.fluid_cand.size == 0:
            # No candidate list: any node may host a drone.
            return init_rng.integers(0, gc.N, size=self.num_drones, dtype=np.int32)
        # Sample without replacement whenever there are enough candidates.
        replace = gc.fluid_cand.size < self.num_drones
        return init_rng.choice(gc.fluid_cand, size=self.num_drones, replace=replace)

    def _next_pos(self, pos: int, rng: np.random.Generator) -> int:
        """Return the node after one walk step from `pos` (may be `pos` itself)."""
        # rng.random() lies in [0, 1); ">=" makes P(move) exactly move_prob,
        # so move_prob == 0 never moves and move_prob == 1 always tries to.
        if rng.random() >= self.move_prob:
            return pos
        start, end = self.gc.indptr[pos], self.gc.indptr[pos + 1]
        if end <= start:
            # Isolated node: nowhere to go.
            return pos
        return int(rng.choice(self.gc.indices[start:end]))

    def _expand(self, base: np.ndarray) -> np.ndarray:
        """Grow a float32 0/1 position mask by `hops_radius` hops (no-op for radius 0)."""
        if self.hops_radius > 0:
            return expand_mask_hops(base.astype(bool), self.gc.indptr, self.gc.indices, self.hops_radius)
        return base

    def mask_for_step(self, step: int) -> np.ndarray:
        """Observation mask after drawing start positions for `step` and moving each drone once.

        Note that this differs from ``mask_for_span(step, 1)``, which reports
        the start positions themselves without applying a move.

        Returns:
            float32 0/1 mask of shape (N,); empty when the graph has no nodes.
        """
        gc = self.gc
        if gc.N == 0:
            return np.zeros(0, dtype=np.float32)

        base = np.zeros(gc.N, dtype=np.float32)
        for drone_id, start in enumerate(self._initial_positions(step)):
            # One deterministic move per drone, seeded by (seed, step, drone_id).
            base[self._next_pos(int(start), self._rng_for(step, drone_id))] = 1.0
        return self._expand(base)

    def mask_for_span(self, start_step: int, steps: int) -> np.ndarray:
        """Union of observations over a trajectory of given length (steps >= 1).

        The start positions (drawn as for ``start_step``) count as the first
        trajectory point; ``steps - 1`` further moves follow, the move at
        trajectory index ``j`` using the generator for step ``start_step + j``.
        Deterministic given (seed, start_step) for a fixed graph cache.

        Returns:
            float32 0/1 mask of shape (N,); empty when the graph has no nodes.
        """
        steps = max(1, int(steps))
        gc = self.gc
        if gc.N == 0:
            return np.zeros(0, dtype=np.float32)

        positions = np.array(self._initial_positions(start_step), dtype=np.int32)
        union_base = np.zeros(gc.N, dtype=np.float32)
        union_base[positions] = 1.0  # include the start positions
        for step in range(start_step + 1, start_step + steps):
            for drone_id in range(self.num_drones):
                positions[drone_id] = self._next_pos(int(positions[drone_id]), self._rng_for(step, drone_id))
            union_base[positions] = 1.0
        return self._expand(union_base)




