import numpy as np
import json
from scipy.special import expit
from sklearn.preprocessing import StandardScaler
from typing import List, Optional, Tuple
from dataclasses import dataclass
import torch.nn as nn
import torch.optim as optim

import torch
import time

class RFFRBFMap:
    """
    Random Fourier features (RFF) for the RBF (Gaussian) kernel:

      k(x, y) = exp(-||x - y||^2 / (2 * sigma^2)) ≈ phi(x)^T phi(y)
      phi(x)  = sqrt(2/D) * cos(x W + b)
      W ~ N(0, 1/sigma^2) elementwise (i.e. std 1/sigma), b ~ Uniform(0, 2*pi)

    The inner product of features is a Monte Carlo estimate of the kernel;
    its accuracy improves as D grows. All randomness comes from
    ``numpy.random.default_rng(seed)``, so maps built with the same
    (d_in, D, sigma, seed) produce identical W and b.

    Args:
        d_in: Input dimensionality (length of the last axis of X).
        D: Number of random features (output dimensionality); must be >= 1.
        sigma: RBF bandwidth; must be > 0.
        seed: Seed for the NumPy random generator.

    Raises:
        ValueError: if ``d_in < 0``, ``D < 1``, or ``sigma`` is not > 0.
    """

    def __init__(self, d_in: int, D: int = 256, sigma: float = 3.0, seed: int = 0):
        self.d_in = int(d_in)
        self.D = int(D)
        self.sigma = float(sigma)

        # Validate up front: otherwise D == 0 / sigma == 0 surface as a bare
        # ZeroDivisionError and a negative sigma fails inside rng.normal.
        if self.d_in < 0:
            raise ValueError(f"d_in must be >= 0, got {self.d_in}")
        if self.D < 1:
            raise ValueError(f"D must be >= 1, got {self.D}")
        if not self.sigma > 0.0:  # also rejects NaN
            raise ValueError(f"sigma must be > 0, got {self.sigma}")

        rng = np.random.default_rng(seed)
        # Draw order (W first, then b) is part of the seeded contract: keep it.
        # Generator.normal/uniform already return float64 arrays.
        self.W = rng.normal(0.0, 1.0 / self.sigma, size=(self.d_in, self.D))
        self.b = (2.0 * np.pi) * rng.uniform(0.0, 1.0, size=(self.D,))
        self.scale = np.sqrt(2.0 / self.D)

    def transform(self, X: np.ndarray) -> np.ndarray:
        """
        Map inputs into the random feature space.

        Args:
            X: Array of shape (n, d_in), or a single point of shape (d_in,).
               More generally any shape (..., d_in); converted to float64.

        Returns:
            Features of shape (..., D), e.g. (n, D) for 2-D input or (D,)
            for a single 1-D point.

        Raises:
            ValueError: if X is a scalar or its last axis is not length d_in.
        """
        X = np.asarray(X, dtype=np.float64)
        if X.ndim == 0 or X.shape[-1] != self.d_in:
            raise ValueError(
                f"expected input with last dimension {self.d_in}, got shape {X.shape}"
            )
        return self.scale * np.cos(X @ self.W + self.b)  # (..., D)


@dataclass
class _PointXYLazyPhi:
    x: np.ndarray
    y: float
    phi: Optional[np.ndarray] = None   # (D,) or None


@dataclass
class _LevelBufferLazyPhi:
    pts: List[_PointXYLazyPhi]
    weight: int