from typing import Callable, Any

import numpy as np
import pandas as pd
import torch
from torch.distributions.transforms import Transform, constraints

from omegaconf import DictConfig

# NOTE: this uses a modified version of Zuko, found in the project's root folder
import zuko.flows as zflows
from causal_nf.modules.causal_nf import CausalNormalizingFlow

from do_shap.shap import shap, MeanStd


# Short alias for the functional softplus, used throughout this module
softplus = torch.nn.functional.softplus


def logsigmoid(x, alpha=1., **kwargs):
    """
    Compute log(sigmoid(alpha * x)) as -softplus(-alpha * x).

    Extra keyword arguments are forwarded unchanged to softplus.
    """
    return -softplus(-alpha * x, **kwargs)


def softplus_inv(x, eps=1e-6, threshold=20.):
    """
    Compute the softplus inverse, log(exp(x) - 1), element-wise.

    Negative inputs are clamped to 0 first. Inputs at or above `threshold`
    are returned unchanged, since the inverse is numerically the identity
    there. Below the threshold, `eps` is added to the input so that x == 0
    maps to a finite value (about log(eps)) instead of -inf; as a result the
    function is only an approximate inverse of softplus near 0.

    Args:
        x: input tensor (expected to be non-negative).
        eps: small offset added before inverting.
        threshold: value from which the identity is used.

    Returns:
        A new tensor with the same shape as `x`.
    """
    x = x.clamp(0.)
    below = x < threshold

    # Cap the argument so that the branch discarded by `where` cannot overflow
    capped = x.clamp(max=threshold)
    # expm1 avoids the cancellation of exp(.) - 1 for small arguments
    inverted = torch.log(torch.expm1(capped + eps))

    return torch.where(below, inverted, x)


class DomainTransform(Transform):
    """
    Adjust the domains of each variable separately.

    The six variables are read from the last dimension of the input, in the
    order below, and each one is mapped onto the whole real line:

    U: positive   (inverse softplus)
    Z: [0, 1]     (logit)
    X: positive   (inverse softplus)
    A: positive   (inverse softplus)
    B: all reals  (identity)
    C: all reals  (identity)
    """

    domain = constraints.real
    codomain = constraints.real
    bijective = True
    sign = +1

    def _call(self, x):
        """
        Forward transformation: from each variable's own domain to all reals.
        """
        # Variables live on the last dimension, so any leading batch shape works
        u, z, x_var, a, b, c = x.unbind(dim=-1)

        # Transform domain to all reals (b and c are already unconstrained)
        u = softplus_inv(u)
        z = torch.logit(z)
        x_var = softplus_inv(x_var)
        a = softplus_inv(a)

        return torch.stack([u, z, x_var, a, b, c], dim=-1)

    def _inverse(self, y):
        """
        Inverse transformation: from all reals back to each variable's domain.
        """
        u, z, x_var, a, b, c = y.unbind(dim=-1)

        # Transform from all reals back to their domain
        u = softplus(u)
        z = torch.sigmoid(z)
        x_var = softplus(x_var)
        a = softplus(a)

        return torch.stack([u, z, x_var, a, b, c], dim=-1)

    def log_abs_det_jacobian(self, x, y):
        """
        Computes the element-wise log det jacobian `log |dy/dx|`.

        Only the transformed values `y` are used; `x` is ignored. The small
        eps offset and the clamping applied by `softplus_inv` in the forward
        pass are not accounted for here.
        """
        u, z, x_var, a, b, c = y.unbind(dim=-1)

        # log-derivative of the *inverse* map, evaluated at y
        u = logsigmoid(u)  # d softplus(u) / du = sigmoid(u)
        z = 2 * logsigmoid(z) - z  # d sigmoid(z) / dz = sigmoid(z) * sigmoid(-z)
        x_var = logsigmoid(x_var)  # softplus
        a = logsigmoid(a)  # softplus
        b = torch.zeros_like(b)  # identity
        c = torch.zeros_like(c)  # identity

        # The forward derivative is the reciprocal of the inverse one,
        # hence the sign flip in log-space
        return -torch.stack([u, z, x_var, a, b, c], dim=-1)


class NormalizeTransform(Transform):
    """
    Normalize tensor given normalization constants.

    The forward pass computes `(x - loc) / scale` and the inverse
    `y * scale + loc`, both broadcasting against the stored constants.
    """

    domain = constraints.real
    codomain = constraints.real
    bijective = True
    sign = +1

    def __init__(
        self,
        loc: torch.Tensor,
        scale: torch.Tensor,
        **kwargs,
    ):
        """
        Args:
            loc: offset subtracted from the input.
            scale: divisor applied after subtracting `loc`; its logarithm is
                precomputed for the log-det-Jacobian.
            **kwargs: forwarded to `Transform.__init__`.
        """
        super().__init__(**kwargs)

        self.loc = loc
        self.scale = scale
        self.log_scale = torch.log(scale)

    def to(self, device):
        """
        Move the normalization constants to `device` in place.

        Returns:
            This transform, so that calls can be chained.
        """
        self.loc = self.loc.to(device)
        self.scale = self.scale.to(device)
        self.log_scale = self.log_scale.to(device)

        return self

    def _call(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize `x`, first moving the constants to its device."""
        self.to(x.device)
        return (x - self.loc) / self.scale

    def _inverse(self, y: torch.Tensor) -> torch.Tensor:
        """Undo the normalization, first moving the constants to `y`'s device."""
        self.to(y.device)
        return y * self.scale + self.loc

    def log_abs_det_jacobian(
        self, x: torch.Tensor, y: torch.Tensor
    ) -> torch.Tensor:
        """
        Element-wise `log |dy/dx|`, i.e. `-log(scale)` expanded to `x`'s shape.
        """
        return -self.log_scale.expand(x.shape).to(x.device)


def create_cnf(
    X: torch.Tensor, adjacency: torch.Tensor, cfg: DictConfig
) -> CausalNormalizingFlow:
    """
    Build a causal normalizing flow on top of an NSF flow.

    Two fixed transforms are placed in front of the NSF layers: a
    `DomainTransform`, followed by a `NormalizeTransform` whose constants are
    the per-column mean and standard deviation of the domain-transformed `X`.

    Args:
        X: training data, used only to compute the normalization constants.
        adjacency: adjacency matrix handed to the flow; its first dimension
            gives the number of features.
        cfg: configuration providing `activation`, `bins`, `num_layers`,
            `hidden_features`, `base_to_data`, `base_distr` and `learn_base`.

    Returns:
        The flow wrapped in a `CausalNormalizingFlow`.

    Raises:
        KeyError: if `cfg.activation` is not one of the supported names.
    """
    supported_activations = {
        "relu": torch.nn.ReLU,
        "elu": torch.nn.ELU,
        "lrelu": torch.nn.LeakyReLU,
        "sigmoid": torch.nn.Sigmoid,
    }
    activation_cls = supported_activations[cfg.activation]

    # Base NSF flow
    nsf_options = dict(
        features=adjacency.shape[0],
        context=0,
        bins=cfg.bins,
        transforms=cfg.num_layers,
        hidden_features=cfg.hidden_features,
        adjacency=adjacency,
        base_to_data=cfg.base_to_data,
        base_distr=cfg.base_distr,
        learn_base=cfg.learn_base,
        activation=activation_cls,
    )
    flow = zflows.NSF(**nsf_options)

    # NSF only transforms [-5, 5], so the domain of X is adapted first...
    domain_layer = zflows.Unconditional(DomainTransform)
    flow.transforms.insert(0, domain_layer)

    # ...and then normalized, with constants taken from the training dataset
    X_real = flow.transforms[0]()(X)
    loc = X_real.mean(0)
    scale = X_real.std(0)
    normalize_layer = zflows.Unconditional(
        NormalizeTransform,
        loc=loc, scale=scale
    )
    flow.transforms.insert(1, normalize_layer)

    return CausalNormalizingFlow(flow)


def cnf_loss(module, batch):
    """
    Negative log-probability of a batch under `module`.

    `batch` must contain exactly one element (the data tensor); the result
    of `module.log_prob` on it is returned with its sign flipped.
    """
    (X,) = batch  # fails if the batch holds anything other than one item
    log_likelihood = module.log_prob(X)

    return -log_likelihood


def compute_cnf_density(
    cnf: CausalNormalizingFlow, col: str, x: np.ndarray, *,
    df_train: pd.DataFrame
) -> np.ndarray:
    """
    Estimate a density for column `col` at each value in `x`.

    For every query value, that value is written into column `col` of every
    training row, the flow's joint density is evaluated on the modified rows,
    and the densities are averaged over the training rows.

    Args:
        cnf: model exposing `log_prob` for a 2-D tensor of rows.
        col: name of the column of `df_train` to evaluate.
        x: 1-D array with the n query values.
        df_train: training data with N rows; it is not modified.

    Returns:
        Array of shape (n,) with one averaged density per query value.

    Raises:
        ValueError: if `col` is not a column of `df_train`.
    """
    col_idx = list(df_train.columns).index(col)
    # Work on a float copy so that integer columns cannot truncate `x`
    X = df_train.to_numpy(dtype=float, copy=True)

    with torch.no_grad():
        N, n = len(X), len(x)

        # Row i * n + j holds training row i with query value j in `col`
        X = X.repeat(n, axis=0)
        X[:, col_idx] = np.tile(x, N)

        log_p = cnf.log_prob(torch.Tensor(X))

        # Same (N, n) layout as above: average over training rows (dim 0)
        return torch.exp(log_p.reshape(N, n)).mean(dim=0).cpu().numpy()


def cnf_intervene(cnf, N: int, intv: list[tuple[int, torch.Tensor]], use_mps: bool = False) -> torch.Tensor:
    """
    Draw N samples from the flow with some variables forced to given values.

    Args:
        cnf: model whose `flow()` returns an object exposing `base` and
            `transform`.
        N: number of samples to draw.
        intv: (column index, values) pairs. Values shorter than N are tiled
            up to length N, which requires N to be a multiple of their length.
        use_mps: if True, the values are moved to the 'mps' device first.

    Returns:
        The samples obtained by pushing the adjusted noise through the
        inverse transform.
    """
    flow_dist = cnf.flow()

    # Exogenous noise, one row per requested sample
    noise = flow_dist.base.rsample((N,))

    # Interventions are applied in index order, taken as the topological
    # order (tuples sort by their first element, the index)
    ordered = list(intv)
    ordered.sort()

    for index, value in ordered:
        # Tile the values up to N entries when fewer were given
        n_given = len(value)
        if n_given < N:
            assert N % n_given == 0
            value = value.repeat(N // n_given)
        assert len(value) == N

        if use_mps:
            value = value.to('mps')

        # Samples implied by the current noise, with the intervention applied
        sample = flow_dist.transform.inv(noise)
        sample[:, index] = value

        # Noise that would produce the intervened samples; only the
        # intervened coordinate is copied back
        implied_noise = flow_dist.transform(sample)
        noise[:, index] = implied_noise[:, index]

    return flow_dist.transform.inv(noise)


def cnf_shap(
    x: pd.DataFrame,
    V: list[str],
    model: Callable[[Any], np.ndarray],
    cnf: CausalNormalizingFlow,
    cnf_colnames: list[str],
    *,
    mc_n: int = 1000,  # Monte Carlo samples
    use_mps: bool = False,
    **kwargs
) -> MeanStd:
    """
    Run `shap` with a value function based on interventions on the flow.

    Args:
        x: samples to explain; converted to a float tensor.
        V: columns of the intervened data that are fed to `model`.
        model: predictor applied to the intervened data frame.
        cnf: flow used to sample under interventions.
        cnf_colnames: column names given to the flow's samples.
        mc_n: number of Monte Carlo samples drawn per row of `x`.
        use_mps: forwarded to `cnf_intervene`.
        **kwargs: forwarded to `shap`.

    Returns:
        Whatever `shap` returns for these inputs.
    """
    samples = torch.Tensor(x.values.astype(float))

    n_draws = mc_n
    n_rows = len(samples)

    def value_function(X: torch.Tensor, subset: tuple[int]) -> np.ndarray:
        # NOTE(review): n_rows comes from the outer `x`, so this presumably
        # expects to be called with a tensor of that same length; confirm
        # against `shap`.
        with torch.no_grad():
            # Fix every column in `subset` to its observed values and draw
            # n_draws samples per row
            interventions = [(i, X[:, i]) for i in subset]
            draws = cnf_intervene(
                cnf, n_draws * n_rows, interventions, use_mps=use_mps
            )
            frame = pd.DataFrame(draws.cpu().numpy(), columns=cnf_colnames)

            # Predict on the intervened data, then average the draws of
            # each individual row
            predictions = model(frame[V])
            return predictions.reshape((n_draws, n_rows)).mean(0)

    return shap(samples, len(V), value_function, **kwargs)
