"""
SCBM and baseline models.
"""

import math
import os
from typing import Optional, Tuple

import torch
import torch.nn.functional as F
from torch import Tensor
from torch import nn
from torch.distributions import RelaxedBernoulli, MultivariateNormal
from torchvision import models


def freeze_module(m):
    """Put module ``m`` in eval mode and turn off gradients for its parameters."""
    m.eval()
    # Equivalent to setting requires_grad = False on every parameter of m
    m.requires_grad_(False)


def unfreeze_module(m):
    """Put module ``m`` in train mode and turn on gradients for its parameters."""
    m.train()
    # Equivalent to setting requires_grad = True on every parameter of m
    m.requires_grad_(True)


class SCBM(nn.Module):
    """
    Stochastic Concept Bottleneck Model (SCBM) with a learned covariance matrix.

    Concept logits are modelled as a multivariate normal distribution. Its mean is a linear function of the input
    features and its scale is a lower-triangular (Cholesky) factor. Monte Carlo samples of the concept logits are
    passed through the label head and the resulting class probabilities are averaged.

    Args:
        n_features (int): Size of the feature vector fed to the concept layers.
        num_concepts (int): Number of (binary) concepts.
        num_classes (int): Number of target classes; must be at least 2. Two classes use a single output logit.
        head_arch (str): "linear" for a single linear head; any other value builds a one-hidden-layer MLP (256 units).
        training_mode (str): The training mode (e.g., "joint", "sequential", "independent").
        num_monte_carlo (int): The number of Monte Carlo samples drawn per input.
        straight_through (bool): Whether to use straight-through gradients for the relaxed Bernoulli samples.
        cov_type (str): "global" (one learned factor shared by all inputs), "empirical" (fixed tensor, see below),
            or anything else for an amortized factor predicted from the features.
        j_epochs (int): Number of epochs used by the temperature schedule when training_mode is "joint".
        t_epochs (int): Number of epochs used by the temperature schedule otherwise.
        concept_learning (str): "hard" feeds sampled binary concepts to the head, "soft" feeds the sampled logits.

    Noteworthy Attributes:
        curr_temp: The most recently computed relaxation temperature (float 1.0 until first computed).
        sigma_concepts: Parameter ("global"), plain tensor ("empirical") or nn.Linear (amortized).
            In "empirical" mode it is initialised to a zero placeholder, while forward() uses it directly as a
            (num_concepts, num_concepts) scale_tril. It is presumably overwritten by the training script before
            the first forward pass -- verify against the caller.

    Raises:
        ValueError: If num_classes is smaller than 2.
    """

    def __init__(
        self,
        n_features,
        num_concepts,
        num_classes,
        head_arch='linear',
        training_mode='joint',
        num_monte_carlo=100,
        straight_through=True,
        cov_type='amortized',
        j_epochs=300,
        t_epochs=100,
        concept_learning='hard'
    ):
        super(SCBM, self).__init__()

        # Fail early with a clear message instead of an AttributeError on the missing pred_dim
        if num_classes < 2:
            raise ValueError(f"num_classes must be at least 2, got {num_classes}")

        # Configuration arguments
        self.n_features = n_features
        self.num_concepts = num_concepts
        self.num_classes = num_classes
        self.head_arch = head_arch
        self.training_mode = training_mode
        self.num_monte_carlo = num_monte_carlo
        self.straight_through = straight_through
        self.curr_temp = 1.0
        self.concept_learning = concept_learning
        # Horizon of the temperature schedule depends on the training mode
        self.num_epochs = j_epochs if self.training_mode == "joint" else t_epochs
        self.cov_type = cov_type

        self.mu_concepts = nn.Linear(self.n_features, self.num_concepts, bias=True)

        # Number of entries in the lower triangle (including the diagonal)
        num_tril_entries = self.num_concepts * (self.num_concepts + 1) // 2
        if self.cov_type == "global":
            # Learn the lower triangle of the concept covariance factor directly
            self.sigma_concepts = nn.Parameter(torch.zeros(num_tril_entries))
        elif self.cov_type == "empirical":
            # Placeholder only; see the class docstring
            self.sigma_concepts = torch.zeros(num_tril_entries)
        else:
            self.sigma_concepts = nn.Linear(n_features, num_tril_entries, bias=True)
            # To prevent exploding precision matrix at initialization
            self.sigma_concepts.weight.data *= 0.01

        # Assume binary concepts
        self.act_c = nn.Sigmoid()

        # Link function g(.): binary targets use a single logit
        self.pred_dim = 1 if self.num_classes == 2 else self.num_classes

        if self.head_arch == "linear":
            fc_y = nn.Linear(self.num_concepts, self.pred_dim)
            self.head = nn.Sequential(fc_y)
        else:
            fc1_y = nn.Linear(self.num_concepts, 256)
            fc2_y = nn.Linear(256, self.pred_dim)
            self.head = nn.Sequential(fc1_y, nn.ReLU(), fc2_y)

    def _concept_scale_tril(self, intermediate, batch_size):
        """Return the per-sample lower-triangular scale factor, shape (batch_size, num_concepts, num_concepts)."""
        if self.cov_type == "empirical":
            # A plain tensor attribute is not moved by Module.to(), so align the device here
            sigma = self.sigma_concepts.to(intermediate.device)
            return sigma.unsqueeze(0).repeat(batch_size, 1, 1)

        if self.cov_type == "global":
            c_sigma = self.sigma_concepts.repeat(batch_size, 1)
        else:
            c_sigma = self.sigma_concepts(intermediate)

        rows, cols = torch.tril_indices(
            row=self.num_concepts, col=self.num_concepts, offset=0, device=c_sigma.device
        )
        # Diagonal entries are made strictly positive; off-diagonal entries are used as predicted
        on_diag = rows == cols
        tril_values = torch.where(on_diag, F.softplus(c_sigma) + 1e-6, c_sigma)

        # Match dtype as well as device so that non-float32 models do not hit a dtype mismatch
        c_triang_cov = torch.zeros(
            (c_sigma.shape[0], self.num_concepts, self.num_concepts),
            dtype=c_sigma.dtype,
            device=c_sigma.device,
        )
        c_triang_cov[:, rows, cols] = tril_values
        return c_triang_cov

    def _mc_average_logits(self, head_inputs):
        """Run the head on each Monte Carlo slice of ``head_inputs`` and return logits of the averaged probabilities."""
        y_pred_probs = 0
        for i in range(self.num_monte_carlo):
            y_pred_logits_i = self.head(head_inputs[:, :, i])
            if self.pred_dim == 1:
                y_pred_probs += torch.sigmoid(y_pred_logits_i)
            else:
                y_pred_probs += torch.softmax(y_pred_logits_i, dim=1)
        y_pred_probs = y_pred_probs / self.num_monte_carlo

        # Map the averaged probabilities back to logit / log space (eps guards against log(0))
        if self.pred_dim == 1:
            return torch.logit(y_pred_probs, eps=1e-6)
        return torch.log(y_pred_probs + 1e-6)

    def forward(self, intermediate, epoch, validation=False, return_full=False, c_true=None):
        """
        Perform a forward pass through the Stochastic Concept Bottleneck Model (SCBM).

        Args:
            intermediate (torch.Tensor): Feature representation fed to the concept layers.
                Assumed shape (batch_size, n_features) -- confirm against the encoder.
            epoch (int): The current epoch number; only used for the relaxation temperature.
            validation (bool, optional): If True, hard concepts are sampled without gradients. Default is False.
            return_full (bool, optional): Flag indicating whether to also return mu of concept. Default is False.
            c_true (torch.Tensor, optional): The ground-truth concept values. Required when training (not validating)
                in "independent" mode. Default is None.

        Returns:
            tuple: ``(c_mcmc_prob, c_triang_cov, y_pred_logits, None, None)``, or, if ``return_full`` is True,
            ``(c_mcmc_prob, c_mu, c_triang_cov, y_pred_logits, None, None)`` where
                - c_mcmc_prob: Sampled concept probabilities. Shape: (batch_size, num_concepts, num_monte_carlo)
                - c_mu: Predicted concept logit means. Shape: (batch_size, num_concepts)
                - c_triang_cov: Cholesky factor of the concept logit covariance.
                  Shape: (batch_size, num_concepts, num_concepts)
                - y_pred_logits: Logits of the Monte-Carlo-averaged target probabilities.
                  Shape: (batch_size, pred_dim), with pred_dim == 1 for binary targets.
            The two trailing ``None`` values are placeholders (presumably to keep the arity in line with
            GSCBM.forward -- confirm with callers).

        Raises:
            ValueError: If ``c_true`` is missing in "independent" training mode.
            NotImplementedError: If ``concept_learning`` is neither "hard" nor "soft".
        """

        # Get mu and cholesky decomposition of covariance
        c_mu = self.mu_concepts(intermediate)
        c_triang_cov = self._concept_scale_tril(intermediate, c_mu.size(0))

        # Sample from predicted normal distribution
        c_dist = MultivariateNormal(c_mu, scale_tril=c_triang_cov)
        c_mcmc_logit = c_dist.rsample([self.num_monte_carlo]).movedim(
            0, -1
        )  # [batch_size,num_concepts,mcmc_size]
        c_mcmc_prob = self.act_c(c_mcmc_logit)

        # For all MCMC samples simultaneously sample from Bernoulli
        if validation or self.training_mode == "sequential":
            # No backpropagation necessary
            c_mcmc = torch.bernoulli(c_mcmc_prob)
        elif self.training_mode == "independent":
            if c_true is None:
                raise ValueError('c_true is required when training_mode is "independent"')
            c_mcmc = c_true.unsqueeze(-1).repeat(1, 1, self.num_monte_carlo).float()
        else:
            # Backpropagation necessary
            curr_temp = self.compute_temperature(epoch, device=c_mcmc_prob.device)
            dist = RelaxedBernoulli(temperature=curr_temp, probs=c_mcmc_prob)

            # Bernoulli relaxation
            mcmc_relaxed = dist.rsample()
            if self.straight_through:
                # Straight-through: hard values in the forward pass, relaxed gradients in the backward pass
                mcmc_hard = (mcmc_relaxed > 0.5) * 1
                c_mcmc = mcmc_hard - mcmc_relaxed.detach() + mcmc_relaxed
            else:
                c_mcmc = mcmc_relaxed

        # Choose what the label head sees
        if self.concept_learning == "hard":
            head_inputs = c_mcmc
        elif self.concept_learning == "soft":
            head_inputs = c_mcmc_logit
        else:
            raise NotImplementedError
        y_pred_logits = self._mc_average_logits(head_inputs)

        # Return concept mu for interventions
        if return_full:
            return c_mcmc_prob, c_mu, c_triang_cov, y_pred_logits, None, None
        else:
            return c_mcmc_prob, c_triang_cov, y_pred_logits, None, None

    def intervene(self, c_mcmc_probs, c_mcmc_logits):
        """
        Predict target logits from externally supplied concept samples.

        Args:
            c_mcmc_probs (torch.Tensor): Concept probabilities, indexed as [:, :, i] for each Monte Carlo sample.
            c_mcmc_logits (torch.Tensor): Matching concept logits; used only when concept_learning is "soft".

        Returns:
            torch.Tensor: Logits of the Monte-Carlo-averaged target probabilities.
        """
        # Hard concepts are always drawn (even in "soft" mode), as in the original implementation
        c_hard = torch.bernoulli(c_mcmc_probs)
        head_inputs = c_mcmc_logits if self.concept_learning == "soft" else c_hard
        return self._mc_average_logits(head_inputs)

    def compute_temperature(self, epoch, device):
        """
        Exponentially anneal the relaxation temperature from 1.0 towards 0.5 over ``self.num_epochs``.

        The value is floored at 0.5, stored in ``self.curr_temp`` and returned as a tensor of shape (1,) on ``device``.
        """
        final_temp = 0.5
        init_temp = 1.0
        # Scalar arithmetic in Python; only the final value is turned into a tensor
        rate = (math.log(final_temp) - math.log(init_temp)) / float(self.num_epochs)
        value = max(init_temp * math.exp(rate * epoch), final_temp)
        curr_temp = torch.tensor([value], device=device)
        self.curr_temp = curr_temp
        return curr_temp

    def freeze_c(self):
        """Freeze the label head (eval mode, no gradients)."""
        self.head.apply(freeze_module)

    def freeze_t(self):
        """Unfreeze the label head and freeze the concept mean and covariance parameters."""
        self.head.apply(unfreeze_module)
        self.mu_concepts.apply(freeze_module)
        if isinstance(self.sigma_concepts, nn.Linear):
            self.sigma_concepts.apply(freeze_module)
        else:
            # Parameter ("global") or plain tensor ("empirical")
            self.sigma_concepts.requires_grad = False

class Identity(nn.Module):
    """Pass-through module: ``forward`` returns its input unchanged."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x

class GSCBM(nn.Module):
    """
    Stochastic Concept Bottleneck Model with learned concept-graph message passing (GSCBM).

    As in SCBM, concept logits are modelled as a multivariate normal with a linear mean and a lower-triangular
    (Cholesky) scale factor. In addition:
    - three learnable (num_concepts x num_concepts) edge matrices define normalised adjacency matrices through which
      the concept means are propagated (``message_passing``); the label head is fed samples drawn around these
      propagated means;
    - a second linear projection of the features is aligned with the concept means through an in-batch contrastive
      loss (``cl_loss``).

    Args:
        n_features (int): Size of the feature vector fed to the concept layers.
        num_concepts (int): Number of (binary) concepts.
        num_classes (int): Number of target classes; must be at least 2. Two classes use a single output logit.
        head_arch (str): "linear" for a single linear head; any other value builds a one-hidden-layer MLP (256 units).
        training_mode (str): The training mode (e.g., "joint", "sequential", "independent").
        num_monte_carlo (int): The number of Monte Carlo samples drawn per input.
        straight_through (bool): Whether to use straight-through gradients for the relaxed Bernoulli samples.
        cov_type (str): "global" (one learned factor shared by all inputs), "empirical" (fixed tensor, see below),
            or anything else for an amortized factor predicted from the features.
        j_epochs (int): Number of epochs used by the temperature schedule when training_mode is "joint".
        t_epochs (int): Number of epochs used by the temperature schedule otherwise.
        concept_learning (str): "hard" feeds sampled binary concepts to the head, "soft" feeds the sampled logits.
        temperature (float): Divisor applied to the cosine similarities in the contrastive loss.

    Noteworthy Attributes:
        curr_temp: The most recently computed relaxation temperature (float 1.0 until first computed).
        edge_param (list): The three edge parameters, also registered as edge_param1..3.
        sigma_concepts: Parameter ("global"), plain tensor ("empirical") or nn.Linear (amortized).
            In "empirical" mode it is initialised to a zero placeholder, while forward() uses it directly as a
            (num_concepts, num_concepts) scale_tril. It is presumably overwritten by the training script before
            the first forward pass -- verify against the caller.

    Raises:
        ValueError: If num_classes is smaller than 2.
    """

    def __init__(
        self,
        n_features,
        num_concepts,
        num_classes,
        head_arch='linear',
        training_mode='joint',
        num_monte_carlo=100,
        straight_through=True,
        cov_type='amortized',
        j_epochs=300,
        t_epochs=100,
        concept_learning='hard',
        temperature=0.1
    ):
        super(GSCBM, self).__init__()

        # Fail early with a clear message instead of an AttributeError on the missing pred_dim
        if num_classes < 2:
            raise ValueError(f"num_classes must be at least 2, got {num_classes}")

        # Configuration arguments
        self.n_features = n_features
        self.num_concepts = num_concepts
        self.num_classes = num_classes
        self.head_arch = head_arch
        self.training_mode = training_mode
        self.num_monte_carlo = num_monte_carlo
        self.straight_through = straight_through
        self.curr_temp = 1.0
        self.concept_learning = concept_learning
        # Horizon of the temperature schedule depends on the training mode
        self.num_epochs = j_epochs if self.training_mode == "joint" else t_epochs
        self.cov_type = cov_type

        self.mu_concepts = nn.Linear(self.n_features, self.num_concepts, bias=True)
        # Second view of the features, used only by the contrastive loss
        self.project_concepts = nn.Linear(self.n_features, self.num_concepts, bias=True)

        # Number of entries in the lower triangle (including the diagonal)
        num_tril_entries = self.num_concepts * (self.num_concepts + 1) // 2
        if self.cov_type == "global":
            # Learn the lower triangle of the concept covariance factor directly
            self.sigma_concepts = nn.Parameter(torch.zeros(num_tril_entries))
        elif self.cov_type == "empirical":
            # Placeholder only; see the class docstring
            self.sigma_concepts = torch.zeros(num_tril_entries)
        else:
            self.sigma_concepts = nn.Linear(n_features, num_tril_entries, bias=True)
            # To prevent exploding precision matrix at initialization
            self.sigma_concepts.weight.data *= 0.01

        # Assume binary concepts
        self.act_c = nn.Sigmoid()

        # Link function g(.): binary targets use a single logit
        self.pred_dim = 1 if self.num_classes == 2 else self.num_classes

        if self.head_arch == "linear":
            fc_y = nn.Linear(self.num_concepts, self.pred_dim)
            self.head = nn.Sequential(fc_y)
        else:
            fc1_y = nn.Linear(self.num_concepts, 256)
            fc2_y = nn.Linear(256, self.pred_dim)
            self.head = nn.Sequential(fc1_y, nn.ReLU(), fc2_y)

        self.graph_init(num_concepts)

        self.sim = nn.CosineSimilarity(dim=-1)
        self.criterion = nn.CrossEntropyLoss()
        self.temperature = temperature

    def graph_init(self, num_concepts):
        """Create and Xavier-initialise the three learnable (num_concepts x num_concepts) edge matrices."""
        self.edge_param1 = nn.Parameter(torch.empty((num_concepts, num_concepts)))
        self.edge_param2 = nn.Parameter(torch.empty((num_concepts, num_concepts)))
        self.edge_param3 = nn.Parameter(torch.empty((num_concepts, num_concepts)))

        self.edge_param = [self.edge_param1, self.edge_param2, self.edge_param3]
        for edge_param in self.edge_param:
            # In-place initialiser; the variant without the trailing underscore is deprecated
            torch.nn.init.xavier_uniform_(edge_param)

        print('Graphs has been setup')

        return

    def get_characteristic_matrix(self, A):
        """
        Turn a raw square edge matrix into a degree-normalised adjacency matrix.

        Negative entries are dropped, self-loops are added, values are clipped to [0, 1], and the result is scaled
        as D^{-1/2} A D^{-1/2}, where the degree D of a node counts its non-zero entries.
        """
        self_loops = torch.eye(A.size(1), device=A.device)
        A = torch.clamp(F.relu(A) + self_loops, max=1, min=0)
        # Entries lie in [0, 1], so ceil() marks every non-zero edge as 1; self-loops keep the degree >= 1
        inv_sqrt_degree = torch.ceil(A).sum(dim=-1).pow(-0.5)
        return inv_sqrt_degree.unsqueeze(-1) * A * inv_sqrt_degree.unsqueeze(0)

    def cl_loss(self, z1, z2, z3=None):
        """
        In-batch contrastive loss between two sets of row vectors.

        Row i of ``z1`` is treated as the positive for row i of ``z2``; all other rows of ``z2`` (and all rows of
        the optional ``z3``) act as negatives. Similarities are cosine similarities divided by ``self.temperature``.
        """
        device = z1.device
        cos_sim = self.sim(z1.unsqueeze(1), z2.unsqueeze(0)) / self.temperature
        if z3 is not None:
            hard_negatives = self.sim(z1.unsqueeze(1), z3.unsqueeze(0)) / self.temperature
            cos_sim = torch.cat([cos_sim, hard_negatives], dim=-1)

        # The matching pair sits on the diagonal of the similarity matrix
        labels = torch.arange(cos_sim.size(0)).long().to(device)
        loss = self.criterion(cos_sim, labels)
        return loss

    def message_passing(self, stage1_output):
        """
        Propagate concept means through the three learned graphs with residual connections.

        Args:
            stage1_output (torch.Tensor): Concept means, shape (batch_size, num_concepts).

        Returns:
            tuple: (propagated means with the same shape as the input, scalar L1 penalty on the adjacency matrices).
        """
        attr_outputs = stage1_output.unsqueeze(-1)
        l1_regularizer = 0
        for edge_param in self.edge_param:
            # Symmetrise the raw edges before normalising
            adj = 0.5 * (edge_param + edge_param.t())
            new_adj = self.get_characteristic_matrix(adj)
            attr_outputs = F.gelu(torch.matmul(new_adj, attr_outputs)) + attr_outputs
            l1_regularizer += 0.5*(torch.norm(new_adj, p=1, dim=-1).mean() + torch.norm(new_adj.T, p=1, dim=-1).mean())

        # Drop only the trailing axis: a bare squeeze() would also remove a batch (or concept) axis of size one
        attr_outputs = attr_outputs.squeeze(-1)

        return attr_outputs, l1_regularizer

    def _concept_scale_tril(self, intermediate, batch_size):
        """Return the per-sample lower-triangular scale factor, shape (batch_size, num_concepts, num_concepts)."""
        if self.cov_type == "empirical":
            # A plain tensor attribute is not moved by Module.to(), so align the device here
            sigma = self.sigma_concepts.to(intermediate.device)
            return sigma.unsqueeze(0).repeat(batch_size, 1, 1)

        if self.cov_type == "global":
            c_sigma = self.sigma_concepts.repeat(batch_size, 1)
        else:
            c_sigma = self.sigma_concepts(intermediate)

        rows, cols = torch.tril_indices(
            row=self.num_concepts, col=self.num_concepts, offset=0, device=c_sigma.device
        )
        # Diagonal entries are made strictly positive; off-diagonal entries are used as predicted
        on_diag = rows == cols
        tril_values = torch.where(on_diag, F.softplus(c_sigma) + 1e-6, c_sigma)

        # Match dtype as well as device so that non-float32 models do not hit a dtype mismatch
        c_triang_cov = torch.zeros(
            (c_sigma.shape[0], self.num_concepts, self.num_concepts),
            dtype=c_sigma.dtype,
            device=c_sigma.device,
        )
        c_triang_cov[:, rows, cols] = tril_values
        return c_triang_cov

    def _mc_average_logits(self, head_inputs):
        """Run the head on each Monte Carlo slice of ``head_inputs`` and return logits of the averaged probabilities."""
        y_pred_probs = 0
        for i in range(self.num_monte_carlo):
            y_pred_logits_i = self.head(head_inputs[:, :, i])
            if self.pred_dim == 1:
                y_pred_probs += torch.sigmoid(y_pred_logits_i)
            else:
                y_pred_probs += torch.softmax(y_pred_logits_i, dim=1)
        y_pred_probs = y_pred_probs / self.num_monte_carlo

        # Map the averaged probabilities back to logit / log space (eps guards against log(0))
        if self.pred_dim == 1:
            return torch.logit(y_pred_probs, eps=1e-6)
        return torch.log(y_pred_probs + 1e-6)

    def forward(self, intermediate, epoch, validation=False, return_full=False, c_true=None):
        """
        Perform a forward pass through the graph-augmented SCBM.

        Args:
            intermediate (torch.Tensor): Feature representation fed to the concept layers.
                Assumed shape (batch_size, n_features) -- confirm against the encoder.
            epoch (int): The current epoch number; only used for the relaxation temperature.
            validation (bool, optional): If True, hard concepts are sampled without gradients. Default is False.
            return_full (bool, optional): Flag indicating whether to also return mu of concept. Default is False.
            c_true (torch.Tensor, optional): The ground-truth concept values. Required when training (not validating)
                in "independent" mode. Default is None.

        Returns:
            tuple: ``(c_mcmc_prob, c_triang_cov, y_pred_logits, cl_loss, l1_regularizer)``, or, if ``return_full``
            is True, ``(c_mcmc_prob, c_mu, c_triang_cov, y_pred_logits, cl_loss, l1_regularizer)`` where
                - c_mcmc_prob: Concept probabilities sampled around the un-propagated means c_mu.
                  Shape: (batch_size, num_concepts, num_monte_carlo)
                - c_mu: Predicted concept logit means. Shape: (batch_size, num_concepts)
                - c_triang_cov: Cholesky factor of the concept logit covariance.
                  Shape: (batch_size, num_concepts, num_concepts)
                - y_pred_logits: Logits of the Monte-Carlo-averaged target probabilities, computed from samples
                  drawn around the graph-propagated means. Shape: (batch_size, pred_dim).
                - cl_loss: Contrastive loss between the concept means and the projected features.
                - l1_regularizer: L1 penalty on the normalised adjacency matrices.

        Raises:
            ValueError: If ``c_true`` is missing in "independent" training mode.
            NotImplementedError: If ``concept_learning`` is neither "hard" nor "soft".
        """

        # Concept means, their graph-propagated counterpart, and the auxiliary losses
        c_mu = self.mu_concepts(intermediate)
        x_c = self.project_concepts(intermediate)

        cg_mu, l1_regularizer = self.message_passing(c_mu)

        cl_loss = self.cl_loss(c_mu, x_c)

        # Both distributions use the same scale factor, so it is built once and shared
        c_triang_cov = self._concept_scale_tril(intermediate, c_mu.size(0))
        cg_triang_cov = c_triang_cov

        # Sample from predicted normal distribution
        c_dist = MultivariateNormal(c_mu, scale_tril=c_triang_cov)
        c_mcmc_logit = c_dist.rsample([self.num_monte_carlo]).movedim(
            0, -1
        )  # [batch_size,num_concepts,mcmc_size]
        c_mcmc_prob = self.act_c(c_mcmc_logit)

        cg_dist = MultivariateNormal(cg_mu, scale_tril=cg_triang_cov)
        cg_mcmc_logit = cg_dist.rsample([self.num_monte_carlo]).movedim(
            0, -1
        )  # [batch_size,num_concepts,mcmc_size]
        cg_mcmc_prob = self.act_c(cg_mcmc_logit)

        # For all MCMC samples simultaneously sample from Bernoulli
        if validation or self.training_mode == "sequential":
            # No backpropagation necessary
            # NOTE: this draw is unused; it is kept so that seeded runs consume the random stream as before
            torch.bernoulli(c_mcmc_prob)
            cg_mcmc = torch.bernoulli(cg_mcmc_prob)
        elif self.training_mode == "independent":
            if c_true is None:
                raise ValueError('c_true is required when training_mode is "independent"')
            cg_mcmc = c_true.unsqueeze(-1).repeat(1, 1, self.num_monte_carlo).float()
        else:
            # Backpropagation necessary
            curr_temp = self.compute_temperature(epoch, device=cg_mcmc_prob.device)
            dist = RelaxedBernoulli(temperature=curr_temp, probs=cg_mcmc_prob)

            # Bernoulli relaxation
            mcmc_relaxed = dist.rsample()
            if self.straight_through:
                # Straight-through: hard values in the forward pass, relaxed gradients in the backward pass
                mcmc_hard = (mcmc_relaxed > 0.5) * 1
                cg_mcmc = mcmc_hard - mcmc_relaxed.detach() + mcmc_relaxed
            else:
                cg_mcmc = mcmc_relaxed

        # Choose what the label head sees
        if self.concept_learning == "hard":
            head_inputs = cg_mcmc
        elif self.concept_learning == "soft":
            head_inputs = cg_mcmc_logit
        else:
            raise NotImplementedError
        y_pred_logits = self._mc_average_logits(head_inputs)

        # Return concept mu for interventions
        if return_full:
            return c_mcmc_prob, c_mu, c_triang_cov, y_pred_logits, cl_loss, l1_regularizer
        else:
            return c_mcmc_prob, c_triang_cov, y_pred_logits, cl_loss, l1_regularizer

    def intervene(self, c_mcmc_probs, c_mcmc_logits):
        """
        Predict target logits from externally supplied concept samples.

        Args:
            c_mcmc_probs (torch.Tensor): Concept probabilities, indexed as [:, :, i] for each Monte Carlo sample.
            c_mcmc_logits (torch.Tensor): Matching concept logits; used only when concept_learning is "soft".

        Returns:
            torch.Tensor: Logits of the Monte-Carlo-averaged target probabilities.
        """
        # Hard concepts are always drawn (even in "soft" mode), as in the original implementation
        c_hard = torch.bernoulli(c_mcmc_probs)
        head_inputs = c_mcmc_logits if self.concept_learning == "soft" else c_hard
        return self._mc_average_logits(head_inputs)

    def compute_temperature(self, epoch, device):
        """
        Exponentially anneal the relaxation temperature from 1.0 towards 0.5 over ``self.num_epochs``.

        The value is floored at 0.5, stored in ``self.curr_temp`` and returned as a tensor of shape (1,) on ``device``.
        """
        final_temp = 0.5
        init_temp = 1.0
        # Scalar arithmetic in Python; only the final value is turned into a tensor
        rate = (math.log(final_temp) - math.log(init_temp)) / float(self.num_epochs)
        value = max(init_temp * math.exp(rate * epoch), final_temp)
        curr_temp = torch.tensor([value], device=device)
        self.curr_temp = curr_temp
        return curr_temp

    def freeze_c(self):
        """Freeze the label head (eval mode, no gradients)."""
        self.head.apply(freeze_module)

    def freeze_t(self):
        """Unfreeze the label head and freeze the concept mean and covariance parameters."""
        self.head.apply(unfreeze_module)
        self.mu_concepts.apply(freeze_module)
        if isinstance(self.sigma_concepts, nn.Linear):
            self.sigma_concepts.apply(freeze_module)
        else:
            # Parameter ("global") or plain tensor ("empirical")
            self.sigma_concepts.requires_grad = False

class Identity(nn.Module):
    """Pass-through module: ``forward`` returns its input unchanged."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x


class SCBLoss(nn.Module):
    """
    Loss function for the Stochastic Concept Bottleneck Model (SCBM).

    Combines a target loss, a Monte Carlo concept loss weighted by ``alpha``, and an L1 penalty on the off-diagonal
    entries of the concept precision matrix (``reg_precision`` is fixed to 'l1' and ``reg_weight`` to 1).
    """

    def __init__(
        self, num_classes: Optional[int] = 2, alpha: float = 1, config: Optional[dict] = None
    ) -> None:
        """
        Initialize the SCBLoss.

        Args:
            num_classes (int, optional): Number of target classes.
            alpha (float, optional): Weight for joint training.
            config (dict, optional): Configuration dictionary. Currently unused; accepted for compatibility.
        """
        super(SCBLoss, self).__init__()
        self.num_classes = num_classes
        self.alpha = alpha
        self.reg_precision = 'l1'
        self.reg_weight = 1

    def forward(
        self,
        concepts_mcmc_probs: Tensor,
        concepts_true: Tensor,
        target_pred_logits: Tensor,
        target_true: Tensor,
        c_triang_cov: Tensor,
        cov_not_triang=False,
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """
        Compute the loss.

        Args:
            concepts_mcmc_probs (Tensor): MCMC matrix of predicted concept probabilities, [B,C,MCMC].
            concepts_true (Tensor): Ground-truth concept values (0 or 1), [B,C].
            target_pred_logits (Tensor): Predicted target logits. For two classes a single logit per sample in
                dimension 1, which is squeezed away.
            target_true (Tensor): Ground-truth target values.
            c_triang_cov (Tensor): Cholesky decomposition of the concept covariance matrix.
            cov_not_triang (bool, optional): Flag indicating if the covariance is in its cholesky form or already the covariance.

        Returns:
            tuple: ``(target_loss, concepts_loss, prec_loss, total_loss)``.
        """

        assert torch.all(
            (concepts_true == 0) | (concepts_true == 1)
        ), "concepts_true must contain only 0 and 1"
        concepts_true_expanded = concepts_true.unsqueeze(-1).expand_as(
            concepts_mcmc_probs
        )

        bce_loss = F.binary_cross_entropy(
            concepts_mcmc_probs, concepts_true_expanded.float(), reduction="none"
        )  # [B,C,MCMC]
        # Joint log-likelihood of all concepts under each Monte Carlo sample
        intermediate_concepts_loss = -torch.sum(bce_loss, dim=1)  # [B,MCMC]
        mcmc_loss = -torch.logsumexp(
            intermediate_concepts_loss, dim=1
        )  # [B], logsumexp for numerical stability due to shift invariance
        concepts_loss = self.alpha * torch.mean(mcmc_loss)

        if self.num_classes == 2:
            # Work on the logits directly: sigmoid followed by BCE saturates for large logits
            target_loss = F.binary_cross_entropy_with_logits(
                target_pred_logits.squeeze(1), target_true.float(), reduction="mean"
            )
        else:
            target_loss = F.cross_entropy(
                target_pred_logits, target_true.long(), reduction="mean"
            )

        # Add precision loss
        if self.reg_precision == "l1":
            if cov_not_triang:
                prec_matrix = torch.inverse(c_triang_cov)
            else:
                # (L L^T)^{-1} = L^{-T} L^{-1}
                c_triang_inv = torch.inverse(c_triang_cov)
                prec_matrix = torch.matmul(
                    torch.transpose(c_triang_inv, dim0=1, dim1=2), c_triang_inv
                )
            # Sum of absolute off-diagonal entries: total minus diagonal
            prec_loss = prec_matrix.abs().sum(dim=(1, 2)) - prec_matrix.diagonal(
                offset=0, dim1=1, dim2=2
            ).abs().sum(-1)
            num_concepts = prec_matrix.size(1)
            # Average over the off-diagonal entries; skipped in the univariate case (can happen when intervening)
            if num_concepts > 1:
                prec_loss = prec_loss / (num_concepts * (num_concepts - 1))
            prec_loss = self.reg_weight * prec_loss.mean(-1)
        else:
            prec_loss = torch.zeros_like(concepts_loss)

        total_loss = target_loss + concepts_loss + prec_loss

        return target_loss, concepts_loss, prec_loss, total_loss