from dataclasses import dataclass, field
from typing import Optional, Dict, Any, List, Tuple, Callable
import copy
import torch
import torch.nn as nn
from tt_sbi.inference.npe import NPETrainer, NPE_TrainConfig
from tt_sbi.inference.nn import get_embedding_net
from tt_sbi.utils.metrics import MMDLoss

@dataclass
class NPE_MMD_TrainConfig(NPE_TrainConfig):
    """Training configuration for NPE with an MMD regulariser.

    Extends ``NPE_TrainConfig`` with the extra settings read by the
    MMD-regularised trainer in this module.
    """

    # Weight of the MMD term in the objective: total = nll + lambda_reg * mmd.
    lambda_reg: float = field(default=1.0)
    # Logging interval. Not read anywhere in this module; presumably consumed
    # by the base trainer's loop (units, e.g. steps vs. epochs, TODO confirm).
    log_every: int = field(default=50)

class MMDNPETrainer(NPETrainer):
    """NPE trainer that adds an MMD penalty on embedding-network outputs.

    The per-batch objective is ``mean(NLL) + lambda_reg * MMD``. The MMD
    compares the embedding of the training batch ``x_b`` with the embedding
    of a fixed observed target set supplied to :meth:`setup`.
    """

    def __init__(self, model: nn.Module, config: NPE_MMD_TrainConfig, device: str = "cpu"):
        super().__init__(model, config, device)
        # Re-annotate with the narrower config type so the MMD fields are typed.
        self.config: NPE_MMD_TrainConfig = config
        # All three are populated by setup(); train_step refuses to run before then.
        self.mmd_loss: Optional[nn.Module] = None
        self.obs_target: Optional[torch.Tensor] = None
        self.embedding_net: Optional[Callable[[torch.Tensor], torch.Tensor]] = None

    def setup(self, thetas: torch.Tensor, xs: torch.Tensor, **kwargs) -> None:
        """Prepare the MMD term before training.

        Presumably invoked by ``NPETrainer.train`` with the extra keyword
        arguments it was given (TODO confirm against the base class).

        Args:
            thetas: Training parameters (not used here; part of the hook signature).
            xs: Training observations (not used here; part of the hook signature).
            **kwargs: Must contain ``obs_target``, the observed data whose
                embeddings the training batches are compared against. Any
                array-like accepted by ``torch.as_tensor`` is allowed. It is
                moved to ``self.device``.

        Raises:
            ValueError: If ``obs_target`` is missing, or if no embedding network
                can be obtained from the model.
        """
        obs_target = kwargs.get("obs_target")
        if obs_target is None:
            raise ValueError("MMDNPETrainer requires 'obs_target' kwarg")

        # as_tensor returns tensors unchanged and converts numpy arrays / lists,
        # which would otherwise fail on the `.to` call.
        self.obs_target = torch.as_tensor(obs_target).to(self.device)
        self.mmd_loss = MMDLoss().to(self.device)

        embedding_net = get_embedding_net(self.model)
        if embedding_net is None:
            # Fail fast here rather than with an opaque TypeError mid-training.
            raise ValueError("MMDNPETrainer requires a model with an embedding network")
        self.embedding_net = embedding_net

    def train_step(self, theta_b: torch.Tensor, x_b: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, float]]:
        """Compute the regularised loss for one mini-batch.

        Args:
            theta_b: Batch of parameters.
            x_b: Batch of observations used as the model's condition.

        Returns:
            ``(loss, metrics)``. ``loss`` is the differentiable total
            ``nll + lambda_reg * mmd``. ``metrics`` holds ``train_nll``,
            ``train_mmd`` and ``train_total`` as Python floats.

        Raises:
            RuntimeError: If called before :meth:`setup`.
        """
        if self.mmd_loss is None or self.embedding_net is None or self.obs_target is None:
            raise RuntimeError("MMDNPETrainer.setup() must be called before train_step()")

        nll = self.model.loss(theta_b, condition=x_b).mean()
        # The target is re-embedded every step rather than cached, so it
        # reflects the current embedding-net weights. Gradients of the MMD term
        # flow through both embedding calls.
        mmd = self.mmd_loss(self.embedding_net(x_b), self.embedding_net(self.obs_target))
        loss = nll + self.config.lambda_reg * mmd

        return loss, {
            "train_nll": nll.item(),
            "train_mmd": mmd.item(),
            "train_total": loss.item(),
        }

    def finalize_history(self, history: Dict[str, Any]) -> Dict[str, Any]:
        """Record the MMD weight in the training history and return it.

        The dict is modified in place and also returned.
        """
        history["lambda_reg"] = self.config.lambda_reg
        return history

def train_NPE_estimator_mmd(
    model: nn.Module,
    thetas: torch.Tensor,
    xs: torch.Tensor,
    obs_target: torch.Tensor,
    config: Optional[NPE_MMD_TrainConfig] = None,
    device: str = "cpu",
    seed: Optional[int] = None,
):
    """Train an NPE model with an MMD regulariser towards ``obs_target``.

    Args:
        model: The NPE model to train. It must expose ``loss`` and an
            embedding network.
        thetas: Training parameters.
        xs: Training observations paired with ``thetas``.
        obs_target: Observed data whose embeddings the training batches are
            pulled towards via the MMD term.
        config: Training configuration. ``None`` (the default) builds a fresh
            ``NPE_MMD_TrainConfig()`` on every call.
        device: Device string passed through to the trainer.
        seed: Optional seed forwarded to ``trainer.train``.

    Returns:
        Whatever ``MMDNPETrainer.train`` returns. This is presumably the
        training-history dict produced via ``finalize_history`` (TODO confirm).
    """
    # Build the default per call. A default instance in the signature would be
    # created once at import time and shared (and possibly mutated) across calls.
    if config is None:
        config = NPE_MMD_TrainConfig()
    trainer = MMDNPETrainer(model, config, device)
    return trainer.train(thetas, xs, seed=seed, obs_target=obs_target)