import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from sklearn.preprocessing import KBinsDiscretizer
from scipy.stats import gaussian_kde
from torch.distributions import MultivariateNormal as MVN
from torch.nn.modules.loss import _Loss


def _gai_loss_md(pred, target, gmm, noise_var):
    """Compute the multi-dimensional GAI (GMM-balanced) loss.

    The per-sample loss is the Gaussian negative log-likelihood of ``target``
    around ``pred`` (covariance ``noise_var * I``) plus a balancing term: the
    log-density of ``pred`` under the mixture described by ``gmm`` whose
    component covariances are widened by ``noise_var * I``.  The sum is scaled
    by the detached factor ``2 * noise_var`` and averaged.

    Args:
        pred: Predictions; the last axis is treated as the event dimension.
        target: Ground truth.  It is given a trailing axis and/or reshaped to
            ``pred.shape`` when the layouts differ but the element counts
            agree and ``pred`` has a trailing axis of size 1.
        gmm: Mapping with tensor entries ``'means'``, ``'variances'`` and
            ``'weights'``.
        noise_var: Tensor holding the noise variance (``.detach()`` is called
            on a multiple of it).

    Returns:
        Scalar tensor: the mean of the scaled per-sample losses.

    Raises:
        ValueError: If ``target`` cannot be aligned with ``pred``.
    """
    dev = pred.device
    event_dim = pred.shape[-1]

    mix_means = gmm['means'].to(dev)
    mix_vars = gmm['variances'].to(dev)
    mix_weights = gmm['weights'].to(dev)

    # Both layouts below are handled the same way: append a trailing axis.
    same_rank_without_column = (
        event_dim == 1 and target.dim() == pred.dim() and target.shape[-1] != 1
    )
    one_axis_short = event_dim > 1 and target.dim() == pred.dim() - 1
    if same_rank_without_column or one_axis_short:
        target = target.unsqueeze(-1)

    if target.shape != pred.shape:
        if not (target.numel() == pred.numel() and event_dim == 1):
            raise ValueError(f"GAI loss: Target shape {target.shape} incompatible with pred shape {pred.shape}")
        target = target.reshape(pred.shape)

    noise_cov = noise_var * torch.eye(event_dim, device=dev)
    nll = -MVN(pred, noise_cov).log_prob(target)

    # Bring scalar per-component variances to matrix form (K, 1, 1).
    if mix_vars.dim() == 1:
        mix_vars = mix_vars.unsqueeze(-1).unsqueeze(-1)
    elif mix_vars.dim() == 2 and mix_vars.shape[1] == 1:
        mix_vars = mix_vars.unsqueeze(-1)

    # Per-component log-density of every prediction, weighted by the mixture
    # weights and combined with a log-sum-exp over the component axis.
    widened_prior = MVN(mix_means, mix_vars + noise_cov)
    component_logp = widened_prior.log_prob(pred.unsqueeze(1)) + mix_weights.log()
    balancing = torch.logsumexp(component_logp, dim=1)

    # The scale is detached so it rescales the loss without contributing
    # gradient to the noise parameter.
    scale = (2 * noise_var).detach()
    return ((nll + balancing) * scale).mean()


def _bmc_loss_md(pred, target, noise_var):
    """Compute the multi-dimensional BMC loss over a batch.

    Every prediction is scored against every target in the batch with a
    Gaussian log-density (covariance ``noise_var * I``); the resulting
    ``[batch, batch]`` matrix is used as logits for a cross-entropy whose
    correct class for row ``i`` is column ``i``.  The result is scaled by the
    detached factor ``2 * noise_var``.

    Args:
        pred: Predictions; the last axis is the event dimension.
        target: Ground truth.  Reshaped to ``pred.shape`` when the element
            counts match and ``pred`` has a trailing axis of size 1.
        noise_var: Tensor holding the noise variance.

    Returns:
        Scalar tensor with the scaled cross-entropy.

    Raises:
        ValueError: If ``target`` cannot be aligned with ``pred``.
    """
    dev = pred.device
    identity = torch.eye(pred.shape[-1], device=dev)

    if target.shape != pred.shape:
        reshapeable = target.numel() == pred.numel() and pred.shape[-1] == 1
        if not reshapeable:
            raise ValueError(f"BMC loss: Target shape {target.shape} incompatible with pred shape {pred.shape}")
        target = target.reshape(pred.shape)

    # pairwise_logp[i, j] = log-density of target j under the Gaussian
    # centred at prediction i.
    per_prediction = MVN(pred.unsqueeze(1), noise_var * identity)
    pairwise_logp = per_prediction.log_prob(target.unsqueeze(0))

    own_column = torch.arange(pred.shape[0], device=dev)
    contrastive = F.cross_entropy(pairwise_logp, own_column)
    return contrastive * (2 * noise_var).detach()


class BMCLossMD(_Loss):
    """BMC balanced-MSE loss module with a learnable noise scale.

    The module owns a single scalar parameter, ``noise_sigma``.  On every
    call the noise variance is derived from it as
    ``softplus(noise_sigma) ** 2 + 1e-6`` and handed to ``_bmc_loss_md``.
    """

    def __init__(self, init_noise_sigma):
        """Create the loss.

        Args:
            init_noise_sigma: Initial value of the raw ``noise_sigma``
                parameter (converted with ``float``).
        """
        super().__init__()
        initial_value = torch.tensor(float(init_noise_sigma))
        self.noise_sigma = nn.Parameter(initial_value)
        print(f"BMCLossMD initialized with noise sigma: {init_noise_sigma:.4f}")

    def forward(self, pred, target):
        """Return the BMC loss of ``pred`` against ``target``."""
        # softplus keeps the scale positive; the epsilon keeps the variance
        # strictly above zero.
        positive_sigma = F.softplus(self.noise_sigma)
        noise_var = positive_sigma ** 2 + 1e-6
        return _bmc_loss_md(pred, target, noise_var)


class GAILossMD(_Loss):
    """GAI balanced-MSE loss module backed by a Gaussian mixture of labels.

    The mixture parameters are converted to float32 tensors once at
    construction and kept in the plain dict ``self.gmm``; the learnable
    scalar ``noise_sigma`` is turned into a variance on every call and the
    actual computation is delegated to ``_gai_loss_md``.
    """

    def __init__(self, gmm_dict, init_noise_sigma):
        """Create the loss.

        Args:
            gmm_dict: Mapping whose values are convertible with
                ``torch.tensor``; the entries ``'means'`` and ``'variances'``
                are read and reshaped here.
            init_noise_sigma: Initial value of the raw ``noise_sigma``
                parameter (converted with ``float``).

        Raises:
            ValueError: If ``gmm_dict`` is ``None``.
        """
        super().__init__()
        if gmm_dict is None:
            raise ValueError("GAILossMD requires GMM dictionary with 'means', 'weights', 'variances'.")

        gmm = {}
        for key, value in gmm_dict.items():
            gmm[key] = torch.tensor(value, dtype=torch.float32)

        # 1-D means become one column per component: (K,) -> (K, 1).
        means = gmm['means']
        if means.dim() == 1:
            gmm['means'] = means.unsqueeze(-1)

        # Scalar variances become 1x1 matrices: (K,) or (K, 1) -> (K, 1, 1).
        variances = gmm['variances']
        n_axes = variances.dim()
        if n_axes == 1:
            gmm['variances'] = variances.unsqueeze(-1).unsqueeze(-1)
        elif n_axes == 2 and variances.shape[1] == 1:
            gmm['variances'] = variances.unsqueeze(-1)

        self.gmm = gmm

        initial_value = torch.tensor(float(init_noise_sigma))
        self.noise_sigma = nn.Parameter(initial_value)
        print(f"GAILossMD initialized with noise sigma: {init_noise_sigma:.4f}")

    def forward(self, pred, target):
        """Return the GAI loss of ``pred`` against ``target``."""
        positive_sigma = F.softplus(self.noise_sigma)
        noise_var = positive_sigma ** 2 + 1e-6
        return _gai_loss_md(pred, target, self.gmm, noise_var)


class WeightedL1Loss(_Loss):
    """L1 loss whose per-element errors are scaled by looked-up weights.

    A weight table is stored as the buffer ``weights``; each call selects the
    rows named by ``batch_indices`` and multiplies them into the unreduced L1
    errors before averaging.  Without a table the module is a plain mean L1.
    """

    def __init__(self, weights_tensor):
        """Store the weight table.

        Args:
            weights_tensor: Tensor of weights indexed by ``batch_indices``
                in ``forward``, or ``None`` to disable weighting.
        """
        super().__init__()
        if weights_tensor is not None:
            self.register_buffer('weights', weights_tensor.cpu())
        else:
            print("Warning: WeightedL1Loss initialized with None weights. Behaves like standard L1Loss.")
            self.weights = None
        shape_text = 'None' if self.weights is None else self.weights.shape
        print(f"WeightedL1Loss initialized. Weights shape: {shape_text}")

    def forward(self, inputs, targets, batch_indices=None):
        """Return the mean weighted absolute error.

        Args:
            inputs: Predictions.
            targets: Ground truth, broadcastable against ``inputs``.
            batch_indices: Index tensor into the weight table; required
                whenever a table was supplied.

        Raises:
            ValueError: If a weight table exists but ``batch_indices`` is
                ``None``.
        """
        abs_errors = F.l1_loss(inputs, targets, reduction='none')

        if self.weights is None:
            return abs_errors.mean()

        if batch_indices is None:
            raise ValueError("WeightedL1Loss requires batch_indices when using precomputed weights.")

        device = inputs.device
        selected = self.weights.to(device)[batch_indices.to(device)]

        # Pad the selected weights with singleton axes so one weight per
        # sample broadcasts over the remaining axes of the error tensor.
        lacks_axes = abs_errors.dim() > selected.dim()
        if lacks_axes and abs_errors.shape[0] == selected.shape[0]:
            singleton_axes = [1] * (abs_errors.dim() - 1)
            selected = selected.view(-1, *singleton_axes)

        return (abs_errors * selected).mean()


class WeightedMSELoss(_Loss):
    """MSE loss whose per-element errors are scaled by looked-up weights.

    A weight table is stored as the buffer ``weights``; each call selects the
    rows named by ``batch_indices`` and multiplies them into the unreduced
    squared errors before averaging.  Without a table the module is a plain
    mean squared error.
    """

    def __init__(self, weights_tensor):
        """Store the weight table.

        Args:
            weights_tensor: Tensor of weights indexed by ``batch_indices``
                in ``forward``, or ``None`` to disable weighting.
        """
        super().__init__()
        if weights_tensor is not None:
            self.register_buffer('weights', weights_tensor.cpu())
        else:
            print("Warning: WeightedMSELoss initialized with None weights. Behaves like standard MSELoss.")
            self.weights = None
        shape_text = 'None' if self.weights is None else self.weights.shape
        print(f"WeightedMSELoss initialized. Weights shape: {shape_text}")

    def forward(self, inputs, targets, batch_indices=None):
        """Return the mean weighted squared error.

        Args:
            inputs: Predictions.
            targets: Ground truth, broadcastable against ``inputs``.
            batch_indices: Index tensor into the weight table; required
                whenever a table was supplied.

        Raises:
            ValueError: If a weight table exists but ``batch_indices`` is
                ``None``.
        """
        sq_errors = F.mse_loss(inputs, targets, reduction='none')

        if self.weights is None:
            return sq_errors.mean()

        if batch_indices is None:
            raise ValueError("WeightedMSELoss requires batch_indices when using precomputed weights.")

        device = inputs.device
        selected = self.weights.to(device)[batch_indices.to(device)]

        # Pad the selected weights with singleton axes so one weight per
        # sample broadcasts over the remaining axes of the error tensor.
        lacks_axes = sq_errors.dim() > selected.dim()
        if lacks_axes and sq_errors.shape[0] == selected.shape[0]:
            singleton_axes = [1] * (sq_errors.dim() - 1)
            selected = selected.view(-1, *singleton_axes)

        return (sq_errors * selected).mean()


class BalancedMSELoss_DEPRECATED(nn.Module):
    """MSE loss re-weighted by the inverse frequency of each target's bin.

    At construction the training targets are discretised with
    ``KBinsDiscretizer`` and one weight per bin is derived from the bin
    counts.  At call time every target is mapped to its bin and the squared
    error of that element is multiplied by the bin's weight before averaging.

    When ``y_train`` is ``None`` no discretizer is built and the module
    behaves like a plain mean squared error.
    """

    def __init__(self, y_train, n_bins=10, strategy='quantile', gamma=1.0):
        """Fit the discretizer and precompute the per-bin weights.

        Args:
            y_train: Training targets (array-like), or ``None`` for uniform
                weights.
            n_bins: Number of bins requested from the discretizer.
            strategy: ``KBinsDiscretizer`` binning strategy; on a
                ``ValueError`` the fit is retried with ``'uniform'``.
            gamma: Exponent applied to the inverse bin counts.
        """
        super().__init__()
        self.gamma = gamma
        self.n_bins = n_bins
        self.strategy = strategy

        if y_train is None:
            print("Warning: y_train is None during BalancedMSELoss initialization. Weights will be uniform.")
            self.register_buffer('weights_per_bin', torch.ones(n_bins))
            self.discretizer = None
            return

        y_train_arr = np.asarray(y_train).reshape(-1, 1)
        self.discretizer = KBinsDiscretizer(n_bins=n_bins, encode='ordinal', strategy=strategy, subsample=None)

        try:
            bin_assignments = self.discretizer.fit_transform(y_train_arr).flatten().astype(int)
        except ValueError as e:
            print(f"Warning: KBinsDiscretizer failed with strategy='{strategy}'. Trying 'uniform'. Error: {e}")
            self.discretizer = KBinsDiscretizer(n_bins=n_bins, encode='ordinal', strategy='uniform', subsample=None)
            bin_assignments = self.discretizer.fit_transform(y_train_arr).flatten().astype(int)

        bin_counts = np.bincount(bin_assignments, minlength=n_bins)
        weights = self._inverse_frequency_weights(bin_counts, self.gamma)
        print(f"BalancedMSELoss: Computed bin weights: {weights}")

        self.register_buffer('weights_per_bin', torch.tensor(weights, dtype=torch.float32))

    @staticmethod
    def _inverse_frequency_weights(bin_counts, gamma):
        """Turn bin counts into weights that average to 1 across bins.

        Populated bins get ``(1 / count) ** gamma``.  Empty bins receive the
        largest weight found among populated bins instead of the reciprocal
        of a tiny epsilon: an epsilon-based weight is so large that, after
        normalisation, it pushes the weights of all populated bins towards
        zero and with them the whole training loss.
        """
        counts = np.asarray(bin_counts, dtype=float)
        populated = counts > 0
        if not populated.any():
            return np.ones_like(counts)

        weights = np.empty_like(counts)
        weights[populated] = (1.0 / counts[populated]) ** gamma
        weights[~populated] = weights[populated].max()
        return weights / np.sum(weights) * counts.size

    def forward(self, y_pred, y_true):
        """Return the bin-weighted mean squared error.

        Args:
            y_pred: Predictions.
            y_true: Ground truth; every element is binned individually.
        """
        if self.discretizer is None:
            return F.mse_loss(y_pred, y_true)

        # Targets outside the fitted range are clipped onto the outermost
        # bin edges before binning.
        edges = self.discretizer.bin_edges_[0]
        y_true_np = y_true.detach().cpu().numpy().reshape(-1, 1)
        y_true_clipped = np.clip(y_true_np, edges[0], edges[-1])
        target_bins = self.discretizer.transform(y_true_clipped).flatten().astype(int)

        with torch.no_grad():
            target_bins_tensor = torch.tensor(target_bins, dtype=torch.long, device=y_pred.device)
            sample_weights = self.weights_per_bin.to(y_pred.device)[target_bins_tensor]

        squared_errors = F.mse_loss(y_pred, y_true, reduction='none')

        # One weight exists per target element, so align the weights with the
        # error tensor element-for-element.  A blind unsqueeze(-1) is only
        # right for [B, 1] inputs; for 1-D inputs it would broadcast to
        # [B, B] and pair every error with every weight.
        if sample_weights.numel() == squared_errors.numel():
            sample_weights = sample_weights.reshape(squared_errors.shape)
        else:
            sample_weights = sample_weights.unsqueeze(-1)

        return (squared_errors * sample_weights).mean()


class LDSLoss_DEPRECATED(nn.Module):
    """MSE loss re-weighted with Label Distribution Smoothing (LDS).

    At construction the training targets are placed into uniform-width bins,
    the resulting histogram is smoothed with a 1-D kernel, and one weight per
    bin is derived from the inverse of the smoothed density.  At call time
    every target is mapped to its bin and the squared error of that element
    is multiplied by the bin's weight before averaging.

    When ``y_train`` is ``None`` no discretizer is built and the module
    behaves like a plain mean squared error.
    """

    def __init__(self, y_train, n_bins=50, kernel='gaussian', ks=5, sigma=2):
        """Fit the discretizer and precompute the smoothed per-bin weights.

        Args:
            y_train: Training targets (array-like), or ``None`` for uniform
                weights.
            n_bins: Number of uniform bins.
            kernel: ``'gaussian'`` or ``'triang'``; any other value falls
                back to the gaussian window.
            ks: Kernel window length in bins.
            sigma: Standard deviation of the gaussian window, in bins.
        """
        super().__init__()
        self.kernel = kernel
        self.ks = ks
        self.sigma = sigma
        self.n_bins = n_bins

        if y_train is None:
            print("Warning: y_train is None during LDSLoss initialization. Weights will be uniform.")
            self.register_buffer('lds_weights_per_bin', torch.ones(n_bins))
            self.discretizer = None
            return

        y_train_arr = np.asarray(y_train).reshape(-1, 1)

        self.discretizer = KBinsDiscretizer(n_bins=n_bins, encode='ordinal', strategy='uniform', subsample=None)
        bin_assignments = self.discretizer.fit_transform(y_train_arr).flatten().astype(int)
        empirical_density = np.bincount(bin_assignments, minlength=n_bins).astype(float)
        empirical_density /= empirical_density.sum()

        kernel_weights = self._get_kernel_window()
        effective_density = np.convolve(empirical_density, kernel_weights, mode='same')

        weights = self._weights_from_density(effective_density, n_bins)
        print(f"LDSLoss: Computed bin weights (first 10): {weights[:10]}")

        self.register_buffer('lds_weights_per_bin', torch.tensor(weights, dtype=torch.float32))

    @staticmethod
    def _weights_from_density(effective_density, n_bins):
        """Invert a smoothed density into weights scaled by ``n_bins / sum``.

        Bins whose density does not exceed ``1e-6`` receive the largest
        weight found among the remaining bins.  Flooring them at ``1e-6`` and
        inverting would give them a weight of ``1e6``, which dominates the
        normalisation and pushes the weights of all well-covered bins (and
        with them the training loss) towards zero.
        """
        density = np.asarray(effective_density, dtype=float)
        covered = density > 1e-6
        if not covered.any():
            return np.ones_like(density)

        weights = np.empty_like(density)
        weights[covered] = 1.0 / density[covered]
        weights[~covered] = weights[covered].max()
        return weights / np.sum(weights) * n_bins

    def _get_kernel_window(self):
        """Return the smoothing window, normalised to sum to 1."""
        if self.kernel == 'triang':
            kernel = 1 - np.abs(np.linspace(-1, 1, self.ks))
        else:
            # 'gaussian' and every unrecognised kernel name share this window.
            half_ks = (self.ks - 1) // 2
            x = np.linspace(-half_ks, half_ks, self.ks)
            kernel = np.exp(-x**2 / (2 * self.sigma**2))

        total = np.sum(kernel)
        if total == 0:
            # The triangular window is zero at both ends, so for ks <= 2 it is
            # entirely zero; use the identity window instead of dividing 0/0.
            return np.ones(1)
        return kernel / total

    def forward(self, y_pred, y_true):
        """Return the LDS-weighted mean squared error.

        Args:
            y_pred: Predictions.
            y_true: Ground truth; every element is binned individually.
        """
        if self.discretizer is None:
            return F.mse_loss(y_pred, y_true)

        # Targets outside the fitted range are clipped onto the outermost
        # bin edges before binning.
        edges = self.discretizer.bin_edges_[0]
        y_true_np = y_true.detach().cpu().numpy().reshape(-1, 1)
        y_true_clipped = np.clip(y_true_np, edges[0], edges[-1])
        target_bins = self.discretizer.transform(y_true_clipped).flatten().astype(int)

        with torch.no_grad():
            target_bins_tensor = torch.tensor(target_bins, dtype=torch.long, device=y_pred.device)
            sample_weights = self.lds_weights_per_bin.to(y_pred.device)[target_bins_tensor]

        squared_errors = F.mse_loss(y_pred, y_true, reduction='none')

        # One weight exists per target element, so align the weights with the
        # error tensor element-for-element.  A blind unsqueeze(-1) is only
        # right for [B, 1] inputs; for 1-D inputs it would broadcast to
        # [B, B] and pair every error with every weight.
        if sample_weights.numel() == squared_errors.numel():
            sample_weights = sample_weights.reshape(squared_errors.shape)
        else:
            sample_weights = sample_weights.unsqueeze(-1)

        return (squared_errors * sample_weights).mean()


def ConR_tabular(features, targets, preds, w=1.0, weights=1.0, t=0.07, e=0.01):
    """Compute the ConR (contrastive regulariser) loss for one batch.

    Pairs of samples are classified from their labels and predictions:

    * positive pair: label distance ``<= w`` (a sample is never its own
      positive);
    * negative pair: label distance ``> w`` but prediction distance ``<= w``.

    For every anchor the loss is the average, over its positives, of the
    negative log of ``exp(sim_pos) / (sum exp(sim) + pushed negatives)``,
    where similarities are dot products of L2-normalised features divided by
    ``t`` and each negative is scaled by ``weight * exp(label_distance * e)``.
    Anchors without any negative contribute zero.

    Args:
        features: Feature representations, ``[batch_size, feature_dim]``;
            they are L2-normalised here.
        targets: Ground-truth labels, one per sample (flattened internally).
        preds: Model predictions, one per sample (flattened internally).
        w: Distance threshold separating positive from negative pairs.
        weights: Per-sample weights.  Either a Python number applied to every
            sample or a tensor with one weight per sample; tensors are
            flattened, so ``[batch_size]`` and ``[batch_size, 1]`` behave the
            same.
        t: Temperature dividing the feature similarities.
        e: Scale of the label-distance term that pushes negatives away.

    Returns:
        Scalar tensor with the batch-mean loss.
    """
    # Queries and keys are the same normalised features, so normalise once.
    emb = F.normalize(features, dim=1)

    labels = targets.flatten()
    outputs = preds.flatten()
    l_dist = torch.abs(labels[:, None] - labels[None, :])
    p_dist = torch.abs(outputs[:, None] - outputs[None, :])

    close_labels = l_dist.le(w)
    # Exclude self-pairs with a mask instead of a per-row Python loop.
    off_diagonal = ~torch.eye(l_dist.shape[0], dtype=torch.bool, device=l_dist.device)
    pos_i = close_labels & off_diagonal
    neg_i = (~close_labels) & p_dist.le(w)

    prod = torch.matmul(emb, emb.t()) / t
    pos = prod * pos_i
    neg = prod * neg_i

    if isinstance(weights, (int, float)):
        weights = torch.ones(targets.shape[0], device=targets.device) * weights
    else:
        # A [batch, 1] (or 0-dim) weight tensor would otherwise broadcast
        # against the [batch, batch] pair matrices along the wrong axes.
        weights = weights.reshape(-1)

    pushing_w = weights.unsqueeze(1) * torch.exp(l_dist * e)
    neg_exp_dot = (pushing_w * torch.exp(neg) * neg_i).sum(1)

    # True for anchors that have at least one negative pair.
    has_negative = neg_i.sum(1).bool()

    # Anchors without positives divide by 1 instead of 0.
    denom = torch.clamp(pos_i.sum(1), min=1)

    # Masked-out entries of `pos` are 0, so they enter the sum as exp(0) = 1.
    pos_exp = torch.exp(pos)
    denominator = (pos_exp.sum(1) + neg_exp_dot).unsqueeze(-1)

    ratio = torch.div(pos_exp, denominator + 1e-8)
    log_ratio = torch.log(ratio + 1e-8)

    loss = (-log_ratio * pos_i).sum(1) / denom
    return (weights * loss * has_negative).mean()


class ConRLoss(nn.Module):
    """MSE loss optionally combined with the ConR contrastive regulariser.

    When features are supplied and the batch holds more than one target the
    result is ``mse_weight * MSE + alpha * ConR``; otherwise it is the plain
    (unscaled) MSE.
    """

    def __init__(self, w=1.0, t=0.07, e=0.01, alpha=1.0, mse_weight=1.0):
        """Store the hyper-parameters.

        Args:
            w: Distance threshold forwarded to ``ConR_tabular``.
            t: Temperature forwarded to ``ConR_tabular``.
            e: Pushing scale forwarded to ``ConR_tabular``.
            alpha: Multiplier of the ConR term.
            mse_weight: Multiplier of the MSE term in the combined loss.
        """
        super().__init__()
        self.w, self.t, self.e = w, t, e
        self.alpha = alpha
        self.mse_weight = mse_weight
        print(f"ConRLoss initialized: w={w}, t={t}, e={e}, alpha={alpha}, mse_weight={mse_weight}")

    def forward(self, pred, target, features=None, weights=None):
        """Return the combined loss, or plain MSE when ConR cannot be used.

        Args:
            pred: Model predictions.
            target: Ground truth.
            features: Feature representations for the contrastive term, or
                ``None`` to skip it.
            weights: Optional sample weights for the contrastive term;
                ``None`` is replaced by ``1.0``.
        """
        mse_loss = F.mse_loss(pred, target)

        # Without features, or with a single sample, only the MSE is returned
        # (note: without the mse_weight factor).
        if features is None or len(target) <= 1:
            return mse_loss

        sample_weights = 1.0 if weights is None else weights
        conr_loss = ConR_tabular(
            features, target, pred,
            w=self.w, weights=sample_weights, t=self.t, e=self.e,
        )
        return self.mse_weight * mse_loss + self.alpha * conr_loss


import random

def stable_argsort(arr, dim=-1, descending=False):
    """Return the indices that stably sort ``arr`` along ``dim``.

    Descending order is obtained by sorting the negated values in ascending
    order, so equal elements keep their original relative order in both
    directions.  Relies on the ``stable`` keyword of ``torch.argsort``.
    """
    sort_keys = -arr if descending else arr
    return torch.argsort(sort_keys, dim=dim, stable=True)

def flipp(T, dim):
    """Return a copy of ``T`` with the entries along ``dim`` in reverse order."""
    length = T.size(dim)
    # Indices length-1, length-2, ..., 0 on the same device as the input.
    reversed_positions = torch.arange(
        start=length - 1, end=-1, step=-1, device=T.device
    )
    return T.index_select(dim, reversed_positions)
    
def rank(seq):
    """Return the 0-based descending rank of every element of ``seq``.

    The ascending sort order is reversed along dimension 1 to obtain the
    descending order, and the argsort of that permutation gives, for each
    position, its place in the descending order (largest value -> 0).  The
    flip is hard-wired to dimension 1, so ``seq`` is expected to be 2-D with
    each row ranked on its own.
    """
    ascending_order = stable_argsort(seq)
    descending_order = flipp(ascending_order, 1)
    return stable_argsort(descending_order)

def rank_normalised(seq):
    """Return 1-based ranks of ``seq`` divided by the size of dimension 1.

    The ranks come from ``rank`` (shifted by one) and are converted to float
    before the division.
    """
    one_based = rank(seq) + 1
    return one_based.float() / seq.shape[1]

class TrueRanker(torch.autograd.Function):
    """Ranking operation with a finite-difference style backward pass.

    The forward pass returns normalised ranks.  The backward pass re-ranks
    the input after shifting it by ``lambda_val * grad_output`` and returns
    the scaled change in ranks as the gradient with respect to the input.
    """

    @staticmethod
    def forward(ctx, sequence, lambda_val):
        """Return ``rank_normalised(sequence)`` and stash state for backward."""
        ranks = rank_normalised(sequence)
        ctx.lambda_val = lambda_val
        ctx.save_for_backward(sequence, ranks)
        return ranks

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradient for ``sequence``; ``lambda_val`` gets ``None``."""
        sequence, ranks = ctx.saved_tensors
        assert grad_output.shape == ranks.shape

        step = ctx.lambda_val
        perturbed_ranks = rank_normalised(sequence + step * grad_output)
        rank_shift = ranks - perturbed_ranks
        # The small epsilon guards the division when lambda_val is zero.
        return -rank_shift / (step + 1e-8), None

def batchwise_ranking_regularizer(features, targets, lambda_val):
    """Compute the RankSim regulariser for one batch.

    For every retained sample the ranking of all samples by label closeness
    is compared (MSE) with the ranking of all samples by feature cosine
    similarity; the per-sample errors are summed.  When the batch contains
    repeated label values, one random sample per unique value is kept.

    Args:
        features: Feature tensor whose first axis indexes samples; the
            remaining axes are flattened.
        targets: One label per sample, given as ``[batch]`` or with trailing
            singleton axes such as ``[batch, 1]``.
        lambda_val: Perturbation strength passed to ``TrueRanker``.

    Returns:
        The summed ranking loss (a zero tensor when fewer than two samples
        remain).
    """
    # Work with 1-D labels.  With a column layout [B, 1] the label distances
    # below would have shape [1, M, 1] after unsqueeze(0), the ranking would
    # run along the size-1 axis, and every label rank would collapse to the
    # same constant -- the regulariser would carry no label information.
    if targets.dim() > 1 and targets.numel() == targets.size(0):
        targets = targets.reshape(-1)

    batch_unique_targets = torch.unique(targets)
    x, y = features, targets
    if len(batch_unique_targets) < len(targets):
        # Keep one randomly chosen representative per distinct label value.
        sampled_indices = []
        for value in batch_unique_targets:
            candidates = (targets == value).nonzero()[:, 0]
            if len(candidates) > 0:
                pick = torch.randint(0, len(candidates), (1,), device=candidates.device).item()
                sampled_indices.append(candidates[pick].item())
        if len(sampled_indices) > 1:
            x = features[sampled_indices]
            y = targets[sampled_indices]

    if len(y) < 2:
        return torch.tensor(0.0, device=features.device)

    # Cosine similarity between all retained samples; reshape (not view) so
    # non-contiguous feature tensors are accepted as well.
    x_norm = F.normalize(x.reshape(x.size(0), -1))
    xxt = torch.matmul(x_norm, x_norm.permute(1, 0))

    loss = 0
    for y_i, similarity_row in zip(y, xxt):
        # Closer labels get larger (less negative) values, hence better ranks.
        label_ranks = rank_normalised((-torch.abs(y_i - y)).unsqueeze(0))
        feature_ranks = TrueRanker.apply(similarity_row.unsqueeze(dim=0), lambda_val)

        # NOTE(review): only reachable for targets with more than one value
        # per sample, which the flattening above does not cover; the original
        # best-effort shape reconciliation is kept unchanged for that case.
        if feature_ranks.shape != label_ranks.shape:
            if feature_ranks.numel() == label_ranks.numel():
                label_ranks = label_ranks.view(feature_ranks.shape)
            else:
                feature_ranks = feature_ranks.flatten()
                label_ranks = label_ranks.flatten()
                min_size = min(feature_ranks.numel(), label_ranks.numel())
                feature_ranks = feature_ranks[:min_size]
                label_ranks = label_ranks[:min_size]

        loss += F.mse_loss(feature_ranks, label_ranks)

    return loss


class RankSimLoss(nn.Module):
    """MSE loss optionally combined with the RankSim ranking regulariser.

    When features are supplied and the batch holds more than one target the
    result is ``MSE + alpha * ranking_loss``; otherwise it is the plain MSE.
    """

    def __init__(self, lambda_val=1.0, alpha=1.0):
        """Store the hyper-parameters.

        Args:
            lambda_val: Perturbation strength forwarded to
                ``batchwise_ranking_regularizer``.
            alpha: Multiplier of the ranking term.
        """
        super().__init__()
        self.lambda_val = lambda_val
        self.alpha = alpha
        print(f"RankSimLoss initialized: lambda_val={lambda_val}, alpha={alpha}")

    def forward(self, pred, target, features=None):
        """Return the combined loss, or plain MSE when ranking cannot be used.

        Args:
            pred: Model predictions.
            target: Ground truth.
            features: Feature representations for the ranking term, or
                ``None`` to skip it.
        """
        mse_loss = F.mse_loss(pred, target)

        # The ranking term needs features and at least two samples.
        if features is None or len(target) <= 1:
            return mse_loss

        ranking_loss = batchwise_ranking_regularizer(features, target, self.lambda_val)
        return mse_loss + self.alpha * ranking_loss
