import torch
import torch.nn.functional as F
import numpy as np
from abc import ABC, abstractmethod


class BaseLoss(ABC):
    """Interface that every loss in this module implements.

    Concrete subclasses must provide ``compute`` plus the two read-only
    properties ``name`` and ``requires_rbm_params``; the class cannot be
    instantiated until all three are overridden.
    """

    @abstractmethod
    def compute(self, *args, **kwargs):
        """Compute and return the loss value for the given inputs."""

    @property
    @abstractmethod
    def name(self):
        """Identifier of the loss."""

    @property
    @abstractmethod
    def requires_rbm_params(self):
        """Whether ``compute`` needs the RBM weights and biases."""


class ReconstructionLoss(BaseLoss):
    """Mean-squared-error reconstruction loss on value-clipped tensors."""

    def compute(self, v_reconstructed, node_pairs, **kwargs):
        """Return the MSE between the clipped reconstruction and clipped target.

        Both tensors are clamped to [-10, 10] before the comparison. When the
        result is NaN or infinite, a new zero scalar (``requires_grad=True``)
        on the reconstruction's device is returned instead. Extra keyword
        arguments are accepted and ignored.
        """
        bound = 10
        mse = F.mse_loss(
            v_reconstructed.clamp(-bound, bound),
            node_pairs.clamp(-bound, bound),
        )

        if torch.isfinite(mse):
            return mse
        # Non-finite result: hand back a zero that is detached from the graph.
        return torch.tensor(0.0, device=v_reconstructed.device, requires_grad=True)

    @property
    def name(self):
        """Registry name of this loss."""
        return "reconstruction"

    @property
    def requires_rbm_params(self):
        """This loss does not use the RBM weights or biases."""
        return False


class NormalizedReconstructionLoss(BaseLoss):
    """Cosine-similarity based reconstruction loss."""

    def compute(self, v_reconstructed, node_pairs, **kwargs):
        """Return the mean of ``1 - cosine_similarity`` over all rows.

        Args:
            v_reconstructed: Reconstructed tensor; the last dimension is the
                feature dimension, all leading dimensions are flattened.
            node_pairs: Target tensor with the same layout.
            **kwargs: Accepted and ignored.

        Returns:
            Scalar tensor: mean cosine distance between matching rows.
        """
        v_recon_norm = F.normalize(v_reconstructed, p=2, dim=-1)
        node_pairs_norm = F.normalize(node_pairs, p=2, dim=-1)

        # reshape (not view): the normalized tensors can be non-contiguous
        # when the inputs are (e.g. a transposed tensor), and view() raises
        # in that case. reshape copies only when it has to.
        cosine_sim = F.cosine_similarity(
            v_recon_norm.reshape(-1, v_recon_norm.size(-1)),
            node_pairs_norm.reshape(-1, node_pairs_norm.size(-1)),
            dim=-1
        )

        return (1 - cosine_sim).mean()

    @property
    def name(self):
        """Registry name of this loss."""
        return "normalized_reconstruction"

    @property
    def requires_rbm_params(self):
        """This loss does not use the RBM weights or biases."""
        return False


class ContrastiveDivergenceLoss(BaseLoss):
    """Clipped contrastive-divergence style loss, averaged over heads.

    NOTE(review): the per-phase score below is ``v.a + sigmoid(b + vW).b``;
    it has no ``v W h`` interaction term and is not negated, unlike the
    free-energy formulation in ``FreeEnergyLoss`` -- confirm this is the
    intended objective.
    """

    def compute(self, v_reconstructed, node_pairs, rbm_weights, visible_bias, hidden_bias, **kwargs):
        """Return the head-averaged score gap between data and reconstruction.

        All five tensor inputs are clamped to [-3, 3], hidden pre-activations
        to [-5, 5], and the final scalar to [-10, 10]. Each tensor is indexed
        by head along its first dimension. A NaN/Inf result is replaced by a
        new zero scalar with ``requires_grad=True``.
        """
        clip = 3
        recon = torch.clamp(v_reconstructed, -clip, clip)
        data = torch.clamp(node_pairs, -clip, clip)
        weights = torch.clamp(rbm_weights, -clip, clip)
        vis_bias = torch.clamp(visible_bias, -clip, clip)
        hid_bias = torch.clamp(hidden_bias, -clip, clip)

        def phase_score(v, w, a, b):
            # Visible-bias term plus hidden-bias term weighted by the
            # (clipped) sigmoid activation, averaged over rows.
            visible_term = (v * a.unsqueeze(0)).sum(dim=1)
            activation = torch.clamp(b.unsqueeze(0) + torch.matmul(v, w), -5, 5)
            hidden_term = (torch.sigmoid(activation) * b.unsqueeze(0)).sum(dim=1)
            return (visible_term + hidden_term).mean()

        per_head = []
        for h in range(rbm_weights.size(0)):
            positive = phase_score(data[h], weights[h], vis_bias[h], hid_bias[h])
            negative = phase_score(recon[h], weights[h], vis_bias[h], hid_bias[h])
            per_head.append(positive - negative)

        loss = torch.clamp(torch.stack(per_head).mean(), -10, 10)

        if torch.isfinite(loss):
            return loss
        return torch.tensor(0.0, device=v_reconstructed.device, requires_grad=True)

    @property
    def name(self):
        """Registry name of this loss."""
        return "contrastive_divergence"

    @property
    def requires_rbm_params(self):
        """This loss needs the RBM weights and both bias tensors."""
        return True


class FreeEnergyLoss(BaseLoss):
    """Clipped free-energy difference loss, averaged over heads."""

    def compute(self, v_reconstructed, node_pairs, rbm_weights, visible_bias, hidden_bias, **kwargs):
        """Return mean free energy of the data minus that of the reconstruction.

        All five tensor inputs are clamped to [-2, 2], hidden inputs to
        [-5, 5], and the final scalar to [-20, 20]. Each tensor is indexed by
        head along its first dimension. A NaN/Inf result is replaced by a new
        zero scalar with ``requires_grad=True``.
        """
        clip = 2
        recon = torch.clamp(v_reconstructed, -clip, clip)
        data = torch.clamp(node_pairs, -clip, clip)
        weights = torch.clamp(rbm_weights, -clip, clip)
        vis_bias = torch.clamp(visible_bias, -clip, clip)
        hid_bias = torch.clamp(hidden_bias, -clip, clip)

        def softplus(x):
            # Branch on the sign so exp() only ever sees a non-positive
            # argument in the selected branch.
            when_positive = x + torch.log(1 + torch.exp(-x))
            when_non_positive = torch.log(1 + torch.exp(x))
            return torch.where(x > 0, when_positive, when_non_positive)

        def free_energy(v, w, a, b):
            # F(v) = -(v.a + sum_j softplus(b_j + (vW)_j)), one value per row.
            visible_term = torch.sum(v * a.unsqueeze(0), dim=1)
            hidden_input = torch.clamp(b.unsqueeze(0) + torch.matmul(v, w), -5, 5)
            hidden_term = torch.sum(softplus(hidden_input), dim=1)
            return -(visible_term + hidden_term)

        per_head = []
        for h in range(rbm_weights.size(0)):
            fe_data = free_energy(data[h], weights[h], vis_bias[h], hid_bias[h])
            fe_recon = free_energy(recon[h], weights[h], vis_bias[h], hid_bias[h])
            per_head.append(fe_data.mean() - fe_recon.mean())

        loss = torch.clamp(torch.stack(per_head).mean(), -20, 20)

        if torch.isfinite(loss):
            return loss
        return torch.tensor(0.0, device=v_reconstructed.device, requires_grad=True)

    @property
    def name(self):
        """Registry name of this loss."""
        return "free_energy"

    @property
    def requires_rbm_params(self):
        """This loss needs the RBM weights and both bias tensors."""
        return True


class KLDivergenceLoss(BaseLoss):
    """KL divergence between original and reconstructed distributions"""

    def compute(self, v_reconstructed, node_pairs, **kwargs):
        """Return KL(p || q) with p = softmax(node_pairs), q = softmax(v_reconstructed).

        Both softmaxes are taken over the last dimension.

        Args:
            v_reconstructed: Reconstructed tensor (source of q).
            node_pairs: Target tensor (source of p).
            **kwargs: Accepted and ignored.

        Returns:
            Scalar tensor. With ``reduction='batchmean'`` the summed
            divergence is divided by the size of the first dimension only.
            A NaN/Inf result is replaced by a new zero scalar with
            ``requires_grad=True``.
        """
        # Target distribution as probabilities.
        p = torch.softmax(node_pairs, dim=-1)
        # F.kl_div expects log-probabilities as input. log_softmax computes
        # them directly, avoiding the bias and the floor that
        # log(softmax(x) + eps) introduces when a probability underflows.
        log_q = torch.log_softmax(v_reconstructed, dim=-1)

        kl_div = F.kl_div(log_q, p, reduction='batchmean')

        if torch.isnan(kl_div) or torch.isinf(kl_div):
            return torch.tensor(0.0, device=v_reconstructed.device, requires_grad=True)
        return kl_div

    @property
    def name(self):
        """Registry name of this loss."""
        return "kl_divergence"

    @property
    def requires_rbm_params(self):
        """This loss does not use the RBM weights or biases."""
        return False


class CrossEntropyLoss(BaseLoss):
    """Soft-target cross entropy between softmax-normalized tensors."""

    def compute(self, v_reconstructed, node_pairs, **kwargs):
        """Return the mean of ``-sum(p * log q)`` over the last dimension.

        ``p`` is the softmax of ``node_pairs`` and ``log q`` the log-softmax
        of ``v_reconstructed``, both along the last dimension. A NaN/Inf
        result is replaced by a new zero scalar with ``requires_grad=True``.
        """
        target_probs = node_pairs.softmax(dim=-1)
        pred_log_probs = v_reconstructed.log_softmax(dim=-1)

        # One cross-entropy value per row, then averaged.
        per_row = (target_probs * pred_log_probs).sum(dim=-1)
        ce_loss = -per_row.mean()

        if torch.isfinite(ce_loss):
            return ce_loss
        return torch.tensor(0.0, device=v_reconstructed.device, requires_grad=True)

    @property
    def name(self):
        """Registry name of this loss."""
        return "cross_entropy"

    @property
    def requires_rbm_params(self):
        """This loss does not use the RBM weights or biases."""
        return False


class HuberLoss(BaseLoss):
    """Huber reconstruction loss: quadratic near zero, linear in the tails."""

    def compute(self, v_reconstructed, node_pairs, delta=1.0, **kwargs):
        """Return the mean element-wise Huber loss.

        Elements whose absolute error is at most ``delta`` contribute
        ``0.5 * err**2``; the rest contribute ``delta * (|err| - 0.5 * delta)``.
        Extra keyword arguments are accepted and ignored.
        """
        residual = v_reconstructed - node_pairs
        magnitude = residual.abs()

        quadratic = 0.5 * residual ** 2
        linear = delta * (magnitude - 0.5 * delta)

        return torch.where(magnitude <= delta, quadratic, linear).mean()

    @property
    def name(self):
        """Registry name of this loss."""
        return "huber"

    @property
    def requires_rbm_params(self):
        """This loss does not use the RBM weights or biases."""
        return False


class SupervisedContrastiveLoss(BaseLoss):
    """Supervised contrastive reconstruction loss that incorporates node labels.
    Handles non-contiguous tensors by flattening with reshape.
    """

    def compute(self, v_reconstructed, node_pairs, **kwargs):
        """Return the label-masked contrastive loss, clamped to [0, 10].

        Each flattened reconstruction row is contrasted against every
        flattened target row; the matching row (same index) is the positive.
        Rows whose edge joins two differently-labelled nodes contribute zero,
        but still count in the mean.

        Args:
            v_reconstructed: Reconstructed tensor; first dimension is the
                head dimension, last dimension the feature dimension.
            node_pairs: Target tensor with the same layout.
            **kwargs: Must contain ``labels`` (indexable by node id) and
                ``edge_index`` (unpackable into ``row, col``); may contain
                ``temperature`` (default 0.1).

        Returns:
            Scalar tensor, or a new zero scalar with ``requires_grad=True``
            when the result is NaN/Inf.

        Raises:
            ValueError: If ``labels`` or ``edge_index`` is missing.
        """
        labels = kwargs.get('labels')
        edge_index = kwargs.get('edge_index')
        temperature = kwargs.get('temperature', 0.1)

        if labels is None or edge_index is None:
            raise ValueError("SupervisedContrastiveLoss requires 'labels' and 'edge_index' in kwargs")

        device = v_reconstructed.device
        row, col = edge_index

        # reshape copies only when the tensor is non-contiguous
        v_recon_flat = v_reconstructed.reshape(-1, v_reconstructed.size(-1))
        node_pairs_flat = node_pairs.reshape(-1, node_pairs.size(-1))

        # Normalize embeddings
        v_recon_norm = F.normalize(v_recon_flat, p=2, dim=-1)
        node_pairs_norm = F.normalize(node_pairs_flat, p=2, dim=-1)

        # Compute similarity matrix
        sim_matrix = torch.exp(torch.matmul(v_recon_norm, node_pairs_norm.T) / temperature)

        # 1.0 where the edge joins two nodes with the same label, tiled once
        # per head so it lines up with the flattened rows.
        num_heads = v_reconstructed.size(0)
        label_sim = (labels[row] == labels[col]).float().repeat(num_heads)

        # Only the diagonal entry of each row is a positive, so read it
        # directly instead of multiplying by a dense diagonal mask.
        positives = torch.diagonal(sim_matrix) * label_sim
        total = torch.sum(sim_matrix, dim=1)

        loss = -torch.log(positives / (total + 1e-10) + 1e-10)
        loss = torch.where(label_sim > 0, loss, torch.zeros_like(loss)).mean()

        if torch.isnan(loss) or torch.isinf(loss):
            return torch.tensor(0.0, device=device, requires_grad=True)

        return torch.clamp(loss, 0, 10)

    @property
    def name(self):
        """Registry name of this loss."""
        return "supervised_contrastive"

    @property
    def requires_rbm_params(self):
        """This loss does not use the RBM weights or biases."""
        return False


class FocalClassificationLoss(BaseLoss):
    """Focal loss for the classification task, down-weighting easy samples.
    Reads 'out' (model logits [N, C]) and 'targets' (labels [N]) from kwargs.
    """

    def compute(self, v_reconstructed=None, node_pairs=None, **kwargs):
        """Return the mean focal loss ``alpha * (1 - pt)**gamma * ce``.

        ``alpha`` (default 0.25) and ``gamma`` (default 2.0) are optional
        kwargs. ``v_reconstructed`` and ``node_pairs`` are unused. A NaN/Inf
        result is replaced by a new zero scalar with ``requires_grad=True``.

        Raises:
            ValueError: If ``out`` or ``targets`` is missing.
        """
        logits = kwargs.get('out')
        targets = kwargs.get('targets')
        if logits is None or targets is None:
            raise ValueError("FocalClassificationLoss requires 'out' and 'targets' in kwargs")

        alpha = kwargs.get('alpha', 0.25)
        gamma = kwargs.get('gamma', 2.0)

        # Per-sample cross entropy; exp(-ce) recovers the true-class probability.
        per_sample_ce = F.cross_entropy(logits, targets, reduction='none')
        true_class_prob = torch.exp(-per_sample_ce)

        modulating = (1 - true_class_prob) ** gamma
        loss = (alpha * modulating * per_sample_ce).mean()

        if torch.isfinite(loss):
            return loss
        return torch.tensor(0.0, device=logits.device, requires_grad=True)

    @property
    def name(self):
        """Registry name of this loss."""
        return "focal_classification"

    @property
    def requires_rbm_params(self):
        """This loss does not use the RBM weights or biases."""
        return False


class GraphSmoothingLoss(BaseLoss):
    """Regularizer penalizing prediction differences across graph edges.
    Reads 'out' (model logits [N, C]) and 'edge_index' from kwargs.
    """

    def compute(self, v_reconstructed=None, node_pairs=None, **kwargs):
        """Return the mean L2 distance between edge endpoints, clamped to [0, 10].

        ``v_reconstructed`` and ``node_pairs`` are unused. A NaN/Inf mean is
        replaced by a new zero scalar with ``requires_grad=True``.

        Raises:
            ValueError: If ``out`` or ``edge_index`` is missing.
        """
        logits = kwargs.get('out')
        edge_index = kwargs.get('edge_index')
        if logits is None or edge_index is None:
            raise ValueError("GraphSmoothingLoss requires 'out' and 'edge_index' in kwargs")

        src, dst = edge_index

        # One L2 distance per edge between the two endpoint predictions.
        edge_gap = (logits[src] - logits[dst]).norm(p=2, dim=1)
        loss = edge_gap.mean()

        if not torch.isfinite(loss):
            return torch.tensor(0.0, device=logits.device, requires_grad=True)
        return loss.clamp(0, 10)

    @property
    def name(self):
        """Registry name of this loss."""
        return "graph_smoothing"

    @property
    def requires_rbm_params(self):
        """This loss does not use the RBM weights or biases."""
        return False


class UnifiedLossManager:
    """
    Registry and dispatcher for the loss functions in this module.

    Looks losses up by name, applies a per-loss scale factor, and can combine
    several losses into one scalar with optional inverse-magnitude weighting.
    """

    def __init__(self):
        """Populate the registry with the built-in losses and their scales."""
        self.losses = {
            "reconstruction": ReconstructionLoss(),
            "normalized_reconstruction": NormalizedReconstructionLoss(),
            "contrastive_divergence": ContrastiveDivergenceLoss(),
            "cd": ContrastiveDivergenceLoss(),  # Alias
            "free_energy": FreeEnergyLoss(),
            "kl_divergence": KLDivergenceLoss(),
            "cross_entropy": CrossEntropyLoss(),
            "huber": HuberLoss(),
            "supervised_contrastive": SupervisedContrastiveLoss(),
            "focal_classification": FocalClassificationLoss(),
            "graph_smoothing": GraphSmoothingLoss(),
        }

        # Loss scaling factors for normalization (empirically determined)
        self.loss_scales = {
            "reconstruction": 1.0,
            "normalized_reconstruction": 1.0,
            "contrastive_divergence": 0.1,
            "cd": 0.1,
            "free_energy": 0.1,
            "kl_divergence": 1.0,
            "cross_entropy": 1.0,
            "huber": 1.0,
            "supervised_contrastive": 1.0,
            "focal_classification": 1.0,
            "graph_smoothing": 0.1,  # Smaller scale as it's a regularizer
        }

        self.default_loss = "normalized_reconstruction"
        self.default_combo = ["normalized_reconstruction", "supervised_contrastive"]  # Default for multi-loss

    def register_loss(self, name, loss_instance, scale=1.0):
        """Register a custom loss function under ``name`` with the given scale.

        Raises:
            ValueError: If ``loss_instance`` is not a ``BaseLoss``.
        """
        if not isinstance(loss_instance, BaseLoss):
            raise ValueError("Loss must inherit from BaseLoss")
        self.losses[name] = loss_instance
        self.loss_scales[name] = scale

    def _resolve_loss_types(self, loss_type):
        """Normalize ``loss_type`` to a non-empty list of loss names."""
        if isinstance(loss_type, str):
            return [loss_type]
        if isinstance(loss_type, (list, tuple)):
            if loss_type:
                return list(loss_type)
            # An empty selection would leave nothing to sum; use the default.
            print(f"Empty loss_type, using default: {self.default_loss}")
            return [self.default_loss]
        print(f"Invalid loss_type '{loss_type}', using default: {self.default_loss}")
        return [self.default_loss]

    def _compute_single(self, lt, v_reconstructed, node_pairs, rbm_weights,
                        visible_bias, hidden_bias, apply_scaling, extra):
        """Compute one (optionally scaled) loss, with reconstruction fallback."""
        if lt not in self.losses:
            print(f"Unknown loss type '{lt}', using default: {self.default_loss}")
            lt = self.default_loss

        loss_fn = self.losses[lt]

        try:
            # Check if RBM parameters are required
            if loss_fn.requires_rbm_params and any(x is None for x in [rbm_weights, visible_bias, hidden_bias]):
                print(f"Loss '{lt}' requires RBM parameters, falling back to reconstruction")
                loss_fn = self.losses["reconstruction"]
                lt = "reconstruction"

            loss_value = loss_fn.compute(
                v_reconstructed=v_reconstructed,
                node_pairs=node_pairs,
                rbm_weights=rbm_weights,
                visible_bias=visible_bias,
                hidden_bias=hidden_bias,
                **extra
            )

            if apply_scaling:
                loss_value = self.loss_scales.get(lt, 1.0) * loss_value
            return loss_value

        except Exception as e:
            print(f"Error computing loss '{lt}': {e}")
            # The reconstruction fallback needs both tensors. Without them it
            # would only raise an unrelated error and hide the real one, so
            # surface the original exception instead.
            if v_reconstructed is None or node_pairs is None:
                raise
            return self.losses["reconstruction"].compute(v_reconstructed, node_pairs)

    @staticmethod
    def _inverse_magnitude_weights(loss_values):
        """Return constant weights inversely proportional to each |loss|.

        The absolute value is used so that a negative loss cannot produce a
        negative weight, which would flip the direction of its gradient.
        """
        magnitudes = [abs(lv.item()) + 1e-10 for lv in loss_values]  # Avoid division by zero
        total_magnitude = sum(magnitudes)
        count = len(magnitudes)
        return [total_magnitude / (m * count) for m in magnitudes]

    def get_loss(self, loss_type, v_reconstructed=None, node_pairs=None,
                 rbm_weights=None, visible_bias=None, hidden_bias=None,
                 apply_scaling=True, dynamic_weighting=True, **kwargs):
        """
        Generic getter method for all loss types, supporting single or multiple loss types.

        Args:
            loss_type: str, or list/tuple of str - Name(s) of the loss function(s).
                An empty sequence or any other type selects the default loss.
            v_reconstructed: [H, E, 2*d_out] - Reconstructed visible states (for reconstruction losses)
            node_pairs: [H, E, 2*d_out] - Original node pairs (for reconstruction losses)
            rbm_weights: [H, 2*d_out, d_out] - RBM weights (if needed)
            visible_bias: [H, 2*d_out] - Visible bias (if needed)
            hidden_bias: [H, d_out] - Hidden bias (if needed)
            apply_scaling: bool - Whether to apply automatic loss scaling
            dynamic_weighting: bool - Whether to weight multiple losses by their inverse magnitudes
            **kwargs: Additional parameters (e.g., labels, edge_index, out, targets)

        Returns:
            torch.Tensor: Computed loss value (weighted sum if multiple losses)

        Raises:
            Exception: Whatever a loss raised, when that loss fails and the
                reconstruction fallback is impossible because
                ``v_reconstructed`` or ``node_pairs`` is None.
        """
        loss_types = self._resolve_loss_types(loss_type)

        loss_values = [
            self._compute_single(lt, v_reconstructed, node_pairs, rbm_weights,
                                 visible_bias, hidden_bias, apply_scaling, kwargs)
            for lt in loss_types
        ]

        # Dynamic weighting balances the losses by inverse magnitude;
        # otherwise every loss gets equal weight.
        if dynamic_weighting and len(loss_values) > 1:
            weights = self._inverse_magnitude_weights(loss_values)
        else:
            weights = [1.0] * len(loss_values)

        # Combine losses
        total_loss = 0.0
        for lv, w in zip(loss_values, weights):
            total_loss = total_loss + w * lv

        if torch.isnan(total_loss) or torch.isinf(total_loss):
            print("NaN or Inf in total loss, returning fallback")
            return self.losses["reconstruction"].compute(v_reconstructed, node_pairs)

        return total_loss

    def get_available_losses(self):
        """Get list of available loss functions"""
        return list(self.losses)

    def get_loss_info(self, loss_type):
        """Get name, RBM requirement and scale of a loss, or None if unknown."""
        if loss_type not in self.losses:
            return None

        loss_fn = self.losses[loss_type]
        return {
            "name": loss_fn.name,
            "requires_rbm_params": loss_fn.requires_rbm_params,
            "scale_factor": self.loss_scales.get(loss_type, 1.0)
        }


# Global loss manager instance, created once when this module is imported.
# Losses registered on it via register_loss() are visible to every user of
# this shared object.
loss_manager = UnifiedLossManager()

