# =======================================================================
# Wrapper for Prior Methods: DUE, SNGP, and DUL
# 
# DUE and SNGP adapted from: https://github.com/y0ast/DUE
# DUL adapted from: https://github.com/yookoon/density_uncertainty_layers
# =======================================================================

# --- General imports ---
import numpy as np
import math

# --- PyTorch and Torch utilities ---
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR
from torch.utils.data import DataLoader, TensorDataset

# --- Ignite (training and evaluation) ---
from ignite.engine import Engine, Events
from ignite.metrics import Loss, Average
from ignite.contrib.handlers import ProgressBar

# --- DUE-specific imports ---
from due.dkl import DKL, GP, initial_values
from due.sngp import Laplace
from due.fc_resnet import FCResNet

# --- GPyTorch (used by DUE) ---
import gpytorch
from gpytorch.mlls import VariationalELBO
from gpytorch.likelihoods import GaussianLikelihood, SoftmaxLikelihood, BernoulliLikelihood

# --- DUL-specific imports ---
from density_uncertainty_layers.network import DensityModel, DensityLinear, BayesianModel, DensityConv2d
from density_uncertainty_layers.utils import *

# ====================================
# TRAINING ENGINE (Used by all models)
# ====================================

def _train_model(model, dl_train, dl_test, loss_fn, optimizer, epochs, likelihood=None, classification=False):
    """
    General training loop for uncertainty models.

    Runs an Ignite trainer over ``dl_train`` for ``epochs`` epochs and
    periodically evaluates on ``dl_test``, printing train/test losses.

    Args:
        model: PyTorch model to train. Its forward pass may return a tensor,
            a ``(mean, variance)`` tuple/list, or (when ``likelihood`` is
            given) an object accepted by ``loss_fn`` and
            ``likelihood.expected_log_prob``.
        dl_train: DataLoader for training data, yielding ``(x, y)`` batches.
        dl_test: DataLoader for test data, yielding ``(x, y)`` batches.
        loss_fn: Loss function to minimize. Called as ``loss_fn(output, y)``
            when ``likelihood`` is given, otherwise ``loss_fn(mean, y)``.
        optimizer: Optimizer (e.g., Adam).
        epochs: Number of training epochs.
        likelihood: GPyTorch likelihood (used in DUE).
        classification: Boolean flag for classification problems. Only
            affects the evaluation metric when ``likelihood`` is None.

    Returns:
        The trained model, switched to eval mode.
    """
    # Batches go to the GPU whenever one is available; callers move the
    # model under the same condition.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def align_target(pred, y):
        # Make the target layout agree with the prediction layout so that
        # element-wise losses never broadcast (N,) against (N, 1).
        if isinstance(pred, torch.Tensor) and pred.dim() == 2 and pred.size(1) == 1:
            # Column prediction: give the target the same (N, 1) shape.
            if y.numel() == pred.numel():
                return y.reshape(pred.shape)
            return y
        if y.dim() == 2 and y.size(1) == 1:
            # Prediction is not a single column: drop the target's trailing
            # singleton column. Only that axis is squeezed, so a batch of
            # size 1 keeps its batch dimension.
            return y.squeeze(1)
        return y

    def step(engine, batch):
        # One training step
        model.train()
        if likelihood is not None:
            likelihood.train()
        optimizer.zero_grad()

        x, y = batch
        x, y = x.to(device), y.to(device)

        output = model(x)
        # Models may return (mean, variance); only the mean feeds plain losses.
        mean = output[0] if isinstance(output, (tuple, list)) else output
        y = align_target(mean, y)

        if likelihood is not None:
            # The likelihood-based loss consumes the full model output.
            loss = loss_fn(output, y).mean()
        else:
            loss = loss_fn(mean, y).mean()

        loss.backward()
        optimizer.step()
        return loss.item()

    def eval_step(engine, batch):
        # One evaluation step
        model.eval()
        if likelihood is not None:
            likelihood.eval()

        x, y = batch
        x, y = x.to(device), y.to(device)

        # No gradients are needed for evaluation; skip building the graph.
        with torch.no_grad():
            output = model(x)
        if isinstance(output, (tuple, list)):
            output = output[0]
        return output, align_target(output, y)

    trainer = Engine(step)
    evaluator = Engine(eval_step)

    # Reset SNGP precision matrix each epoch
    @trainer.on(Events.EPOCH_STARTED)
    def reset_precision(engine):
        if hasattr(model, "reset_precision_matrix"):
            model.reset_precision_matrix()

    train_metric = Average()
    train_metric.attach(trainer, "loss")

    pbar = ProgressBar()
    pbar.attach(trainer)

    # Evaluation metric
    if likelihood is not None:
        eval_metric = Loss(lambda y_pred, y: -likelihood.expected_log_prob(y, y_pred).mean())
    elif classification:
        def eval_loss(y_pred, y):
            pred = y_pred[0] if isinstance(y_pred, (tuple, list)) else y_pred
            if y.is_floating_point():
                # Float targets: binary classification on logits.
                return F.binary_cross_entropy_with_logits(pred, align_target(pred, y))
            # Integer targets: class indices against (N, C) logits.
            return F.cross_entropy(pred, y)
        eval_metric = Loss(eval_loss)
    else:
        def eval_loss(y_pred, y):
            pred = y_pred[0] if isinstance(y_pred, (tuple, list)) else y_pred
            # Shapes are aligned first: mse_loss on (N,) vs (N, 1) would
            # broadcast to (N, N) and report a meaningless value.
            return F.mse_loss(pred, align_target(pred, y))
        eval_metric = Loss(eval_loss)
    eval_metric.attach(evaluator, "loss")

    # Logging (roughly ten reports over the whole run)
    @trainer.on(Events.EPOCH_COMPLETED(every=int(epochs / 10) + 1))
    def log_results(trainer):
        evaluator.run(dl_test)
        train_loss = trainer.state.metrics.get('loss')
        test_loss = evaluator.state.metrics.get('loss')
        if train_loss is None or test_loss is None:
            print(f"Epoch {trainer.state.epoch} | Metrics not available yet.")
        else:
            print(f"Epoch {trainer.state.epoch} | Test Loss: {test_loss:.4f} | Training Loss: {train_loss:.4f}")

    trainer.run(dl_train, max_epochs=epochs)
    model.eval()
    if likelihood is not None:
        likelihood.eval()
    return model

def check_2d_tensor(y_pred):
    """
    Check if tensor is 2D with second dim=1

    Args:
        y_pred: Any object; only ``torch.Tensor`` instances can qualify.

    Returns:
        bool: True when ``y_pred`` is a tensor of shape ``(N, 1)``, False
        otherwise (including for non-tensor inputs). Always a built-in
        ``bool``, so identity checks such as ``is False`` behave as expected.
    """
    return (
        isinstance(y_pred, torch.Tensor)
        and y_pred.dim() == 2
        and y_pred.shape[1] == 1
    )

# ====================================
# METHOD: DUE
# ====================================

def train_due(X_train, y_train, X_test, y_test, feature_extractor,
              num_outputs=1, n_inducing_points=20, kernel="RBF", lr=1e-3, batch_size=100, steps=1000, classification=False, multiclass=False):
    """
    Train model using Deterministic Uncertainty Estimation (DUE).

    Args:
        X_train, y_train: NumPy arrays with training inputs and targets.
            Targets may be 1-D or 2-D; a single target column is flattened.
        X_test, y_test: NumPy arrays with test inputs and targets.
        feature_extractor: Network passed to ``initial_values`` and ``DKL``.
        num_outputs: Number of GP outputs (number of classes for multiclass).
        n_inducing_points: Number of inducing points for the GP.
        kernel: Kernel name forwarded to ``GP``.
        lr: Learning rate for both model and likelihood parameters.
        batch_size: Training mini-batch size.
        steps: Approximate number of optimizer steps; converted to whole
            epochs, with a minimum of one epoch.
        classification: Use a Bernoulli likelihood (unless ``multiclass``).
        multiclass: Use a softmax likelihood over ``num_outputs`` classes.

    Returns:
        Tuple ``(model, likelihood)`` after training.

    Raises:
        ValueError: If the training set is empty.
    """
    if len(X_train) == 0:
        raise ValueError("train_due requires a non-empty training set")

    # Prepare dataset: flatten a single target column to 1-D. The ndim check
    # keeps already-flat targets from raising IndexError, and reshape (unlike
    # a bare squeeze) keeps a one-sample array 1-D.
    if y_train.ndim == 2 and y_train.shape[1] == 1:
        y_train = y_train.reshape(-1)
        y_test = y_test.reshape(-1)

    if classification and multiclass:
        y_train_tensor = torch.from_numpy(y_train).long()
        y_test_tensor = torch.from_numpy(y_test).long()
    else:
        y_train_tensor = torch.from_numpy(y_train).float()
        y_test_tensor = torch.from_numpy(y_test).float()

    ds_train = TensorDataset(torch.from_numpy(X_train).float(), y_train_tensor)
    ds_test = TensorDataset(torch.from_numpy(X_test).float(), y_test_tensor)

    # Drop the ragged last batch only when there is more than one batch;
    # otherwise a dataset smaller than batch_size would yield no batches.
    dl_train = DataLoader(ds_train, batch_size=batch_size, shuffle=True,
                          drop_last=len(ds_train) > batch_size)
    dl_test = DataLoader(ds_test, batch_size=512, shuffle=False)

    # Convert the step budget to epochs; always train for at least one epoch
    # so a small budget does not silently return an untrained model.
    epochs = max(1, int(steps // len(dl_train)))

    # Initialize GP and DKL model
    initial_inducing_points, initial_lengthscale = initial_values(ds_train, feature_extractor, n_inducing_points)
    gp = GP(num_outputs=num_outputs, initial_lengthscale=initial_lengthscale,
            initial_inducing_points=initial_inducing_points, kernel=kernel)
    model = DKL(feature_extractor, gp)

    # Choose appropriate likelihood
    if multiclass:
        likelihood = SoftmaxLikelihood(num_classes=num_outputs, mixing_weights=False)
    elif classification:
        likelihood = BernoulliLikelihood()
    else:
        likelihood = GaussianLikelihood()

    if torch.cuda.is_available():
        model = model.cuda()
        likelihood = likelihood.cuda()

    # Define ELBO loss (negated, since the optimizer minimizes)
    elbo_fn = VariationalELBO(likelihood, model.gp, num_data=len(ds_train))
    loss_fn = lambda x, y: -elbo_fn(x, y)

    # Optimizer
    parameters = [{"params": model.parameters(), "lr": lr}, {"params": likelihood.parameters(), "lr": lr}]
    optimizer = torch.optim.Adam(parameters)

    return _train_model(model, dl_train, dl_test, loss_fn, optimizer, epochs, likelihood), likelihood

# ====================================
# METHOD: SNGP
# ====================================

def train_sngp(X_train, y_train, X_test, y_test, feature_extractor, representation_dim, classification, binary,
               num_gp_features=128, num_random_features=1024, normalize_gp_features=True,
               feature_scale=2, ridge_penalty=1, lr=1e-3, batch_size=5000, steps=100, num_outputs=1):
    """
    Train model using Spectral-normalized Gaussian Process (SNGP).

    Args:
        X_train, y_train: NumPy arrays with training inputs and targets.
            Targets may be 1-D or 2-D; a single target column is flattened.
        X_test, y_test: NumPy arrays with test inputs and targets.
        feature_extractor: Network passed to ``Laplace``.
        representation_dim: Second positional argument of ``Laplace``
            (presumably the feature extractor's output width — confirm
            against ``due.sngp.Laplace``).
        classification: True for classification, False for regression.
        binary: With ``classification``, selects binary cross-entropy on
            logits; otherwise multiclass cross-entropy is used.
        num_gp_features, num_random_features, normalize_gp_features,
        feature_scale, ridge_penalty, num_outputs: Forwarded to ``Laplace``.
        lr: Adam learning rate.
        batch_size: Training mini-batch size.
        steps: Approximate number of optimizer steps; converted to whole
            epochs, with a minimum of one epoch.

    Returns:
        The trained model in eval mode.

    Raises:
        ValueError: If the training set is empty.
    """
    if len(X_train) == 0:
        raise ValueError("train_sngp requires a non-empty training set")

    # Flatten a single target column to 1-D. The ndim check keeps already-flat
    # targets from raising IndexError, and reshape (unlike a bare squeeze)
    # keeps a one-sample array 1-D.
    if y_train.ndim == 2 and y_train.shape[1] == 1:
        y_train = y_train.reshape(-1)
        y_test = y_test.reshape(-1)

    # Choose appropriate dtype: class indices for multiclass, floats otherwise
    if classification and not binary:
        y_train_tensor = torch.from_numpy(y_train).long()
        y_test_tensor = torch.from_numpy(y_test).long()
    else:
        y_train_tensor = torch.from_numpy(y_train).float()
        y_test_tensor = torch.from_numpy(y_test).float()

    ds_train = TensorDataset(torch.from_numpy(X_train).float(), y_train_tensor)
    ds_test = TensorDataset(torch.from_numpy(X_test).float(), y_test_tensor)

    # Drop the ragged last batch only when there is more than one batch;
    # with the large default batch_size a smaller dataset would otherwise
    # produce an empty loader.
    dl_train = DataLoader(ds_train, batch_size=batch_size, shuffle=True,
                          drop_last=len(ds_train) > batch_size)
    dl_test = DataLoader(ds_test, batch_size=512, shuffle=False)

    # Convert the step budget to epochs; always train for at least one epoch
    # so a small budget does not silently return an untrained model.
    epochs = max(1, int(steps // len(dl_train)))

    model = Laplace(
        feature_extractor,
        representation_dim,
        num_gp_features,
        normalize_gp_features,
        num_random_features,
        num_outputs=num_outputs,
        num_data=len(ds_train),
        train_batch_size=batch_size,
        ridge_penalty=ridge_penalty,
        feature_scale=feature_scale,
    )

    if classification and binary:
        loss_fn = F.binary_cross_entropy_with_logits
    elif classification:
        loss_fn = F.cross_entropy
    else:
        loss_fn = F.mse_loss

    if torch.cuda.is_available():
        model = model.cuda()

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    return _train_model(model, dl_train, dl_test, loss_fn, optimizer, epochs, classification=classification)

# ====================================
# METHOD: DUL
# ====================================

def density_uncertainty(X_train, Y_train, model, criterion=nn.BCELoss(), learning_rate=1e-2, epochs=50, batch_size=50, warmup_epochs=0, kl_mult=1.0, kl_beta=1.0, ll_scale=0.01):
    """
    Train model using Density Uncertainty Layers (DUL).

    Args:
        X_train, Y_train: NumPy arrays with training inputs and targets.
            Targets may be 1-D or 2-D; a single target column is flattened.
        model: Network whose forward pass returns a ``(mean, variance)`` pair.
            ``BayesianModel`` instances add a KL penalty to the loss;
            ``DensityModel`` instances subtract a scaled log-likelihood.
        criterion: Loss applied to ``(mean, targets)``. Targets are cast to
            long when this is an ``nn.CrossEntropyLoss``.
        learning_rate: Adam learning rate.
        epochs: Number of passes over the training data.
        batch_size: Mini-batch size (batches are taken in dataset order).
        warmup_epochs: Epochs over which the learning rate ramps linearly
            from 0 to ``learning_rate``; 0 disables warmup.
        kl_mult: Initial multiplier on the KL term; raised by 1 (capped at
            1.0) after every Bayesian batch.
        kl_beta: Constant weight on the KL term.
        ll_scale: Weight on the density log-likelihood term.

    Returns:
        The trained network, on the device used for training.
    """
    # Prepare dataset: flatten a single target column to 1-D. The ndim check
    # keeps already-flat targets from raising IndexError.
    if Y_train.ndim == 2 and Y_train.shape[1] == 1:
        Y_train = Y_train.reshape(-1)
    features = torch.from_numpy(X_train).float()
    labels = torch.from_numpy(Y_train).float()
    train_dataset = TensorDataset(features, labels)
    train_loader = DataLoader(train_dataset, batch_size, shuffle=False)
    num_train = len(train_dataset)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    net = model.to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)

    # Linear warmup measured in optimizer steps; the guard also avoids a
    # division by zero when warmup is disabled.
    warmup_steps = warmup_epochs * len(train_loader)

    def lr_factor(step):
        if step < warmup_steps:
            return min(1.0, step / warmup_steps)
        return 1.0

    scheduler = LambdaLR(optimizer, lr_factor)

    is_bayesian = isinstance(net, BayesianModel)
    is_density = isinstance(net, DensityModel)
    uses_class_indices = isinstance(criterion, nn.CrossEntropyLoss)

    for epoch in range(epochs):
        net.train()
        loss = None

        for batch_idx, (inputs, targets) in enumerate(train_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            first_batch = epoch == 0 and batch_idx == 0

            mean, _ = net(inputs)

            # Drop only a trailing singleton column; a bare squeeze() would
            # also collapse the batch axis of a final batch of size 1.
            if targets.dim() > 1 and targets.size(-1) == 1:
                targets = targets.squeeze(-1)
            if uses_class_indices:
                targets = targets.long()
            loss = criterion(mean, targets)

            if is_bayesian:
                if first_batch:
                    print("BayesianModel")
                # Scale the KL term by the dataset size so its weight is the
                # same for every mini-batch.
                kl_div = net.kl_div() / num_train
                loss = loss + kl_beta * kl_mult * kl_div
                kl_mult = min(1.0, kl_mult + 1)

            if is_density:
                if first_batch:
                    print("DensityModel")
                loss = loss - net.loglikelihood() * ll_scale

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()

        # Report the loss of the epoch's last batch every tenth epoch.
        if epoch % 10 == 0 and loss is not None:
            print(f"Epoch {epoch} | Loss: {loss.item():.4f}")

    return net