from dataclasses import dataclass, field
from typing import Optional, Tuple, List, Union, Callable, Dict
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from sklearn.kernel_approximation import RBFSampler

# Names exported by a star-import of this module.
__all__ = [
    "RFFTTAConfig",
    "RFFTTAAdapter",
]

@dataclass
class RFFTTAConfig:
    """Hyper-parameters read by ``RFFTTAAdapter``."""

    # --- random Fourier feature map ---
    # Number of random features (passed to RBFSampler as n_components).
    rff_dim: int = 256
    # RBF kernel gamma. NOTE(review): None presumably means "derive from the
    # data" -- confirm how RFFTTAAdapter.fit treats an explicit value.
    gamma: Optional[float] = None

    # --- regressor (s -> mean RFF embedding) ---
    # Widths of the hidden Linear+ReLU layers; a fresh list per instance.
    hidden_dims: List[int] = field(default_factory=lambda: [256] * 2)
    # Passes over the training set, AdamW learning rate, and mini-batch size.
    regressor_epochs: int = 50
    regressor_lr: float = 0.001
    batch_size: int = 256

    # --- test-time adaptation ---
    # NOTE(review): tta_lr is not read by the adapter code visible in this
    # file (the L-BFGS optimizer there uses lr=1.0).
    tta_lr: float = 0.1
    # Upper bound on L-BFGS iterations (max_iter) in RFFTTAAdapter.adapt.
    tta_steps: int = 200
    # Gate threshold on the initial squared error; None disables the gate.
    # Overwritten by the calibration step of RFFTTAAdapter.fit when it runs.
    tau: Optional[float] = None
    # NOTE(review): the next three fields are not read by the adapter code
    # visible in this file.
    eval_every: int = 5
    patience: int = 10
    consecutive_accept: int = 2

    # --- misc ---
    # NOTE(review): verbose is not read by the adapter code visible here.
    verbose: bool = False
    # Seeds the calibration split, the median-heuristic subsample and the
    # RBFSampler random_state.
    seed: int = 42


class RFFTTAAdapter:
    """Test-time adapter that fits a vector ``s`` to observations in RFF space.

    ``fit`` builds a random-Fourier-feature map (``RBFSampler``) over the
    per-point features and trains a Linear/ReLU regressor that maps a
    standardized vector ``s`` to the mean RFF embedding of its sample set.
    ``adapt`` then optimizes ``s`` with L-BFGS so that the regressor output
    matches the mean embedding of newly observed points.  When ``tau`` is set,
    the optimization is skipped if the initial squared error is already at or
    below ``tau``.
    """

    def __init__(self, config: Optional["RFFTTAConfig"] = None, device: str = "cpu"):
        """Store the configuration and target device; nothing is fitted yet.

        Args:
            config: Hyper-parameters; a default ``RFFTTAConfig`` is created
                when omitted.
            device: Anything accepted by ``torch.device``.
        """
        self.config = config or RFFTTAConfig()
        self.device = torch.device(device)
        # Populated by ``fit``.
        self.rff = None
        self.regressor = None
        self.s_mean = None
        self.s_std = None
        self.gamma = None
        self.regressor_losses = None
        # Populated by ``fit`` only when calibration runs.
        self.calibrated_tau = None

    def fit(self, X_train: np.ndarray,
            S_train: np.ndarray,
            calibration_split: float = 0.05,
            alpha: float = 0.05) -> "RFFTTAAdapter":
        """Fit the RFF map and the regressor, optionally calibrating ``tau``.

        Args:
            X_train: Array of shape (M, N, D): M sample sets of N points each.
            S_train: Array of shape (M, H): one target vector per sample set.
            calibration_split: Fraction of the M sets held out to calibrate
                ``tau``; a value ``<= 0`` skips calibration.
            alpha: ``tau`` is set to the ``1 - alpha`` quantile of the
                held-out losses.

        Returns:
            ``self``.

        Raises:
            ValueError: If ``calibration_split >= 1`` or if the split yields
                fewer than 10 calibration samples.
        """
        cfg = self.config

        if calibration_split >= 1:
            # Would leave no data to train on; fail with a clear message
            # instead of an obscure downstream error on empty arrays.
            raise ValueError(
                f"calibration_split={calibration_split} must be < 1 so that "
                "some samples remain for training."
            )

        X_calib = S_calib = None
        if calibration_split > 0:
            n_calib = int(len(X_train) * calibration_split)
            if n_calib < 10:
                raise ValueError(f"calibration_split={calibration_split} yields only {n_calib} samples. Need at least 10.")

            # Seeded shuffle so the train/calibration split is reproducible.
            order = np.random.default_rng(cfg.seed).permutation(len(X_train))
            calib_idx, train_idx = order[:n_calib], order[n_calib:]

            X_calib, S_calib = X_train[calib_idx], S_train[calib_idx]
            X_train, S_train = X_train[train_idx], S_train[train_idx]

        # All points of all sets, stacked: (M * N, D).
        X_flat = X_train.reshape(-1, X_train.shape[-1])
        self.gamma = self._resolve_gamma(X_flat)
        self.rff = RBFSampler(gamma=self.gamma,
                              n_components=cfg.rff_dim,
                              random_state=cfg.seed)

        M, N, _ = X_train.shape
        # Mean embedding per sample set: (M, D_RFF).
        Z_train = self.rff.fit_transform(X_flat).reshape(M, N, -1).mean(axis=1)

        S = torch.as_tensor(S_train, dtype=torch.float32)  # (M, H)
        Z = torch.as_tensor(Z_train, dtype=torch.float32)  # (M, D_RFF)
        # Standardize s per dimension; the epsilon guards constant dimensions.
        self.s_mean, self.s_std = S.mean(0), S.std(0) + 1e-8
        S_norm = (S - self.s_mean) / self.s_std

        self.regressor = self._build_regressor(S.shape[1], Z.shape[1])
        self.regressor_losses = self._train_regressor(S_norm, Z)

        self.s_mean = self.s_mean.to(self.device)
        self.s_std = self.s_std.to(self.device)

        if X_calib is not None:
            self._calibrate_tau_from_data(X_calib, S_calib, alpha)

        return self

    def adapt(self, s_init: torch.Tensor, x_obs: np.ndarray,
              config: Optional["RFFTTAConfig"] = None) -> dict:
        """Adapt ``s_init`` so the regressor matches the embedding of ``x_obs``.

        Args:
            s_init: Starting vector in the original (un-normalized) scale.
            x_obs: Observed points; flattened to (-1, D) before embedding.
            config: Optional per-call override of ``self.config``.

        Returns:
            Dict with keys ``'best_s'`` (CPU tensor in the original scale),
            ``'gate_passed'``, ``'n_steps'`` (number of closure evaluations),
            ``'stop_reason'`` (``'gate'`` or ``'converged'``) and ``'losses'``
            (list of Python floats).
        """
        cfg = config or self.config

        z_obs = self._mean_embedding(x_obs)
        s_norm = (s_init.to(self.device) - self.s_mean) / self.s_std

        # Gate: skip optimization when the initial fit is already good enough.
        if cfg.tau is not None:
            with torch.no_grad():
                # ``.item()`` keeps 'losses' a list of floats on both paths.
                gate_loss = self._compute_loss(s_norm, z_obs).item()
            if gate_loss <= cfg.tau:
                return {'best_s': s_init.cpu(), 'gate_passed': True,
                        'n_steps': 0, 'stop_reason': 'gate', 'losses': [gate_loss]}

        # Optimize a copy of the normalized vector.
        s_opt = nn.Parameter(s_norm.detach().clone())
        losses = []

        # The learning rate is fixed at 1.0 (cfg.tta_lr is not consulted);
        # the strong-Wolfe line search chooses the actual step length.
        optimizer = torch.optim.LBFGS(
            [s_opt],
            lr=1.0,
            max_iter=cfg.tta_steps,
            max_eval=cfg.tta_steps * 4,
            tolerance_grad=1e-5,
            tolerance_change=1e-5,
            history_size=10,
            line_search_fn='strong_wolfe'
        )

        def closure():
            # Called repeatedly by L-BFGS; every call records one loss value.
            optimizer.zero_grad()
            loss = self._compute_loss(s_opt, z_obs)
            loss.backward()
            losses.append(loss.item())
            return loss

        # A single ``step`` runs up to ``max_iter`` internal iterations.
        optimizer.step(closure)

        # NOTE: the reason is reported as 'converged' regardless of whether a
        # tolerance or the iteration/evaluation cap ended the run.
        return self._result(s_opt.detach(), False, len(losses), 'converged', losses)

    def _calibrate_tau_from_data(self, X_calib: np.ndarray,
                                 S_calib: np.ndarray,
                                 alpha: float = 0.05):
        """Set ``tau`` to the ``1 - alpha`` quantile of held-out losses.

        Stores the value in ``self.calibrated_tau`` and overwrites
        ``self.config.tau``.  Prints a summary unconditionally.
        """
        losses = []
        for x_i, s_i in zip(X_calib, S_calib):
            z_obs = self._mean_embedding(x_i)  # (D_RFF,)
            s = torch.as_tensor(s_i, dtype=torch.float32, device=self.device)
            s_norm = (s - self.s_mean) / self.s_std
            with torch.no_grad():
                losses.append(self._compute_loss(s_norm, z_obs).item())

        self.calibrated_tau = float(np.quantile(losses, 1 - alpha))
        self.config.tau = self.calibrated_tau  # Update config

        print(f"Calibrated tau={self.calibrated_tau:.6f} (alpha={alpha}, n_calib={len(X_calib)})")
        print(f"  Loss stats: mean={np.mean(losses):.6f}, std={np.std(losses):.6f}")

    def _mean_embedding(self, x: np.ndarray) -> torch.Tensor:
        """Return the mean RFF embedding of the points in ``x`` as float32."""
        points = x.reshape(-1, x.shape[-1])
        return torch.as_tensor(self.rff.transform(points).mean(0),
                               dtype=torch.float32, device=self.device)

    def _compute_loss(self, s_norm, z_obs):
        """Squared error between the regressor output and ``z_obs``."""
        z_pred = self.regressor(s_norm.unsqueeze(0)).squeeze(0)
        return torch.sum((z_pred - z_obs) ** 2)

    def _result(self, best_s_norm, gate_passed, n_steps, reason, losses):
        """Package a result dict, mapping ``best_s_norm`` back to raw scale."""
        return {
            'best_s': (best_s_norm * self.s_std + self.s_mean).cpu(),
            'gate_passed': gate_passed, 'n_steps': n_steps,
            'stop_reason': reason, 'losses': losses
        }

    def _resolve_gamma(self, X_flat):
        """Return ``config.gamma`` when set, else the median-heuristic value."""
        if self.config.gamma is not None:
            return self.config.gamma
        return self._median_heuristic(X_flat)

    def _median_heuristic(self, X, n=5000):
        """Return ``1 / (2 * median squared pairwise distance + 1e-8)``.

        At most ``n`` rows of ``X`` are used, drawn without replacement by an
        RNG seeded with ``config.seed``.
        """
        # Use seeded RNG for reproducibility
        rng = np.random.default_rng(self.config.seed)
        idx = rng.choice(len(X), min(n, len(X)), replace=False)
        X_sub = X[idx]

        # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, for all pairs at once.
        sq_norms = np.sum(X_sub ** 2, axis=1)
        dists = sq_norms[:, None] + sq_norms[None, :] - 2 * (X_sub @ X_sub.T)
        # Round-off can push near-zero distances slightly below zero.
        dists = np.maximum(dists, 0.0)

        # Strict upper triangle: each unordered pair once, no self-distances.
        pairwise = dists[np.triu_indices(len(idx), k=1)]
        return 1.0 / (2.0 * np.median(pairwise) + 1e-8)

    def _build_regressor(self, in_dim, out_dim):
        """Build the Linear/ReLU stack described by ``config.hidden_dims``."""
        layers = []
        for h in self.config.hidden_dims:
            layers += [nn.Linear(in_dim, h), nn.ReLU()]
            in_dim = h
        layers.append(nn.Linear(in_dim, out_dim))
        return nn.Sequential(*layers).to(self.device)

    def _train_regressor(self, S_norm, Z):
        """Train the regressor with AdamW on MSE; return per-batch losses."""
        loader = DataLoader(TensorDataset(S_norm, Z), batch_size=self.config.batch_size, shuffle=True)
        optimizer = torch.optim.AdamW(self.regressor.parameters(), lr=self.config.regressor_lr)
        criterion = nn.MSELoss()
        self.regressor.train()
        losses = []
        for _ in range(self.config.regressor_epochs):
            for s, z in loader:
                optimizer.zero_grad()
                loss = criterion(self.regressor(s.to(self.device)), z.to(self.device))
                loss.backward()
                optimizer.step()
                losses.append(loss.item())
        # Leave the network in eval mode for calibration and adaptation.
        self.regressor.eval()
        return losses