import torch
import torch.nn as nn
import torch.nn.functional as F

class ConceptBottleneckIB(nn.Module):
    """Concept-bottleneck classifier with information-bottleneck gating.

    A gate in ``[0, 1]`` is predicted for each of the ``k`` concepts from its
    embedding by a shared linear map. The concept activations are multiplied
    by the gates and fed to ``classifier``. The loss is cross-entropy plus
    ``beta`` times the KL divergence between each gate's Bernoulli
    distribution and a sparse Bernoulli prior (compression term).

    Args:
        concept_embs: Tensor of shape ``[k, d]``, one embedding per concept.
            Registered as a buffer (not a trainable parameter).
        num_class: Number of classes. Not used inside this module (the output
            size is determined by ``classifier``); kept for API compatibility.
        classifier: Module applied to the gated ``[batch, k]`` concept tensor;
            its output is used as logits for ``F.cross_entropy``.
        hidden_dim: Unused; kept for API compatibility.
        temperature: Relaxation temperature of the gate sampler (must be > 0).
        beta: Weight of the KL (compression) term in the total loss.
        prior_logit: Logit of the Bernoulli prior on every gate, e.g.
            ``sigmoid(-2.2) ≈ 0.1`` encodes a sparse prior.
    """

    def __init__(self, concept_embs, num_class, classifier, hidden_dim=64, temperature=1.0, beta=1.0, prior_logit=-2.2):
        super().__init__()
        self.k, self.d = concept_embs.shape  # k = number of concepts, d = embedding dim
        self.register_buffer("concept_embs", concept_embs)  # [k, d]

        # Shared linear map: concept embedding -> scalar gate logit.
        # Kept inside nn.Sequential so state_dict keys stay "gate_predictor.0.*".
        self.gate_predictor = nn.Sequential(nn.Linear(self.d, 1))

        self.temperature = temperature
        self.beta = beta
        self.prior_logit = prior_logit

        # Final label predictor over the gated concept vector.
        self.classifier = classifier

    def sample_gumbel_sigmoid(self, logits):
        """Draw a relaxed Bernoulli (binary-concrete) sample for each logit.

        The noise is logistic, ``log(u) - log(1 - u)`` (equivalently the
        difference of two Gumbel samples), so ``P(sample > 0.5)`` equals
        ``sigmoid(logits)`` at any temperature. A single Gumbel sample would
        bias every gate towards being open, because its mean is ~0.577
        rather than 0.
        """
        tiny = torch.finfo(logits.dtype).eps
        # Keep u strictly inside (0, 1) so neither log term is infinite.
        u = torch.rand_like(logits).clamp(tiny, 1.0 - tiny)
        logistic_noise = torch.log(u) - torch.log1p(-u)
        return torch.sigmoid((logits + logistic_noise) / self.temperature)

    def kl_bernoulli(self, logits, prior_logit):
        """Return the summed KL(Bernoulli(sigmoid(logits)) || Bernoulli(sigmoid(prior_logit))).

        ``prior_logit`` may be a float or a tensor broadcastable to ``logits``.
        Log-probabilities are computed with ``logsigmoid`` for numerical
        stability instead of ``log(p + eps)``.
        """
        prior_logit = torch.as_tensor(prior_logit, dtype=logits.dtype, device=logits.device)
        p = torch.sigmoid(logits)
        log_p, log_not_p = F.logsigmoid(logits), F.logsigmoid(-logits)
        log_q, log_not_q = F.logsigmoid(prior_logit), F.logsigmoid(-prior_logit)
        kl = p * (log_p - log_q) + (1 - p) * (log_not_p - log_not_q)
        return kl.sum()

    def forward(self, C, labels, class_weight=None):
        """Compute the information-bottleneck loss and predictions.

        Args:
            C: Concept activations, shape ``[batch, k]`` (e.g. from a frozen
                graph-text aligner).
            labels: Targets passed to ``F.cross_entropy`` (typically class
                indices of shape ``[batch]``).
            class_weight: Optional per-class weights passed to
                ``F.cross_entropy``.

        Returns:
            ``(total_loss, logits, gates)``. ``logits`` is the classifier
            output and ``gates`` has shape ``[k]``; both are detached.

        In training mode the gates are stochastic relaxed-Bernoulli samples.
        In eval mode they are the deterministic noise-free value
        ``sigmoid(gate_logits / temperature)``, so predictions are
        reproducible.
        """
        # Gate logits from concept embeddings (shared predictor): [k, 1] -> [k]
        gate_logits = self.gate_predictor(self.concept_embs).squeeze(-1)
        if self.training:
            gates = self.sample_gumbel_sigmoid(gate_logits)  # [k]
        else:
            # Median of the relaxed distribution (logistic noise has median 0).
            gates = torch.sigmoid(gate_logits / self.temperature)

        # Broadcast gates across the batch: [batch, k] * [k] -> [batch, k]
        C_selected = C * gates
        logits = self.classifier(C_selected)

        pred_loss = F.cross_entropy(logits, labels, weight=class_weight)

        # KL loss over gates (compression towards the sparse prior).
        kl_loss = self.kl_bernoulli(gate_logits, self.prior_logit)

        total_loss = pred_loss + self.beta * kl_loss
        return total_loss, logits.detach(), gates.detach()