import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import time

def get_loss_class(mix_type, device):
    """
    Return an instance of the contrastive loss selected by ``mix_type``.

    Used for the mix-method ablation. The mapping is:
    'cross' -> PositiveSupCon_loss, 'within' -> PositiveSupCon_loss2,
    'all' -> PositiveSupCon_loss3, 'noMix' -> PositiveSupCon_loss_.

    Args:
        mix_type: one of 'cross', 'within', 'all', 'noMix'.
        device: forwarded unchanged to the loss constructor.

    Returns:
        The instantiated loss object.

    Raises:
        NotImplementedError: if ``mix_type`` is not a known key, or the
            mapped class name is not defined in this module.
    """
    loss_names = {
        'cross': 'PositiveSupCon_loss',
        'within': 'PositiveSupCon_loss2',
        'all': 'PositiveSupCon_loss3',
        'noMix': 'PositiveSupCon_loss_',
    }
    loss_name = loss_names.get(mix_type)
    if loss_name is None:
        # Fail with a clear message instead of an unbound-local error.
        raise NotImplementedError("Unknown mix_type : {}".format(mix_type))
    if loss_name not in globals():
        raise NotImplementedError("Algorithm not found : {}".format(loss_name))
    return globals()[loss_name](device)


class MMD_loss(nn.Module):
    """
    Maximum Mean Discrepancy between two batches of feature vectors.

    ``kernel_type='linear'`` returns the squared distance between the two
    batch means. ``kernel_type='rbf'`` uses a sum of ``kernel_num`` Gaussian
    kernels whose bandwidths are spaced by factors of ``kernel_mul``.

    NOTE(review): the 'rbf' path reduces the kernel matrix inside
    ``torch.no_grad()``, so the returned value carries no gradient. This is
    kept as-is; confirm whether the module is meant as a metric only.
    """

    def __init__(self, kernel_type = 'rbf', kernel_mul = 2.0, kernel_num=5):
        super(MMD_loss, self).__init__()
        self.kernel_num = kernel_num
        self.kernel_mul = kernel_mul
        # When left as None the bandwidth is estimated from the data.
        self.fix_sigma = None
        self.kernel_type = kernel_type

    def gaussian_kernel(self, source, target, kernel_mul = 2.0, kernel_num=5, fix_sigma=None):
        """
        Return the multi-bandwidth Gaussian kernel matrix of ``[source; target]``.

        Args:
            source, target: 2-D tensors with the same second dimension.
            kernel_mul: ratio between consecutive bandwidths.
            kernel_num: number of kernels that are summed.
            fix_sigma: base bandwidth; when falsy (None or 0) the mean
                pairwise squared distance is used instead.

        Returns:
            Square tensor of side ``len(source) + len(target)``.
        """
        n_samples = int(source.size(0)) + int(target.size(0))
        total = torch.cat([source, target], dim=0)
        # Pairwise squared Euclidean distances via broadcasting.
        L2_distance = ((total.unsqueeze(0) - total.unsqueeze(1)) ** 2).sum(2)
        if fix_sigma:
            bandwidth = fix_sigma
        else:
            bandwidth = torch.sum(L2_distance.detach()) / (n_samples ** 2 - n_samples)
        # Out-of-place division: an in-place `/=` would silently shrink a
        # tensor `fix_sigma` owned by the caller on every call.
        bandwidth = bandwidth / (kernel_mul ** (kernel_num // 2))
        bandwidth_list = [bandwidth * (kernel_mul ** i) for i in range(kernel_num)]
        kernel_val = [torch.exp(-L2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list]

        return sum(kernel_val)

    def linear_mmd2(self, f_of_X, f_of_Y):
        """Return the squared L2 distance between the two batch means."""
        delta = f_of_X.float().mean(0) - f_of_Y.float().mean(0)
        # `delta` is 1-D, so no transpose is needed for the inner product.
        return delta.dot(delta)

    def forward(self, source, target):
        """
        Compute the MMD between ``source`` and ``target``.

        Raises:
            ValueError: if ``self.kernel_type`` is neither 'linear' nor 'rbf'.
        """
        if self.kernel_type == 'linear':
            return self.linear_mmd2(source, target)
        if self.kernel_type == 'rbf':
            batch_size = int(source.size(0))
            kernels = self.gaussian_kernel(
                source, target,
                kernel_mul=self.kernel_mul,
                kernel_num=self.kernel_num,
                fix_sigma=self.fix_sigma,
            )
            # The first `batch_size` rows/columns belong to `source`.
            with torch.no_grad():
                XX = torch.mean(kernels[:batch_size, :batch_size])
                YY = torch.mean(kernels[batch_size:, batch_size:])
                XY = torch.mean(kernels[:batch_size, batch_size:])
                YX = torch.mean(kernels[batch_size:, :batch_size])
                loss = torch.mean(XX + YY - XY - YX)
            torch.cuda.empty_cache()
            return loss
        raise ValueError("Unsupported kernel_type : {}".format(self.kernel_type))

class PositiveSupCon_loss(nn.Module):
    """
    Supervised contrastive loss with cross-domain mixup.

    Every (source, target) pair that shares a class label is blended into
    one extra sample; the loss is then computed over source, target and
    the blended samples together.
    """
    def __init__(self, device, temperature = 0.5, alpha=1.0):
        super().__init__()
        self.device = device
        self.temperature = temperature
        self.alpha = alpha

    def forward(self, src_feat, trg_feat, src_y, trg_y):
        """
        Args:
            src_feat, trg_feat: 2-D feature tensors (one row per sample).
            src_y, trg_y: 1-D label tensors matching the feature rows.

        Returns:
            Scalar tensor: mean over samples of
            -log(sum of exp-similarities to same-label samples /
                 sum of exp-similarities to all samples).
        """
        dev = self.device
        src_y = src_y.to(dev)
        trg_y = trg_y.to(dev)
        # L2-normalise so the dot product below is a cosine similarity.
        src_feat = F.normalize(src_feat.to(dev), dim=1)
        trg_feat = F.normalize(trg_feat.to(dev), dim=1)

        feature_parts = [src_feat, trg_feat]
        label_parts = [src_y, trg_y]

        # pair_mask[i, j] is True when source i and target j share a label.
        pair_mask = src_y[:, None] == trg_y[None, :]
        if pair_mask.any():
            # Row-major indices of the matching pairs.
            src_idx, trg_idx = pair_mask.nonzero(as_tuple=True)
            feature_parts.append(self.mixup(src_feat[src_idx], trg_feat[trg_idx], self.alpha))
            # Both members of a pair share the label, so take the source one.
            label_parts.append(src_y[src_idx])

        all_features = torch.cat(feature_parts, dim=0)
        all_labels = torch.cat(label_parts, dim=0)

        exp_sim = torch.exp(torch.matmul(all_features, all_features.T) / self.temperature)
        positives = (all_labels[:, None] == all_labels[None, :]).float()

        per_sample = -torch.log((exp_sim * positives).sum(dim=1) / exp_sim.sum(dim=1))
        return per_sample.mean()

    def mixup(self, x1, x2, alpha):
        """
        Blend ``x1`` and ``x2`` with a single coefficient shared by all rows.

        The coefficient is drawn from Beta(alpha, alpha) when ``alpha > 0``
        and is fixed to 1 (i.e. returns ``x1``) otherwise.
        """
        if alpha > 0:
            lam = torch.distributions.Beta(alpha, alpha).sample()
        else:
            lam = torch.tensor(1.0)
        lam = lam.to(self.device)
        return lam * x1 + (1 - lam) * x2


class PositiveSupCon_loss2(PositiveSupCon_loss):
    """
    Supervised contrastive loss with within-domain mixup only.

    Source and target batches are augmented independently: samples are
    mixed only with same-class samples from their own domain. No
    cross-domain mixing is performed.
    """
    def __init__(self, device, temperature = 0.5, alpha=1.0):
        super().__init__(device, temperature=temperature, alpha=alpha)

    def forward(self, src_feat, trg_feat, src_y, trg_y):
        """
        Args:
            src_feat, trg_feat: 2-D feature tensors (one row per sample).
            src_y, trg_y: 1-D label tensors matching the feature rows.

        Returns:
            Scalar tensor: mean over samples of
            -log(sum of exp-similarities to same-label samples /
                 sum of exp-similarities to all samples).
        """
        src_feat, src_y = src_feat.to(self.device), src_y.to(self.device)
        trg_feat, trg_y = trg_feat.to(self.device), trg_y.to(self.device)

        src_feat = F.normalize(src_feat, dim=1)
        trg_feat = F.normalize(trg_feat, dim=1)

        # Mix source and target independently (source first, so the random
        # draws happen in a fixed order).
        mixed_src_features, mixed_src_labels = self.mixup_within_class(src_feat, src_y)
        mixed_trg_features, mixed_trg_labels = self.mixup_within_class(trg_feat, trg_y)

        all_features = torch.cat([src_feat, trg_feat, mixed_src_features, mixed_trg_features], dim=0)
        all_labels = torch.cat([src_y, trg_y, mixed_src_labels, mixed_trg_labels], dim=0)

        # Contrastive loss
        sim_matrix = torch.matmul(all_features, all_features.T) / self.temperature
        mask = torch.eq(all_labels.unsqueeze(1), all_labels.unsqueeze(0)).float()

        exp_sim_matrix = torch.exp(sim_matrix)
        masked_exp_sim_matrix = exp_sim_matrix * mask
        loss = -torch.log(torch.sum(masked_exp_sim_matrix, dim=1) / torch.sum(exp_sim_matrix, dim=1))

        return torch.mean(loss)

    def mixup_within_class(self, features, labels):
        """
        Mix every sample with a randomly chosen sample of the same class.

        For each class with at least two samples, the class samples are
        paired with a random permutation of themselves and each pair is
        blended with ``self.mixup`` (one fresh coefficient per pair).
        Classes with a single sample are skipped, so the number of mixed
        samples per class equals that class's sample count.

        Returns:
            (mixed_features, mixed_labels); both are empty (zero rows) when
            no class has two or more samples.
        """
        mixed_features = []
        mixed_labels = []

        for cls in torch.unique(labels):
            cls_features = features[labels == cls]

            num_samples = cls_features.size(0)
            if num_samples < 2:  # nothing to pair with
                continue

            shuffled_features = cls_features[torch.randperm(num_samples)]
            for i in range(num_samples):
                mixed_features.append(self.mixup(cls_features[i], shuffled_features[i], self.alpha))
                mixed_labels.append(cls)

        if mixed_features:
            # `cls` values are 0-d tensors; stacking keeps their dtype and
            # device instead of round-tripping through Python scalars.
            return torch.stack(mixed_features), torch.stack(mixed_labels).to(self.device)

        # Empty results follow the dtype of the inputs so that a later
        # torch.cat does not change the dtype of the real data.
        empty_features = features.new_empty((0, features.size(1))).to(self.device)
        empty_labels = labels.new_empty((0,)).to(self.device)
        return empty_features, empty_labels

class PositiveSupCon_loss3(PositiveSupCon_loss2):
    """
    Supervised contrastive loss using both mixup schemes.

    The batch is extended first with cross-domain mixes (every same-class
    source/target pair) and then with within-domain mixes produced by
    ``mixup_within_class`` for source and target separately.
    """
    def __init__(self, device, temperature = 0.5, alpha=1.0):
        super().__init__(device)

        self.device = device
        self.alpha = alpha
        self.temperature = temperature

    def forward(self, src_feat, trg_feat, src_y, trg_y):
        """
        Args:
            src_feat, trg_feat: 2-D feature tensors (one row per sample).
            src_y, trg_y: 1-D label tensors matching the feature rows.

        Returns:
            Scalar tensor: mean over samples of
            -log(sum of exp-similarities to same-label samples /
                 sum of exp-similarities to all samples).
        """
        dev = self.device
        src_y = src_y.to(dev)
        trg_y = trg_y.to(dev)
        src_feat = F.normalize(src_feat.to(dev), dim=1)
        trg_feat = F.normalize(trg_feat.to(dev), dim=1)

        feature_parts = [src_feat, trg_feat]
        label_parts = [src_y, trg_y]

        # Cross-domain mixup: pair_mask[i, j] marks source i / target j
        # sharing a label.
        pair_mask = src_y[:, None] == trg_y[None, :]
        if pair_mask.any():
            src_idx, trg_idx = pair_mask.nonzero(as_tuple=True)
            feature_parts.append(self.mixup(src_feat[src_idx], trg_feat[trg_idx], self.alpha))
            # Labels agree within a pair, so the source label is used.
            label_parts.append(src_y[src_idx])

        # Within-domain mixup, source first and target second.
        for feats, labels in ((src_feat, src_y), (trg_feat, trg_y)):
            within_features, within_labels = self.mixup_within_class(feats, labels)
            feature_parts.append(within_features)
            label_parts.append(within_labels)

        all_features = torch.cat(feature_parts, dim=0)
        all_labels = torch.cat(label_parts, dim=0)

        exp_sim = torch.exp(torch.matmul(all_features, all_features.T) / self.temperature)
        positives = (all_labels[:, None] == all_labels[None, :]).float()

        per_sample = -torch.log((exp_sim * positives).sum(dim=1) / exp_sim.sum(dim=1))
        return per_sample.mean()


class PositiveSupCon_loss_(nn.Module):
    """
    Supervised contrastive loss without any mixup ('noMix' ablation).

    Computes, for every sample in the concatenated source/target batch,
    -log(sum of exp-similarities to same-label samples /
         sum of exp-similarities to all samples)
    and returns the mean.
    """
    def __init__(self, device, temperature=0.5):
        super().__init__()
        self.temperature = temperature
        self.device = device

    def forward(self, src_feat, trg_feat, src_y, trg_y):
        """
        Args:
            src_feat, trg_feat: 2-D feature tensors (one row per sample).
            src_y, trg_y: 1-D label tensors matching the feature rows.

        Returns:
            Scalar loss tensor.
        """
        # Move features and labels to the specified device
        src_feat, src_y = src_feat.to(self.device), src_y.to(self.device)
        trg_feat, trg_y = trg_feat.to(self.device), trg_y.to(self.device)

        # Normalize features so the dot product is a cosine similarity
        src_feat = F.normalize(src_feat, dim=1)
        trg_feat = F.normalize(trg_feat, dim=1)

        all_features = torch.cat([src_feat, trg_feat], dim=0)
        all_labels = torch.cat([src_y, trg_y], dim=0)

        # same_class_mask[i, j] is True when samples i and j share a label
        same_class_mask = all_labels.unsqueeze(1) == all_labels.unsqueeze(0)

        sim_matrix = torch.matmul(all_features, all_features.T) / self.temperature
        exp_sim_matrix = torch.exp(sim_matrix)

        # The mask must be applied AFTER exponentiation. Masking the raw
        # similarities first would turn every negative pair into
        # exp(0) = 1 and leak it into the positive sum.
        positive_sum = torch.sum(exp_sim_matrix * same_class_mask.float(), dim=1)
        total_sum = torch.sum(exp_sim_matrix, dim=1)

        # Small epsilon guards the log/division
        loss = -torch.log((positive_sum + 1e-8) / (total_sum + 1e-8))

        return torch.mean(loss)
        
#####################################################
### loss for baselines ##############################
#####################################################

class CDAC_loss(nn.Module):
    """
    Combined loss for the CDAC baseline.

    Sums three terms computed on one batch and two further views of it:
    an adversarial adaptive clustering loss, a pseudo-labeling loss and a
    consistency loss.
    """
    def __init__(self):
        super().__init__()

    def forward(self, hparams, backbone, classifier, im_data, im_data_bar, im_data_bar2, BCE, w_cons, device, target):
        """
        Args:
            hparams: mapping providing 'threshold' and (when ``target`` is
                None) 'topk'.
            backbone: callable mapping inputs to features.
            classifier: callable on features; must accept the keyword
                arguments ``reverse`` and ``eta``.
            im_data, im_data_bar, im_data_bar2: three views of the batch.
            BCE: callable ``(prob1, prob2, simi) -> scalar``.
            w_cons: weight of the consistency term.
            device: device for the pairwise target tensor.
            target: 1-D labels, or None to derive pairwise targets from
                feature rankings.

        Returns:
            (total_loss, feat) where ``feat`` is ``backbone(im_data)``.
        """
        aac_loss, pl_loss, con_loss, feat = self.get_losses_unlabeled(
            hparams, backbone, classifier, im_data, im_data_bar, im_data_bar2, BCE, w_cons, device, target
        )

        return aac_loss + pl_loss + con_loss, feat

    def get_losses_unlabeled(self, hparams, G, F1, im_data, im_data_bar, im_data_bar2, BCE, w_cons, device, target):
        """
        Compute the three loss terms.

        Returns:
            (aac_loss, pl_loss, con_loss, feat).
        """
        feat = G(im_data)  # G: feature extractor
        feat_bar = G(im_data_bar)
        feat_bar2 = G(im_data_bar2)

        # Classifier outputs with reverse=True feed the adversarial term.
        # NOTE(review): presumably this enables gradient reversal inside
        # F1 - confirm against the classifier implementation.
        output = F1(feat, reverse=True, eta=1.0)  # F1: classifier
        output_bar = F1(feat_bar, reverse=True, eta=1.0)
        prob, prob_bar = F.softmax(output, dim=1), F.softmax(output_bar, dim=1)

        # loss for adversarial adaptive clustering
        aac_loss = self.advbce_unlabeled(hparams, target=target, feat=feat, prob=prob, prob_bar=prob_bar, device=device, bce=BCE)

        # Plain classifier outputs for the remaining two terms.
        output = F1(feat)
        output_bar = F1(feat_bar)
        output_bar2 = F1(feat_bar2)

        prob = F.softmax(output, dim=1)
        prob_bar = F.softmax(output_bar, dim=1)
        prob_bar2 = F.softmax(output_bar2, dim=1)

        # Pseudo labels come from the first view and carry no gradient.
        max_probs, pseudo_labels = torch.max(prob.detach(), dim=-1)
        mask = max_probs.ge(hparams['threshold']).float()

        # loss for pseudo labeling: only confident samples contribute
        pl_loss = (F.cross_entropy(output_bar2, pseudo_labels, reduction='none') * mask).mean()

        # loss for consistency between the two other views
        con_loss = w_cons * F.mse_loss(prob_bar, prob_bar2)

        return aac_loss, pl_loss, con_loss, feat

    def advbce_unlabeled(self, hparams, target, feat, prob, prob_bar, device, bce):
        """ Construct adversarial adaptive clustering loss (negated ``bce``). """
        target_ulb = self.pairwise_target(hparams, feat, target, device)
        prob_bottleneck_row, _ = self.PairEnum2D(prob)
        _, prob_bottleneck_col = self.PairEnum2D(prob_bar)
        adv_bce_loss = -bce(prob_bottleneck_row, prob_bottleneck_col, target_ulb)
        return adv_bce_loss

    def pairwise_target(self, hparams, feat, target, device):
        """
        Produce a 0/1 similarity label for every ordered pair of samples.

        With ``target=None`` a pair is similar (1) when both samples have
        the same set of ``hparams['topk']`` highest-valued feature indices;
        otherwise a pair is similar when the two labels are equal.

        Returns:
            Float tensor of length ``batch * batch`` on ``device``.
        """
        if target is None:
            # For unlabeled data: compare the top-k feature index sets.
            rank_idx = torch.argsort(feat.detach(), dim=1, descending=True)
            rank_idx1, rank_idx2 = self.PairEnum2D(rank_idx)
            rank_idx1, rank_idx2 = rank_idx1[:, :hparams['topk']], rank_idx2[:, :hparams['topk']]
            # Sorting makes the comparison independent of rank order.
            rank_idx1, _ = torch.sort(rank_idx1, dim=1)
            rank_idx2, _ = torch.sort(rank_idx2, dim=1)
            rank_diff = torch.sum(torch.abs(rank_idx1 - rank_idx2), dim=1)
            target_ulb = torch.ones_like(rank_diff).float().to(device)
            target_ulb[rank_diff > 0] = 0
        else:
            # For labeled data: compare the labels directly.
            target_row, target_col = self.PairEnum1D(target)
            target_ulb = torch.zeros(target.size(0) * target.size(0)).float().to(device)
            target_ulb[target_row == target_col] = 1
        return target_ulb

    def PairEnum1D(self, x):
        """
        Enumerate all ordered pairs of entries of a 1-D tensor.

        Returns two tensors of length ``n * n``: the first tiles ``x``,
        the second repeats each entry ``n`` times.
        """
        assert x.ndimension() == 1, 'Input dimension must be 1'
        x1 = x.repeat(x.size(0), )
        x2 = x.repeat(x.size(0)).view(-1, x.size(0)).transpose(1, 0).reshape(-1)
        return x1, x2

    def PairEnum2D(self, x):
        """
        Enumerate all ordered pairs of rows of a 2-D tensor.

        Returns two tensors with ``n * n`` rows: the first tiles the rows
        of ``x``, the second repeats each row ``n`` times.
        """
        assert x.ndimension() == 2, 'Input dimension must be 2'
        x1 = x.repeat(x.size(0), 1)
        x2 = x.repeat(1, x.size(0)).view(-1, x.size(1))
        return x1, x2

class BCE(nn.Module):
    """
    Pairwise binary cross-entropy on the inner product of two probability
    rows.

    With ``P = sum(prob1 * prob2, dim=1)`` the per-pair value is
    ``-log(P * simi + [simi == -1] + eps)``: ``-log(P)`` where ``simi`` is 1
    and ``-log(1 - P)`` where ``simi`` is -1. The mean over pairs is
    returned.
    """
    eps = 1e-7

    def forward(self, prob1, prob2, simi):
        # Out-of-place ops only: the in-place variants would overwrite the
        # caller's `prob1` tensor.
        P = (prob1 * prob2).sum(1)
        P = P * simi + simi.eq(-1).type_as(P)
        neglogP = -torch.log(P + self.eps)
        return neglogP.mean()

class BCE_softlabels(nn.Module):
    """
    Binary cross-entropy between a soft similarity label and the inner
    product of two probability rows.

    With ``P = sum(prob1 * prob2, dim=1)`` the per-pair value is
    ``-(simi * log(P + eps) + (1 - simi) * log(1 - P + eps))``; the mean
    over pairs is returned.
    """
    eps = 1e-7

    def forward(self, prob1, prob2, simi):
        P = (prob1 * prob2).sum(1)
        # Use this class's own epsilon rather than reaching into another
        # class for it.
        eps = self.eps
        neglogP = -(simi * torch.log(P + eps) + (1. - simi) * torch.log(1. - P + eps))
        return neglogP.mean()

class CrossEntropyWLogits(torch.nn.Module):
    """
    Cross-entropy between logits and (possibly soft) target distributions.

    ``reduction='mean'`` averages over the batch dimension;
    ``reduction='none'`` returns one value per row.
    """
    def __init__(self, reduction='mean'):
        super(CrossEntropyWLogits, self).__init__()
        # Only these two reductions are implemented.
        assert reduction in ('mean', 'none'), 'utils.loss.CrossEntropyWLogits : reduction not recognized'
        self.reduction = reduction

    def forward(self, logits, targets):
        # `targets` must have the same shape as `logits`.
        per_sample = (-targets * torch.log_softmax(logits, dim=1)).sum(dim=1)
        if self.reduction != 'mean':
            return per_sample
        return per_sample.mean(dim=0)
            
class AdaMatch_loss(nn.Module):
    """
    Loss for the AdaMatch baseline.

    Labeled samples (source plus labeled target) yield a supervised loss on
    randomly interpolated logits. Unlabeled target samples yield a
    pseudo-label loss that is gated by a relative confidence threshold and
    weighted by a cosine warm-up factor.
    """
    def __init__(self):
        super().__init__()

    def forward(self, hparams, backbone, classifier, 
                im_data_src, im_data_bar_src, im_data_trg, im_data_bar_trg, im_data_trg_ul, im_data_bar_trg_ul,
               gt_labels_src, gt_labels_trg, step, warm_steps, device):
        """
        Args:
            hparams: mapping providing 'tau' (relative threshold factor).
            backbone: feature extractor module.
            classifier: classification head applied to the features.
            im_data_src / im_data_bar_src: two views of the source batch.
            im_data_trg / im_data_bar_trg: two views of the labeled target
                batch.
            im_data_trg_ul / im_data_bar_trg_ul: two views of the unlabeled
                target batch; may be empty (zero elements).
            gt_labels_src, gt_labels_trg: labels of the labeled batches.
            step, warm_steps: current step and warm-up length for the
                target-loss weight.
            device: device for tensors created here.

        Returns:
            (loss, source_loss, target_loss).
        """
        loss, source_loss, target_loss = self.get_losses(hparams, backbone, classifier,
                                                        im_data_src, im_data_bar_src, im_data_trg, im_data_bar_trg, im_data_trg_ul, im_data_bar_trg_ul,
                                                        gt_labels_src, gt_labels_trg, step, warm_steps, device)

        return loss, source_loss, target_loss

    def get_losses(self, hparams, feature_extractor, classifier, im_data_src, im_data_bar_src, im_data_trg, im_data_bar_trg,
              im_data_trg_ul, im_data_bar_trg_ul, gt_labels_src, gt_labels_trg, step, warm_steps, device):
        """
        Compute the total, source and target losses.

        In this function "source" means all labeled samples and "target"
        means the unlabeled samples. See ``forward`` for the arguments.
        """
        all_labeled_img = torch.cat([im_data_src, im_data_trg])
        all_labeled_bar_img = torch.cat([im_data_bar_src, im_data_bar_trg])
        labels_source = torch.cat([gt_labels_src, gt_labels_trg])  # labeled data -> source
        num_labeled = all_labeled_img.size(0)

        has_unlabeled = im_data_trg_ul.numel() > 0

        source_combined_img = torch.cat([all_labeled_img, all_labeled_bar_img], dim=0)
        source_total = source_combined_img.size(0)

        # The unlabeled views are appended only when they exist, so that
        # everything after `source_total` in the logits is target data.
        if has_unlabeled:
            data_combined_img = torch.cat([source_combined_img, im_data_trg_ul, im_data_bar_trg_ul], dim=0)
        else:
            data_combined_img = source_combined_img

        data_combined = feature_extractor(data_combined_img)

        # Compute logits for the whole batch
        logits_combined = classifier(data_combined)
        logits_source_p = logits_combined[:source_total]

        # Second pass over the labeled samples alone, without updating the
        # batch-norm running statistics.
        self._disable_batchnorm_tracking(feature_extractor)
        self._disable_batchnorm_tracking(classifier)
        source_combined = feature_extractor(source_combined_img)
        logits_source_pp = classifier(source_combined)
        self._enable_batchnorm_tracking(feature_extractor)
        self._enable_batchnorm_tracking(classifier)

        # Random element-wise interpolation between the two sets of logits
        lamb = torch.rand_like(logits_source_p).to(device)
        final_logits_source = (lamb * logits_source_p) + ((1 - lamb) * logits_source_pp)

        logits_source_weak = final_logits_source[:num_labeled]
        logits_source_strong = final_logits_source[num_labeled:]
        pseudolabels_source = F.softmax(logits_source_weak, dim=1)

        # Relative confidence threshold: tau times the mean top-1
        # confidence on the first-view labeled samples.
        row_wise_max, _ = torch.max(pseudolabels_source, dim=1)
        final_sum = torch.mean(row_wise_max, 0)
        c_tau = hparams['tau'] * final_sum

        source_loss = self._compute_source_loss(logits_source_weak, logits_source_strong, labels_source)
        target_loss = torch.tensor(0.0, device=device)

        if has_unlabeled:
            num_unlabeled = im_data_trg_ul.size(0)
            logits_target = logits_combined[source_total:]
            # softmax for logits of weakly augmented target images
            pseudolabels_target = F.softmax(logits_target[:num_unlabeled], dim=1)

            # align target label distribution to source label distribution
            expectation_ratio = (1e-6 + torch.mean(pseudolabels_source, dim=0)) / (1e-6 + torch.mean(pseudolabels_target, dim=0))
            final_pseudolabels = F.normalize((pseudolabels_target * expectation_ratio), p=1, dim=1)  # L1 normalization
            max_values, final_pseudolabels_cls = torch.max(final_pseudolabels, dim=1)
            mask = (max_values >= c_tau).float()

            # Only compute the target loss when at least one pseudo-label
            # passes the threshold.
            if mask.any():
                pseudolabels = final_pseudolabels_cls.detach()
                target_loss = self._compute_target_loss(pseudolabels, logits_target[num_unlabeled:], mask)

        # compute target loss weight (mu): cosine ramp from 0 to 1 over
        # `warm_steps`, then constant
        pi = torch.tensor(np.pi, dtype=torch.float).to(device)
        step = torch.tensor(step, dtype=torch.float).to(device)
        mu = 0.5 - torch.cos(torch.minimum(pi, (pi * step) / (warm_steps + 1e-5))) / 2

        source_loss = torch.nan_to_num(source_loss, nan=0.0)
        target_loss = torch.nan_to_num(target_loss, nan=0.0)

        # get total loss
        loss = source_loss + (mu * target_loss)

        return loss, source_loss, target_loss

    @staticmethod
    def _compute_source_loss(logits_weak, logits_strong, labels):
        """
        Average of the mean cross-entropy on the two sets of logits.

        Receives logits as input (dense layer outputs with no activation
        function).
        """
        loss_function = nn.CrossEntropyLoss()  # default : 'reduction="mean"'
        weak_loss = loss_function(logits_weak, labels)
        strong_loss = loss_function(logits_strong, labels)

        return (weak_loss + strong_loss) / 2

    @staticmethod
    def _compute_target_loss(pseudolabels, logits_strong, mask):
        """
        Masked mean cross-entropy against pseudo-labels.

        Receives logits as input (dense layer outputs with no activation
        function). `pseudolabels` are treated as ground truth; samples with
        ``mask == 0`` contribute zero but still count in the mean.
        """
        loss_function = nn.CrossEntropyLoss(reduction="none")

        loss = loss_function(logits_strong, pseudolabels)

        return (loss * mask).mean()

    @staticmethod
    def _disable_batchnorm_tracking(model):
        """Set ``track_running_stats = False`` on every batch-norm layer."""
        def fn(module):
            if isinstance(module, nn.modules.batchnorm._BatchNorm):
                module.track_running_stats = False

        model.apply(fn)

    @staticmethod
    def _enable_batchnorm_tracking(model):
        """Set ``track_running_stats = True`` on every batch-norm layer."""
        def fn(module):
            if isinstance(module, nn.modules.batchnorm._BatchNorm):
                module.track_running_stats = True

        model.apply(fn)

class dst_loss(AdaMatch_loss):
    """
    Two-head loss for the DST baseline.

    The first classifier is trained on labeled samples only. The second
    classifier is trained on labeled samples plus unlabeled samples whose
    pseudo-labels come from the first classifier.
    """
    def __init__(self):
        super().__init__()

    def forward(self, hparams, backbone, classifier, classifier2, im_data_src, im_data_bar_src, 
               im_data_trg, im_data_bar_trg, im_data_trg_ul, im_data_bar_trg_ul, gt_labels_src, gt_labels_trg,
               step, warm_steps, device, ablation=''):
        """Return ``(loss1, loss2)`` for the first and second classifier."""
        loss1, loss2 = self.get_losses(
            hparams, backbone, classifier, classifier2,
            im_data_src, im_data_bar_src, im_data_trg, im_data_bar_trg,
            im_data_trg_ul, im_data_bar_trg_ul, gt_labels_src, gt_labels_trg,
            step, warm_steps, device, ablation,
        )
        return loss1, loss2

    def get_losses(self, hparams, feature_extractor, classifier, classifier2, im_data_src, im_data_bar_src, im_data_trg, im_data_bar_trg,
                  im_data_trg_ul, im_data_bar_trg_ul, gt_labels_src, gt_labels_trg, step, warm_steps, device, ablation):
        """
        In this function, source refers to labeled samples and target
        refers to unlabeled samples.

        When ``ablation`` contains 'nologit' the random interpolation of
        the labeled logits with a second forward pass is skipped.

        Returns:
            (loss1, loss2): loss of the first classifier (labeled only) and
            of the second classifier (labeled + weighted pseudo-labeled).
        """
        weak_labeled = torch.cat([im_data_src, im_data_trg])
        strong_labeled = torch.cat([im_data_bar_src, im_data_bar_trg])
        labels_source = torch.cat([gt_labels_src, gt_labels_trg])
        n_labeled = weak_labeled.size(0)

        source_combined_img = torch.cat([weak_labeled, strong_labeled], dim=0)
        source_total = source_combined_img.size(0)

        has_unlabeled = im_data_trg_ul.numel() > 0
        interpolate_logits = 'nologit' not in ablation

        # First pass: the full batch when unlabeled data exists, otherwise
        # the labeled samples alone.
        if has_unlabeled:
            data_combined = feature_extractor(
                torch.cat([weak_labeled, strong_labeled, im_data_trg_ul, im_data_bar_trg_ul], dim=0)
            )
            logits_combined = classifier(data_combined)
            logits_source_p = logits_combined[:source_total]
            logits_target = logits_combined[source_total:]
        else:
            source_combined = feature_extractor(source_combined_img)
            logits_source_p = classifier(source_combined)

        if interpolate_logits:
            # Second pass over the labeled samples with batch-norm
            # statistics tracking switched off, then a random element-wise
            # blend of both sets of logits.
            for module in (feature_extractor, classifier):
                self._disable_batchnorm_tracking(module)
            source_combined = feature_extractor(source_combined_img)
            logits_source_pp = classifier(source_combined)
            for module in (feature_extractor, classifier):
                self._enable_batchnorm_tracking(module)

            lamb = torch.rand_like(logits_source_p).to(device)
            final_logits_source = (lamb * logits_source_p) + ((1 - lamb) * logits_source_pp)
        else:
            final_logits_source = logits_source_p

        logits_source_weak = final_logits_source[:n_labeled]
        logits_source_strong = final_logits_source[n_labeled:]

        # first classifier is trained only on labeled samples
        loss1 = self._compute_source_loss(logits_source_weak, logits_source_strong, labels_source)

        if has_unlabeled:
            n_unlabeled = im_data_trg_ul.size(0)
            final_pseudolabels = F.softmax(logits_target, dim=1)[:n_unlabeled]

            # relative confidence threshold from the labeled first-view
            # predictions (no gradient)
            source_conf, _ = torch.max(F.softmax(logits_source_weak, dim=1).detach(), dim=1)
            c_tau = hparams['tau'] * torch.mean(source_conf, 0)

            max_values, hard_labels = torch.max(final_pseudolabels, dim=1)
            mask = (max_values >= c_tau).float()
            pseudolabels = hard_labels.detach()

            # second classifier is trained on labeled and pseudolabeled samples
            logits2_source = classifier2(data_combined[:source_total])
            logits2_target = classifier2(data_combined[source_total:])
            loss2_target = self._compute_target_loss(pseudolabels, logits2_target[n_unlabeled:], mask)
        else:
            logits2_source = classifier2(source_combined)
            loss2_target = torch.tensor(0.0, device=device)

        # target loss weight (mu): cosine ramp over `warm_steps`
        pi = torch.tensor(np.pi, dtype=torch.float).to(device)
        step = torch.tensor(step, dtype=torch.float).to(device)
        mu = 0.5 - torch.cos(torch.minimum(pi, (pi * step) / (warm_steps + 1e-5))) / 2

        loss2_source = self._compute_source_loss(
            logits2_source[:n_labeled], logits2_source[n_labeled:], labels_source
        )

        loss2 = loss2_source + (mu * loss2_target)
        return loss1, loss2
            
class univ_ssda_loss(AdaMatch_loss):
    """AdaMatch-style loss with a second classifier trained on labeled data only.

    ``forward`` is a thin wrapper around ``get_losses``; both return the tuple
    ``(loss, source_loss, target_loss, loss2)``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, hparams, backbone, classifier, classifier2, im_data_src, im_data_bar_src,
               im_data_trg, im_data_bar_trg, im_data_trg_ul, im_data_bar_trg_ul, gt_labels_src, gt_labels_trg,
               step, warm_steps, device, ablation=''):
        """Delegate to ``get_losses`` using ``backbone`` as the feature extractor."""
        return self.get_losses(
            hparams, backbone, classifier, classifier2,
            im_data_src, im_data_bar_src, im_data_trg, im_data_bar_trg,
            im_data_trg_ul, im_data_bar_trg_ul,
            gt_labels_src, gt_labels_trg, step, warm_steps, device, ablation,
        )

    def get_losses(self, hparams, feature_extractor, classifier, classifier2, im_data_src, im_data_bar_src, 
                   im_data_trg, im_data_bar_trg, im_data_trg_ul, im_data_bar_trg_ul, gt_labels_src, gt_labels_trg,
                  step, warm_steps, device, ablation):
        """
        Compute the main-classifier loss and the classifier2 loss for one batch.

        In this function, "source" refers to labeled samples (source domain plus
        labeled target) and "target" refers to unlabeled target samples.

        Args:
            hparams: mapping; only ``hparams['tau']`` is read here.
            feature_extractor: backbone applied to images.
            classifier: main head applied to backbone features.
            classifier2: auxiliary head applied to detached features.
            im_data_src / im_data_trg / im_data_trg_ul: labeled source, labeled
                target and unlabeled target images.
            im_data_bar_*: second view of the matching batch (treated as the
                "strong" view below -- presumably strong augmentation; confirm
                against the data pipeline).
            gt_labels_src / gt_labels_trg: labels for the two labeled batches.
            step, warm_steps: drive the cosine warm-up weight of the target loss.
            device: device used for freshly created tensors.
            ablation: string of flags; 'nologit', 'nogrpl' and 'nopredavg' each
                switch one component off.

        Returns:
            (loss, source_loss, target_loss, loss2)
        """
        all_labeled_img = torch.cat([im_data_src, im_data_trg])
        all_labeled_bar_img = torch.cat([im_data_bar_src, im_data_bar_trg])
        labels_source = torch.cat([gt_labels_src, gt_labels_trg])

        has_unlabeled = im_data_trg_ul.numel() > 0
        has_labeled_trg = im_data_trg.numel() > 0
        num_labeled = all_labeled_img.size(0)

        # Full batch layout: [labeled view 1 | labeled view 2 | unlabeled view 1 | unlabeled view 2]
        batch_parts = [all_labeled_img, all_labeled_bar_img]
        if has_unlabeled:
            batch_parts += [im_data_trg_ul, im_data_bar_trg_ul]
        data_combined_img = torch.cat(batch_parts, dim=0)

        source_combined_img = torch.cat([all_labeled_img, all_labeled_bar_img], dim=0)
        data_combined = feature_extractor(data_combined_img)
        source_total = source_combined_img.size(0)

        # classifier2 only ever sees detached features of labeled samples, extracted
        # while the inherited helper has BatchNorm tracking switched off.
        self._disable_batchnorm_tracking(feature_extractor)
        with torch.no_grad():
            labeled_features_src = feature_extractor(im_data_src).detach()
            if has_labeled_trg:
                labeled_features_trg = feature_extractor(im_data_trg).detach()
        self._enable_batchnorm_tracking(feature_extractor)
        logits2_src = classifier2(labeled_features_src)
        if has_labeled_trg:
            logits2_trg = classifier2(labeled_features_trg)

        # First set of labeled logits: taken from the joint forward pass.
        logits_combined = classifier(data_combined)
        logits_source_p = logits_combined[:source_total]

        final_logits_source = logits_source_p
        if 'nologit' not in ablation:
            # Second set of labeled logits: labeled-only pass with BatchNorm tracking off,
            # then a random element-wise interpolation between the two sets.
            self._disable_batchnorm_tracking(feature_extractor)
            self._disable_batchnorm_tracking(classifier)
            source_combined = feature_extractor(source_combined_img)
            logits_source_pp = classifier(source_combined)
            self._enable_batchnorm_tracking(feature_extractor)
            self._enable_batchnorm_tracking(classifier)

            lamb = torch.rand_like(logits_source_p).to(device)
            final_logits_source = (lamb * logits_source_p) + ((1 - lamb) * logits_source_pp)

        logits_source_weak = final_logits_source[:num_labeled]
        logits_source_strong = final_logits_source[num_labeled:]

        # Relative confidence threshold: tau times the mean top-1 probability
        # of the (detached) labeled first-view predictions.
        source_probs = F.softmax(logits_source_weak, dim=1).detach()
        source_confidence = source_probs.max(dim=1).values
        c_tau = hparams['tau'] * source_confidence.mean(0)

        if not has_unlabeled:
            target_loss = torch.tensor(0.0, device=device)
        else:
            num_unlabeled = im_data_trg_ul.size(0)
            logits_target = logits_combined[source_total:]
            pseudolabels_target = F.softmax(logits_target[:num_unlabeled], dim=1)

            # classifier2 predictions on the unlabeled first view (detached features).
            self._disable_batchnorm_tracking(feature_extractor)
            with torch.no_grad():
                unlabeled_features_trg = feature_extractor(im_data_trg_ul).detach()
            self._enable_batchnorm_tracking(feature_extractor)
            pseudolabels2_target = F.softmax(classifier2(unlabeled_features_trg), dim=1)

            final_pseudolabels = pseudolabels_target
            if 'nogrpl' not in ablation:
                # Rescale each row by the ratio of the two heads' total probability mass.
                est_prop_shared = pseudolabels_target.sum(dim=1, keepdim=True).detach()
                est2_prop_shared = pseudolabels2_target.sum(dim=1, keepdim=True).detach()
                rescaled = pseudolabels_target * est2_prop_shared / (1e-6 + est_prop_shared)
                final_pseudolabels = F.normalize(rescaled, p=1, dim=1)

            if 'nopredavg' not in ablation:
                # Average the predictions of both heads.
                final_pseudolabels = (final_pseudolabels + pseudolabels2_target) / 2

            # L1-normalise each row so it sums to one again.
            final_pseudolabels = F.normalize(final_pseudolabels, p=1, dim=1)

            max_values, final_pseudolabels_cls = torch.max(final_pseudolabels, dim=1)
            mask = (max_values >= c_tau).float()

            # Hard pseudo-labels supervise the second-view logits of the unlabeled batch.
            pseudolabels = final_pseudolabels_cls.detach()
            target_loss = self._compute_target_loss(pseudolabels, logits_target[num_unlabeled:], mask)

        source_loss = self._compute_source_loss(logits_source_weak, logits_source_strong, labels_source)

        # Cosine warm-up weight (mu) for the target loss: 0 at step 0, 1 from warm_steps on.
        pi = torch.tensor(np.pi, dtype=torch.float).to(device)
        step = torch.tensor(step, dtype=torch.float).to(device)
        mu = 0.5 - torch.cos(torch.minimum(pi, (pi*step) / (warm_steps + 1e-5))) / 2

        loss = source_loss + (mu * target_loss)

        # classifier2 loss: cross-entropy per labeled domain, weighted by sample count.
        num_samp_src = im_data_src.shape[0]
        num_samp_trg = im_data_trg.shape[0]
        num_samp = num_samp_src + num_samp_trg
        loss2_src = F.cross_entropy(logits2_src, gt_labels_src)
        loss2_trg = F.cross_entropy(logits2_trg, gt_labels_trg) if has_labeled_trg else 0

        loss2 = (num_samp_src/num_samp) * loss2_src + (num_samp_trg/num_samp) * loss2_trg

        return loss, source_loss, target_loss, loss2

def sigmoid_rampup(current, rampup_length):
    """Return the ramp-up weight ``exp(-5 * (1 - t)**2)`` with ``t = current / rampup_length``.

    ``current`` is clipped to ``[0, rampup_length]`` first, so the result grows
    from ``exp(-5)`` at the start to ``1.0`` at (and after) ``rampup_length``.
    A ``rampup_length`` of zero disables the ramp and always yields ``1.0``.
    """
    if rampup_length == 0:
        return 1.0
    clipped = np.clip(current, 0.0, rampup_length)
    remaining = 1.0 - clipped / rampup_length
    return float(np.exp(-5.0 * remaining * remaining))



class CLDA_InterDomainContrastiveLoss(nn.Module):
    """
    Inter-Domain Contrastive Alignment for CLDA
    Maximizes similarity between centroids of same class from both domains
    """
    def __init__(self, temperature=0.5):
        """
        Args:
            temperature: divisor applied to every cosine similarity before exponentiation
        """
        super().__init__()
        self.temperature = temperature

    def forward(self, source_centroids, target_centroids):
        """
        Args:
            source_centroids: dict {class_id: centroid_tensor}
            target_centroids: dict {class_id: centroid_tensor}
            (centroids are normalized along dim 0, i.e. treated as 1-D vectors)

        Returns:
            InfoNCE loss averaged over the classes present in both dicts, as a
            tensor of shape (1,). For each shared class the positive pair is
            (source centroid, target centroid); the negatives are the same-domain
            centroids of every other shared class. A scalar zero tensor is
            returned when fewer than two classes are shared, because there are
            no negatives to contrast against.
        """
        if not source_centroids or not target_centroids:
            # Place the zero on the device of whichever side has centroids,
            # falling back to CPU only when both sides are empty.
            available = source_centroids or target_centroids
            device = next(iter(available.values())).device if available else 'cpu'
            return torch.tensor(0.0, device=device)

        device = next(iter(source_centroids.values())).device

        # Common classes between source and target
        common_classes = set(source_centroids.keys()) & set(target_centroids.keys())
        if len(common_classes) < 2:
            return torch.tensor(0.0, device=device)

        # Normalize every shared centroid once instead of once per (class, other) pair.
        src_unit = {c: F.normalize(source_centroids[c], dim=0).unsqueeze(0) for c in common_classes}
        tgt_unit = {c: F.normalize(target_centroids[c], dim=0).unsqueeze(0) for c in common_classes}

        per_class_losses = []
        for class_id in common_classes:
            src_centroid = src_unit[class_id]
            tgt_centroid = tgt_unit[class_id]

            # Positive pair: same class centroids from different domains
            pos_sim = F.cosine_similarity(src_centroid, tgt_centroid)

            # Negative pairs: different class centroids within each domain
            neg_sims = []
            for other_class in common_classes:
                if other_class == class_id:
                    continue
                neg_sims.append(F.cosine_similarity(src_centroid, src_unit[other_class]))
                neg_sims.append(F.cosine_similarity(tgt_centroid, tgt_unit[other_class]))
            neg_sims = torch.stack(neg_sims)

            # InfoNCE loss
            numerator = torch.exp(pos_sim / self.temperature)
            denominator = numerator + torch.sum(torch.exp(neg_sims / self.temperature))
            per_class_losses.append(-torch.log(numerator / (denominator + 1e-8)))

        return torch.stack(per_class_losses).sum(0) / len(per_class_losses)

class CLDA_InstanceContrastiveLoss(nn.Module):
    """
    Instance Contrastive Alignment for CLDA
    Maximizes similarity between original and strongly augmented unlabeled target images
    """
    def __init__(self, temperature=0.5):
        """
        Args:
            temperature: divisor applied to the cosine-similarity matrix
        """
        super().__init__()
        self.temperature = temperature

    def forward(self, original_features, augmented_features):
        """
        Args:
            original_features: tensor of shape [batch_size, feature_dim]
            augmented_features: tensor of shape [batch_size, feature_dim]

        Returns:
            Scalar cross-entropy over the row-wise similarity matrix, where row i's
            target is column i (the augmented view of the same sample). A zero
            tensor on ``original_features``' device is returned when either batch
            is empty.
        """
        if original_features.size(0) == 0 or augmented_features.size(0) == 0:
            # An empty tensor still carries a valid device, so no CPU fallback is needed.
            return torch.tensor(0.0, device=original_features.device)

        # Normalize features so the matrix product below yields cosine similarities
        original_features = F.normalize(original_features, dim=1)
        augmented_features = F.normalize(augmented_features, dim=1)

        # Compute similarity matrix
        similarity_matrix = torch.mm(original_features, augmented_features.t()) / self.temperature

        # Positive pairs are on the diagonal
        batch_size = original_features.size(0)
        labels = torch.arange(batch_size, device=original_features.device)

        # Cross-entropy loss for contrastive learning
        return F.cross_entropy(similarity_matrix, labels)

class CLDA_ContrastiveLoss(nn.Module):
    """
    Supervised contrastive loss over the union of a source and a target batch.

    Samples sharing a label are positives for each other regardless of which
    domain they come from; every other sample in the joint batch is a negative.
    """
    def __init__(self, temperature=0.5, base_temperature=0.5):
        super().__init__()
        self.temperature = temperature
        self.base_temperature = base_temperature

    def forward(self, src_features, trg_features, src_labels, trg_labels):
        """
        Args:
            src_features: source domain features [batch_size, feature_dim]
            trg_features: target domain features [batch_size, feature_dim]
            src_labels: source domain labels [batch_size]
            trg_labels: target domain labels [batch_size]

        Returns:
            Scalar loss averaged over the anchors that have at least one positive;
            a zero tensor when either batch is empty or no anchor has a positive.
        """
        device = src_features.device
        if src_features.size(0) == 0 or trg_features.size(0) == 0:
            return torch.tensor(0.0, device=device)

        # Joint batch of unit-length features with their labels.
        features = torch.cat(
            [F.normalize(src_features, dim=1), F.normalize(trg_features, dim=1)], dim=0
        )
        labels = torch.cat([src_labels, trg_labels], dim=0)

        # Temperature-scaled pairwise similarities, shifted by the row maximum
        # (detached) for numerical stability.
        similarity = torch.matmul(features, features.T) / self.temperature
        row_max = similarity.max(dim=1, keepdim=True).values
        logits = similarity - row_max.detach()

        # same_class[i, j] == 1 when samples i and j share a label.
        same_class = torch.eq(labels.unsqueeze(1), labels.unsqueeze(0)).float()

        # not_self zeroes the diagonal so an anchor is never contrasted with itself.
        not_self = torch.ones_like(same_class)
        not_self.fill_diagonal_(0)
        positives = same_class * not_self

        # Log-softmax over all non-self entries of each row.
        partition = (torch.exp(logits) * not_self).sum(1, keepdim=True)
        log_prob = logits - torch.log(partition + 1e-8)

        # Average log-likelihood of the positives of each anchor.
        positives_per_anchor = positives.sum(1)
        mean_log_prob_pos = (positives * log_prob).sum(1) / (positives_per_anchor + 1e-8)
        per_anchor_loss = -(self.temperature / self.base_temperature) * mean_log_prob_pos

        # Anchors without any positive carry no signal and are left out of the mean.
        has_positive = positives_per_anchor > 0
        if has_positive.sum() > 0:
            return per_anchor_loss[has_positive].mean()
        return torch.tensor(0.0, device=device)
