from dataclasses import dataclass
from typing import Optional, Dict, Any, Tuple
import torch
import torch.nn as nn

from .npe import NPETrainer, NPE_TrainConfig
from tt_sbi.utils.misc import SpikeAndSlabTransform 

@dataclass
class NoisyNPE_TrainConfig(NPE_TrainConfig):
    """NPE training config extended with spike-and-slab observation-noise settings.

    When the trainer is not given an explicit noise transform, ``slab_scale``,
    ``spike_scale`` and ``spike_prob`` are used to build a
    ``SpikeAndSlabTransform``. ``noise_on_val`` controls whether validation
    batches are corrupted too.
    """

    slab_scale: float = 0.2  # Cauchy noise scale (outlier / slab component)
    spike_scale: float = 0.01  # Gaussian noise scale (clean / spike component)
    spike_prob: float = 0.5  # probability of the clean (spike) component vs. an outlier (slab)
    noise_on_val: bool = False  # if True, validation observations also go through the noise transform


class NoisyNPETrainer(NPETrainer):
    """NPE trainer that corrupts the conditioning observations with noise.

    Every training batch's observations are passed through
    ``self.noise_transform`` before the model loss is computed. Validation
    batches are corrupted only when ``config.noise_on_val`` is True.

    The noise transform is chosen in ``setup``. A ``noise_transform`` keyword
    argument is used when it is given and not None. Otherwise a
    ``SpikeAndSlabTransform`` is built from the config's ``slab_scale``,
    ``spike_scale`` and ``spike_prob``.
    """

    def __init__(self, model: nn.Module, config: NoisyNPE_TrainConfig, device: str = "cpu"):
        super().__init__(model, config, device)
        # Filled in by ``setup``; None until then.
        self.noise_transform = None

    def setup(self, thetas, xs, **kwargs):
        """Select the noise transform for this training run.

        ``thetas`` and ``xs`` are accepted for interface compatibility but are
        not used here.
        """
        # NOTE(review): this override does not call super().setup(...); if
        # NPETrainer.setup performs work of its own, confirm that skipping it
        # is intended.
        supplied = kwargs.get("noise_transform")
        if supplied is not None:
            self.noise_transform = supplied
            return
        cfg = self.config
        self.noise_transform = SpikeAndSlabTransform(
            slab_scale=cfg.slab_scale,
            spike_scale=cfg.spike_scale,
            spike_prob=cfg.spike_prob,
        )

    def train_step(self, theta_b, x_b):
        """Return ``(mean loss tensor, metrics dict)`` on a noise-corrupted batch."""
        corrupted = self.noise_transform(x_b)
        per_sample = self.model.loss(theta_b, condition=corrupted)
        loss = per_sample.mean()
        metrics = {"train_loss": loss.item()}
        return loss, metrics

    def validate_batch(self, theta_b, x_b):
        """Return the summed validation loss for one batch as a Python float.

        The observations are corrupted only when ``config.noise_on_val`` is set.
        """
        # A sum (not a mean) is returned; presumably the base class normalizes
        # by the total number of validation samples. TODO confirm.
        condition = self.noise_transform(x_b) if self.config.noise_on_val else x_b
        per_sample = self.model.loss(theta_b, condition=condition)
        return per_sample.sum().item()


def train_NPE_estimator_noisy(
    model: nn.Module,
    thetas,
    xs,
    config=None,
    device: str = "cpu",
    seed=None,
    noise_transform=None,
):
    """Train an NPE model on noise-corrupted observations.

    This is a convenience wrapper: it builds a ``NoisyNPETrainer`` and returns
    the result of its ``train`` method.

    Args:
        model: Model passed to the trainer; its
            ``loss(theta, condition=x)`` is what gets optimized.
        thetas: Training parameters, forwarded to ``NoisyNPETrainer.train``.
        xs: Clean training observations, forwarded to ``NoisyNPETrainer.train``.
            Noise is applied per batch during training.
        config: Training configuration. ``None`` selects a default
            ``NoisyNPE_TrainConfig()``. Unless ``noise_transform`` is given,
            the config must provide ``slab_scale``, ``spike_scale`` and
            ``spike_prob``.
        device: Device string forwarded to the trainer constructor.
        seed: Forwarded to ``NoisyNPETrainer.train``.
        noise_transform: Optional callable applied to observation batches.
            It is forwarded to ``train`` and presumably reaches
            ``NoisyNPETrainer.setup``. If it is None, ``setup`` builds a
            ``SpikeAndSlabTransform`` from the config.

    Returns:
        Whatever ``NoisyNPETrainer.train`` returns.
    """
    # Use an explicit None check rather than ``config or ...`` so the default
    # is chosen only when no config was supplied, not based on truthiness.
    if config is None:
        config = NoisyNPE_TrainConfig()
    trainer = NoisyNPETrainer(model, config, device)
    return trainer.train(thetas, xs, seed=seed, noise_transform=noise_transform)