from typing import Protocol

import numpy as np
from mi_estimators.mi_correlation import mutual_information_correlation
from mi_estimators.mi_hsic import hsic_estimate, hsic_loss_2
from mi_estimators.mi_mine import estimate_mine, mutual_information_mine
from causallearn.utils.cit import CIT


import torch
import torch.nn.functional as F
from sklearn.feature_selection import mutual_info_regression


class LossFunction(Protocol):
    """Structural type of the callables returned by ``get_loss_function``.

    A loss function receives the model ``prediction``, the regression
    ``target``, the ``condition`` variable the prediction is compared against,
    and a ``params`` dict of hyper-parameters (several losses read
    ``params["beta"]``).  Extra keyword arguments (for example a
    ``mine_network``) are passed through ``**kwargs``.

    It returns a 2-tuple ``(loss, metrics)``: the loss tensor to optimise and
    a dict of values to log.
    """

    def __call__(
        self, prediction, target, condition, params: dict, **kwargs
    ) -> "tuple[torch.Tensor, dict]": ...


# Threshold below which the HSIC penalty is clipped to zero in the "*_relu" losses.
_HSIC_RELU_THRESHOLD = 0.002
# Multiplier applied to the HSIC penalty in the "*_scaled" losses.
_HSIC_SCALE = 1000


def _sklearn_mi_tensor(prediction, condition):
    """Return scikit-learn's MI estimate of ``prediction`` vs ``condition`` as a tensor.

    Both inputs are detached and converted to NumPy before the estimate is
    computed, so the returned tensor is a constant for autograd: it changes
    the loss value but carries no gradient.
    """
    pred_np = prediction.detach().cpu().numpy()
    condition_np = condition.detach().cpu().numpy()
    mi_score = mutual_info_regression(pred_np, condition_np)
    return torch.tensor(mi_score, dtype=torch.float32, device=prediction.device)


def _mine_estimate_or_none(mine_network, prediction, condition):
    """Return the MINE estimate as a Python number, or ``None`` without a network."""
    if mine_network is None:
        return None
    return estimate_mine(mine_network, prediction, condition.unsqueeze(1)).item()


def _dependence_diagnostics(prediction, condition):
    """Compute the diagnostics shared by the debug losses.

    Returns ``(pred_np, cond_np, metrics)`` where the arrays are the detached
    NumPy copies of ``prediction.squeeze(1)`` and ``condition`` and ``metrics``
    holds the ``correlation``, ``mi_binning`` and ``mi_regression`` entries.

    NOTE(review): ``estimate_mi_binning`` and ``estimate_mi_sklearn_regression``
    are not imported in the visible part of this module - confirm they are
    provided elsewhere.
    """
    correlation = mutual_information_correlation(prediction.squeeze(1), condition)

    pred_np = prediction.squeeze(1).detach().cpu().numpy()
    cond_np = condition.detach().cpu().numpy()
    metrics = {
        "correlation": correlation.item(),
        "mi_binning": float(estimate_mi_binning(pred_np, cond_np)),
        "mi_regression": float(estimate_mi_sklearn_regression(pred_np, cond_np)),
    }
    return pred_np, cond_np, metrics


def _make_hsic_mse_loss(scale=1, threshold=None):
    """Build an ``mse + beta * scale * penalty`` loss based on ``hsic_estimate``.

    With ``threshold=None`` the penalty is the raw HSIC score; otherwise it is
    ``max(hsic - threshold, 0)``, so scores below the threshold are not
    penalised.
    """

    def hsic_mse_pytorch(prediction, target, condition, params, **kwargs):
        hsic_score = hsic_estimate(prediction, condition.unsqueeze(1))
        mse = F.mse_loss(prediction, target)

        penalty = hsic_score
        if threshold is not None:
            penalty = torch.max(
                hsic_score - threshold, torch.tensor(0.0, device=prediction.device)
            )

        loss = mse + params["beta"] * scale * penalty
        # The logged value is always the raw (unclipped, unscaled) HSIC score.
        return loss, {
            "mi_estimate": float(hsic_score.item()),
            "mse": float(mse.item()),
        }

    return hsic_mse_pytorch


def get_loss_function(loss_type) -> "LossFunction":
    """Return the loss callable registered under ``loss_type``.

    Every returned callable has the signature
    ``(prediction, target, condition, params, **kwargs)`` and returns a
    ``(loss, metrics)`` tuple, where ``metrics`` is a dict of values to log.
    Losses that use a MINE network accept it as the ``mine_network`` keyword.

    Supported values of ``loss_type``:

    * ``"mse"``, ``"mse_ind"`` - plain MSE / MSE to the target minus MSE to
      the condition.
    * ``"mse_mi"``, ``"mse_mi_true"``, ``"mi"`` - MSE minus / plus / only the
      scikit-learn MI estimate (a constant for autograd).
    * ``"mi_mse_pytorch"``, ``"mi_mse_pytorch_rescaled"`` - KDE-based MI terms.
    * ``"cor_mse_pytorch"`` - MSE plus ``beta`` times the correlation measure.
    * ``"mine_mse_pytorch"`` - MSE plus ``beta`` times the MINE estimate.
    * ``"mi_mse_debug"``, ``"hsic_mse_debug"``, ``"hsic2_mse_debug"``,
      ``"all_mi"``, ``"mse_and_hsic"`` - losses that also log diagnostics.
    * ``"hsic_mse_pytorch"`` and its ``_scaled`` / ``_relu`` /
      ``_scaled_relu`` variants - MSE plus an HSIC penalty.

    NOTE(review): ``mutual_information_kde`` and ``ee`` are not imported in
    the visible part of this module - confirm they are provided elsewhere.

    Raises:
        ValueError: if ``loss_type`` is not one of the names above.
    """
    if loss_type == "mse":

        def mse(prediction, target, condition, params, **kwargs):
            return F.mse_loss(prediction, target), {}

        return mse

    elif loss_type == "mse_ind":

        def mse_ind(prediction, target, condition, params, **kwargs):
            mse_prediction_target = F.mse_loss(prediction, target)
            mse_prediction_condition = F.mse_loss(prediction, condition)
            return mse_prediction_target - mse_prediction_condition, {
                "mse_prediction_target": mse_prediction_target,
                "mse_prediction_condition": mse_prediction_condition,
            }

        return mse_ind

    elif loss_type == "mse_mi":

        def mse_mi(prediction, target, condition, params, **kwargs):
            mi_score = _sklearn_mi_tensor(prediction, condition)
            mse = F.mse_loss(prediction, target)
            # The MI term is subtracted here; "mse_mi_true" adds it instead.
            return mse - mi_score, {
                "mse": float(mse.item()),
                "mi_score": float(mi_score.sum().item()),
            }

        return mse_mi

    elif loss_type == "mse_mi_true":

        def mse_mi_true(prediction, target, condition, params, **kwargs):
            mi_score = _sklearn_mi_tensor(prediction, condition)
            mse = F.mse_loss(prediction, target)
            return mse + mi_score, {
                "mse": float(mse.item()),
                "mi_score": float(mi_score.sum().item()),
            }

        return mse_mi_true

    elif loss_type == "mi":

        def mi(prediction, target, condition, params, **kwargs):
            mi_score = _sklearn_mi_tensor(prediction, condition)
            return mi_score, {"mi_score": float(mi_score.sum().item())}

        return mi

    elif loss_type == "mi_mse_pytorch":

        def mi_mse_pytorch(prediction, target, condition, params, **kwargs):
            a = prediction.squeeze(1)
            mi_score = mutual_information_kde(a, condition, device=prediction.device)
            # NOTE(review): despite the name, no MSE term is added here -
            # confirm this is intentional.
            return params["beta"] * mi_score, {
                "mi_score": float(mi_score.item()),
            }

        return mi_mse_pytorch

    elif loss_type == "mi_mse_pytorch_rescaled":

        def mi_mse_pytorch_rescaled(prediction, target, condition, params, **kwargs):
            a = prediction.squeeze(1)
            mi_score = mutual_information_kde(a, condition, device=prediction.device)
            mse = F.mse_loss(prediction, target)
            # Fixed 0.1 weight on the MI term; params["beta"] is not used.
            return mse + 0.1 * mi_score, {
                "mi_score": float(mi_score.item()),
                "mse": float(mse.item()),
            }

        return mi_mse_pytorch_rescaled

    elif loss_type == "cor_mse_pytorch":

        def cor_mse_pytorch(prediction, target, condition, params, **kwargs):
            a = prediction.squeeze(1)
            coeff = mutual_information_correlation(a, condition)
            return F.mse_loss(prediction, target) + params["beta"] * coeff, {
                "correlation": float(coeff.item()),
            }

        return cor_mse_pytorch

    elif loss_type == "mine_mse_pytorch":

        def mse_mine(prediction, target, condition, params, mine_network=None, **kwargs):
            beta = params.get("beta", 1.0)

            mine_loss_value = estimate_mine(
                mine_network, prediction, condition.unsqueeze(1)
            )
            mse = F.mse_loss(prediction, target)
            total_loss = mse + beta * mine_loss_value
            return total_loss, {
                "mse": mse.item(),
                "mi_estimate": mine_loss_value.item(),
            }

        return mse_mine

    elif loss_type == "mi_mse_debug":

        def mse_mine(prediction, target, condition, params, mine_network=None, **kwargs):
            mi_estimate = _mine_estimate_or_none(mine_network, prediction, condition)
            _, _, diagnostics = _dependence_diagnostics(prediction, condition)

            # Detached: HSIC is only logged here, it is not part of the loss.
            hsic_score = hsic_estimate(
                prediction.detach(),
                condition.unsqueeze(1).detach(),
            )

            mse = F.mse_loss(prediction, target)
            return mse, {
                "mse": mse.item(),
                "mi_estimate": mi_estimate,
                **diagnostics,
                "hsic_estimate": hsic_score.item(),
            }

        return mse_mine

    elif loss_type == "hsic_mse_debug":

        def mse_mine(prediction, target, condition, params, mine_network=None, **kwargs):
            beta = params.get("beta", 1.0)

            mi_estimate = _mine_estimate_or_none(mine_network, prediction, condition)
            pred_np, cond_np, diagnostics = _dependence_diagnostics(
                prediction, condition
            )

            npeet = float(ee.mi(pred_np, cond_np))
            npeet_5 = float(ee.mi(pred_np, cond_np, k=5))
            npeet_10 = float(ee.mi(pred_np, cond_np, k=10))

            # Not detached: this HSIC score is part of the optimised loss.
            hsic_score = hsic_estimate(
                prediction,
                condition.unsqueeze(1),
            )
            hsic_loss_2_score = hsic_loss_2(
                prediction,
                condition.unsqueeze(1),
            )

            mse = F.mse_loss(prediction, target)
            total_loss = mse + beta * hsic_score

            return total_loss, {
                "mse": mse.item(),
                "mi_estimate": mi_estimate,
                **diagnostics,
                "hsic_estimate": hsic_score.item(),
                "hsic2_estimate": hsic_loss_2_score.item(),
                "npeet": npeet,
                "npeet_5": npeet_5,
                "npeet_10": npeet_10,
            }

        return mse_mine

    elif loss_type == "hsic2_mse_debug":

        def mse_mine(prediction, target, condition, params, mine_network=None, **kwargs):
            beta = params.get("beta", 1.0)

            mi_estimate = _mine_estimate_or_none(mine_network, prediction, condition)
            _, _, diagnostics = _dependence_diagnostics(prediction, condition)

            hsic_score = hsic_estimate(
                prediction,
                condition.unsqueeze(1),
            )
            # NOTE(review): condition is passed without unsqueeze(1) here,
            # unlike the hsic_loss_2 call in "hsic_mse_debug" - confirm.
            hsic2_score = hsic_loss_2(
                prediction,
                condition,
            )

            mse = F.mse_loss(prediction, target)
            total_loss = mse + beta * hsic2_score

            return total_loss, {
                "mse": mse.item(),
                "mi_estimate": mi_estimate,
                **diagnostics,
                "hsic_estimate": hsic_score.item(),
                "hsic2_estimate": hsic2_score.item(),
            }

        return mse_mine

    elif loss_type == "hsic_mse_pytorch":
        return _make_hsic_mse_loss()

    elif loss_type == "hsic_mse_pytorch_scaled":
        return _make_hsic_mse_loss(scale=_HSIC_SCALE)

    elif loss_type == "hsic_mse_pytorch_scaled_relu":
        return _make_hsic_mse_loss(scale=_HSIC_SCALE, threshold=_HSIC_RELU_THRESHOLD)

    elif loss_type == "hsic_mse_pytorch_relu":
        return _make_hsic_mse_loss(threshold=_HSIC_RELU_THRESHOLD)

    elif loss_type == "all_mi":

        def all_mi(prediction, target, condition, params, mine_network=None, **kwargs):
            pred_np, cond_np, diagnostics = _dependence_diagnostics(
                prediction, condition
            )

            # Everything below is logged only; the optimised loss is the MSE.
            hsic_score = hsic_estimate(
                prediction.detach(),
                condition.unsqueeze(1).detach(),
            )
            mine_score = mutual_information_mine(
                prediction.detach(),
                condition.unsqueeze(1).detach(),
            )

            data = np.column_stack((pred_np, cond_np))
            kci = float(CIT(data, "kci")(0, 1))
            npeet = float(ee.mi(pred_np, cond_np))

            mse = F.mse_loss(prediction, target)
            return mse, {
                "mse": mse.item(),
                **diagnostics,
                "hsic_estimate": hsic_score.item(),
                "mine_estimate": mine_score.item(),
                "kci": kci,
                "npeet": npeet,
            }

        return all_mi

    elif loss_type == "mse_and_hsic":

        def mse_and_hsic(prediction, target, condition, params, mine_network=None, **kwargs):
            # HSIC is computed on detached tensors and only logged.
            hsic_score = hsic_estimate(
                prediction.detach(),
                condition.unsqueeze(1).detach(),
            )
            mse = F.mse_loss(prediction, target)
            return mse, {
                "mse": mse.item(),
                "hsic_estimate": hsic_score.item(),
            }

        return mse_and_hsic

    else:
        raise ValueError("Invalid loss type")
