import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import (
    Categorical, Normal, Independent, MixtureSameFamily,
    LowRankMultivariateNormal, MultivariateNormal
)


class GMM4PR(nn.Module):
    """Gaussian mixture over a latent space used to generate bounded input perturbations.

    Latent samples are decoded to the input shape (directly, or through an
    optional ``up_sampler``) and mapped into the perturbation budget set by
    ``set_budget``. ``compute_pr`` scores a classifier by the fraction of
    perturbed copies of each input that are still predicted as the true label.

    The mixture weights, means and covariances are either free parameters or
    linear heads, depending on ``cond_mode`` (None, "x", "y" or "xy"); see
    ``set_condition``, which must be called before ``forward``.
    """

    def __init__(self, K, latent_dim, device, T_pi=1.0, T_mu=1.0, T_sigma=1.0,
                 T_shared=1.0, logstd_bounds=(-3.0, 1.0)):
        """Store hyper-parameters; the mixture itself is built by ``set_condition``.

        Args:
            K: number of mixture components.
            latent_dim: dimensionality D of each component.
            device: device on which every parameter and head is created.
            T_pi, T_mu, T_shared: multiplicative factors that ``_build_dist``
                applies to the mixture logits, the means and the shared-trunk
                features respectively.
            T_sigma: factor used by ``_build_dist`` to rescale component scales.
            logstd_bounds: (lo, hi) clamp applied to the per-dimension log std.
        """
        super().__init__()
        self.K = K
        self.latent_dim = latent_dim
        self.device = device

        self.T_pi = T_pi
        self.T_mu = T_mu
        self.T_sigma = T_sigma
        self.T_shared = T_shared
        self.logstd_bounds = logstd_bounds
        self.budget = {"norm": "linf", "eps": 8/255}

        # Weights applied to the terms returned by compute_regularization.
        self.reg_coeffs = {
            'pi_entropy': 0.01,
            'mean_diversity': 0.001,
        }

        # Optional sub-modules, attached later through the set_* methods.
        self.feat_extractor = None
        self.shared_trunk = None

        self.y_emb = None
        self.y_emb_normalize = True
        self.num_cls = None

        self.up_sampler = None

    def _make_shared_trunk(self, in_dim, h_dim):
        """Linear -> BatchNorm1d -> ReLU trunk shared by the x-conditioned heads."""
        return nn.Sequential(
            nn.Linear(in_dim, h_dim),
            nn.BatchNorm1d(h_dim),
            nn.ReLU(),
        ).to(self.device)

    def _make_head(self, h_dim, out_dim):
        """Single linear layer producing one group of mixture parameters."""
        return nn.Linear(h_dim, out_dim).to(self.device)

    def set_y_embedding(self, num_cls, y_dim, normalize=True):
        """Use a learned label embedding (optionally L2-normalized) instead of one-hot labels."""
        self.y_emb = nn.Embedding(num_cls, y_dim).to(self.device)
        self.y_emb_normalize = normalize
        self.num_cls = num_cls

    def set_temperatures(self, T_pi=None, T_mu=None, T_sigma=None, T_shared=None):
        """Update any subset of the temperatures; ``None`` leaves a value unchanged."""
        if T_pi is not None:
            self.T_pi = T_pi
        if T_mu is not None:
            self.T_mu = T_mu
        if T_sigma is not None:
            self.T_sigma = T_sigma
        if T_shared is not None:
            self.T_shared = T_shared

    def set_regularization(self, **coeffs):
        """Override regularization weights by name (e.g. ``pi_entropy=0.1``)."""
        self.reg_coeffs.update(coeffs)

    def set_condition(self, cond_mode, cov_type, cov_rank, feat_dim, num_cls, hidden_dim):
        """Create the mixture parameters/heads for a conditioning mode.

        cond_mode:
            None -- pi, mu and the covariance are free ``nn.Parameter`` tensors.
            "x"  -- a shared trunk on image features feeds heads for pi, mu and cov.
            "xy" -- pi is a head on the label vector; mu and cov are heads on the trunk.
            "y"  -- pi is a head on the label vector; mu and cov are free parameters.

        Call ``set_y_embedding`` first if labels should be embedded: the label
        vector width is read from ``self.y_emb`` here (one-hot width otherwise).

        Args:
            cov_type: "diag", "lowrank" or "full".
            cov_rank: rank of the low-rank covariance factor (used by "lowrank").
            feat_dim: width of the flattened feature-extractor output.
            num_cls: number of classes.
            hidden_dim: width of the shared trunk.

        Raises:
            ValueError: for an unknown ``cond_mode`` or ``cov_type``.
        """
        self.cond_mode = cond_mode
        self.hidden_dim = hidden_dim

        self.feat_dim = feat_dim
        # Width of the label vector fed to label-conditioned heads.
        y_dim = (self.y_emb.embedding_dim if (self.y_emb is not None) else num_cls)
        self.num_cls = num_cls

        self.cov_type = cov_type
        self.cov_rank = cov_rank

        if cond_mode is None:
            self.pi = nn.Parameter(torch.randn(self.K, device=self.device) * 0.01)
            self.mu = nn.Parameter(torch.zeros(self.K, self.latent_dim, device=self.device))
            self._init_cov_params(cov_type, cov_rank)

        elif cond_mode == "x":
            self.shared_trunk = self._make_shared_trunk(feat_dim, hidden_dim)

            self.pi = self._make_head(hidden_dim, self.K)
            self.mu = self._make_head(hidden_dim, self.K * self.latent_dim)
            self._init_cov_heads(cov_type, cov_rank, hidden_dim)

        elif cond_mode == "xy":
            self.shared_trunk = self._make_shared_trunk(feat_dim, hidden_dim)

            self.pi = self._make_head(y_dim, self.K)

            self.mu = self._make_head(hidden_dim, self.K * self.latent_dim)
            self._init_cov_heads(cov_type, cov_rank, hidden_dim)

        elif cond_mode == "y":
            assert self.shared_trunk is None, "For cond_mode='y', shared_trunk should be None"

            self.pi = self._make_head(y_dim, self.K)

            self.mu = nn.Parameter(torch.zeros(self.K, self.latent_dim, device=self.device))
            self._init_cov_params(cov_type, cov_rank)

        else:
            raise ValueError(f"cond_mode must be x/y/xy/none, got {cond_mode}")

        self._report_param_counts()

    @staticmethod
    def _numel(obj):
        """Count scalars in a tensor/parameter or in all parameters of a module (0 for None)."""
        if obj is None:
            return 0
        if isinstance(obj, torch.Tensor):
            return obj.numel()
        return sum(p.numel() for p in obj.parameters())

    def _report_param_counts(self):
        """Print parameter counts for the trunk, pi, mu and covariance parts.

        The covariance count is summed over the covariance attributes directly
        (rather than taken as a remainder), so trunk, embedding or other
        sub-module parameters are not attributed to it. ``Total`` counts every
        parameter of the module.
        """
        total_params = sum(p.numel() for p in self.parameters())
        trunk_params = self._numel(self.shared_trunk)
        pi_params = self._numel(self.pi)
        mu_params = self._numel(self.mu)
        cov_params = sum(
            self._numel(getattr(self, name, None))
            for name in ("log_sigma", "logsig", "U", "L_raw", "L")
        )
        print(f"[Params] Shared trunk: {trunk_params:,} | pi: {pi_params:,} | mu: {mu_params:,} | cov: {cov_params:,} | Total: {total_params:,}")

    def _init_cov_params(self, cov_type, cov_rank):
        """Create free (unconditioned) covariance parameters, zero-initialized."""
        if cov_type == "diag":
            self.log_sigma = nn.Parameter(torch.zeros(self.K, self.latent_dim, device=self.device))

        elif cov_type == "lowrank":
            self.log_sigma = nn.Parameter(torch.zeros(self.K, self.latent_dim, device=self.device))
            self.U = nn.Parameter(torch.zeros(self.K, self.latent_dim, cov_rank, device=self.device))

        elif cov_type == "full":
            self.L_raw = nn.Parameter(torch.zeros(self.K, self.latent_dim, self.latent_dim, device=self.device))
        else:
            raise ValueError(f"cov_type must be diag/full/lowrank, got {cov_type}")

    def _init_cov_heads(self, cov_type, cov_rank, hidden_dim):
        """Create linear heads that predict covariance parameters from trunk features."""
        if cov_type == "diag":
            self.logsig = self._make_head(hidden_dim, self.K * self.latent_dim)
        elif cov_type == "lowrank":
            self.logsig = self._make_head(hidden_dim, self.K * self.latent_dim)
            self.U = self._make_head(hidden_dim, self.K * self.latent_dim * cov_rank)
        elif cov_type == "full":
            self.L = self._make_head(hidden_dim, self.K * self.latent_dim * self.latent_dim)
        else:
            raise ValueError(f"cov_type must be diag/full/lowrank, got {cov_type}")

    def set_feat_extractor(self, feat_extractor):
        """Attach a feature extractor and freeze its parameters."""
        self.feat_extractor = feat_extractor.to(self.device)
        for p in self.feat_extractor.parameters():
            p.requires_grad = False

    def set_up_sampler(self, up_sampler):
        """Attach a module that maps latent vectors to input-shaped tensors."""
        self.up_sampler = up_sampler.to(self.device)

    def set_budget(self, norm="linf", eps=8/255):
        """Set the perturbation norm ("linf" or "l2") and radius."""
        self.budget = {"norm": norm, "eps": float(eps)}

    def _encode_y(self, y, mode):
        """Turn integer labels into the label vector used by label-conditioned heads.

        Uses the learned embedding (optionally L2-normalized) when one is set,
        otherwise a float one-hot vector of width ``self.num_cls``.
        """
        if self.num_cls is None:
            raise ValueError(f"self.num_cls must be set for {mode}-conditioning")
        if self.y_emb is not None:
            yvec = self.y_emb(y)
            if self.y_emb_normalize:
                yvec = F.normalize(yvec, dim=-1)
            return yvec
        return F.one_hot(y, num_classes=self.num_cls).float().to(self.device)

    def _make_condition(self, x=None, y=None):
        """Build the (part_x, part_y) conditioning tensors for the current cond_mode.

        Modes that ignore an input use a ``(B, 1)`` tensor of ones in its place.
        Image features are computed under ``torch.no_grad()``.
        """
        mode = self.cond_mode
        part_x = None
        part_y = None

        B = x.size(0) if x is not None else (y.size(0) if y is not None else 1)

        if mode is None:
            part_x = part_y = torch.ones(B, 1, device=self.device)

        elif mode == "x":
            if self.feat_extractor is None:
                raise ValueError(f"cond_mode={mode} requires feat_extractor")
            if x is None:
                raise ValueError(f"cond_mode={mode} requires x input")
            with torch.no_grad():
                part_x = self.feat_extractor(x).reshape(x.size(0), -1)
            part_y = torch.ones(B, 1, device=self.device)

        elif mode == "xy":
            if self.feat_extractor is None:
                raise ValueError(f"cond_mode={mode} requires feat_extractor")
            if x is None or y is None:
                raise ValueError(f"cond_mode={mode} requires both x and y inputs")
            with torch.no_grad():
                part_x = self.feat_extractor(x).reshape(x.size(0), -1)
            part_y = self._encode_y(y, mode)

        elif mode == "y":
            if y is None:
                raise ValueError(f"cond_mode={mode} requires y input")
            part_y = self._encode_y(y, mode)
            part_x = torch.ones(B, 1, device=self.device)

        else:
            raise ValueError(f"cond_mode must be x/y/xy/none, got {mode}")

        if (part_x is None) or (part_y is None):
            raise ValueError(f"No enough conditioning created for mode={mode}")

        return part_x, part_y

    def _decode_latent(self, eps, out_shape):
        """Map latent samples of shape (..., D) to (..., *out_shape).

        Without an up-sampler the latent vector is reshaped directly, so D must
        equal prod(out_shape). With an up-sampler, 3-D input (S, B, D) is
        flattened to (S*B, D) before decoding and reshaped back afterwards.
        """
        shape = tuple(out_shape)
        if self.up_sampler is None:
            assert eps.size(-1) == np.prod(shape), \
                f"Latent vector size {eps.size(-1)} does not match output shape {out_shape}"
            return eps.reshape(tuple(eps.shape[:-1]) + shape)

        if eps.dim() == 2:
            u = self.up_sampler(eps)
            return u.reshape(eps.size(0), *shape) if u.dim() == 2 else u
        elif eps.dim() == 3:
            S, B, D = eps.shape
            u = self.up_sampler(eps.reshape(-1, D))
            if u.dim() == 2:
                return u.reshape(S, B, *shape)
            return u.reshape(S, B, *u.shape[1:])
        else:
            raise ValueError(f"eps must be 2D or 3D, got shape {eps.shape}")

    def _project_to_budget(self, u):
        """Map raw decoder output into the perturbation budget.

        "linf": ``eps * tanh(u)`` element-wise, for any shape.
        "l2": every sample is rescaled to L2 norm exactly ``eps``; supported for
        4-D (B, ...) and 5-D (S, B, ...) inputs, normalizing over all dims after
        the leading batch dim(s).
        """
        norm = self.budget["norm"].lower()
        eps = float(self.budget["eps"])

        if norm == "linf":
            return eps * torch.tanh(u)
        if norm == "l2" and u.dim() in (4, 5):
            n_lead = u.dim() - 3  # 1 leading dim for 4-D input, 2 for 5-D input
            flat = u.reshape(*u.shape[:n_lead], -1)
            n = flat.norm(p=2, dim=-1, keepdim=True).clamp(min=1e-8)
            return (eps * flat / n).view_as(u)

        raise ValueError(f"Unsupported norm={norm} or shape={u.shape}")

    def _tempered_log_std(self, h_shared, B, K, D):
        """Clamped log std shifted by log(T_sigma), i.e. std multiplied by T_sigma."""
        log_std = self._get_log_std(h_shared, B, K, D)
        return log_std + torch.log(torch.tensor(self.T_sigma, device=log_std.device))

    def _build_dist(self, part_x, part_y):
        """Assemble the batched mixture distribution and a cache of its parameters.

        Returns:
            (dist, cache): ``dist`` is a ``MixtureSameFamily`` with batch shape (B,)
            and event shape (D,); ``cache`` holds ``pi_logits`` (B, K), ``mu``
            (B, K, D) and the covariance tensors (``log_std``/``U``/``L``) used by
            ``_rsample_from_gmm`` and ``compute_regularization``.
        """
        B = part_x.size(0)
        if part_y is not None and part_y.size(0) != B:
            raise ValueError(f"Batch size mismatch between x_cond ({B}) and y_cond ({part_y.size(0)})")
        K, D = self.K, self.latent_dim

        if self.cond_mode is None:
            assert self.shared_trunk is None, "Unconditional GMM should not have shared trunk"
            h_shared = None

            assert all(isinstance(getattr(self, name, None), nn.Parameter) for name in ("pi","mu")), \
                "Unconditional GMM requires pi and mu to be parameters"
            pi_logits = self.pi.unsqueeze(0).expand(B, -1) * self.T_pi
            mu = self.mu.unsqueeze(0).expand(B, K, D) * self.T_mu

        elif self.cond_mode == "x":
            assert self.shared_trunk is not None, "cond_mode='x' requires shared_trunk"
            h_shared = self.shared_trunk(part_x) * self.T_shared

            assert all(callable(getattr(self, name, None)) for name in ("pi", "mu")), \
                "Conditional GMM of x requires pi and mu to be mappings"
            pi_logits = self.pi(h_shared) * self.T_pi
            mu = (self.mu(h_shared) * self.T_mu).view(B, K, D)

        elif self.cond_mode == "xy":
            assert self.shared_trunk is not None, "cond_mode='xy' requires shared_trunk"
            h_shared = self.shared_trunk(part_x) * self.T_shared

            assert all(callable(getattr(self, name, None)) for name in ("pi", "mu")), \
                "Conditional GMM of xy requires pi and mu to be mappings"
            pi_logits = self.pi(part_y) * self.T_pi
            mu = (self.mu(h_shared) * self.T_mu).view(B, K, D)

        elif self.cond_mode == "y":
            assert self.shared_trunk is None, "conditional GMM of y should not have shared trunk"
            h_shared = None

            assert callable(getattr(self, "pi", None)), "Conditional GMM of y requires pi to be a mapping"
            pi_logits = self.pi(part_y) * self.T_pi

            assert isinstance(self.mu, nn.Parameter), "Conditional GMM of y requires mu to be a parameter"
            mu = self.mu.unsqueeze(0).expand(B, K, D) * self.T_mu

        else:
            raise ValueError(f"Unknown cond_mode: {self.cond_mode}")

        mix = Categorical(logits=pi_logits)
        cache = {"pi_logits": pi_logits, "mu": mu}

        # NOTE(review): T_sigma scales the diagonal std by T (variance by T**2),
        # but the low-rank factor U and the Cholesky factor L by sqrt(T)
        # (covariance by T). Confirm whether this asymmetry is intended.
        if self.cov_type == "diag":
            log_std = self._tempered_log_std(h_shared, B, K, D)
            comp = Independent(Normal(mu, torch.exp(log_std)), 1)
            cache["log_std"] = log_std

        elif self.cov_type == "lowrank":
            log_std = self._tempered_log_std(h_shared, B, K, D)
            U = self._get_U(h_shared, B, K, D) * np.sqrt(self.T_sigma)
            comp = LowRankMultivariateNormal(loc=mu, cov_factor=U, cov_diag=torch.exp(2.0 * log_std))
            cache.update({"log_std": log_std, "U": U})

        elif self.cov_type == "full":
            L = self._get_cholesky(h_shared, B, K, D) * np.sqrt(self.T_sigma)
            comp = MultivariateNormal(loc=mu, scale_tril=L)
            cache["L"] = L

        else:
            raise ValueError(f"Unknown cov_type: {self.cov_type}")

        dist = MixtureSameFamily(mixture_distribution=mix, component_distribution=comp)
        return dist, cache

    def _get_log_std(self, h_shared, B, K, D):
        """Per-component log std of shape (B, K, D), clamped to ``logstd_bounds``."""
        if callable(getattr(self, "logsig", None)):
            log_std = self.logsig(h_shared).view(B, K, D)
        else:
            log_std = self.log_sigma.unsqueeze(0).expand(B, K, D)

        lo, hi = self.logstd_bounds
        return torch.clamp(log_std, lo, hi)

    def _get_U(self, h_shared, B, K, D):
        """Low-rank covariance factor of shape (B, K, D, cov_rank)."""
        if callable(getattr(self, "U", None)):
            U = self.U(h_shared).view(B, K, D, self.cov_rank)
        else:
            U = self.U.unsqueeze(0).expand(B, K, D, self.cov_rank)
        return U

    def _get_cholesky(self, h_shared, B, K, D):
        """Lower-triangular Cholesky factor of shape (B, K, D, D).

        The strictly-lower part is taken as-is from the raw output; the diagonal
        goes through ``softplus(.) + 1e-4`` so it is strictly positive.
        """
        if callable(getattr(self, "L", None)):
            L_raw = self.L(h_shared).view(B, K, D, D)
        else:
            L_raw = self.L_raw.unsqueeze(0).expand(B, K, D, D)

        strict_lower = torch.tril(L_raw, diagonal=-1)
        diag = F.softplus(torch.diagonal(L_raw, dim1=-2, dim2=-1)) + 1e-4
        return strict_lower + torch.diag_embed(diag)

    def compute_regularization(self, cache):
        """Regularizers on the mixture produced by ``_build_dist``.

        Returns:
            dict with
              'pi_entropy': mean of ``1 - H(pi) / log K`` (0 for a uniform mixture;
                defined as 0 when K == 1, where the normalizer log K vanishes);
              'mean_diversity': mean of ``relu(log K - logdet(G))`` where G is the
                Gram matrix of the L2-normalized component means (plus 1e-6 * I).
        """
        reg_terms = {}
        pi_logits = cache['pi_logits']
        mu = cache['mu']
        K = mu.size(1)

        n_comp = pi_logits.size(-1)
        pi_entropy = Categorical(probs=F.softmax(pi_logits, dim=-1)).entropy()
        if n_comp > 1:
            norm_entropy_loss = 1.0 - pi_entropy / float(np.log(n_comp))
        else:
            # A single component has zero entropy and log(1) == 0: avoid 0/0 = NaN.
            norm_entropy_loss = torch.zeros_like(pi_entropy)
        reg_terms['pi_entropy'] = norm_entropy_loss.mean()

        mu_hat = F.normalize(mu, p=2, dim=-1, eps=1e-8)

        gram = torch.bmm(mu_hat, mu_hat.transpose(1, 2))
        # Small ridge keeps slogdet finite when means are (nearly) collinear.
        gram = gram + 1e-6 * torch.eye(K, device=mu.device, dtype=gram.dtype).unsqueeze(0)

        sign, logabsdet = torch.linalg.slogdet(gram)

        target_diversity = torch.log(torch.tensor(float(K), device=mu.device))
        diversity_loss = F.relu(target_diversity - logabsdet).mean()

        if self.training and torch.rand(1).item() < 0.05:
            if (sign < 0).any():
                print(f"[Warning] Negative det: sign={sign.min().item()}")

        reg_terms['mean_diversity'] = diversity_loss

        return reg_terms

    def forward(self, x=None, y=None):
        """Return ``{"dist": mixture, "cache": parameter cache}`` for inputs (x, y)."""
        part_x, part_y = self._make_condition(x=x, y=y)
        dist, cache = self._build_dist(part_x, part_y)
        return {"dist": dist, "cache": cache}

    def _rsample_from_gmm(self, cache, num_samples, temperature=1.0):
        """Differentiable relaxed sample of shape (num_samples, B, D).

        Component choice is relaxed with Gumbel-softmax at ``temperature``. The
        result is the weight-averaged sample of all K components (reparameterized),
        which matches a hard mixture draw only as ``temperature`` approaches 0.
        """
        pi_logits = cache['pi_logits']
        mu = cache['mu']
        B, K, D = mu.shape

        uniform = torch.rand(num_samples, B, K, device=pi_logits.device)
        gumbel_noise = -torch.log(-torch.log(uniform + 1e-20) + 1e-20)
        weights = F.softmax((pi_logits.unsqueeze(0) + gumbel_noise) / temperature, dim=-1)

        # Standard-normal noise shared by every covariance type below.
        eps_std = torch.randn(num_samples, B, K, D, device=mu.device)
        mu_s = mu.unsqueeze(0)

        if self.cov_type == "diag":
            std = torch.exp(cache['log_std']).unsqueeze(0)
            component_samples = mu_s + std * eps_std

        elif self.cov_type == "lowrank":
            std = torch.exp(cache['log_std']).unsqueeze(0)
            eta = torch.randn(num_samples, B, K, self.cov_rank, device=mu.device)
            lowrank_term = torch.matmul(cache['U'].unsqueeze(0), eta.unsqueeze(-1)).squeeze(-1)
            component_samples = mu_s + lowrank_term + std * eps_std

        elif self.cov_type == "full":
            component_samples = mu_s + torch.matmul(
                cache['L'].unsqueeze(0), eps_std.unsqueeze(-1)
            ).squeeze(-1)

        else:
            raise ValueError(f"Unknown cov_type: {self.cov_type}")

        return (component_samples * weights.unsqueeze(-1)).sum(dim=2)

    def _sample_and_classify(self, x, num_samples, classifier, cache, temperature=1.0):
        """Classify ``x + delta`` for relaxed samples; returns logits (num_samples, B, C)."""
        B = x.size(0)

        eps = self._rsample_from_gmm(cache, num_samples, temperature=temperature)
        u = self._decode_latent(eps, out_shape=x.shape[1:])
        delta = self._project_to_budget(u)
        x_rep = x.unsqueeze(0).expand_as(delta)

        logits = classifier((x_rep + delta).flatten(0, 1))
        return logits.view(num_samples, B, -1)

    @staticmethod
    def _chunk_sizes(num_samples, chunk_size):
        """Split ``num_samples`` into consecutive chunk sizes of at most ``chunk_size``.

        Raises:
            ValueError: if chunking is needed and ``chunk_size`` < 1.
        """
        if num_samples <= chunk_size:
            return [num_samples]
        if chunk_size < 1:
            raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")
        n_full, rest = divmod(num_samples, chunk_size)
        return [chunk_size] * n_full + ([rest] if rest else [])

    def pr_loss(self, x, y, classifier, num_samples=8, loss_variant="cw", kappa=0.0,
                chunk_size=None, return_reg_details=False, gumbel_temperature=1.0):
        """Differentiable attack loss on relaxed mixture samples, plus regularization.

        ``loss_variant="cw"`` uses ``softplus(logit_y - max_other + kappa)``; any
        other value uses ``1 - cross_entropy``. Samples are classified in chunks
        of ``chunk_size`` (default ``max(1, 32 // B)``).

        Returns:
            dict with "pr" (float), "loss" (differentiable), "main" and "reg"
            (detached); plus "reg_details" and "pi_probs" when
            ``return_reg_details`` is True.
        """
        out = self.forward(x=x, y=y)

        cache = out["cache"]
        B = x.size(0)

        if chunk_size is None:
            max_batch = 32
            chunk_size = max(1, max_batch // B)

        chunks = [
            self._sample_and_classify(x, n, classifier, cache, gumbel_temperature)
            for n in self._chunk_sizes(num_samples, chunk_size)
        ]
        logits = chunks[0] if len(chunks) == 1 else torch.cat(chunks, dim=0)

        # Shift-invariant for both losses; improves numerical range.
        logits = logits - logits.max(dim=-1, keepdim=True).values
        if loss_variant == "cw":
            y_rep = y.unsqueeze(0).expand(num_samples, -1)
            logit_y = logits.gather(-1, y_rep.unsqueeze(-1)).squeeze(-1)
            mask = F.one_hot(y_rep, logits.size(-1)).bool()
            max_others = logits.masked_fill(mask, float("-inf")).max(-1).values

            margin = logit_y - max_others + kappa
            main_loss = F.softplus(margin).mean()

        else:
            main_loss = 1 - F.cross_entropy(
                logits.flatten(0, 1),
                y.unsqueeze(0).expand(num_samples, -1).flatten()
            )

        reg_terms = self.compute_regularization(cache)
        total_reg = sum(self.reg_coeffs.get(k, 0.0) * v for k, v in reg_terms.items())

        total_loss = main_loss + total_reg

        predictions = logits.argmax(dim=-1)
        pr = self.compute_pr(predictions, y, reduction='mean').item()

        result = {
            "pr": pr,
            "loss": total_loss,
            "main": main_loss.detach(),
            "reg": total_reg.detach(),
        }

        if return_reg_details:
            result["reg_details"] = {k: v.detach().item() for k, v in reg_terms.items()}
            result["pi_probs"] = F.softmax(cache['pi_logits'], dim=-1).mean(dim=0).detach()

        return result

    @staticmethod
    def compute_pr(predictions, y, reduction='mean'):
        """Fraction of samples whose prediction equals the label, per image.

        Args:
            predictions: (S, B) predicted classes, or a flat (S*B,) tensor.
            y: (B,) labels.
            reduction: 'mean', 'sum' or 'none' over the B per-image values.
        """
        if predictions.dim() == 1:
            S_times_B = predictions.size(0)
            B = y.size(0)
            if S_times_B % B != 0:
                raise ValueError(f"predictions size {S_times_B} not divisible by y size {B}")
            S = S_times_B // B
            predictions = predictions.view(S, B)
        elif predictions.dim() == 2:
            S, B = predictions.shape
            if y.size(0) != B:
                raise ValueError(f"predictions batch size {B} != y size {y.size(0)}")
        else:
            raise ValueError(f"predictions must be 1D or 2D, got shape {predictions.shape}")

        y_expanded = y.unsqueeze(0).expand(S, -1)
        success = predictions.eq(y_expanded).float()
        per_image_pr = success.mean(dim=0)

        if reduction == 'mean':
            return per_image_pr.mean()
        elif reduction == 'sum':
            return per_image_pr.sum()
        elif reduction == 'none':
            return per_image_pr
        else:
            raise ValueError(f"Unknown reduction: {reduction}. Use 'mean', 'sum', or 'none'.")

    @torch.no_grad()
    def evaluate_pr(self, x, y, classifier, num_samples=100,
                    use_soft_sampling=False, temperature=1.0, reduction='none',
                    chunk_size=None):
        """Estimate PR from ``num_samples`` perturbations per image (no gradients).

        The mixture is built once for (x, y) and reused for every chunk.
        Hard samples come from ``dist.sample``; ``use_soft_sampling`` switches to
        the Gumbel-relaxed sampler at ``temperature``.
        """
        B = x.size(0)

        if chunk_size is None:
            max_batch = 32
            chunk_size = max(1, max_batch // B)

        forward_out = self.forward(x=x, y=y)
        chunks = [
            self._evaluate_chunk(x, y, n, classifier, use_soft_sampling, temperature,
                                 forward_out=forward_out)
            for n in self._chunk_sizes(num_samples, chunk_size)
        ]
        predictions = chunks[0] if len(chunks) == 1 else torch.cat(chunks, dim=0)

        return self.compute_pr(predictions, y, reduction=reduction)

    def _evaluate_chunk(self, x, y, num_samples, classifier,
                        use_soft_sampling, temperature, forward_out=None):
        """Predicted classes (num_samples, B) for one chunk of perturbations.

        ``forward_out`` may carry a precomputed ``self.forward(x, y)`` result;
        it is computed here when omitted.
        """
        B = x.size(0)

        if forward_out is None:
            forward_out = self.forward(x=x, y=y)

        if use_soft_sampling:
            eps = self._rsample_from_gmm(forward_out["cache"], num_samples, temperature=temperature)
        else:
            eps = forward_out["dist"].sample((num_samples,))

        u = self._decode_latent(eps, out_shape=x.shape[1:])
        delta = self._project_to_budget(u)
        x_rep = x.unsqueeze(0).expand_as(delta)

        logits = classifier((x_rep + delta).flatten(0, 1))
        return logits.argmax(dim=-1).view(num_samples, B)

    @torch.no_grad()
    def sample(self, x=None, y=None, num_samples=1, out_shape=None, chunk_size=None):
        """Draw perturbations from the mixture.

        Returns:
            dict with "eps" (latent samples), "u" (decoded) and "delta" (budgeted).
            ``out_shape`` is taken from ``x`` when given, otherwise required.
        """
        if self.cond_mode in ("x", "xy") and x is None:
            raise ValueError(
                f"GMM trained with cond_mode='{self.cond_mode}' requires x input.\n"
                f"Provide x or train a model with different conditioning."
            )
        if self.cond_mode in ("y", "xy") and y is None:
            raise ValueError(
                f"GMM trained with cond_mode='{self.cond_mode}' requires y input.\n"
                f"Provide y or train a model with different conditioning."
            )

        dist = self.forward(x=x, y=y)["dist"]

        if x is not None:
            out_shape = x.shape[1:]
        elif out_shape is None:
            raise ValueError(
                "out_shape must be provided when x is None.\n"
                "Pass out_shape=(C, H, W) explicitly, e.g., out_shape=(3, 32, 32)"
            )

        sizes = [num_samples] if chunk_size is None else self._chunk_sizes(num_samples, chunk_size)

        eps_list, u_list, delta_list = [], [], []
        for n in sizes:
            eps_chunk = dist.sample((n,))
            u_chunk = self._decode_latent(eps_chunk, out_shape=out_shape)
            eps_list.append(eps_chunk)
            u_list.append(u_chunk)
            delta_list.append(self._project_to_budget(u_chunk))

        if len(sizes) == 1:
            return {"eps": eps_list[0], "u": u_list[0], "delta": delta_list[0]}
        return {
            "eps": torch.cat(eps_list, dim=0),
            "u": torch.cat(u_list, dim=0),
            "delta": torch.cat(delta_list, dim=0),
        }

    def save(self, path, extra=None):
        """Write ``{"state_dict", "config"}`` to ``path``; ``extra`` is merged into config."""
        cfg = dict(
            K=self.K,
            latent_dim=self.latent_dim,

            T_pi=self.T_pi,
            T_mu=self.T_mu,
            T_sigma=self.T_sigma,
            T_shared=self.T_shared,
            logstd_bounds=self.logstd_bounds,
            budget=self.budget,

            reg_coeffs=self.reg_coeffs,

            has_y_emb=(self.y_emb is not None),
            y_emb_dim=(self.y_emb.embedding_dim if self.y_emb is not None else None),
            y_emb_normalize=self.y_emb_normalize,

            cond_mode=self.cond_mode,
            cov_type=self.cov_type,
            cov_rank=self.cov_rank,
            feat_dim=self.feat_dim,
            num_cls=self.num_cls,
            hidden_dim=self.hidden_dim,
        )
        if extra:
            cfg.update(extra)
        torch.save({"state_dict": self.state_dict(), "config": cfg}, path)
        print(f"Model saved to {path}")

    @classmethod
    def load_from_checkpoint(cls, path, feat_extractor=None, up_sampler=None,
                             map_location="cpu", strict=True):
        """Rebuild a model from a file written by ``save``.

        ``feat_extractor`` / ``up_sampler`` are attached before
        ``load_state_dict`` when the checkpoint stores their weights (so those
        weights are restored), and after it otherwise (so the modules keep
        their own weights and strict loading does not report missing keys).

        Note: ``torch.load`` unpickles the file; only load trusted checkpoints.
        """
        ckpt = torch.load(path, map_location=map_location)
        cfg = ckpt["config"]
        state_keys = list(ckpt["state_dict"].keys())

        def _stored(prefix):
            # True when the checkpoint holds weights for sub-module `prefix`.
            return any(k.startswith(prefix + ".") for k in state_keys)

        feat_in_ckpt = _stored("feat_extractor")
        up_in_ckpt = _stored("up_sampler")

        model = cls(
            K=cfg["K"],
            latent_dim=cfg["latent_dim"],
            device=map_location,
            T_pi=cfg.get("T_pi", 1.0),
            T_mu=cfg.get("T_mu", 1.0),
            T_sigma=cfg.get("T_sigma", 1.0),
            T_shared=cfg.get("T_shared", 1.0),
            logstd_bounds=cfg.get("logstd_bounds", (-3.0, 1.0))
        )

        if cfg.get("has_y_emb", False):
            model.set_y_embedding(
                num_cls=cfg["num_cls"],
                y_dim=cfg["y_emb_dim"],
                normalize=cfg.get("y_emb_normalize", True)
            )

        model.set_condition(
            cond_mode=cfg["cond_mode"],
            cov_type=cfg["cov_type"],
            cov_rank=cfg.get("cov_rank", 0),
            feat_dim=cfg["feat_dim"],
            num_cls=cfg["num_cls"],
            hidden_dim=cfg["hidden_dim"]
        )

        if feat_extractor is not None and feat_in_ckpt:
            model.set_feat_extractor(feat_extractor)
        if up_sampler is not None and up_in_ckpt:
            model.set_up_sampler(up_sampler)

        model.set_budget(**cfg.get("budget", {"norm": "linf", "eps": 8/255}))
        model.set_regularization(**cfg.get("reg_coeffs", {}))

        model.load_state_dict(ckpt["state_dict"], strict=strict)

        if feat_extractor is not None and not feat_in_ckpt:
            model.set_feat_extractor(feat_extractor)
        if up_sampler is not None and not up_in_ckpt:
            model.set_up_sampler(up_sampler)

        print(f"Model loaded from {path}")
        return model