from dataclasses import dataclass
from typing import Optional, Callable, Dict, Union
import numpy as np
import torch
import pickle
from pathlib import Path

from tt_sbi.ocsvm import OCSVMOutlierDetector, clean_observations


# Public API of this module: the names exported by a star-import.
__all__ = [
    "OCSVMTTAConfig",
    "OCSVMTTAAdapter",
]


@dataclass
class OCSVMTTAConfig:
    """Settings consumed by ``OCSVMTTAAdapter``."""

    # Path to a pickled detector; read by ``OCSVMTTAAdapter.fit`` only when no
    # detector instance was passed to the adapter's constructor.
    detector_path: Optional[str] = None
    # Seed forwarded to the observation-cleaning step in ``OCSVMTTAAdapter.adapt``.
    seed: int = 42
    # When true, ``adapt`` adds 'outlier_mask' and 'scores' to its result dict.
    return_outlier_info: bool = True


class OCSVMTTAAdapter:
    """Adapter that cleans observations with an OCSVM detector and re-summarizes them.

    Typical use: construct with a detector (or a config pointing at a pickled
    one), call :meth:`fit` once to validate shapes and install the summary
    function, then call :meth:`adapt` per observation set.

    NOTE(review): ``device`` is stored on the instance but never consulted;
    ``adapt`` builds its tensor on the CPU. Confirm whether the input to
    ``summary_fn`` should be moved to ``self.device``.
    """

    def __init__(
        self,
        config: Optional["OCSVMTTAConfig"] = None,
        detector: Optional["OCSVMOutlierDetector"] = None,
        device: str = "cpu",
    ):
        """Store configuration and an optional pre-built detector.

        Args:
            config: Adapter settings; a default ``OCSVMTTAConfig`` is created
                when omitted.
            detector: Detector instance. When ``None``, :meth:`fit` loads one
                from ``config.detector_path``.
            device: Device label kept on the instance (currently unused).
        """
        self.config = config or OCSVMTTAConfig()
        self.detector = detector
        self.device = device
        self.summary_fn = None
        self._fitted = False

    @staticmethod
    def _default_summary(x):
        """Flatten ``x``: keep the leading axis when ``x.ndim > 2``, else flatten fully."""
        if x.ndim > 2:
            return x.reshape(x.shape[0], -1)
        return x.reshape(-1)

    def fit(
        self,
        X_train: np.ndarray,
        S_train: np.ndarray,
        summary_fn: Optional[Callable] = None,
        **kwargs
    ) -> "OCSVMTTAAdapter":
        """Resolve the detector, validate the training shape, install ``summary_fn``.

        Args:
            X_train: Training data; its shape from axis 2 onward must equal
                the detector's ``element_shape``.
            S_train: Unused by this adapter.
            summary_fn: Callable applied to the cleaned observation tensor in
                :meth:`adapt`. Defaults to a flattening function.
            **kwargs: Ignored.

        Returns:
            ``self``, so calls can be chained.

        Raises:
            ValueError: If no detector is available, or the training element
                shape disagrees with the detector's.
        """
        if self.detector is None:
            detector_path = self.config.detector_path
            if detector_path is None:
                raise ValueError(
                    "Must provide either detector or config.detector_path"
                )
            # SECURITY: unpickling executes arbitrary code embedded in the
            # file. Only point detector_path at files from a trusted source.
            with open(detector_path, 'rb') as f:
                self.detector = pickle.load(f)

        # Normalize both sides to tuples so a detector that stores its shape
        # as a list (or torch.Size) still compares equal.
        actual_element_shape = tuple(X_train.shape[2:])
        expected_element_shape = tuple(self.detector.element_shape)
        if actual_element_shape != expected_element_shape:
            raise ValueError(
                f"Training data element shape {actual_element_shape} doesn't match "
                f"detector element shape {self.detector.element_shape}"
            )

        # A named static method (rather than a lambda) keeps the fitted
        # adapter picklable when the default summary is used.
        self.summary_fn = summary_fn if summary_fn is not None else self._default_summary

        self._fitted = True
        return self

    def adapt(
        self,
        s_init: torch.Tensor,
        x_obs: Union[np.ndarray, torch.Tensor],
        config: Optional["OCSVMTTAConfig"] = None,
        **kwargs
    ) -> Dict:
        """Clean ``x_obs`` with the detector and summarize the cleaned set.

        Observations containing non-finite values are replaced by randomly
        chosen finite observations (seeded by ``cfg.seed``) before being
        handed to ``clean_observations``, and are reported as outliers.

        Args:
            s_init: Fallback summary, returned unchanged (on CPU) when every
                observation is non-finite.
            x_obs: Observations with the observation index on axis 0; reshaped
                to ``(n_obs, *detector.element_shape)`` when the trailing
                shape differs.
            config: Per-call override of the adapter's config.
            **kwargs: Ignored.

        Returns:
            Dict with keys 'best_s', 'gate_passed' (always False), 'n_steps'
            (always 0), 'stop_reason' ('ocsvm_clean' or 'all_invalid'),
            'n_outliers', 'outlier_frac', 'losses' (always empty), plus
            'outlier_mask' and 'scores' when ``cfg.return_outlier_info`` is set.

        Raises:
            RuntimeError: If :meth:`fit` has not been called.
            ValueError: If ``x_obs`` cannot be reshaped to the detector shape.
        """
        if not self._fitted:
            raise RuntimeError("Adapter not fitted. Call fit() first.")

        cfg = config or self.config

        if isinstance(x_obs, torch.Tensor):
            # detach() first: .numpy() refuses tensors that require grad.
            x_obs = x_obs.detach().cpu().numpy()
        x_obs = np.asarray(x_obs)

        detector_shape = tuple(self.detector.element_shape)
        if x_obs.shape[1:] != detector_shape:
            try:
                x_obs = x_obs.reshape(x_obs.shape[0], *detector_shape)
            except ValueError as err:
                raise ValueError(
                    f"Cannot reshape observations {x_obs.shape} to match "
                    f"detector element shape {detector_shape}"
                ) from err

        n_obs = len(x_obs)

        # An observation is invalid when any of its entries is NaN or +/-inf.
        element_axes = tuple(range(1, x_obs.ndim))
        invalid_mask = ~np.isfinite(x_obs).all(axis=element_axes)
        n_invalid = int(invalid_mask.sum())
        has_invalid = n_invalid > 0

        if has_invalid and n_invalid == n_obs:
            # Nothing usable to clean: hand back the initial summary.
            result = {
                'best_s': s_init.cpu() if isinstance(s_init, torch.Tensor) else torch.tensor(s_init),
                'gate_passed': False,
                'n_steps': 0,
                'stop_reason': 'all_invalid',
                'n_outliers': n_obs,
                'outlier_frac': 1.0,
                'losses': [],
            }
            if cfg.return_outlier_info:
                # Keep the result schema consistent with the normal path.
                # The detector never ran, so scores are NaN placeholders
                # (assumes one score per observation -- TODO confirm).
                result['outlier_mask'] = invalid_mask
                result['scores'] = np.full(n_obs, np.nan)
            return result

        if has_invalid:
            # Substitute invalid rows with finite ones so the detector only
            # sees finite input; use the configured seed for reproducibility.
            valid_obs = x_obs[~invalid_mask]
            rng = np.random.default_rng(cfg.seed)
            replacement_idx = rng.choice(len(valid_obs), size=n_invalid)
            x_obs = x_obs.copy()
            x_obs[invalid_mask] = valid_obs[replacement_idx]

        x_clean, outlier_mask, scores = clean_observations(
            self.detector, x_obs, seed=cfg.seed
        )

        if has_invalid:
            # Substituted rows were not real data; always count them as outliers.
            outlier_mask = outlier_mask | invalid_mask

        n_outliers = int(outlier_mask.sum())

        # Add a leading axis of size 1 before calling the summary function.
        x_clean_tensor = torch.from_numpy(x_clean[np.newaxis, ...]).float()

        with torch.no_grad():
            s_clean = self.summary_fn(x_clean_tensor)

        if isinstance(s_clean, np.ndarray):
            s_clean = torch.from_numpy(s_clean).float()

        # Drop the size-1 leading axis again when the summary kept it.
        if s_clean.ndim == 2 and s_clean.shape[0] == 1:
            s_clean = s_clean.squeeze(0)

        result = {
            'best_s': s_clean.cpu(),
            'gate_passed': False,
            'n_steps': 0,
            'stop_reason': 'ocsvm_clean',
            'n_outliers': n_outliers,
            # Guard the zero-observation case against division by zero.
            'outlier_frac': n_outliers / n_obs if n_obs else 0.0,
            'losses': [],
        }

        if cfg.return_outlier_info:
            result['outlier_mask'] = outlier_mask
            result['scores'] = scores

        return result

    @classmethod
    def from_detector_path(
        cls,
        detector_path: Union[str, Path],
        device: str = "cpu",
        **config_kwargs
    ) -> "OCSVMTTAAdapter":
        """Build an adapter whose detector is loaded from ``detector_path`` at fit time.

        Args:
            detector_path: Location of the pickled detector.
            device: Passed through to the constructor.
            **config_kwargs: Extra fields forwarded to ``OCSVMTTAConfig``.

        Returns:
            A new, unfitted adapter.
        """
        config = OCSVMTTAConfig(
            detector_path=str(detector_path),
            **config_kwargs
        )
        return cls(config=config, device=device)
