import numpy as np
import torch as th
from scipy.optimize import brentq, minimize
from scipy.stats import kurtosis
from torch.nn import functional as F

from utils.constants import EPS


def weighted_mse_loss(values, targets, weights, loss_type):
    """Per-sample weighted regression loss, summed over members and averaged over samples.

    Despite the name, three element-wise criteria are supported, selected by
    ``loss_type``: ``"mse"`` (``F.mse_loss``), ``"mae"`` (``F.l1_loss``) and
    ``"huber"`` (``F.huber_loss``).

    Args:
        values: tensor indexed as ``values[:, i, 0]`` for ``i`` in
            ``range(values.shape[1])``; presumably ``(batch, ensemble, k)`` —
            TODO confirm against callers.
        targets: tensor indexed as ``targets[:, i]`` (one column per member).
        weights: per-sample weights multiplied into each member's unreduced loss.
        loss_type: ``"mse"``, ``"mae"`` or ``"huber"``.

    Returns:
        Scalar tensor: the mean over samples of the weighted losses summed
        across members.

    Raises:
        NotImplementedError: if ``loss_type`` is not one of the supported names.
    """
    candidates = (
        ("mse", F.mse_loss),
        ("mae", F.l1_loss),
        ("huber", F.huber_loss),
    )
    criterion = next((fn for name, fn in candidates if loss_type == name), None)
    if criterion is None:
        raise NotImplementedError(f"Unknown loss type: {loss_type}")

    total = 0
    for member in range(values.shape[1]):
        # Keep the loss unreduced so every sample can be scaled by its own weight.
        per_sample = criterion(values[:, member, 0], targets[:, member], reduction="none")
        total = total + per_sample * weights
    return total.mean()


def optimal_xi(variance, batch_size, min_batch_size):
    """Return the non-negative variance offset ``xi`` that restores the effective batch size.

    Inverse-variance weights ``w_i = a_i / sum(a)`` with ``a_i = 1 / (variance_i + xi)``
    have effective batch size ``1 / sum(w_i**2)``. The target is
    ``min(batch_size - 1, min_batch_size)``. If the raw variances (``xi = 0``)
    already reach it, ``0`` is returned. Otherwise ``xi`` is chosen so that the
    effective batch size equals the target.

    The effective batch size is non-decreasing in ``xi``: its derivative is
    proportional to ``S1*S3 - S2**2 >= 0`` (Cauchy-Schwarz, ``Sk = sum a_i**k``).
    It also tends to ``len(variance)`` as ``xi`` grows. Therefore the solution is
    bracketed by doubling an upper bound and then located with Brent's method.
    Unlike an unconstrained Nelder-Mead search, this is guaranteed to converge
    to the target whenever the target is reachable.

    Args:
        variance: 1-D float array of per-sample variances (expected positive;
            callers clip them to ``EPS``).
        batch_size: number of samples in the batch.
        min_batch_size: desired minimum effective batch size, capped at
            ``batch_size - 1``.

    Returns:
        ``0`` when no offset is needed (including when the effective batch
        size is NaN); otherwise a non-negative ``float``.
    """

    def effective_batch_size(var):
        inverse_variance = np.power(var, -1)
        weights = inverse_variance / np.sum(inverse_variance)
        return 1 / np.sum(np.power(weights, 2))

    target = min(batch_size - 1, min_batch_size)
    # Written as `not (<)` so a NaN effective batch size also returns 0.
    if not effective_batch_size(variance) < target:
        return 0

    def gap(xi):
        return effective_batch_size(variance + xi) - target

    # gap(0) < 0 here; grow the upper end until the target is reached.
    upper = float(np.max(variance))
    if not (np.isfinite(upper) and upper > 0):
        upper = 1.0
    for _ in range(64):
        if gap(upper) >= 0:
            break
        upper *= 2
    else:
        # Target not reachable (e.g. target >= len(variance)): the largest
        # offset tried brings the effective batch size as close as possible.
        return upper
    return float(brentq(gap, 0.0, upper))


def biv_loss(values, targets, gamma, min_batch_size, loss_type):
    """Batch inverse-variance (BIV) weighted regression loss.

    Computes a variance per sample as the ``ddof=1`` variance of
    ``values[..., 0]`` along its last axis, scaled by ``gamma**2`` and clipped
    below at ``EPS``. Each sample is then weighted by
    ``1 / (variance + optimal_xi(...))``. The weights are normalized to sum to
    1 and passed to ``weighted_mse_loss``.

    Args:
        values: tensor whose first dimension is the batch. ``values[..., 0]``
            is reduced over its last axis, presumably the ensemble axis of a
            ``(batch, ensemble, k)`` tensor (TODO confirm). At least two
            entries are needed along that axis, otherwise the ``ddof=1``
            variance is NaN.
        targets: regression targets, forwarded to ``weighted_mse_loss``.
        gamma: scale applied squared to the variance (presumably the discount
            factor; confirm with caller).
        min_batch_size: minimum effective batch size requested from
            ``optimal_xi``.
        loss_type: ``"mse"``, ``"mae"`` or ``"huber"``.

    Returns:
        Scalar loss tensor.
    """
    batch_size, *_ = values.shape
    approx = values[..., 0].detach().cpu().numpy()
    # Clip *after* scaling, so that gamma == 0 (or a tiny gamma) cannot yield a
    # zero variance and hence infinite inverse variance and NaN weights.
    variance = (np.var(approx, axis=-1, ddof=1) * gamma**2).clip(EPS)
    inverse_variance = np.power(variance + optimal_xi(variance, batch_size, min_batch_size), -1)
    biv_weights = th.from_numpy(inverse_variance / np.sum(inverse_variance))
    # NumPy produces float64; match the network's dtype so the loss is not
    # silently promoted to float64 (slow on GPUs).
    biv_weights = biv_weights.to(device=values.device, dtype=values.dtype)
    return weighted_mse_loss(values, targets, biv_weights, loss_type)


def biev_loss(values, targets, min_batch_size, loss_type):
    """Inverse-variance weighted loss using the spread of TD errors across the ensemble.

    The TD errors are ``values[..., 0] - targets``. Their population variance
    (``ddof=0``) along the last axis is divided by
    ``kurt / ensemble_size + (ensemble_size + 1) / (ensemble_size - 1)`` and
    clipped below at ``EPS``. Here ``kurt`` is scipy's bias-corrected excess
    kurtosis along the same axis. NOTE(review): this looks like a finite-sample
    correction for the variance estimate; confirm the derivation. The samples
    are then weighted by ``1 / (variance + optimal_xi(...))``, the weights are
    normalized to sum to 1, and they are passed to ``weighted_mse_loss``.

    Args:
        values: 3-D tensor unpacked as ``(batch_size, ensemble_size, _)``.
            ``ensemble_size`` must be >= 2, otherwise ``ensemble_size - 1``
            raises ``ZeroDivisionError``.
        targets: tensor broadcastable against ``values[..., 0]``. It is also
            forwarded to ``weighted_mse_loss``.
        min_batch_size: minimum effective batch size requested from
            ``optimal_xi``.
        loss_type: ``"mse"``, ``"mae"`` or ``"huber"``.

    Returns:
        Scalar loss tensor.
    """
    batch_size, ensemble_size, _ = values.shape
    td_errors = (values[..., 0] - targets).detach().cpu().numpy()
    # scipy returns NaN kurtosis for rows whose TD errors are (numerically)
    # constant. The variance numerator is ~0 there anyway, so fall back to 0
    # (the Gaussian value) instead of letting one NaN survive `.clip` and
    # poison every normalized weight.
    excess_kurtosis = np.nan_to_num(kurtosis(td_errors, axis=-1, bias=False), nan=0.0)
    correction = excess_kurtosis / ensemble_size + (ensemble_size + 1) / (ensemble_size - 1)
    variance = (np.var(td_errors, axis=-1, ddof=0) / correction).clip(EPS)
    inverse_variance = np.power(variance + optimal_xi(variance, batch_size, min_batch_size), -1)
    biev_weights = th.from_numpy(inverse_variance / np.sum(inverse_variance))
    # NumPy produces float64; match the network's dtype so the loss is not
    # silently promoted to float64 (slow on GPUs).
    biev_weights = biev_weights.to(device=values.device, dtype=values.dtype)
    return weighted_mse_loss(values, targets, biev_weights, loss_type)
