from dataclasses import dataclass
from typing import Optional, List, Dict
import numpy as np
import torch
import torch.nn as nn

from ..utils.metrics import MMDLoss


# Public API of this module: the names exported by `from <module> import *`.
__all__ = [
    "NCPPTTAConfig",
    "NCPPTTAAdapter",
]


@dataclass
class NCPPTTAConfig:
    """Hyper-parameters for ``NCPPTTAAdapter`` test-time adaptation.

    Invalid values that would otherwise only fail deep inside adaptation
    (a zero evaluation interval, zero repeated draws, ...) are rejected at
    construction time with ``ValueError``.
    """

    n_samples: int = 1000  # sampler draws per loss / evaluation call
    lr: float = 0.1  # Adam learning rate for the conditioning vector
    tta_steps: int = 200  # maximum number of optimiser steps
    tau: Optional[float] = None  # accept threshold on evaluated MMD; None disables gate/accept-stop
    eval_every: int = 5  # run an evaluation every this many optimiser steps
    patience: int = 10  # non-improving evaluations tolerated before stopping
    consecutive_accept: int = 2  # consecutive evaluations with MMD <= tau needed to stop early
    eps: float = 1e-6  # minimum decrease of the evaluated MMD counted as an improvement
    m_eval: int = 500  # NOTE(review): not read by NCPPTTAAdapter as shown — confirm it is used elsewhere
    R: int = 3  # independent sampler draws averaged per evaluation
    verbose: bool = False  # print progress messages during adaptation
    seed: int = 42  # seed for the calibration-subset permutation in fit()

    def __post_init__(self) -> None:
        """Validate fields whose bad values would otherwise fail obscurely later.

        Raises:
            ValueError: if ``n_samples``, ``eval_every`` or ``R`` is below 1,
                or ``tta_steps`` is negative.
        """
        # eval_every == 0 -> ZeroDivisionError in `t % eval_every`;
        # R == 0 -> every evaluation averages an empty list and yields NaN,
        # which silently disables all tau comparisons.
        for name in ("n_samples", "eval_every", "R"):
            value = getattr(self, name)
            if value < 1:
                raise ValueError(f"{name} must be >= 1, got {value}")
        if self.tta_steps < 0:
            raise ValueError(f"tta_steps must be >= 0, got {self.tta_steps}")


class NCPPTTAAdapter:
    """Test-time adaptation of a conditioning vector by MMD minimisation.

    ``adapt`` runs Adam on a conditioning vector ``s`` so that samples drawn
    from ``ncpp_sampler`` conditioned on ``s`` match an observed sample set
    ``x_obs`` under ``MMDLoss``. An optional acceptance threshold ``tau``
    (set on the config, or calibrated by ``fit``) skips optimisation when the
    initial ``s`` already passes, and allows early stopping once enough
    consecutive evaluations pass.

    NOTE(review): ``ncpp_sampler`` is only used as
    ``sample((n,), s.reshape(1, -1))``; it is assumed to return a tensor that
    reshapes to ``(n, x_dim)`` and is differentiable w.r.t. ``s`` — confirm.
    """

    def __init__(self, ncpp_sampler, config: Optional["NCPPTTAConfig"] = None,
                 device: str = "cpu"):
        """Store the sampler and default config, and build the MMD loss.

        Args:
            ncpp_sampler: Conditional sampler exposing ``sample(shape, context)``.
            config: Default adaptation settings; ``NCPPTTAConfig()`` if omitted.
            device: Torch device on which all tensors are created.
        """
        self.config = config or NCPPTTAConfig()
        self.device = torch.device(device)
        self.ncpp_sampler = ncpp_sampler
        self.mmd_loss = MMDLoss().to(self.device)
        # Filled in by _calibrate_tau (via fit); None until calibration runs.
        self.calibrated_tau = None

    def fit(self, X_train: np.ndarray, S_train: np.ndarray,
            calibration_split: float = 0.05, alpha: float = 0.05) -> "NCPPTTAAdapter":
        """Calibrate the acceptance threshold ``tau`` on a random training subset.

        A fraction ``calibration_split`` of the (observation, conditioning)
        pairs is drawn with ``np.random.default_rng(config.seed)`` and handed
        to ``_calibrate_tau``. Nothing happens when ``calibration_split <= 0``.

        Args:
            X_train: Per-example observation sets, indexable by an integer
                (an array whose first axis indexes examples, or a list of
                per-example arrays, possibly of different lengths).
            S_train: Conditioning vectors aligned with ``X_train``.
            calibration_split: Fraction of examples used for calibration.
            alpha: tau is set to the ``1 - alpha`` empirical quantile.

        Returns:
            ``self``, to allow chaining.

        Raises:
            ValueError: if ``X_train`` and ``S_train`` differ in length, or
                fewer than 10 calibration examples would be drawn.
        """
        if calibration_split <= 0:
            return self

        n_total = len(X_train)
        if len(S_train) != n_total:
            raise ValueError(
                f"X_train and S_train must have the same length, "
                f"got {n_total} and {len(S_train)}")
        n_calib = int(n_total * calibration_split)
        if n_calib < 10:
            raise ValueError(f"Need at least 10 calibration samples, got {n_calib}")

        rng = np.random.default_rng(self.config.seed)
        calib_idx = rng.permutation(n_total)[:n_calib]
        # Index element-wise: works for arrays and for (ragged) lists, which
        # do not support NumPy fancy indexing.
        X_calib = [X_train[i] for i in calib_idx]
        S_calib = [S_train[i] for i in calib_idx]

        self._calibrate_tau(X_calib, S_calib, alpha)
        return self

    def adapt(self, s_init: torch.Tensor, x_obs: np.ndarray,
              config: Optional["NCPPTTAConfig"] = None) -> Dict:
        """Optimise a conditioning vector so sampler output matches ``x_obs``.

        Args:
            s_init: Initial conditioning vector (squeezed before use).
            x_obs: Observed samples; last axis is the feature dimension.
            config: Settings for this call; defaults to ``self.config``. The
                same settings are used for training and evaluation draws.

        Returns:
            dict with keys:
              ``best_s`` — CPU tensor, the best evaluated conditioning vector
              (the initial one if no evaluation ran or improved on it);
              ``gate_passed`` — True iff the initial ``s`` already met ``tau``;
              ``n_steps`` — optimiser steps taken (0 when gated or
              ``tta_steps == 0``);
              ``stop_reason`` — ``'gate'``, ``'consecutive_accept'``,
              ``'patience'``, or ``'converged'`` when the step budget ran out;
              ``losses`` — evaluated MMD values, one per evaluation.
        """
        cfg = config if config is not None else self.config

        s = s_init.detach().clone().squeeze().to(self.device)
        x_obs_t = torch.as_tensor(x_obs, dtype=torch.float32, device=self.device)
        x_dim = x_obs_t.shape[-1]

        # Gate: skip optimisation entirely if the initial s already passes.
        if cfg.tau is not None:
            val = self._eval_mmd(s, x_obs_t, x_dim, cfg)
            if val <= cfg.tau:
                if cfg.verbose:
                    print(f"Gate passed! Initial MMD²={val:.6f} ≤ tau={cfg.tau:.6f}")
                return {'best_s': s.cpu(), 'gate_passed': True,
                        'n_steps': 0, 'stop_reason': 'gate', 'losses': [val]}

        if cfg.verbose:
            print(f"Starting optimization (tau={cfg.tau})...")

        s = s.clone().requires_grad_(True)
        optimizer = torch.optim.Adam([s], lr=cfg.lr)
        x_target = x_obs_t.reshape(-1, x_dim)  # loop-invariant

        best_s = s.detach().clone()
        best_val = float('inf')
        no_improve = 0
        accept_streak = 0
        losses = []
        stop_reason = 'converged'
        t = 0  # keeps n_steps defined when tta_steps == 0

        for t in range(1, cfg.tta_steps + 1):
            optimizer.zero_grad()
            q_samples = self.ncpp_sampler.sample((cfg.n_samples,), s.reshape(1, -1))
            loss = self.mmd_loss(x_target, q_samples.reshape(cfg.n_samples, -1))
            loss.backward()
            optimizer.step()

            if t % cfg.eval_every != 0:
                continue

            val = self._eval_mmd(s.detach(), x_obs_t, x_dim, cfg)
            losses.append(val)

            if cfg.verbose:
                print(f"[{t}/{cfg.tta_steps}] eval_MMD²={val:.6f}, best={best_val:.6f}")

            # Record improvements BEFORE the stopping checks, so the iterate
            # that triggers an acceptance stop is kept as best_s.
            if val < best_val - cfg.eps:
                best_val = val
                best_s = s.detach().clone()
                no_improve = 0
            else:
                no_improve += 1

            if cfg.tau is not None and val <= cfg.tau:
                accept_streak += 1
                if accept_streak >= cfg.consecutive_accept:
                    stop_reason = 'consecutive_accept'
                    if cfg.verbose:
                        print(f"Stopping: consecutive accept ({accept_streak} times)")
                    break
            else:
                accept_streak = 0

            # Patience is only checked on evaluations that did not improve.
            if no_improve > 0 and no_improve >= cfg.patience:
                stop_reason = 'patience'
                if cfg.verbose:
                    print(f"Stopping: patience exhausted ({no_improve} evals)")
                break

        return {'best_s': best_s.cpu(), 'gate_passed': False,
                'n_steps': t, 'stop_reason': stop_reason, 'losses': losses}

    def _eval_mmd(self, s: torch.Tensor, x_obs: torch.Tensor, x_dim: int,
                  cfg: Optional["NCPPTTAConfig"] = None) -> float:
        """Return the MMD loss for ``s`` averaged over ``cfg.R`` sampler draws.

        Runs under ``torch.no_grad``; each draw uses ``cfg.n_samples`` samples.

        Args:
            s: Conditioning vector (reshaped to ``(1, -1)`` for the sampler).
            x_obs: Observed samples, reshaped to ``(-1, x_dim)``.
            x_dim: Feature dimension of the observations.
            cfg: Settings to use; defaults to ``self.config``.

        Raises:
            ValueError: if ``cfg.R < 1`` (the average would be NaN).
        """
        cfg = cfg if cfg is not None else self.config
        if cfg.R < 1:
            raise ValueError(f"R must be >= 1, got {cfg.R}")
        vals = []
        with torch.no_grad():
            target = x_obs.reshape(-1, x_dim)
            context = s.reshape(1, -1)
            for _ in range(cfg.R):
                q = self.ncpp_sampler.sample((cfg.n_samples,), context)
                vals.append(self.mmd_loss(target, q.reshape(cfg.n_samples, -1)).item())
        return float(np.mean(vals))

    def _calibrate_tau(self, X_calib: np.ndarray, S_calib: np.ndarray, alpha: float = 0.05):
        """Set tau to the ``1 - alpha`` empirical quantile of calibration MMDs.

        Each (observation set, conditioning vector) pair is scored with
        ``_eval_mmd`` using ``self.config``.

        Side effects: writes ``self.calibrated_tau`` and ``self.config.tau``.
        The config object is mutated in place, so any caller or adapter that
        shares it also sees the new tau.

        Raises:
            ValueError: if no calibration pairs are given.
        """
        losses = []
        for x_np, s_np in zip(X_calib, S_calib):
            x = torch.as_tensor(x_np, dtype=torch.float32, device=self.device)
            s = torch.as_tensor(s_np, dtype=torch.float32, device=self.device)
            losses.append(self._eval_mmd(s, x, x.shape[-1]))
        if not losses:
            raise ValueError("No calibration samples provided")

        self.calibrated_tau = float(np.quantile(losses, 1 - alpha))
        self.config.tau = self.calibrated_tau

        print(f"Calibrated tau={self.calibrated_tau:.6f} (alpha={alpha}, n_calib={len(X_calib)})")
        print(f"  Loss stats: mean={np.mean(losses):.6f}, std={np.std(losses):.6f}")