
import logging
from collections import defaultdict

import numpy as np
import torch
from scipy.ndimage import binary_erosion
from scipy.ndimage import label as ndi_label


def weight_policy_weighter(weights, thresh):
    """Zero out weights not above ``thresh`` and renormalise the survivors.

    Elements strictly greater than ``thresh`` are kept, everything else is set
    to 0, and the kept values are rescaled so that all elements together sum
    to ``weights.shape[0]``. If the kept values sum to zero (e.g. nothing
    exceeds ``thresh``), a tensor of ones is returned instead.

    Args:
        weights (torch.Tensor): Weight tensor; ``weights.shape[0]`` is the
            target sum after renormalisation.
        thresh (float): Threshold; only values strictly greater are kept.

    Returns:
        torch.Tensor: Same shape as ``weights`` and on the same device.
    """
    coeff = weights.shape[0]
    # zeros_like/ones_like follow the input's device and dtype instead of
    # forcing CUDA, so this also works for CPU tensors / CPU-only machines.
    result = torch.where(weights > thresh, weights, torch.zeros_like(weights))
    total = result.sum()
    if total == 0:
        return torch.ones_like(weights)
    return coeff * result / total

class LabelStatistics:
    """Collect per-class shape statistics from labels and score predictions against them.

    For every non-zero class id seen in the training labels, the voxel count
    ("volume") and the number of connected components are recorded; their mean
    and standard deviation are later used by ``likelihood_weight`` to score how
    plausible a predicted label map is.
    """

    def __init__(self):
        # Per-class accumulator, created on first access. The surface-area
        # fields are initialised here but not populated by any method below.
        self.label_stats = defaultdict(lambda: {
            'mean_volume': 0, 'std_volume': 0,
            'mean_num_ccs': 0, 'std_num_ccs': 0,
            'mean_surface_area_to_volume': 0, 'std_surface_area_to_volume': 0,
            'volumes': [], 'surface_ratios': [], "num_ccs": [], 'count': 0
        })

    @staticmethod
    def _as_numpy(label):
        """Return ``label`` as a NumPy array, moving torch tensors to the CPU first."""
        if torch.is_tensor(label):
            return label.detach().cpu().numpy()
        return np.asarray(label)

    def compute_surface_area(self, binary_volume):
        """Count the boundary voxels of a mask.

        A voxel counts as boundary if it is foreground but is removed by one
        binary erosion with scipy's default structuring element. Erosion treats
        voxels outside the array as background, so foreground voxels on the
        array border are always counted as boundary.

        Args:
            binary_volume (array-like): Mask; any non-zero entry is foreground.

        Returns:
            Number of boundary voxels.
        """
        # Work in bool so non-0/1 foreground values are not lost in `&`.
        mask = np.asarray(binary_volume, dtype=bool)
        eroded = binary_erosion(mask)
        return np.sum(mask & ~eroded)

    def calc_metrics(self, label):
        """Compute volume and connected-component count for each non-zero class.

        Args:
            label (np.ndarray | torch.Tensor): Label map; class 0 is skipped as
                background.

        Returns:
            dict: ``{class_id (int): {"volume": int, "num_ccs": int}}`` for every
            non-zero class present in ``label``.
        """
        label = self._as_numpy(label)
        output = {}
        for class_id in np.unique(label):
            if class_id == 0:
                continue
            mask = label == class_id
            # ndi_label returns (labelled_array, num_features); connectivity
            # is scipy's default structuring element.
            _, num_ccs = ndi_label(mask)
            output[int(class_id)] = {
                "volume": int(np.count_nonzero(mask)),
                "num_ccs": int(num_ccs),
            }
        return output

    def calculate_statistics(self, train_loader):
        """Accumulate label statistics over a dataset and compute mean/std per class.

        Args:
            train_loader (Iterable): Yields tuples of length 2
                ``(volume, label)`` or length 4 ``(volume, label, _, _)``.
                ``label`` may be a NumPy array or a torch tensor.

        Raises:
            ValueError: If an item does not have 2 or 4 elements.
        """
        for outputs in train_loader:
            if len(outputs) == 2:  # (volume, label)
                _, label = outputs
            elif len(outputs) == 4:  # (volume, label, additional1, additional2)
                _, label, _, _ = outputs
            else:
                raise ValueError(f"Unexpected number of outputs: {len(outputs)}")

            # NOTE(review): each loader item is measured as a whole. If the
            # loader batches several volumes, volumes are summed across the
            # batch and ndi_label may join components across adjacent batch
            # entries — confirm batch_size == 1.
            for class_id, metrics in self.calc_metrics(label).items():
                stats = self.label_stats[class_id]
                stats['volumes'].append(metrics['volume'])
                stats['num_ccs'].append(metrics['num_ccs'])
                stats['count'] += 1

        # Summarise the accumulated samples for each class.
        for stats in self.label_stats.values():
            stats['mean_num_ccs'] = np.mean(stats['num_ccs'])
            stats['std_num_ccs'] = np.std(stats['num_ccs'])
            stats['mean_volume'] = np.mean(stats['volumes'])
            stats['std_volume'] = np.std(stats['volumes'])

    def closeness(self, x, mean, std):
        """Return ``(std + 1e-6) / (|x - mean| + 1e-6)``.

        The value is >= 1 when ``x`` lies within one standard deviation of
        ``mean`` and shrinks toward 0 as ``x`` moves further away. The 1e-6
        terms avoid division by zero.
        """
        return (std + 1e-6) / (np.abs(x - mean) + 1e-6)

    def likelihood_weight(self, predicted_label):
        """Score a predicted label map against the stored statistics.

        For every class recorded by ``calculate_statistics``, the predicted
        volume and connected-component count are compared with the stored
        mean/std via ``closeness``. Each closeness is capped at 1, and all of
        them are multiplied together. A class missing from the prediction
        contributes a factor of 1.

        Args:
            predicted_label (np.ndarray | torch.Tensor): Predicted label map
                (0 = background).

        Returns:
            Scalar weight in (0, 1]; ``0`` if ``predicted_label`` sums to zero.
        """
        label = self._as_numpy(predicted_label)
        if label.sum() == 0:
            return 0
        label_metrics = self.calc_metrics(label)
        combined_likelihood = 1
        for class_id, stats in self.label_stats.items():
            metrics = label_metrics.get(int(class_id))
            if metrics is None:
                # Missing class: neutral factor of 1 (nothing to multiply).
                logging.getLogger(__name__).debug(
                    "Class %s not found in label metrics", class_id)
                continue
            volume_closeness = self.closeness(
                metrics['volume'], stats['mean_volume'], stats['std_volume'])
            num_ccs_closeness = self.closeness(
                metrics['num_ccs'], stats['mean_num_ccs'], stats['std_num_ccs'])
            combined_likelihood *= min(volume_closeness, 1) * min(num_ccs_closeness, 1)
        return combined_likelihood


    # def likelihood_weight(self, predicted_label):
    #     """
    #     Computes likelihood-based weight for a predicted label based on precomputed statistics.
    #     Args:
    #         predicted_label (torch.Tensor): Predicted label tensor with shape (D, H, W).
    #     Returns:
    #         torch.Tensor: Weight tensor with the same shape as predicted_label, with weights based on likelihood.
    #     """
    #     combined_likelihood = 1
    #     if predicted_label.sum() == 0:
    #         return 0
    #     label_metrics = self.calc_metrics(predicted_label)
    #     # print(label_metrics)
    #     for class_id, stats in self.label_stats.items():
    #         # Calculate likelihood for volume and surface ratio independently
    #         volume_closeness = self.closeness(
    #             label_metrics.get(int(class_id), {}).get('volume', 0), 
    #             stats['mean_volume'], 
    #             stats['std_volume']
    #         )
    #         # surface_ratio_closeness = self.closeness(label_metrics[class_id]['ratio_surface_volume'], mean_surface_ratio, std_surface_ratio)
    #         num_ccs_closeness = self.closeness(
    #             label_metrics.get(int(class_id), {}).get('num_ccs', 0), 
    #             stats['mean_num_ccs'], 
    #             stats['std_num_ccs']
    #         )
    #         # Combined likelihood as a product of individual likelihoods
    #         combined_likelihood *= min(volume_closeness, 1) * min(num_ccs_closeness, 1)
    #         # print(volume_closeness, num_ccs_closeness, combined_likelihood)

    #     return combined_likelihood
