import numpy as np
import json
from scipy.special import expit
from sklearn.preprocessing import StandardScaler
from typing import List, Optional, Tuple
from dataclasses import dataclass
import torch.nn as nn
import torch.optim as optim
from rff import RFFRBFMap, _PointXYLazyPhi, _LevelBufferLazyPhi

import torch
import time

class StreamingRFFCoresetXY_LazyPhi:
    """
    Streaming coreset using RFF features, computed lazily:

    - add(x,y): store x,y only (phi=None)
    - compact(): when buffer reaches size m, compute Phi for these m points ONCE,
                cache into each point, then do sign-compaction using Phi.
    - points pushed upward keep their cached phi, so no recomputation later.

    Levels form a hierarchy: level h holds points of weight 2**h. When a level
    fills up to m points it is compacted: a sign vector is computed from the
    points' features, one sign class is kept and pushed to level h+1, and
    level h is emptied. Compaction can cascade upward through several levels.
    """

    def __init__(
        self,
        buffer_size: int,
        rff_map: RFFRBFMap,
        randomized: bool = False,
        seed: Optional[int] = None,
    ):
        """
        Args:
            buffer_size: number of points m a level holds before it is
                compacted. Must be >= 2.
            rff_map: feature map; its ``transform(X)`` is called on stacked
                inputs of shape (k, d) to obtain features (k, D).
            randomized: if True, the kept sign class is chosen by a fair coin
                flip; otherwise the smaller class is kept.
            seed: seed for ``np.random.default_rng`` (only used when
                ``randomized`` is True).

        Raises:
            ValueError: if ``buffer_size < 2``.
        """
        if buffer_size < 2:
            raise ValueError("buffer_size must be >= 2")

        self.m = int(buffer_size)
        self.rff = rff_map
        self.randomized = bool(randomized)
        self.rng = np.random.default_rng(seed)

        # Level 0 starts with weight 1; higher levels are created on demand.
        self.levels: List[_LevelBufferLazyPhi] = [_LevelBufferLazyPhi(pts=[], weight=1)]

    def add(self, x: np.ndarray, y: float) -> None:
        """Insert one sample at level 0; its feature vector is computed later."""
        x = np.asarray(x, dtype=np.float64).reshape(-1)
        y = float(y)
        self._push(0, _PointXYLazyPhi(x=x, y=y, phi=None))

    def _push(self, h: int, pt: _PointXYLazyPhi) -> None:
        """Append ``pt`` to level ``h`` (creating levels as needed); compact when full."""
        while h >= len(self.levels):
            # A newly created level h carries weight 2**h.
            self.levels.append(_LevelBufferLazyPhi(pts=[], weight=1 << h))

        buf = self.levels[h]
        buf.pts.append(pt)

        if len(buf.pts) >= self.m:
            self._compact(h)

    def _fill_missing_phi(self, pts: List[_PointXYLazyPhi]) -> None:
        """
        Compute and cache phi for every point in ``pts`` whose phi is None.

        Only the missing points are transformed (in a single batch); points
        that already carry a cached phi are left untouched, so the stored
        features never get silently recomputed.
        """
        missing = [p for p in pts if p.phi is None]
        if not missing:
            return

        X = np.stack([p.x for p in missing], axis=0)   # (k, d)
        Phi = self.rff.transform(X)                     # (k, D)
        for i, p in enumerate(missing):
            p.phi = Phi[i].copy()

    def _ensure_phi_cached(self, pts: List[_PointXYLazyPhi]) -> np.ndarray:
        """
        Ensure all points in pts have cached phi.
        Return stacked Phi: (m, D)

        The returned rows are exactly the cached phi vectors, so the signs
        computed from them are consistent with what is stored on the points.
        """
        m = len(pts)
        assert m == self.m, "compaction expects a full buffer"

        self._fill_missing_phi(pts)
        return np.stack([p.phi for p in pts], axis=0)

    def _compact(self, h: int) -> None:
        """
        Halve-compact the full level ``h``.

        Signs are computed from the level's features; the kept sign class is
        either chosen by a coin flip (randomized mode) or is the smaller of the
        two classes, with ties going to the +1 class. Kept points move to level
        h+1 together with their cached phi, and level h is emptied.
        """
        buf = self.levels[h]
        assert len(buf.pts) == self.m, "compaction expects a full buffer"

        # lazy compute + cache phi
        Phi = self._ensure_phi_cached(buf.pts)     # (m, D)

        sigma = signs_from_features(Phi)           # (m,)

        pos = np.flatnonzero(sigma > 0)
        neg = np.flatnonzero(sigma < 0)

        if self.randomized:
            choose_pos = bool(self.rng.integers(0, 2))
        else:
            choose_pos = (len(pos) <= len(neg))

        keep_idx = pos if choose_pos else neg
        kept = [buf.pts[i] for i in keep_idx]

        # push kept points upward with cached phi (may cascade to higher levels;
        # it never re-enters level h, so clearing afterwards is safe)
        for p in kept:
            self._push(h + 1, p)

        buf.pts.clear()

    def finalize(self) -> List[Tuple[np.ndarray, float, float, Optional[np.ndarray]]]:
        """
        Returns list of (x, y, weight, phi).
        NOTE: phi may be None if a point never went through compaction.
        """
        out = []
        for lvl in self.levels:
            w = float(lvl.weight)
            for pt in lvl.pts:
                out.append((pt.x, pt.y, w, pt.phi))
        return out

    def finalize_with_phi(self) -> List[Tuple[np.ndarray, float, float, np.ndarray]]:
        """
        Returns list of (x, y, weight, phi), guaranteeing phi is not None.

        Missing phi vectors are filled by one batch transform per level; points
        that already have a cached phi (e.g. all points promoted by a
        compaction) are not recomputed. Filled phis stay cached on the points.
        """
        out = []
        for lvl in self.levels:
            self._fill_missing_phi(lvl.pts)
            w = float(lvl.weight)
            out.extend((pt.x, pt.y, w, pt.phi) for pt in lvl.pts)
        return out



def signs_from_features(Phi: np.ndarray) -> np.ndarray:
    """
    Equivalent to running Algorithm-1 on Gram=Phi Phi^T,
    but without constructing Gram.

    Greedy signing: keep the running signed sum
    v = sum_{j<i} sigma_j * Phi[j] instead of the (m, m) Gram matrix, so each
    step costs O(D). Row 0 gets +1; each later row i gets the sign that
    minimizes ||v + sigma_i * Phi[i]||, i.e. +1 when <v, Phi[i]> <= 0 and -1
    otherwise (ties go to +1).

    Args:
        Phi: array-like of shape (m, D); converted to float64.

    Returns:
        int8 array of shape (m,) with entries in {+1, -1}; empty when m == 0.

    Raises:
        ValueError: if Phi is not 2-D.
    """
    Phi = np.asarray(Phi, dtype=np.float64)
    if Phi.ndim != 2:
        raise ValueError(f"Phi must be a 2-D (m, D) array, got shape {Phi.shape}")
    m = Phi.shape[0]

    sigma = np.empty(m, dtype=np.int8)
    if m == 0:
        # Nothing to sign; avoid indexing row 0 of an empty array.
        return sigma

    sigma[0] = 1
    v = Phi[0].copy()  # (D,) running signed sum

    for i in range(1, m):
        s = float(v @ Phi[i])
        sigma[i] = 1 if s <= 0.0 else -1
        v += float(sigma[i]) * Phi[i]

    return sigma