import torch
from dataclasses import dataclass
from typing import Optional
from pathlib import Path

from sbi.neural_nets import likelihood_nn

def load_ncpp_model(model_path: Path, device: str = "cpu"):
    """Rebuild a trained NCPP flow network from a saved checkpoint.

    The checkpoint is expected to be a dict with a ``"config"`` mapping and a
    ``"model_state_dict"``. The architecture is re-created with sbi's
    ``likelihood_nn`` builder (``model="nsf"``, ``num_bins=10``). Hidden width
    and number of transforms come from the config, falling back to 64 and 5.
    The saved weights are then restored.

    Args:
        model_path: Checkpoint file to read with ``torch.load``.
        device: Device the returned network is moved to.

    Returns:
        ``(model, cfg)``: the network in eval mode on ``device``, and the
        config stored in the checkpoint.
    """
    # NOTE(review): weights_only=False unpickles arbitrary Python objects;
    # only load checkpoints that come from a trusted source.
    checkpoint = torch.load(model_path, map_location="cpu", weights_only=False)
    config = checkpoint["config"]

    flow_factory = likelihood_nn(
        model="nsf",
        hidden_features=config.get("ncpp_hidden_features", 64),
        num_transforms=config.get("ncpp_num_transforms", 5),
        num_bins=10,
    )

    # Zero-valued placeholder batches (2 rows each) sized from the config.
    # Presumably only their feature dimensions matter to the builder, since
    # the trained weights are loaded right afterwards — TODO confirm.
    x_placeholder = torch.zeros(2, config["x_dim"])
    s_placeholder = torch.zeros(2, config["embedding_dim"])

    network = flow_factory(x_placeholder, s_placeholder)
    network.load_state_dict(checkpoint["model_state_dict"])
    network = network.to(device)
    network.eval()
    return network, config


def find_models(models_dir: Path, dim: int, n_obs: int) -> dict:
    """Locate the NPE checkpoints for one ``(dim, n_obs)`` setting.

    File names are matched against the base ``npe_d{dim}_n{n_obs}``:

    * ``<base>.pt``                  -> key ``"npe"``
    * ``<base>_noisy.pt``            -> key ``"npe_noisy"``
    * ``<base>_rs_lambda_<lam>.pt``  -> key ``"npe_rs_<lam>"``; files whose
      name contains ``_contam_`` are skipped.

    Args:
        models_dir: Directory holding the checkpoints.
        dim: Dimension encoded in the file name (``d{dim}``).
        n_obs: Observation count encoded in the file name (``n{n_obs}``).

    Returns:
        Mapping from model key to checkpoint path. The lambda variants are
        inserted in sorted filename order, so iteration order is reproducible.
        The mapping is empty when ``models_dir`` is not a directory.
    """
    models = {}
    if not models_dir.is_dir():
        return models

    base = f"npe_d{dim}_n{n_obs}"

    plain = models_dir / f"{base}.pt"
    if plain.exists():
        models["npe"] = plain

    noisy = models_dir / f"{base}_noisy.pt"
    if noisy.exists():
        models["npe_noisy"] = noisy

    # Path.glob yields entries in filesystem-dependent order; sort so the
    # resulting dict (and anything iterating over it) is deterministic.
    for f in sorted(models_dir.glob(f"{base}_rs_lambda_*.pt")):
        if "_contam_" in f.name:
            continue
        lam = f.stem.split("_rs_lambda_")[1]
        models[f"npe_rs_{lam}"] = f

    return models


def spike_and_slab_noise(
    x: torch.Tensor,
    slab_scale: float,
    spike_scale: float = 0.01,
    spike_prob: float = 0.5,
    generator: Optional[torch.Generator] = None,
) -> torch.Tensor:
    """Return ``x`` plus element-wise spike-and-slab noise.

    Each element independently takes its noise from one of two components:

    * spike (probability ``spike_prob``): Gaussian ``randn * spike_scale``;
    * slab (probability ``1 - spike_prob``): standard Cauchy ``* slab_scale``.

    Args:
        x: Input tensor. The noise is drawn with the same shape, dtype and
            device.
        slab_scale: Multiplier applied to the standard Cauchy draws.
        spike_scale: Multiplier applied to the standard normal draws.
        spike_prob: Probability that an element receives spike noise.
        generator: Optional RNG used for all three draws, for reproducibility.
            Presumably it must live on ``x``'s device.

    Returns:
        A new tensor ``x + noise``; ``x`` itself is not modified.
    """
    device = x.device
    dtype = x.dtype

    # 1 where the element takes slab (Cauchy) noise, 0 where it takes spike noise.
    is_slab = torch.bernoulli(
        torch.full(x.shape, 1 - spike_prob, device=device, dtype=dtype),
        generator=generator,
    )

    spike = torch.randn(x.shape, device=device, dtype=dtype, generator=generator) * spike_scale
    slab = torch.empty(x.shape, device=device, dtype=dtype).cauchy_(generator=generator) * slab_scale

    # Select rather than blend: `(1 - m) * spike + m * slab` evaluates
    # 0 * inf = NaN whenever the *unused* component is non-finite, which
    # would poison elements that should carry only the other component.
    noise = torch.where(is_slab.bool(), slab, spike)

    return x + noise


class SpikeAndSlabTransform:
    """Callable that applies ``spike_and_slab_noise`` with fixed settings.

    Args:
        slab_scale: Multiplier for the Cauchy (slab) component.
        spike_scale: Multiplier for the Gaussian (spike) component.
        spike_prob: Probability that an element receives spike noise.
        generator: Optional ``torch.Generator`` forwarded to every call, so
            that repeated applications can be made reproducible. ``None``
            (the default) leaves RNG selection to torch, as before.
    """

    def __init__(
        self,
        slab_scale: float,
        spike_scale: float = 0.01,
        spike_prob: float = 0.5,
        generator: Optional[torch.Generator] = None,
    ):
        self.slab_scale = slab_scale
        self.spike_scale = spike_scale
        self.spike_prob = spike_prob
        self.generator = generator

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` with fresh spike-and-slab noise added."""
        return spike_and_slab_noise(
            x,
            slab_scale=self.slab_scale,
            spike_scale=self.spike_scale,
            spike_prob=self.spike_prob,
            # getattr keeps instances unpickled from before `generator`
            # existed working (they fall back to None).
            generator=getattr(self, "generator", None),
        )

    def __repr__(self):
        return (f"SpikeAndSlabTransform(slab_scale={self.slab_scale}, "
                f"spike_scale={self.spike_scale}, spike_prob={self.spike_prob})")

def get_misspec_suffix(misspec_type: str, **kwargs) -> str:
    """Build the file-name suffix that identifies a misspecification setting.

    Supported ``misspec_type`` values and the keyword arguments they read:

    * ``"none"`` -> ``"well_specified"``
    * ``"prior_location"`` -> uses ``prior_location_shift`` (required)
    * ``"prior_scale"`` -> uses ``prior_scale_factor`` (required)
    * ``"likelihood_scale"`` -> uses ``likelihood_scale_factor`` (required)
    * ``"contamination"`` -> uses ``contamination_eps`` (falling back to
      ``epsilon``, then 0.0) and, if truthy, ``contamination_shift``

    Raises:
        KeyError: A required keyword argument for the chosen type is missing.
        ValueError: ``misspec_type`` is not one of the values above.
    """

    def _contamination_suffix() -> str:
        eps = kwargs.get('contamination_eps', kwargs.get('epsilon', 0.0))
        shift = kwargs.get('contamination_shift', 0.0)
        # A zero / missing shift is omitted from the suffix entirely.
        if not shift:
            return f"contam_{eps}"
        return f"contam_{eps}_shift_{shift}"

    # Each entry is evaluated lazily, so only the keys needed by the
    # selected type are looked up.
    suffix_builders = {
        "none": lambda: "well_specified",
        "prior_location": lambda: f"prior_loc_{kwargs['prior_location_shift']}",
        "prior_scale": lambda: f"prior_scale_{kwargs['prior_scale_factor']}",
        "likelihood_scale": lambda: f"lik_scale_{kwargs['likelihood_scale_factor']}",
        "contamination": _contamination_suffix,
    }

    builder = suffix_builders.get(misspec_type)
    if builder is None:
        raise ValueError(f"Unknown misspecification type: {misspec_type}")
    return builder()

def get_test_data_path(data_dir: Path, d: int, n: int, misspec_type: str, **kwargs) -> Path:
    """Return the location of the test-set file for one experiment setting.

    The file lives in ``data_dir`` and is named
    ``test_d{d}_n{n}_{suffix}.pt``, where ``suffix`` is whatever
    ``get_misspec_suffix(misspec_type, **kwargs)`` returns. Any error raised
    by that helper (unknown type, missing keyword) propagates unchanged.
    """
    suffix = get_misspec_suffix(misspec_type, **kwargs)
    filename = f"test_d{d}_n{n}_{suffix}.pt"
    return data_dir / filename