import torch
import torch.nn as nn
import numpy as np
from torch.nn import functional as F
import einops
from collections import Counter

def get_loss_module(problem, data_loader=None):
    """Build the loss module for the given problem identifier.

    Args:
        problem: Problem name. 'TUAB', 'CHB-MIT' and 'TUEV' select dedicated
            losses; any other value falls back to ``NoFussCrossEntropyLoss``.
        data_loader: Optional iterable of ``(X, targets, IDs)`` batches. Only
            used for 'TUEV', where it supplies fixed class weights.

    Returns:
        An ``nn.Module`` loss instance.
    """
    if problem == 'TUAB':
        return NoFussBCELoss(reduction='mean')
    if problem == 'CHB-MIT':
        return FocalLoss()
    if problem == 'TUEV':
        if data_loader is None:
            # No loader to precompute class frequencies from: weights are
            # derived from each batch's targets instead.
            return DynamicWeightedCELoss(reduction='mean')
        weights = get_class_proportions(data_loader)
        return WeightedCELoss(reduction='mean', class_weights=weights)
    # Fallback: plain cross-entropy, averaged over the batch.
    return NoFussCrossEntropyLoss(reduction='mean')

def l2_reg_loss(model):
    """Return the squared L2 norm (sum of squares) of ``output_layer.weight``.

    Returns ``None`` if ``model`` has no parameter named exactly
    ``'output_layer.weight'``.
    """
    weight = dict(model.named_parameters()).get('output_layer.weight')
    if weight is None:
        return None
    return weight.square().sum()

def get_class_proportions(data_loader, num_classes=6):
    """Compute inverse-frequency class weights from the targets in a DataLoader.

    For class ``c`` the weight is ``total_samples / count_c / num_classes``, so a
    perfectly balanced dataset yields a weight of 1.0 for every class.

    Args:
        data_loader: Iterable of ``(X, targets, IDs)`` batches; only ``targets``
            (a tensor of class labels) is read.
        num_classes: Expected number of classes; every label must lie in
            ``range(num_classes)`` and every class must occur at least once.

    Returns:
        A 1-D tensor of ``num_classes`` weights (torch default float dtype).

    Raises:
        ValueError: if a label falls outside ``range(num_classes)`` or some class
            never occurs (its weight would be undefined).
    """
    label_counts = Counter()
    for _, targets, _ in data_loader:  # X, targets, IDs
        label_counts.update(targets.reshape(-1).tolist())

    # Explicit checks instead of `assert` (stripped under `python -O`). Checking
    # the label range, not just the number of distinct labels, catches e.g.
    # labels 1..6 with num_classes=6, which used to pass and mis-weight class 0.
    unexpected = sorted(set(label_counts) - set(range(num_classes)))
    if unexpected:
        raise ValueError(f"labels outside range({num_classes}): {unexpected}")
    missing = [c for c in range(num_classes) if label_counts[c] == 0]
    if missing:
        raise ValueError(f"classes with no samples: {missing}")

    total_samples = sum(label_counts.values())
    return torch.tensor([total_samples / label_counts[c] / num_classes
                         for c in range(num_classes)])

class NoFussCrossEntropyLoss(nn.CrossEntropyLoss):
    """``nn.CrossEntropyLoss`` that casts targets to ``long`` before the loss.

    All options configured on the parent (``weight``, ``ignore_index``,
    ``reduction``, ``label_smoothing``) are forwarded to ``F.cross_entropy``.
    """

    def forward(self, inp, target):
        # label_smoothing used to be dropped here, so a value passed to the
        # constructor was silently ignored.
        return F.cross_entropy(inp, target.long(), weight=self.weight,
                               ignore_index=self.ignore_index, reduction=self.reduction,
                               label_smoothing=self.label_smoothing)

class NoFussBCELoss(nn.BCEWithLogitsLoss):
    """``nn.BCEWithLogitsLoss`` that tolerates extra singleton logit dims and non-float targets.

    Targets are cast to float. Logits whose shape differs from the target's but
    hold the same number of elements (e.g. ``(B, 1)`` logits against ``(B,)``
    targets) are reshaped to the target's shape. The parent's ``weight``,
    ``pos_weight`` and ``reduction`` are all forwarded.
    """

    def forward(self, inp, target):
        target = target.float()
        # reshape rather than squeeze(): squeeze() also drops the batch dim when
        # B == 1, turning (1, 1) logits into a 0-d tensor that no longer matches
        # a (1,) target.
        if inp.shape != target.shape and inp.numel() == target.numel():
            inp = inp.reshape(target.shape)
        return F.binary_cross_entropy_with_logits(
            inp, target, weight=self.weight, pos_weight=self.pos_weight,
            reduction=self.reduction)

class WeightedCELoss(nn.Module):
    """Cross-entropy loss with optional fixed per-class weights.

    Args:
        class_weights: Per-class weights as a tensor, list, tuple or numpy
            array, or ``None`` for an unweighted loss. Integer weights are
            converted to float, because ``F.cross_entropy`` requires a
            floating-point ``weight``.
        reduction: Passed through to ``F.cross_entropy``.
    """

    def __init__(self, class_weights=None, reduction='mean'):
        super().__init__()
        self.reduction = reduction
        if class_weights is not None:
            class_weights = torch.as_tensor(class_weights)
            if not class_weights.is_floating_point():
                class_weights = class_weights.float()
            # Registered as a buffer so it is saved in state_dict and moved by Module.to().
            self.register_buffer('class_weights', class_weights)
        else:
            self.class_weights = None

    def forward(self, inp, target):
        weight = self.class_weights
        if weight is not None:
            # Align with the logits so e.g. float64 (numpy-derived) weights work
            # with float32 logits.
            weight = weight.to(device=inp.device, dtype=inp.dtype)
        return F.cross_entropy(inp, target, weight=weight, reduction=self.reduction)

class DynamicWeightedCELoss(nn.Module):
    """Cross-entropy whose class weights are recomputed from each batch's targets.

    For class ``c`` the weight is ``N / (count_c + 1) / num_classes``, where ``N``
    is the number of targets in the batch. The ``+ 1`` keeps a class that is
    absent from the batch from causing a division by zero.
    """

    def __init__(self, reduction='mean'):
        super().__init__()
        self.reduction = reduction

    def forward(self, inp, target, num_classes=None):
        """Compute the batch-weighted cross-entropy.

        Args:
            inp: Logits with the class dimension at dim 1 (dim 0 if unbatched).
            target: Integer class indices.
            num_classes: Number of classes. Defaults to the size of ``inp``'s
                class dimension, so the weight vector always matches the logits.
        """
        if num_classes is None:
            num_classes = inp.shape[1] if inp.dim() > 1 else inp.shape[0]
        flat_target = target.reshape(-1)
        counts = torch.bincount(flat_target, minlength=num_classes).to(inp.dtype) + 1.
        # Vectorized over all classes. The old per-class loop stopped at a fixed
        # 6 classes and left any extra entries as raw counts.
        weights = flat_target.numel() / counts / num_classes
        return F.cross_entropy(inp, target, weight=weights, reduction=self.reduction)


class FocalLoss(nn.Module):
    """Binary focal loss on raw logits.

    Per element, with ``p = sigmoid(y_hat)``::

        loss = -alpha * (1 - p)**gamma * y * log(p)
               - (1 - alpha) * p**gamma * (1 - y) * log(1 - p)

    Args:
        reduction: 'mean' (default), 'sum', or 'none' (per-element losses,
            flattened to 1-D).
        alpha: Weight of the positive (``y == 1``) term; the negative term is
            weighted by ``1 - alpha``.
        gamma: Focusing exponent applied to ``(1 - p)`` and ``p``.

    Raises:
        ValueError: if ``reduction`` is not one of the supported values.
    """

    def __init__(self, reduction='mean', alpha=0.8, gamma=0.7):
        super().__init__()
        if reduction not in ('mean', 'sum', 'none'):
            raise ValueError(f"unsupported reduction: {reduction!r}")
        self.reduction = reduction
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, y_hat, y):
        y_hat = y_hat.reshape(-1, 1)
        y = y.reshape(-1, 1).to(y_hat.dtype)
        # Work in log space. With sigmoid + torch.log, a saturated logit gives
        # p == 1.0 exactly, log(1 - p) == -inf, and (1 - y) * -inf == nan for
        # positives. logsigmoid(x) == log(p) and logsigmoid(-x) == log(1 - p)
        # are finite everywhere. The powers are taken as exp(gamma * log(.)),
        # which also keeps the gradients finite at saturation.
        log_p = F.logsigmoid(y_hat)
        log_q = F.logsigmoid(-y_hat)
        loss = (-self.alpha * torch.exp(self.gamma * log_q) * y * log_p
                - (1 - self.alpha) * torch.exp(self.gamma * log_p) * (1 - y) * log_q)
        if self.reduction == 'sum':
            return loss.sum()
        if self.reduction == 'none':
            return loss.reshape(-1)
        return loss.mean()
    
class MultiFocalLoss(nn.Module):
    """Multi-class focal loss built on unreduced cross-entropy.

    Per sample, with ``ce`` the cross-entropy and ``pt = exp(-ce)``::

        loss = alpha * (1 - pt)**gamma * ce

    For plain class-index targets, ``pt`` is the softmax probability assigned to
    the target class.

    Args:
        reduction: 'mean' (default), 'sum', or 'none' (per-sample losses).
        alpha: Constant scale applied to every sample's loss.
        gamma: Focusing exponent on ``(1 - pt)``.

    Raises:
        ValueError: if ``reduction`` is not one of the supported values.
    """

    def __init__(self, reduction='mean', alpha=0.8, gamma=0.7):
        super().__init__()
        if reduction not in ('mean', 'sum', 'none'):
            raise ValueError(f"unsupported reduction: {reduction!r}")
        self.reduction = reduction
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, outputs, targets):
        ce_loss = F.cross_entropy(outputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        # For a perfectly classified sample pt == 1, and d/dx x**gamma at x == 0
        # is inf for gamma < 1, so autograd would produce 0 * inf == nan.
        # Clamping to the smallest normal float leaves the forward value
        # negligibly changed (it is multiplied by ce ~ 0) and zeroes that gradient.
        focal = (1 - pt).clamp_min(torch.finfo(ce_loss.dtype).tiny) ** self.gamma
        loss = self.alpha * focal * ce_loss
        if self.reduction == 'sum':
            return loss.sum()
        if self.reduction == 'none':
            return loss
        return loss.mean()