import numpy as np
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR
from torch.utils.data import DataLoader, TensorDataset

from ignite.engine import Engine, Events
from ignite.metrics import Loss, Average
from ignite.contrib.handlers import ProgressBar

from due.dkl import DKL, GP, initial_values
from due.sngp import Laplace
from due.fc_resnet import FCResNet

import gpytorch
from gpytorch.mlls import VariationalELBO
from gpytorch.likelihoods import GaussianLikelihood, SoftmaxLikelihood

from due.dkl import DKL, GP, initial_values
from due.sngp import Laplace

from density_uncertainty_layers.network import DensityModel, DensityLinear, BayesianModel, DensityConv2d
from density_uncertainty_layers.utils import *


def _train_model(model, dl_train, dl_test, loss_fn, optimizer, epochs, likelihood=None, classification=False, dtype='float32'):
    """
    Train ``model`` with an ignite ``Engine`` and periodically print the test loss.

    The model may return a single prediction or a tuple/list whose first
    element is the prediction (e.g. ``(mean, variance)``).

    Args:
        model: Module to train; called as ``model(x)``.
        dl_train: Iterable of ``(x, y)`` training batches.
        dl_test: Iterable of ``(x, y)`` evaluation batches.
        loss_fn: Callable ``loss_fn(prediction, y)``; its result is reduced
            with ``.mean()`` before back-propagation.
        optimizer: Optimizer that updates the trainable parameters.
        epochs: Number of passes over ``dl_train``.
        likelihood: Optional likelihood module. When given, ``loss_fn``
            receives the raw model output and the test loss is the negative
            mean of ``likelihood.expected_log_prob``.
        classification: Only used without a likelihood. Selects a
            classification test loss instead of mean squared error.
        dtype: Unused; kept so existing callers keep working.

    Returns:
        The trained model, switched to eval mode.
    """
    use_cuda = torch.cuda.is_available()

    def first_output(output):
        """Return the prediction from an output that may be a tuple/list."""
        return output[0] if isinstance(output, (tuple, list)) else output

    def align_target(pred, y):
        """Give ``y`` the layout ``pred`` expects, leaving the batch axis alone."""
        # Only a trailing singleton axis is dropped: a bare ``squeeze()``
        # would also collapse a batch of size one.
        if y.dim() > 1 and y.size(-1) == 1:
            y = y.squeeze(-1)
        # A (batch, 1) prediction needs a (batch, 1) target.
        if check_2d_tensor(pred) and y.dim() == 1:
            y = y.unsqueeze(1)
        return y

    def step(engine, batch):
        model.train()
        if likelihood is not None:
            likelihood.train()
        optimizer.zero_grad()
        x, y = batch
        if use_cuda:
            x, y = x.cuda(), y.cuda()
        output = model(x)
        mean = first_output(output)
        # Targets are aligned the same way on CPU and GPU.
        y = align_target(mean, y)
        # With a likelihood the loss consumes the full model output.
        prediction = output if likelihood is not None else mean
        loss = loss_fn(prediction, y).mean()
        loss.backward()
        optimizer.step()
        return loss.item()

    def eval_step(engine, batch):
        model.eval()
        if likelihood is not None:
            likelihood.eval()
        x, y = batch
        if use_cuda:
            x, y = x.cuda(), y.cuda()
        # No gradients are needed for evaluation.
        with torch.no_grad():
            output = first_output(model(x))
        return output, align_target(output, y)

    trainer = Engine(step)
    evaluator = Engine(eval_step)

    # Reset the precision matrix at the start of each epoch if the model has one.
    @trainer.on(Events.EPOCH_STARTED)
    def reset_precision(engine):
        if hasattr(model, "reset_precision_matrix"):
            model.reset_precision_matrix()

    train_metric = Average()
    train_metric.attach(trainer, "loss")

    pbar = ProgressBar()
    pbar.attach(trainer)

    if likelihood is not None:
        eval_metric = Loss(lambda y_pred, y: -likelihood.expected_log_prob(y, y_pred).mean())
    elif classification:
        def eval_loss(y_pred, y):
            # Float targets are binary labels for a logit output; integer
            # targets are class indices and need cross-entropy.
            if y.is_floating_point():
                return F.binary_cross_entropy_with_logits(y_pred, y)
            return F.cross_entropy(y_pred, y)
        eval_metric = Loss(eval_loss)
    else:
        def eval_loss(y_pred, y):
            pred = first_output(y_pred)
            # Bring the target to the prediction's shape; mse_loss would
            # otherwise broadcast (B,) against (B, 1) into a (B, B) matrix.
            if pred.shape != y.shape and pred.numel() == y.numel():
                y = y.reshape(pred.shape)
            return F.mse_loss(pred, y)
        eval_metric = Loss(eval_loss)
    eval_metric.attach(evaluator, "loss")

    @trainer.on(Events.EPOCH_COMPLETED(every=int(epochs / 10) + 1))
    def log_results(engine):
        evaluator.run(dl_test)
        train_loss = engine.state.metrics.get('loss')
        test_loss = evaluator.state.metrics.get('loss')
        if train_loss is None or test_loss is None:
            print(f"Epoch {engine.state.epoch} | Metrics not available yet.")
        else:
            print(f"Epoch {engine.state.epoch} | Test Loss: {test_loss:.4f} | Training Loss: {train_loss:.4f}")

    trainer.run(dl_train, max_epochs=epochs)
    model.eval()
    if likelihood is not None:
        likelihood.eval()
    return model


def check_2d_tensor(y_pred):
    """
    Return True when ``y_pred`` is a torch tensor of shape ``(N, 1)``.

    Anything that is not a ``torch.Tensor`` yields ``False``. The result is
    always a plain Python ``bool``.
    """
    return (
        isinstance(y_pred, torch.Tensor)
        and y_pred.dim() == 2
        and y_pred.shape[1] == 1
    )


def train_due(X_train, y_train, X_test, y_test, feature_extractor,
              num_outputs=1, n_inducing_points=20, kernel="RBF", lr=1e-3, batch_size=100, steps=1000, classification=False, multiclass=False):
    """
    Train a model using Deterministic Uncertainty Estimation (DUE).

    Args:
        X_train, y_train: Training inputs and targets as NumPy arrays.
            ``(N, 1)`` targets are flattened to ``(N,)``.
        X_test, y_test: Evaluation inputs and targets as NumPy arrays.
        feature_extractor: Network wrapped by ``DKL`` and passed to
            ``initial_values``.
        num_outputs: Number of GP outputs, and number of classes for the
            softmax likelihood.
        n_inducing_points: Number of inducing points requested from
            ``initial_values``.
        kernel: Kernel name forwarded to ``GP``.
        lr: Adam learning rate for both model and likelihood parameters.
        batch_size: Training batch size, capped at the dataset size.
        steps: Target number of optimisation steps; converted to epochs.
        classification: With ``multiclass=False`` selects a Bernoulli
            likelihood.
        multiclass: Selects a softmax likelihood.

    Returns:
        Tuple ``(model, likelihood)``; the model comes from ``_train_model``.

    Raises:
        ValueError: If the training set is empty.
    """
    # Local import: the file's module-level gpytorch imports do not bring
    # this name into scope.
    from gpytorch.likelihoods import BernoulliLikelihood

    # Flatten (N, 1) targets to (N,). 1-D targets pass through, and a single
    # sample is not collapsed to a scalar.
    if y_train.ndim == 2 and y_train.shape[1] == 1:
        y_train = y_train.reshape(-1)
    if y_test.ndim == 2 and y_test.shape[1] == 1:
        y_test = y_test.reshape(-1)

    integer_labels = classification and multiclass

    def to_target(y):
        tensor = torch.from_numpy(y)
        return tensor.long() if integer_labels else tensor.float()

    ds_train = torch.utils.data.TensorDataset(torch.from_numpy(X_train).float(), to_target(y_train))
    ds_test = torch.utils.data.TensorDataset(torch.from_numpy(X_test).float(), to_target(y_test))

    if len(ds_train) == 0:
        raise ValueError("X_train must contain at least one sample")

    # With drop_last=True a batch larger than the dataset yields no batches
    # at all, so the batch size is capped at the dataset size.
    batch_size = min(batch_size, len(ds_train))
    dl_train = torch.utils.data.DataLoader(ds_train, batch_size=batch_size, shuffle=True, drop_last=True)
    dl_test = torch.utils.data.DataLoader(ds_test, batch_size=512, shuffle=False)

    # Always train for at least one epoch, even when steps < batches per epoch.
    epochs = max(1, int(steps // len(dl_train)))

    # Initialize GP and Model
    initial_inducing_points, initial_lengthscale = initial_values(ds_train, feature_extractor, n_inducing_points)

    gp = GP(
        num_outputs=num_outputs,
        initial_lengthscale=initial_lengthscale,
        initial_inducing_points=initial_inducing_points,
        kernel=kernel)
    model = DKL(feature_extractor, gp)

    if multiclass:
        likelihood = SoftmaxLikelihood(num_classes=num_outputs, mixing_weights=False)
    elif classification:
        likelihood = BernoulliLikelihood()
    else:
        likelihood = GaussianLikelihood()

    if torch.cuda.is_available():
        model = model.cuda()
        likelihood = likelihood.cuda()

    elbo_fn = VariationalELBO(likelihood, model.gp, num_data=len(ds_train))

    def loss_fn(pred, target):
        # The ELBO is maximised, so its negative is the training loss.
        return -elbo_fn(pred, target)

    # Optimizer
    parameters = [{"params": model.parameters(), "lr": lr}, {"params": likelihood.parameters(), "lr": lr}]
    optimizer = torch.optim.Adam(parameters)

    # Training and Evaluation
    trained_model = _train_model(model, dl_train, dl_test, loss_fn, optimizer, epochs, likelihood)
    return trained_model, likelihood


def train_sngp(X_train, y_train, X_test, y_test, feature_extractor, representation_dim, classification, binary,
               num_gp_features=128, num_random_features=1024, normalize_gp_features=True,
               feature_scale=2, ridge_penalty=1, lr=1e-3, batch_size=5000, steps=100, num_outputs=1):
    """
    Train a model using Spectral-normalized Gaussian Process (SNGP).

    Args:
        X_train, y_train: Training inputs and targets as NumPy arrays.
            ``(N, 1)`` targets are flattened to ``(N,)``.
        X_test, y_test: Evaluation inputs and targets as NumPy arrays.
        feature_extractor, representation_dim, num_gp_features,
        normalize_gp_features, num_random_features, num_outputs,
        ridge_penalty, feature_scale: Forwarded to ``Laplace``.
        classification: ``True`` for classification, ``False`` for
            regression (MSE loss).
        binary: With ``classification=True`` selects BCE-with-logits on float
            targets; otherwise cross-entropy on integer targets.
        lr: Adam learning rate.
        batch_size: Training batch size, capped at the dataset size.
        steps: Target number of optimisation steps; converted to epochs.

    Returns:
        The model returned by ``_train_model``.

    Raises:
        ValueError: If the training set is empty.
    """
    # Flatten (N, 1) targets to (N,). 1-D targets pass through, and a single
    # sample is not collapsed to a scalar.
    if y_train.ndim == 2 and y_train.shape[1] == 1:
        y_train = y_train.reshape(-1)
    if y_test.ndim == 2 and y_test.shape[1] == 1:
        y_test = y_test.reshape(-1)

    # Only multi-class classification uses integer class indices.
    integer_labels = classification and not binary

    def to_target(y):
        tensor = torch.from_numpy(y)
        return tensor.long() if integer_labels else tensor.float()

    ds_train = torch.utils.data.TensorDataset(torch.from_numpy(X_train).float(), to_target(y_train))
    ds_test = torch.utils.data.TensorDataset(torch.from_numpy(X_test).float(), to_target(y_test))

    if len(ds_train) == 0:
        raise ValueError("X_train must contain at least one sample")

    # With drop_last=True a batch larger than the dataset yields no batches
    # at all, so the batch size is capped at the dataset size.
    batch_size = min(batch_size, len(ds_train))
    dl_train = torch.utils.data.DataLoader(ds_train, batch_size=batch_size, shuffle=True, drop_last=True)
    dl_test = torch.utils.data.DataLoader(ds_test, batch_size=512, shuffle=False)

    # Always train for at least one epoch, even when steps < batches per epoch.
    epochs = max(1, int(steps // len(dl_train)))

    # Initialize SNGP Model
    model = Laplace(
        feature_extractor,
        representation_dim,
        num_gp_features,
        normalize_gp_features,
        num_random_features,
        num_outputs=num_outputs,
        num_data=len(ds_train),
        train_batch_size=batch_size,
        ridge_penalty=ridge_penalty,
        feature_scale=feature_scale,
    )

    if not classification:
        loss_fn = F.mse_loss
    elif binary:
        loss_fn = F.binary_cross_entropy_with_logits
    else:
        loss_fn = F.cross_entropy

    if torch.cuda.is_available():
        model = model.cuda()

    # Optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # Training and Evaluation
    return _train_model(model, dl_train, dl_test, loss_fn, optimizer, epochs, classification=classification)


def density_uncertainty(X_train, Y_train, model, criterion=nn.BCELoss(), learning_rate=1e-2, epochs=50, batch_size=50, warmup_epochs=0, kl_mult=1.0, kl_beta=1.0, ll_scale=0.01):
    """
    Train a model using Density Uncertainty Layers (DUL) with MLP.

    Args:
        X_train, Y_train: Training inputs and targets as NumPy arrays.
            ``(N, 1)`` targets are flattened to ``(N,)``.
        model: Network returning ``(mean, variance)``; only ``mean`` enters
            the loss.
        criterion: Loss applied to ``(mean, targets)``. Targets are cast to
            integers when it is an ``nn.CrossEntropyLoss``.
        learning_rate: Adam learning rate.
        epochs: Number of passes over the training data.
        batch_size: Mini-batch size (batches are not shuffled).
        warmup_epochs: Epochs of linear learning-rate warm-up.
        kl_mult, kl_beta: Multipliers on the KL term for ``BayesianModel``
            networks.
        ll_scale: Weight of the log-likelihood term for ``DensityModel``
            networks.

    Returns:
        The trained network, on the device used for training.
    """
    # Flatten (N, 1) targets to (N,); 1-D targets pass through unchanged.
    if Y_train.ndim == 2 and Y_train.shape[1] == 1:
        Y_train = Y_train.reshape(-1)
    train_dataset = TensorDataset(torch.from_numpy(X_train).float(), torch.from_numpy(Y_train).float())
    train_loader = DataLoader(train_dataset, batch_size, shuffle=False)
    # The KL term is scaled by the dataset size, not by the number of
    # samples seen so far in the epoch.
    n_train = len(train_dataset)

    # Prepare torch objects
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    net = model.to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)

    warmup_steps = warmup_epochs * len(train_loader)

    def lr_factor(step):
        # Linear warm-up over warmup_steps, then a constant factor of 1.
        if step < warmup_steps:
            return min(1.0, step / warmup_steps)
        return 1.0

    scheduler = LambdaLR(optimizer, lr_factor)

    uses_class_indices = isinstance(criterion, nn.CrossEntropyLoss)
    is_bayesian = isinstance(net, BayesianModel)
    is_density = isinstance(net, DensityModel)

    loss = None
    for epoch in range(epochs):
        net.train()

        for batch_idx, (inputs, targets) in enumerate(train_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            first_batch = epoch + batch_idx == 0

            mean, _variance = net(inputs)
            # Drop only a trailing singleton axis; a bare squeeze() would
            # turn a final batch of one sample into a scalar.
            if targets.dim() > 1 and targets.size(-1) == 1:
                targets = targets.squeeze(-1)
            if uses_class_indices:
                targets = targets.long()
            loss = criterion(mean, targets)

            if is_bayesian:
                if first_batch:
                    print("BayesianModel")
                kl_div = net.kl_div() / n_train
                loss = loss + kl_beta * kl_mult * kl_div
                # NOTE(review): this jumps kl_mult to 1.0 after the first
                # batch; confirm the intended annealing schedule.
                kl_mult = min(1.0, kl_mult + 1)

            if is_density:
                if first_batch:
                    print("DensityModel")
                loss = loss - net.loglikelihood() * ll_scale

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()

        # Reports the loss of the last batch of the epoch.
        if epoch % 10 == 0 and loss is not None:
            print(f"Epoch {epoch} | Loss: {loss.item():.4f}")

    return net