from typing import Dict, Optional

import torch.nn as nn
import torch


class Criterion:
    """Wrap a named ``torch.nn`` loss and optionally add a weight penalty.

    The loss is selected by name (see ``_CRITERIA``) and instantiated with
    its default arguments. ``regularization`` is a dict whose keys select a
    penalty over ``model.parameters()``:

    * ``"l1"``: ``lambda * sum(|p|)``
    * ``"l2"``: ``lambda * sum(p ** 2)``
    * ``"elasticnet"``: a mapping with ``"alpha"`` and ``"ratio"`` giving
      ``alpha * (ratio * L1 + (1 - ratio) * L2)``
    * ``"weight_decay"``: ``lambda * sum(p ** 2)`` (same formula as ``"l2"``)

    Only ONE penalty is applied per call. If several keys are present the
    first match in the order listed above wins.
    """

    # Supported criterion names mapped to the loss class to instantiate.
    _CRITERIA = {
        "MSE": nn.MSELoss,
        "CrossEntropy": nn.CrossEntropyLoss,
        "BCE": nn.BCELoss,
        "BCEWithLogits": nn.BCEWithLogitsLoss,
        "NLL": nn.NLLLoss,
        "L1": nn.L1Loss,
        "SmoothL1": nn.SmoothL1Loss,
        "HingeEmbedding": nn.HingeEmbeddingLoss,
        "MarginRanking": nn.MarginRankingLoss,
        "CosineEmbedding": nn.CosineEmbeddingLoss,
        "MultiLabelMargin": nn.MultiLabelMarginLoss,
    }

    def __init__(
        self, criterion_name: str, regularization: Optional[Dict[str, float]] = None
    ):
        """Build the criterion.

        Args:
            criterion_name: Key into the supported-criteria table.
            regularization: Optional penalty configuration (see class
                docstring). ``None`` or an empty dict disables the penalty.

        Raises:
            ValueError: If ``criterion_name`` is not supported.
        """
        self.criterion_name = criterion_name
        self.regularization = regularization if regularization else {}
        # Must exist before compute_loss() runs; otherwise its "model not
        # set" guard would fail with AttributeError instead of ValueError.
        self.model: Optional[nn.Module] = None
        self.criterion = self.get_criterion()

    def get_criterion(self):
        """Instantiate the loss module named by ``self.criterion_name``.

        Raises:
            ValueError: If the name is not in the supported-criteria table.
        """
        try:
            loss_cls = self._CRITERIA[self.criterion_name]
        except KeyError:
            raise ValueError(
                f"Unsupported criterion: {self.criterion_name}"
            ) from None
        return loss_cls()

    def set_from_model(self, model: nn.Module) -> None:
        """Register the model whose parameters the penalty is computed over."""
        self.model = model

    def compute_loss(self, outputs, targets):
        """Return the criterion loss plus the configured penalty, if any.

        For ``CrossEntropy`` the targets are converted to class indices
        first: a 2-D target with more than one column is reduced with
        ``argmax(dim=1)``; anything else is flattened. The result is cast to
        ``torch.long``.

        Args:
            outputs: Model outputs passed to the criterion.
            targets: Targets passed to the criterion.

        Raises:
            ValueError: If no model has been set, or if the ``"elasticnet"``
                configuration lacks ``"alpha"`` or ``"ratio"``.
        """
        if self.model is None:
            raise ValueError(
                "Model has not been set. Please call set_from_model() before computing the loss."
            )

        if isinstance(self.criterion, nn.CrossEntropyLoss):
            if targets.dim() == 2 and targets.shape[1] > 1:
                targets = targets.argmax(dim=1)
            else:
                # reshape (not view) so non-contiguous targets also work.
                targets = targets.reshape(-1)
            targets = targets.to(dtype=torch.long)

        loss = self.criterion(outputs, targets)

        penalty = self._regularization_penalty()
        if penalty is not None:
            # Out-of-place add: avoids mutating the criterion's output and
            # lets dtype promotion happen instead of failing in place.
            loss = loss + penalty
        return loss

    def get_regularization_type(self):
        """Return the first key of the regularization dict, or ``None``.

        NOTE: with several keys configured this is the first key in
        insertion order, which is not necessarily the penalty that
        ``compute_loss`` applies (that follows a fixed precedence).
        """
        if self.regularization:
            return next(iter(self.regularization))
        return None

    def _regularization_penalty(self):
        """Return the single penalty term to add, or ``None`` if none applies."""
        config = self.regularization
        if "l1" in config:
            return config["l1"] * self._l1_norm()
        if "l2" in config:
            return config["l2"] * self._l2_norm()
        if "elasticnet" in config:
            settings = config["elasticnet"]
            alpha = settings.get("alpha")
            ratio = settings.get("ratio")
            if alpha is None or ratio is None:
                raise ValueError(
                    "Elastic net regularization requires both 'alpha' and 'ratio'."
                )
            return alpha * (
                ratio * self._l1_norm() + (1 - ratio) * self._l2_norm()
            )
        if "weight_decay" in config:
            return config["weight_decay"] * self._l2_norm()
        return None

    def _l1_norm(self):
        """Sum of absolute values over all model parameters."""
        return sum(param.abs().sum() for param in self.model.parameters())

    def _l2_norm(self):
        """Sum of squared values over all model parameters."""
        return sum(param.pow(2).sum() for param in self.model.parameters())