from __future__ import annotations
import torch
import torch.optim as optim
from torch import nn
from typing import Dict, Any, Tuple, Callable, Optional, Union
import os
import gc
import typing as ty

class Lambda(nn.Module):
    """Adapt an arbitrary callable into an ``nn.Module`` layer.

    The callable is kept on ``self.f`` and applied, unchanged, to the input
    in ``forward`` — handy for dropping ad-hoc tensor ops into an
    ``nn.Sequential``.
    """

    def __init__(self, f: Callable) -> None:
        super(Lambda, self).__init__()
        self.f = f

    def forward(self, x):
        fn = self.f
        return fn(x)

def make_class_weight(y: torch.Tensor,
                      num_classes: int,
                      device,
                      max_weight: float = 10.0) -> torch.Tensor:
    """Compute inverse-frequency class weights from integer labels.

    Each class gets weight ``total_count / class_count``. The weights are
    rescaled so that their mean over the classes that actually occur in ``y``
    is 1, then clipped from above at ``max_weight``.

    Args:
        y: Integer class labels of any shape (flattened internally; need not
            be contiguous).
        num_classes: Minimum length of the returned weight vector.
        device: Device the result is moved to.
        max_weight: Upper clamp applied after normalisation.

    Returns:
        A float tensor of per-class weights on ``device``. Classes that never
        occur in ``y`` receive ``max_weight``. If ``y`` is empty, all weights
        are 1.
    """
    # reshape (not view) so non-contiguous label tensors are accepted.
    y_flat = y.reshape(-1).long()
    counts = torch.bincount(y_flat, minlength=num_classes).float()

    present = counts > 0
    if not bool(present.any()):
        # No labels at all: inverse frequency is undefined (would be 0/0 = NaN).
        return torch.ones_like(counts).to(device)

    weights = counts.sum() / (counts + 1e-6)
    # Normalise over observed classes only: an absent class has weight
    # ~N/1e-6, which would otherwise dominate the mean and push every
    # present class's weight towards zero.
    weights = weights / weights[present].mean()
    weights = torch.clamp(weights, max=max_weight)

    return weights.to(device)

def make_optimizer(
    optimizer: str,
    parameter_groups,
    lr: float,
    weight_decay: float,
) -> optim.Optimizer:
    """Construct a ``torch.optim`` optimizer from its short name.

    Args:
        optimizer: One of ``'adam'``, ``'adamw'`` or ``'sgd'``. The match is
            case-insensitive.
        parameter_groups: Iterable of parameters or parameter-group dicts.
        lr: Learning rate.
        weight_decay: Weight-decay coefficient.

    Returns:
        The optimizer instance. SGD is always built with ``momentum=0.9``.

    Raises:
        KeyError: If ``optimizer`` is not a recognised name.
    """
    factories = {
        'adam': optim.Adam,
        'adamw': optim.AdamW,
        'sgd': optim.SGD,
    }
    key = optimizer.lower() if isinstance(optimizer, str) else optimizer
    try:
        optimizer_cls = factories[key]
    except KeyError:
        raise KeyError(
            f"Unknown optimizer {optimizer!r}; expected one of {sorted(factories)}"
        ) from None

    # Pass momentum by keyword rather than relying on SGD's positional order.
    extra = {'momentum': 0.9} if optimizer_cls is optim.SGD else {}
    return optimizer_cls(parameter_groups, lr=lr, weight_decay=weight_decay, **extra)

def _train_epoch(model, loader, optimizer, loss_fn, device, max_grad_norm, loss_fn_kwargs):
    """Run one optimisation pass over ``loader``; return the sample-weighted mean loss."""
    model.train()
    total_loss = 0.0
    n_samples = 0
    for xb, yb in loader:
        xb, yb = xb.to(device), yb.to(device)
        optimizer.zero_grad()

        loss = loss_fn(model, xb, yb, **loss_fn_kwargs)
        loss.backward()

        # Clipping is opt-in: the default max_grad_norm of inf disables it.
        if max_grad_norm < float('inf'):
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)

        optimizer.step()
        batch_size = xb.size(0)
        total_loss += loss.item() * batch_size
        n_samples += batch_size

    if n_samples == 0:
        raise ValueError("train_loader yielded no samples")
    return total_loss / n_samples


def _eval_epoch(model, loader, loss_fn, device, loss_fn_kwargs):
    """Evaluate ``model`` on ``loader`` without gradients; return the sample-weighted mean loss."""
    model.eval()
    total_loss = 0.0
    n_samples = 0
    with torch.no_grad():
        for xb, yb in loader:
            xb, yb = xb.to(device), yb.to(device)
            batch_loss = loss_fn(model, xb, yb, **loss_fn_kwargs)
            batch_size = xb.size(0)
            total_loss += batch_loss.item() * batch_size
            n_samples += batch_size

    if n_samples == 0:
        raise ValueError("val_loader yielded no samples")
    return total_loss / n_samples


def early_stop_training_loop(
    model: nn.Module,
    train_loader,
    val_loader,
    optimizer: optim.Optimizer,
    loss_fn: Callable,
    max_epochs: int = 1000,
    patience: int = 20,
    min_delta: float = 1e-4,
    min_epochs: int = 400,
    verbose: bool = True,
    checkpoint_path: str = "best_model.pt",
    max_grad_norm: float = float('inf'),
    device: Optional[torch.device] = None,
    **loss_fn_kwargs
) -> Tuple[nn.Module, Dict[str, Any]]:
    """Train ``model`` with validation-based early stopping and best-model checkpointing.

    Each epoch runs one training pass over ``train_loader`` and then one
    evaluation pass over ``val_loader``. Whenever the validation loss improves
    on the best so far by more than ``min_delta``, ``model.state_dict()`` is
    saved to ``checkpoint_path``. Training stops early once the (0-based)
    epoch index is ``>= min_epochs`` and there have been ``patience``
    consecutive epochs without such an improvement. If any checkpoint was
    saved, the best weights are reloaded before returning.

    Side effect: any existing file at ``checkpoint_path`` is deleted before
    training starts.

    Args:
        model: Module to train. It is moved to ``device``.
        train_loader: Iterable of ``(inputs, targets)`` tensor batches.
        val_loader: Iterable of ``(inputs, targets)`` tensor batches.
        optimizer: Optimizer already bound to ``model``'s parameters.
        loss_fn: Called as ``loss_fn(model, xb, yb, **loss_fn_kwargs)``. It
            must return a scalar tensor.
        max_epochs: Hard cap on the number of epochs.
        patience: Non-improving epochs tolerated before stopping.
        min_delta: Minimum decrease in validation loss that counts as an
            improvement.
        min_epochs: Early stopping is not allowed before this epoch index.
        verbose: Print checkpoint deletion and early-stop messages.
        checkpoint_path: Where the best ``state_dict`` is written.
        max_grad_norm: Gradient-norm clip threshold. ``inf`` disables
            clipping.
        device: Target device (``torch.device`` or string). Defaults to CUDA
            when available, otherwise CPU.
        **loss_fn_kwargs: Extra keyword arguments forwarded to ``loss_fn``.

    Returns:
        ``(model, logs)``, where ``model`` is in eval mode and ``logs`` holds
        the per-epoch ``"train_loss"`` and ``"val_loss"`` lists. Each entry
        is averaged over the samples actually seen in that epoch.

    Raises:
        ValueError: If either loader yields no samples in an epoch.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        # Accept strings such as "cuda:0" as well as torch.device objects.
        device = torch.device(device)

    model = model.to(device)

    if os.path.exists(checkpoint_path):
        os.remove(checkpoint_path)
        if verbose:
            print(f"Deleted old checkpoint: {checkpoint_path}")

    best_val_loss = float("inf")
    epochs_no_improve = 0
    checkpoint_saved = False

    logs: Dict[str, list] = {"train_loss": [], "val_loss": []}

    for ep in range(max_epochs):
        train_loss = _train_epoch(
            model, train_loader, optimizer, loss_fn, device, max_grad_norm, loss_fn_kwargs
        )
        logs["train_loss"].append(train_loss)

        val_loss = _eval_epoch(model, val_loader, loss_fn, device, loss_fn_kwargs)
        logs["val_loss"].append(val_loss)

        if val_loss < best_val_loss - min_delta:
            best_val_loss = val_loss
            torch.save(model.state_dict(), checkpoint_path)
            checkpoint_saved = True
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
            if ep >= min_epochs and epochs_no_improve >= patience:
                if verbose:
                    print(f"Early stop @ {ep}. Best val = {best_val_loss:.6f}")
                break

        torch.cuda.empty_cache()
        gc.collect()

    if checkpoint_saved:
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))

    model.eval()
    return model, logs

def sbl_loss_wrapper(model, x, y, class_weight=None):
    """Delegate loss computation to the model's own ``loss`` method.

    ``x``, ``y`` and ``class_weight`` are passed positionally, in that order,
    to ``model.loss`` and its result is returned unchanged.
    """
    compute_loss = model.loss
    return compute_loss(x, y, class_weight)


def standard_loss_wrapper(model, x, y, criterion, class_weight=None):
    """Compute ``criterion(model(x), y)``, optionally with per-class weights.

    Args:
        model: Callable producing predictions from ``x``.
        x: Model input.
        y: Targets passed straight to ``criterion``.
        criterion: Loss callable. If it has a ``weight`` attribute and
            ``class_weight`` is given, that attribute is set to
            ``class_weight`` for this call only.
        class_weight: Optional per-class weight tensor.

    Returns:
        Whatever ``criterion`` returns (typically a scalar loss tensor).
    """
    outputs = model(x)
    if class_weight is None or not hasattr(criterion, 'weight'):
        return criterion(outputs, y)

    # Install the weight temporarily and restore it afterwards, so a shared
    # criterion does not keep applying these weights on later calls that
    # pass class_weight=None.
    previous_weight = criterion.weight
    criterion.weight = class_weight
    try:
        return criterion(outputs, y)
    finally:
        criterion.weight = previous_weight

def mlp_loss_wrapper(model, x, y, criterion, class_weight=None):
    """Compute a classification loss for a model taking ``(x_num, x_cat)``.

    Args:
        model: Called as ``model(x_num, x_cat)`` when ``x`` is a tuple, else
            as ``model(x, None)``.
        x: Either a ``(x_num, x_cat)`` tuple or a single input tensor.
        y: Targets. They are cast to ``long`` for multi-class outputs (more
            than one column) and to ``float`` otherwise.
        criterion: Loss module or callable.
        class_weight: Optional per-class weights.
            - For ``nn.BCEWithLogitsLoss``, a fresh criterion is built with
              ``pos_weight = class_weight[1] / class_weight[0]``. The
              ``reduction`` is kept; the original criterion's own
              ``weight``/``pos_weight`` are not carried over.
            - For any other criterion exposing ``weight``, that attribute is
              set for this call only and then restored.

    Returns:
        The loss returned by the (possibly replaced) criterion.
    """
    if isinstance(x, tuple):
        x_num, x_cat = x
        outputs = model(x_num, x_cat)
    else:
        outputs = model(x, None)

    restore_weight = False
    previous_weight = None
    if class_weight is not None:
        # BCEWithLogitsLoss also has a ``weight`` attribute, but it is a
        # per-element weight, not per-class, so it must be checked first.
        if isinstance(criterion, torch.nn.BCEWithLogitsLoss):
            pos_weight = class_weight[1] / class_weight[0] if len(class_weight) > 1 else None
            criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight,
                                                   reduction=criterion.reduction)
        elif hasattr(criterion, 'weight'):
            previous_weight = criterion.weight
            criterion.weight = class_weight
            restore_weight = True

    try:
        if outputs.dim() > 1 and outputs.shape[1] > 1:
            target = y.long()
            # For (batch, classes) logits the target must stay 1-D. squeeze()
            # would collapse a batch of one to a 0-d tensor and break the loss.
            target = target.reshape(-1) if outputs.dim() == 2 else target.squeeze()
            return criterion(outputs, target)
        target = y.float().squeeze()
        return criterion(outputs.squeeze(), target)
    finally:
        if restore_weight:
            criterion.weight = previous_weight


def create_mlp_loss_wrapper(criterion_class, **criterion_kwargs):
    """Build a ``loss_fn(model, x, y, class_weight=None)`` around ``mlp_loss_wrapper``.

    On every call, a fresh ``criterion_class(**criterion_kwargs)`` is
    constructed. When ``class_weight`` is given and the constructor accepts a
    ``weight`` argument (explicitly or via ``**kwargs``), it is also passed as
    ``weight=class_weight``. ``class_weight`` is always forwarded to
    ``mlp_loss_wrapper`` as well.

    Args:
        criterion_class: Loss class, e.g. ``nn.CrossEntropyLoss``.
        **criterion_kwargs: Extra constructor arguments for the loss.

    Returns:
        A closure suitable as ``loss_fn`` for ``early_stop_training_loop``.
    """
    import inspect

    try:
        params = inspect.signature(criterion_class).parameters
        accepts_weight = 'weight' in params or any(
            p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()
        )
    except (TypeError, ValueError):
        # Signature not introspectable: fall back to always passing weight=.
        accepts_weight = True

    def wrapper(model, x, y, class_weight=None):
        kwargs = dict(criterion_kwargs)
        # Only pass weight= when the constructor takes it; losses such as
        # nn.MSELoss have no weight parameter and would raise TypeError.
        if class_weight is not None and accepts_weight:
            kwargs['weight'] = class_weight
        criterion = criterion_class(**kwargs)
        return mlp_loss_wrapper(model, x, y, criterion, class_weight)

    return wrapper
