import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import logging


# Configure root logging once at import time: INFO level, "time - level - message" lines.
# NOTE(review): basicConfig inside an importable module configures the root logger for
# every program that imports it; consider moving this call to the training entry script.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)  # module-scoped logger used by the classes below

class EnhancedMLPWithAttention(nn.Module):
    """Shared MLP encoder blended with attention over learnable class kernels.

    ``forward`` produces two ``output_dim``-sized views of every input row:

    * ``features``: the output of the shared MLP stack;
    * ``attention_output``: a softmax-weighted mix of per-class value vectors.
      The query is a linear projection of the raw input. The keys and values
      are linear projections of ``self.kernels`` (one learnable row per class).

    The two views are mixed as ``alpha * features + (1 - alpha) * attention_output``.
    ``alpha`` is a learnable scalar and is not clamped to ``[0, 1]``.

    Args:
        input_dim: Size of the last dimension of the input ``x``.
        hidden_dims: Widths of the hidden layers of the shared MLP. The default
            is a tuple, so it is never a shared mutable list.
        output_dim: Size of the features, the kernels and the attention output.
        num_classes: Number of class kernels, i.e. attention logits per row.
        alpha_mix: Initial value of the learnable mixing weight ``alpha``.
    """

    def __init__(self, input_dim, hidden_dims=(512, 256), output_dim=64, num_classes=5, alpha_mix=0.3):
        super().__init__()
        hidden_dims = list(hidden_dims)
        self.num_classes = num_classes
        self.output_dim = output_dim
        # Kernels are created before the layers below. Keep this order so that
        # seeded runs consume the RNG identically and reproduce the same weights.
        # margin=1.414 ~ sqrt(2): for unit vectors this requires pairwise cosine <= ~0.
        self.kernels = nn.Parameter(self.initialize_kernels_with_margin(
            num_classes, output_dim, margin=1.414, max_iterations=1000), requires_grad=True)
        logger.info(f"Initialized {num_classes} kernels with shape: {self.kernels.shape}")

        layers = []
        current_dim = input_dim
        for i, hidden_dim in enumerate(hidden_dims):
            layers += [
                nn.Linear(current_dim, hidden_dim),
                nn.LeakyReLU(),
                nn.BatchNorm1d(hidden_dim),
                nn.Dropout(0.2),
            ]
            # NOTE(review): with exactly two hidden layers, a second LeakyReLU is
            # appended after the first block's Dropout. It looks accidental, but it
            # shifts the nn.Sequential indices (hence the state_dict keys) of every
            # later layer, so it is kept for checkpoint compatibility.
            if len(hidden_dims) == 2 and i == 0:
                layers.append(nn.LeakyReLU())
            current_dim = hidden_dim
        layers.append(nn.Linear(current_dim, output_dim))
        self.shared_mlp = nn.Sequential(*layers)

        self.query_fc = nn.Linear(input_dim, output_dim)
        self.key_fc = nn.Linear(output_dim, output_dim)
        self.value_fc = nn.Linear(output_dim, output_dim)
        self.alpha = nn.Parameter(torch.tensor(alpha_mix))
        logger.info(f"EnhancedMLPWithAttention alpha (mix parameter) initialized to: {alpha_mix}")

    def initialize_kernels_with_margin(self, num_classes, embedding_dim, margin, max_iterations=1000):
        """Build ``num_classes`` unit vectors pushed at least ``margin`` apart.

        The kernels start as random Gaussian directions normalised to unit L2
        norm. The method then repeatedly sweeps over all pairs ``(i, j)``. A pair
        is adjusted when its Euclidean distance, measured at the start of the
        sweep, is below ``margin``. Both vectors are moved half the shortfall
        along their difference direction and then renormalised.

        Iteration stops when one of these happens:

        * a sweep finds no pair closer than ``margin`` (success);
        * a sweep finds close pairs but cannot move any of them, e.g. two
          identical kernels whose difference direction is the zero vector;
        * ``max_iterations`` sweeps have run.

        Args:
            num_classes: Number of kernels (rows).
            embedding_dim: Dimensionality of each kernel.
            margin: Target minimum pairwise Euclidean distance.
            max_iterations: Upper bound on the number of sweeps.

        Returns:
            A ``(num_classes, embedding_dim)`` tensor of unit-norm rows. If the
            loop stops early or runs out of iterations, a warning is logged and
            the margin may not hold.
        """
        logger.info(f"Initializing kernels (Num Classes: {num_classes}, Dim: {embedding_dim}, Margin: {margin})...")
        kernels = F.normalize(torch.randn(num_classes, embedding_dim) * 0.1, p=2, dim=1)
        initial_distances = torch.cdist(kernels, kernels, p=2)
        logger.debug(f"Initial kernel distances (before adjustment):\n{initial_distances}")

        for iteration in range(max_iterations):
            distances = torch.cdist(kernels, kernels, p=2)
            too_close = False
            max_adjustment = 0.0

            for i in range(num_classes):
                for j in range(i + 1, num_classes):
                    dist_ij = distances[i, j]
                    if dist_ij < margin:
                        too_close = True
                        push_amount = (margin - dist_ij) / 2.0
                        direction = F.normalize(kernels[j] - kernels[i], p=2, dim=0)
                        adjustment = push_amount * direction
                        max_adjustment = max(max_adjustment, adjustment.abs().max().item())

                        kernels[i] -= adjustment
                        kernels[j] += adjustment
                        kernels[i] = F.normalize(kernels[i], p=2, dim=0)
                        kernels[j] = F.normalize(kernels[j], p=2, dim=0)

            # Close pairs exist but nothing moved (zero difference direction), so
            # further sweeps cannot make progress.
            stuck = too_close and max_adjustment == 0.0

            if iteration % 100 == 0 or not too_close or stuck:
                current_distances = torch.cdist(kernels, kernels, p=2)
                if num_classes > 1:
                    rows, cols = torch.triu_indices(num_classes, num_classes, offset=1)
                    min_dist = current_distances[rows, cols].min().item()
                else:
                    min_dist = float('inf')
                logger.debug(f"Iter {iteration+1}: Min distance={min_dist:.4f}, Max adjustment={max_adjustment:.4f}")

            if not too_close:
                logger.info(f"Kernels initialized successfully after {iteration + 1} iterations.")
                final_distances = torch.cdist(kernels, kernels, p=2)
                logger.info(f"Final kernel distances:\n{final_distances}")
                return kernels
            if stuck:
                logger.warning(f"Stuck in kernel initialization at iter {iteration+1}. Min distance: {min_dist:.4f}")
                break

        logger.warning(f"Kernel initialization may not have fully converged after {max_iterations} iterations.")
        final_distances = torch.cdist(kernels, kernels, p=2)
        logger.info(f"Final kernel distances (may be suboptimal):\n{final_distances}")
        return kernels

    def forward(self, x):
        """Encode ``x`` and blend the MLP features with kernel attention.

        Args:
            x: 2-D input of shape ``(batch, input_dim)``.

        Returns:
            A dict with these entries:

            * ``enhanced_features``: ``alpha * features + (1 - alpha) * attention_output``,
              shape ``(batch, output_dim)``;
            * ``features``: shared-MLP output, shape ``(batch, output_dim)``;
            * ``attention_output``: attended kernel values (after dropout),
              shape ``(batch, output_dim)``;
            * ``attention_scores``: pre-softmax scaled dot products,
              shape ``(batch, num_classes)``;
            * ``kernels``: the learnable kernel parameter itself.
        """
        features = self.shared_mlp(x)

        query = self.query_fc(x)                 # (batch, output_dim)
        keys = self.key_fc(self.kernels)         # (num_classes, output_dim)
        values = self.value_fc(self.kernels)     # (num_classes, output_dim)

        # Every row attends over the same key set, so a plain matmul replaces the
        # original expand + bmm. Scale by sqrt(d_k) as in scaled dot-product attention.
        attention_scores = torch.matmul(query, keys.t()) / (keys.shape[-1] ** 0.5)
        attention_weights = F.softmax(attention_scores, dim=1)
        attention_output = torch.matmul(attention_weights, values)
        # Honour train/eval mode. A freshly constructed nn.Dropout is always in
        # training mode, so it would keep dropping units during evaluation.
        attention_output = F.dropout(attention_output, p=0.2, training=self.training)

        enhanced_features = self.alpha * features + (1 - self.alpha) * attention_output

        return {
            "enhanced_features": enhanced_features,
            "features": features,
            "attention_output": attention_output,
            "attention_scores": attention_scores,
            "kernels": self.kernels,
        }

class SupervisedContrastiveLossWithKernel(nn.Module):
    """Weighted sum of contrastive, kernel-offset, focal, orthogonality and magnitude losses.

    ``forward`` takes a dict with the keys produced by
    ``EnhancedMLPWithAttention.forward`` ("enhanced_features", "features",
    "attention_output", "attention_scores", "kernels") plus integer labels. It
    returns ``(total_loss, components)``, where ``components`` maps each term
    name to its weighted Python float.

    Terms:
        * contrastive: supervised contrastive loss on ``enhanced_features``,
          blended 0.8 / 0.2 with a hard-negative term;
        * offset: pulls ``enhanced_features`` toward their own class kernel and
          away from the nearest other kernel;
        * classification: focal loss with ``attention_scores`` as logits;
        * orthogonality: penalises ``|cos(features, attention_output)| > 0.3``;
        * magnitude (optional): matches attention-output norms to kernel norms.

    Args:
        num_classes: Number of classes (length of the default uniform weights).
        embedding_dim: Accepted for API compatibility; not used.
        temperature: Softmax temperature of the contrastive term.
        lambda_offset: Stored as ``self.lambda_offset``; not used by any term here.
        class_weights: Per-class weights (tensor or sequence). ``None`` means uniform.
        w_contrastive, w_offset, w_class, w_ortho: Weights of the respective terms.
        enable_magnitude_loss: Whether the magnitude term contributes.
        w_magnitude_value: Weight of the magnitude term when enabled.
    """

    def __init__(self, num_classes, embedding_dim=64, temperature=0.1, lambda_offset=0.1, class_weights=None,
                 w_contrastive=1.0, w_offset=1.0, w_class=3.0, w_ortho=5.0,
                 enable_magnitude_loss=False, w_magnitude_value=1.0):
        super().__init__()
        self.temperature = temperature
        self.lambda_offset = lambda_offset
        if class_weights is None:
            logger.warning("Class weights not provided to loss function, using uniform weights.")
            class_weights = torch.ones(num_classes)
        elif not isinstance(class_weights, torch.Tensor):
            # Accept lists / tuples / arrays too. The loss terms call .to(device) and index it.
            class_weights = torch.as_tensor(class_weights, dtype=torch.float32)
        self.class_weights = class_weights

        self.eps = 1e-8

        self.w_contrastive = w_contrastive
        self.w_offset = w_offset
        self.w_class = w_class
        self.w_ortho = w_ortho
        self.w_magnitude_eff = w_magnitude_value if enable_magnitude_loss else 0.0
        if enable_magnitude_loss:
            logger.info(f"Magnitude loss enabled with weight: {self.w_magnitude_eff}")
        else:
            logger.info("Magnitude loss disabled.")

    @staticmethod
    def _as_1d_labels(labels):
        """Squeeze ``labels`` to 1-D, keeping a batch of one as shape ``(1,)``."""
        labels = labels.squeeze()
        return labels.unsqueeze(0) if labels.ndim == 0 else labels

    def compute_combined_offset_penalty(self, features, labels, kernels):
        """Class-weighted penalty tying each L2-normalised feature to its own kernel.

        For each sample of class ``c``, with ``d_own`` the distance to kernel ``c``
        and ``d_other`` the distance to the nearest other kernel:

        * in-class term: ``relu(d_own - 0.63) ** 2``;
        * cross-class term: ``relu(d_own - d_other + 2.0) ** 2``.

        Both features and kernels are normalised first. The sum of the weighted
        terms is divided by the batch size.

        NOTE(review): for unit vectors every distance is <= 2, so with margin 2.0
        the cross-class hinge is never inactive. It acts as a squared
        ``d_own - d_other + 2`` penalty; confirm this is intended.

        Returns:
            A 0-dim tensor, also when no label matches any class.
        """
        device = features.device
        batch_size, embedding_dim = features.shape
        num_classes, kernel_dim = kernels.shape
        assert embedding_dim == kernel_dim, f"Embedding dim {embedding_dim} != Kernel dim {kernel_dim}"

        labels = self._as_1d_labels(labels)
        features = F.normalize(features, p=2, dim=1)
        kernels = F.normalize(kernels, p=2, dim=1)
        weights = self.class_weights.to(device)
        delta = 0.63   # radius around the own kernel that is not penalised
        margin = 2.0

        # Start from a tensor (not 0.0) so the result always supports .item().
        total_penalty = features.new_zeros(())
        for cls in range(num_classes):
            mask = labels == cls
            if not mask.any():
                continue

            class_features = features[mask]
            distances_to_own_kernel = torch.norm(class_features - kernels[cls].unsqueeze(0), p=2, dim=1)
            in_class_penalty = torch.relu(distances_to_own_kernel - delta) ** 2

            if num_classes > 1:
                other_kernels = kernels[torch.arange(num_classes, device=device) != cls]
                min_dist_to_others = torch.cdist(class_features, other_kernels, p=2).min(dim=1).values
                cross_class_penalty = torch.relu(distances_to_own_kernel - min_dist_to_others + margin) ** 2
            else:
                cross_class_penalty = features.new_zeros(())

            total_penalty = total_penalty + weights[cls] * (in_class_penalty.sum() + cross_class_penalty.sum())

        return total_penalty / batch_size

    def compute_contrastive_loss(self, features, labels, class_weights):
        """Supervised contrastive loss (0.8) blended with a hard-negative term (0.2).

        Features are L2-normalised, and similarities are cosine similarities
        divided by ``self.temperature``. The log-softmax denominator runs over
        every sample except the anchor itself, as in SupCon. The positive term is
        the class-weighted mean, over the batch, of the negative average
        log-probability of same-class samples. The hard-negative term is the
        class-weighted mean log-probability of each anchor's ``min(10, #negatives)``
        most similar different-class samples. Minimising it pushes those
        negatives away. The hard-negative term is skipped (0) when some sample
        has no negatives.

        Returns:
            A 0-dim tensor. It is zero for a batch of fewer than two samples.
        """
        device = features.device
        batch_size = features.shape[0]
        labels = self._as_1d_labels(labels)

        if batch_size < 2:
            # Nothing to contrast; keep the autograd graph connected.
            return features.sum() * 0.0

        features = F.normalize(features, p=2, dim=1)
        similarity_matrix = torch.matmul(features, features.T) / self.temperature

        self_mask = torch.eye(batch_size, dtype=torch.bool, device=device)
        same_class = labels.unsqueeze(1) == labels.unsqueeze(0)
        positive_mask = (same_class & ~self_mask).float()
        # The diagonal is always same-class, so it is already excluded here.
        negative_mask = (~same_class).float()

        # The row-max shift is for numerical stability; it cancels in the log-softmax.
        logits = similarity_matrix - similarity_matrix.max(dim=1, keepdim=True)[0].detach()
        log_denominator = torch.logsumexp(logits.masked_fill(self_mask, float('-inf')), dim=1, keepdim=True)
        log_prob = logits - log_denominator

        sample_weights = class_weights.to(device)[labels]

        positive_pairs_count = torch.clamp(positive_mask.sum(dim=1), min=self.eps)
        mean_positive_log_prob = (log_prob * positive_mask).sum(dim=1) / positive_pairs_count
        contrastive_loss = -(sample_weights * mean_positive_log_prob).mean()

        k_desired = 10
        hard_negative_loss = torch.tensor(0.0, device=device)
        min_negatives_in_batch = int(negative_mask.sum(dim=1).min().item())
        k_safe = min(k_desired, min_negatives_in_batch)
        if k_safe > 0:
            # Mask non-negatives with -inf, not 0. Scaled cosine similarities are
            # often negative, so zeroed positives (and self) would otherwise outrank
            # the real negatives in topk.
            negative_similarity = similarity_matrix.masked_fill(negative_mask == 0, float('-inf'))
            _, hardest_negatives_indices = torch.topk(negative_similarity, k=k_safe, dim=1)
            hard_negative_mask = torch.zeros_like(negative_mask).scatter_(1, hardest_negatives_indices, 1.0)

            mean_hard_negative_log_prob = (log_prob * hard_negative_mask).sum(dim=1) / k_safe
            # Defensive: zero out any non-finite rows (e.g. NaN inputs) instead of poisoning the loss.
            mean_hard_negative_log_prob = torch.where(
                torch.isfinite(mean_hard_negative_log_prob),
                mean_hard_negative_log_prob,
                torch.zeros_like(mean_hard_negative_log_prob),
            )
            hard_negative_loss = (sample_weights * mean_hard_negative_log_prob).mean()

        alpha_hnm = 0.8
        beta_hnm = 0.2
        total_contrastive_loss = alpha_hnm * contrastive_loss + beta_hnm * hard_negative_loss
        # Lazy %-formatting: the tensor is only converted (a device sync) when DEBUG is enabled.
        logger.debug("Contrastive Loss: %.4f", total_contrastive_loss)
        return total_contrastive_loss

    def compute_orthogonality_loss(self, features, attention_output, labels, margin=0.5):
        """Class-weighted mean of ``relu(|cos(features, attention_output)| - margin)`` per sample."""
        device = features.device
        labels = self._as_1d_labels(labels)

        features = F.normalize(features, p=2, dim=1)
        attention_output = F.normalize(attention_output, p=2, dim=1)
        cosine_similarity = (features * attention_output).sum(dim=1)

        orthogonality_loss = torch.relu(torch.abs(cosine_similarity) - margin)
        sample_weights = self.class_weights.to(device)[labels]
        return (orthogonality_loss * sample_weights).mean()

    def compute_focal_loss(self, inputs, targets):
        """Class-weighted focal loss (gamma = 2) with ``inputs`` as unnormalised logits."""
        targets = self._as_1d_labels(targets.long())
        alpha_focal = self.class_weights.detach().to(inputs.device)
        gamma_focal = 2.0

        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)  # model probability of the true class
        focal_loss_unweighted = ((1 - pt) ** gamma_focal) * ce_loss

        return (alpha_focal[targets] * focal_loss_unweighted).mean()

    def compute_magnitude_loss(self, attention_output, labels, kernels):
        """Match attention-output norms to kernel norms and space kernel norms ordinally.

        Label formats:
            * 1-D (or ``(B, 1)``) class indices: each output norm should equal the
              norm of its class kernel.
            * ``(B, 2)`` rows of ``(class, intensity)``: the target norm is the
              class kernel norm scaled by ``intensity``.

        A second term asks the norm difference of kernels ``i < j`` to be
        ``0.67 * (j - i)``.

        Raises:
            ValueError: if ``labels`` is not 1-D or 2-D.
        """
        device = attention_output.device
        weights = self.class_weights.to(device)

        output_magnitudes = torch.norm(attention_output, p=2, dim=1)   # (B,)
        kernel_magnitudes = torch.norm(kernels, p=2, dim=1)            # (num_kernels,)

        # A trailing singleton class column is plain classification, not the ordinal format.
        if labels.dim() == 2 and labels.shape[1] == 1:
            labels = labels.squeeze(1)

        if labels.dim() == 1:  # standard classification
            class_labels = labels.long()
            target_magnitudes = kernel_magnitudes[class_labels]
        elif labels.dim() == 2:  # ordinal classification with intensity
            class_labels = labels[:, 0].long()
            intensities = labels[:, 1].float()
            target_magnitudes = kernel_magnitudes[class_labels] * intensities
        else:
            raise ValueError("labels must be a 1D or 2D tensor for magnitude loss computation.")

        magnitude_loss = (weights[class_labels] * torch.abs(output_magnitudes - target_magnitudes)).mean()

        kernel_magnitude_loss = 0.0
        num_kernels = kernels.shape[0]
        if num_kernels > 1:
            magnitude_factor = 0.67
            rows, cols = torch.triu_indices(num_kernels, num_kernels, offset=1, device=device)
            magnitude_diffs = torch.abs(kernel_magnitudes[rows] - kernel_magnitudes[cols])
            target_diffs = (cols - rows).float() * magnitude_factor
            kernel_magnitude_loss = torch.abs(magnitude_diffs - target_diffs).mean()

        return magnitude_loss + kernel_magnitude_loss

    def forward(self, model_output, labels):
        """Compute the weighted total loss and a dict of weighted per-term floats.

        NOTE(review): only the magnitude term understands ``(B, 2)`` ordinal
        labels. The other terms squeeze ``labels`` and expect class indices.
        """
        enhanced_features = model_output["enhanced_features"]
        features = model_output["features"]
        attention_output = model_output["attention_output"]
        attention_scores = model_output["attention_scores"]
        kernels = model_output["kernels"]

        contrastive_loss_val = self.compute_contrastive_loss(enhanced_features, labels, self.class_weights)
        offset_loss_val = self.compute_combined_offset_penalty(enhanced_features, labels, kernels)
        classification_loss_val = self.compute_focal_loss(attention_scores, labels)
        orthogonality_loss_val = self.compute_orthogonality_loss(features, attention_output, labels, margin=0.3)

        magnitude_loss_val = torch.tensor(0.0, device=enhanced_features.device)
        if self.w_magnitude_eff > 0:
            magnitude_loss_val = self.compute_magnitude_loss(attention_output, labels, kernels)

        total_loss = (
            self.w_contrastive * contrastive_loss_val
            + self.w_offset * offset_loss_val
            + self.w_class * classification_loss_val
            + self.w_ortho * orthogonality_loss_val
            + self.w_magnitude_eff * magnitude_loss_val
        )

        loss_components = {
            "total": total_loss.item(),
            "contrastive": contrastive_loss_val.item() * self.w_contrastive,
            "offset": offset_loss_val.item() * self.w_offset,
            "classification": classification_loss_val.item() * self.w_class,
            "orthogonality": orthogonality_loss_val.item() * self.w_ortho,
            "magnitude": magnitude_loss_val.item() * self.w_magnitude_eff,
        }

        return total_loss, loss_components
