import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, ConstantKernel as C
from sklearn.linear_model import BayesianRidge
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

import numpy as np
from typing import Optional, Tuple
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split


def learn_and_predict_nans_direct(data_manager, predictions, estimated_values) -> np.ndarray:
    """Fill the NaN entries of ``estimated_values`` with Gaussian-process predictions.

    A GP regressor is fitted on the observations whose value is known, using as
    inputs the observation features together with the ensemble mean and variance
    of the task model's predictions. Each missing entry is then replaced by
    ``gp_mean + 2 * gp_std``.

    Args:
        data_manager: Object exposing ``full_X``, indexable by observation index
            and by a boolean mask over the observations.
        predictions: The task model's predictions for the observations
            (ensemble_size, num_observations).
        estimated_values: The values to be estimated (pemse or bias squared), which
            may contain NaNs (num_observations,). This array is modified in place.

    Returns:
        np.ndarray: ``estimated_values`` itself, with its NaNs replaced. If there
        were no NaNs it is returned untouched and no model is fitted.

    Raises:
        ValueError: If every entry of ``estimated_values`` is NaN, since there is
            then nothing to fit the model on.
    """
    # True where a value is already available
    mask = ~np.isnan(estimated_values)

    if mask.all():
        # Nothing to impute. Predicting on a zero-row matrix would raise, so stop here.
        return estimated_values
    if not mask.any():
        raise ValueError("estimated_values contains only NaNs; there are no known values to learn from.")

    known_indices = np.where(mask)[0]

    # Per-observation summary statistics of the ensemble
    predictions_mean = np.mean(predictions, axis=0)  # shape (num_observations,)
    predictions_variance = np.var(predictions, axis=0)  # shape (num_observations,)

    x_train_lossmodel = np.column_stack(
        (data_manager.full_X[known_indices], predictions_mean[mask], predictions_variance[mask])
    )
    y_train_lossmodel = estimated_values[known_indices]

    # Scaled RBF kernel; the bounds limit the hyper-parameter search during fitting
    kernel = C(1.0, (1e-3, 1e3)) * RBF(length_scale=1.0, length_scale_bounds=(1e-2, 1e2))

    gp = GaussianProcessRegressor(kernel=kernel, alpha=1e-6, normalize_y=True)
    gp.fit(x_train_lossmodel, y_train_lossmodel)

    # Same feature layout as the training inputs, for the observations that are missing
    x_test_lossmodel = np.column_stack(
        (data_manager.full_X[~mask], predictions_mean[~mask], predictions_variance[~mask])
    )
    y_pred, y_std = gp.predict(x_test_lossmodel, return_std=True)

    # Imputed score is the predictive mean plus lambda_ predictive standard deviations
    lambda_ = 2.0
    y_pred_score = y_pred + lambda_ * y_std

    estimated_values[~mask] = y_pred_score
    return estimated_values


def learn_and_predict_nans_quadratic_naive(
    data_manager,
    ensemble_predictions,
    estimated_values,
    estimated_values_matrix,
):
    """Learn a model that predicts entries of the co-bias/co-emse matrix (naive version).

    We notice that the MSE is essentially the expected value over a squared norm, which can
    also be interpreted as the inner product between two identical elements of a Hilbert space.
    In which case we can consider the bias-covariance decomposition, but for non-identical
    elements of the input to derive a cobias-covariance tradeoff.

    A Gaussian process is fitted on every known matrix entry ``(i, j)``. Its input is the
    concatenation of the features of observations ``i`` and ``j`` followed by their ensemble
    means and ensemble variances. Each missing entry is replaced by ``gp_mean + 2 * gp_std``.

    Args:
        data_manager: Object exposing ``full_X``, indexable by arrays of observation indices.
        ensemble_predictions: The predictions for the observations (ensemble_size, num_observations)
        estimated_values: The values to be estimated, which may contain nans (num_observations,).
            Not read by this function; kept for signature parity with the other variants.
        estimated_values_matrix: The matrix of estimated values, which may contain nans
            (num_observations, num_observations). Described by the caller as computed using
            estimated_values.T @ estimated_values. Its NaN entries are filled in place.

    Returns:
        np.ndarray: The diagonal of the completed matrix (num_observations,).

    Raises:
        ValueError: If every entry of ``estimated_values_matrix`` is NaN.
    """
    # True where a matrix entry is already available
    mask = ~np.isnan(estimated_values_matrix)

    if mask.all():
        # Nothing to impute. Predicting on a zero-row matrix would raise, so stop here.
        return np.diagonal(estimated_values_matrix)
    if not mask.any():
        raise ValueError("estimated_values_matrix contains only NaNs; there are no known entries to learn from.")

    # Each is a tuple of two arrays (i_indices, j_indices)
    known_indices = np.where(mask)
    unknown_indices = np.where(~mask)

    # Ensemble statistics are a property of a single observation, so compute them once
    # per observation and look them up per pair instead of reducing over every pair.
    predictions_mean = np.mean(ensemble_predictions, axis=0)  # shape (num_observations,)
    predictions_variance = np.var(ensemble_predictions, axis=0)  # shape (num_observations,)

    def pair_features(i_indices, j_indices):
        # Layout: [x_i, x_j, mean_i, mean_j, var_i, var_j], one row per (i, j) pair
        return np.column_stack(
            (
                data_manager.full_X[i_indices],
                data_manager.full_X[j_indices],
                predictions_mean[i_indices],
                predictions_mean[j_indices],
                predictions_variance[i_indices],
                predictions_variance[j_indices],
            )
        )

    x_train_lossmodel = pair_features(*known_indices)
    y_train_lossmodel = estimated_values_matrix[known_indices]  # shape (num_known_pairs,)

    # Scaled RBF kernel; the bounds limit the hyper-parameter search during fitting
    kernel = C(1.0, (1e-3, 1e3)) * RBF(length_scale=1.0, length_scale_bounds=(1e-2, 1e2))

    gp = GaussianProcessRegressor(kernel=kernel, alpha=1e-6, normalize_y=True)
    gp.fit(x_train_lossmodel, y_train_lossmodel)

    x_test_lossmodel = pair_features(*unknown_indices)
    y_pred, y_std = gp.predict(x_test_lossmodel, return_std=True)

    # Imputed score is the predictive mean plus lambda_ predictive standard deviations
    lambda_ = 2.0
    y_pred_score = y_pred + lambda_ * y_std

    # Replace the NaN entries of the matrix with the predicted scores
    estimated_values_matrix[unknown_indices] = y_pred_score

    # Only the diagonal is handed back to the caller
    return np.diagonal(estimated_values_matrix)


class EmbedNet(nn.Module):
    """Small MLP mapping an ``input_dim``-dimensional vector to an ``h``-dimensional embedding.

    Two hidden stages (64 and 32 units), each made of Linear -> BatchNorm1d -> ReLU -> Dropout,
    are followed by a final linear projection to ``h`` outputs.

    Args:
        h: Size of the output embedding.
        input_dim: Number of input features.
        dropout: Dropout probability used after each hidden stage.
    """

    def __init__(self, h=16, input_dim=2, dropout=0.1):
        super().__init__()
        widths = (input_dim, 64, 32)
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.BatchNorm1d(fan_out))
            layers.append(nn.ReLU(inplace=True))
            layers.append(nn.Dropout(dropout))
        # Final projection has no normalisation or activation
        layers.append(nn.Linear(widths[-1], h))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the embedding of ``x``; output has ``h`` columns."""
        embedding = self.net(x)
        return embedding


class PairwiseDot(nn.Module):
    """Score a pair of inputs by the dot product of their embeddings.

    Both inputs go through the same ``EmbedNet``, so they share weights.

    Args:
        hidden_dim: Size of the embedding produced for each input.
        input_dim: Number of features of each input.
    """

    def __init__(self, hidden_dim=16, input_dim=2):
        super().__init__()
        self.embed = EmbedNet(h=hidden_dim, input_dim=input_dim)

    def forward(self, x_i, x_j):
        """Return the row-wise dot product of the two embeddings, shape (batch, 1)."""
        left = self.embed(x_i)  # (batch, h)
        right = self.embed(x_j)  # (batch, h)
        # Reduce over the embedding dimension, keeping a trailing singleton axis
        return torch.sum(left * right, dim=1, keepdim=True)


def _fit_pairwise_and_predict(data_manager, ensemble_predictions, values_matrix, i_k, j_k, i_u, j_u):
    """Train a PairwiseDot model on the known pairs (i_k, j_k) and predict the pairs (i_u, j_u).

    Returns a 1-D float array with one prediction per unknown pair, in the order of ``i_u``/``j_u``.
    """
    # Ensemble mean and variance per observation (the same two statistics the GP-based
    # variants in this module use). They depend on a single observation only, so they are
    # computed once and looked up per pair.
    pred_mean = np.mean(ensemble_predictions, axis=0)  # shape (num_observations,)
    pred_var = np.var(ensemble_predictions, axis=0)  # shape (num_observations,)

    def side_features(obs_idx):
        # One row per requested observation: [features, ensemble mean, ensemble variance]
        return np.column_stack((data_manager.full_X[obs_idx], pred_mean[obs_idx], pred_var[obs_idx]))

    # Split the known pairs into training and validation sets
    i_train, i_val, j_train, j_val, y_train, y_val = train_test_split(
        i_k, j_k, values_matrix[i_k, j_k], test_size=0.15, random_state=42
    )

    # Both sides of a pair go through the same embedding network, so they must share one
    # input scaling. Fit a single scaler on the training rows of both sides and reuse it
    # for validation and prediction.
    scaler = StandardScaler()
    scaler.fit(np.vstack((side_features(i_train), side_features(j_train))))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def to_inputs(obs_idx):
        return torch.from_numpy(scaler.transform(side_features(obs_idx))).float().to(device)

    def to_targets(values):
        return torch.from_numpy(np.ascontiguousarray(values)).float().to(device).unsqueeze(1)

    train_dataset = TensorDataset(to_inputs(i_train), to_inputs(j_train), to_targets(y_train))
    val_dataset = TensorDataset(to_inputs(i_val), to_inputs(j_val), to_targets(y_val))

    train_batch_size = 1024
    # BatchNorm1d cannot run in training mode on a single-sample batch, so drop a
    # trailing batch of size one (only when at least one full batch remains).
    drop_last = len(train_dataset) > train_batch_size and len(train_dataset) % train_batch_size == 1
    train_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True, drop_last=drop_last)
    val_loader = DataLoader(val_dataset, batch_size=4096, shuffle=False)

    # Create the model
    h_dim = 16
    input_dim = train_dataset.tensors[0].shape[1]  # number of features per side
    model = torch.compile(PairwiseDot(hidden_dim=h_dim, input_dim=input_dim).to(device), mode="default")
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=1e-5)

    patience = 200
    max_epochs = 2000

    best_val_loss = float("inf")
    best_state = None
    epochs_no_improvement = 0
    for epoch in range(max_epochs):
        model.train()
        for xb_i, xb_j, yb in train_loader:
            optimizer.zero_grad()
            loss = criterion(model(xb_i, xb_j), yb)
            loss.backward()
            optimizer.step()

        # Validation: sample-weighted mean of the per-batch MSE
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for xb_i, xb_j, yb in val_loader:
                val_loss += criterion(model(xb_i, xb_j), yb).item() * xb_i.size(0)
        val_loss /= len(val_loader.dataset)

        # Early stopping
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # state_dict() hands out references to the live tensors; clone them so the
            # snapshot is not overwritten by later optimizer steps.
            best_state = {name: tensor.detach().clone() for name, tensor in model.state_dict().items()}
            epochs_no_improvement = 0
        else:
            epochs_no_improvement += 1

        if epochs_no_improvement >= patience:
            print(f"Early stopping at epoch {epoch}")
            break

    # Restore the best weights. best_state stays None if the validation loss never
    # improved (e.g. it was NaN); in that case keep the final weights.
    if best_state is not None:
        model.load_state_dict(best_state)

    # Predict in chunks so memory stays bounded when there are many unknown pairs
    predict_batch_size = 65536
    chunks = []
    model.eval()
    with torch.no_grad():
        for start in range(0, len(i_u), predict_batch_size):
            stop = start + predict_batch_size
            batch_pred = model(to_inputs(i_u[start:stop]), to_inputs(j_u[start:stop]))
            chunks.append(batch_pred.cpu().numpy().ravel())
    return np.concatenate(chunks)


def learn_and_predict_nans_quadratic(
    data_manager, ensemble_predictions, estimated_values, estimated_values_matrix, batch_strategy="top-k"
):
    """Learn a model that predicts entries of the co-bias/co-emse matrix.

    A ``PairwiseDot`` network (dot product of two embeddings from a shared network) is
    trained on the known matrix entries and used to fill in the NaN entries. The input for
    each side of a pair is the observation's features followed by the ensemble mean and
    ensemble variance of its predictions, standardised with a single scaler. Training uses
    a 15% validation split and early stopping, and the best-validation weights are restored
    before predicting.

    Args:
        data_manager: Object exposing ``full_X``, indexable by arrays of observation indices.
        ensemble_predictions: The predictions for the observations (ensemble_size, num_observations)
        estimated_values: The values to be estimated, which may contain nans (num_observations,).
            Not read by this function; kept for signature parity with the other variants.
        estimated_values_matrix: The matrix of estimated values, which may contain nans
            (num_observations, num_observations). It is not modified; a filled copy is made.
        batch_strategy: ``"top-k"`` returns only the diagonal of the completed matrix; any
            other value returns the diagonal together with the completed matrix.

    Returns:
        np.ndarray: The diagonal of the completed matrix when ``batch_strategy == "top-k"``.
        Tuple[np.ndarray, np.ndarray]: ``(diagonal, completed_matrix)`` otherwise.
        If the matrix has no NaNs, no model is trained and the copy is returned as is.
    """
    if torch.cuda.is_available():
        torch.set_float32_matmul_precision("high")

    # True where a matrix entry is already available
    mask = ~np.isnan(estimated_values_matrix)

    # Each is a tuple of two arrays (i_indices, j_indices)
    known_indices = np.where(mask)
    unknown_indices = np.where(~mask)

    # Avoid training the model twice on non-diagonals: keep the upper triangle only
    i_k, j_k = known_indices
    upper_mask = i_k <= j_k
    i_k = i_k[upper_mask]
    j_k = j_k[upper_mask]

    print(f"### We have {len(i_k)} known pairs of indices to train on, and {len(unknown_indices[0])} unknown pairs.")

    M_filled = estimated_values_matrix.copy()

    i_u, j_u = unknown_indices
    if len(i_u) > 0:
        y_pred = _fit_pairwise_and_predict(
            data_manager, ensemble_predictions, estimated_values_matrix, i_k, j_k, i_u, j_u
        )
        M_filled[i_u, j_u] = y_pred
        # Mirror each prediction to (j, i), but never overwrite an entry that was observed
        mirror_missing = ~mask[j_u, i_u]
        M_filled[j_u[mirror_missing], i_u[mirror_missing]] = y_pred[mirror_missing]

    if batch_strategy == "top-k":
        return np.diagonal(M_filled)
    return np.diagonal(M_filled), M_filled
