from typing import cast, NamedTuple, Hashable, Callable, Optional, Any
from itertools import permutations, combinations
import functools
import math
from logging import getLogger

import numpy as np
import pandas as pd
import torch

from dcg.graph import CausalGraph
from dcg.node import CausalNode

from tqdm.auto import tqdm


# Anything that `CausalGraph.__getitem__` can resolve to a node: a node name
# or the CausalNode instance itself.
NODE_LIKE = str | CausalNode
# Data-generating process. `dgp_shap` calls it as
# ``dgp(n_samples, rs=random_state, **interventions)`` where interventions map
# column names to arrays of fixed values; it must return a DataFrame.
# Declared with ``...`` because the keyword arguments vary per call.
DGP_LIKE = Callable[..., pd.DataFrame]


class MeanStd(NamedTuple):
    """Result of the Shapley estimators in this module.

    Both arrays are stacked along a leading axis with one row per number of
    accumulated permutations (see `shap`).
    """
    # Running mean of the marginal contributions.
    mean: np.ndarray
    # Standard error of `mean`: sqrt(sample variance / number of samples).
    std: np.ndarray


logger = getLogger(__name__)


# Signature of the `cache_redirect` hook accepted by `shap`: maps a coalition
# (either an int bitmask, decoded with `decode_subset`, or a tuple of variable
# indices) to the coalition that is actually evaluated and used as cache key,
# so coalitions known to share a value are only computed once.
CACHE_REDIRECT = (
    Callable[[int], int] | Callable[[tuple[int, ...]], tuple[int, ...]]
)


def _cache_redirect(comb: tuple[int, ...]) -> tuple[int, ...]:
    return tuple(sorted(comb))


def _cache_redirect_parents(
    comb: tuple[int, ...],
    *,
    parents: tuple[int, ...]
) -> tuple[int, ...]:
    if set(comb) >= set(parents):
        return parents
    else:
        return comb


def decode_subset(s: int) -> tuple[int, ...]:
    """Decode a bitmask into the indices of its set bits, highest first.

    Non-positive inputs decode to the empty tuple. Uses exact integer bit
    arithmetic: `math.log2` rounds up for masks like ``2**k - 1`` once k is
    around 49 or more, which made the original float-based loop drop bits.
    """
    s = int(s)
    members = []

    while s > 0:
        top = s.bit_length() - 1  # index of the highest set bit
        s ^= 1 << top
        members.append(top)

    return tuple(members)


def cache_ratio(k: int, n: int) -> float:
    """Expected fraction of the 2**k coalitions of k players visited by n permutations.

    A single permutation visits exactly one coalition of each size, so a
    given coalition of size s (one of C(k, s)) stays unvisited after n
    permutations with probability (1 - 1/C(k, s))**n.
    """
    sizes = range(k + 1)
    counts = (math.comb(k, size) for size in sizes)
    expected_unvisited = sum(c * (1 - 1 / c) ** n for c in counts)
    return 1 - 2 ** -k * expected_unvisited


def cache_ratio_find_N(k: int, alpha: float) -> int:
    """Smallest n in (0, 2**k] with cache_ratio(k, n) >= alpha, by bisection.

    Requires cache_ratio(k, 0) <= alpha <= cache_ratio(k, 2**k); the
    assertion also trips when k is so large that cache_ratio overflows.
    """
    lo, hi = 0, 2 ** k

    assert cache_ratio(k, lo) <= alpha <= cache_ratio(k, hi), (
        'alpha is not in the accepted range, '
        'or K is too large and results in float overflow'
    )

    # Invariant: cache_ratio(k, lo) is below alpha (or lo == 0) and
    # cache_ratio(k, hi) >= alpha; shrink until they are adjacent.
    while hi - lo > 1:
        mid = (lo + hi) // 2
        if cache_ratio(k, mid) < alpha:
            lo = mid
        else:
            hi = mid

    return hi


def shap(
    x: Any,
    V: int,
    f: Callable[[Any, tuple[int, ...]], np.ndarray],
    /, *,
    adaptive: bool = False,
    max_perms: int = ...,
    min_perms: int = ...,
    cache_size: Optional[int] = None,
    cache_redirect: CACHE_REDIRECT = _cache_redirect,
    rs: Optional[np.random.RandomState] = None,
    use_tqdm: bool = True
) -> MeanStd:
    """Estimate Shapley values by sampling permutations of the variables.

    For a permutation `perm`, the contribution of variable ``perm[j]`` is
    ``f(x, perm[:j + 1]) - f(x, perm[:j])``. Contributions are averaged with
    Welford's online algorithm and the running estimate after every update
    is kept, so convergence can be inspected.

        - x: sample(s) to explain; passed unchanged to `f`.
        - V: number of input variables, indexed ``0 .. V - 1``.
        - f: value function (sample, subset) -> value.
        - adaptive: if False, every permutation evaluates all V variables.
            If True, `min_perms` full permutations are evaluated first and
            the remaining budget is spent one variable at a time on the
            variable with the largest estimated gain.
        - max_perms: non-adaptive: number of permutations, default V!.
            When it equals V!, all permutations are enumerated in random
            order and, if the cache can hold all 2**V coalitions, the cache
            is filled up front. Adaptive: required; the total budget is
            ``max_perms * V`` single-variable evaluations.
        - min_perms: adaptive only: number of full permutations before the
            adaptive phase; defaults to ``max_perms // 2`` and must satisfy
            ``0 < min_perms < max_perms``.
        - cache_size: maximum number of stored items in the cache. If None, no maximum.
        - cache_redirect: function comb -> comb that returns the combination
            to compute (and the cache key). Permutation prefixes are passed
            unsorted, so it should canonicalise them (the default sorts).
            Useful in case some comb has the same value as comb a prior,
            such as in do-SHAP with sets >= Pa(Y); in this case the redirect
            returns Pa(Y) for any sets that fulfill the condition,
            so only the value for Pa(Y) is computed.
        - rs: random state used to draw permutations; created if None.
        - use_tqdm: whether to display progress bars.

    Returns a MeanStd whose arrays stack one row per permutation count
    (row ``k - 1`` is the estimate after ``k`` contributions), with the last
    axis indexing variables. `std` is the standard error ``sqrt(s2 / k)``.
    In adaptive mode, rows a variable has not reached yet are NaN for it.
    """
    if rs is None:
        rs = np.random.RandomState()

    # The cache lives in this closure, so it is discarded once we return.
    @functools.lru_cache(maxsize=cache_size)
    def _value_f(comb: tuple[int, ...] | int):
        if isinstance(comb, int):
            comb = decode_subset(comb)

        return f(x, comb)

    def value_f(comb: tuple[int, ...]) -> np.ndarray:
        return _value_f(cache_redirect(comb))

    def marginals(perm: tuple[int, ...]) -> np.ndarray:
        """Contribution of every variable along `perm`, indexed by variable."""
        # perm[:i] is not sorted here; cache_redirect canonicalises it.
        values = np.stack([value_f(perm[:i]) for i in range(V + 1)], axis=-1)
        # diff[..., j] belongs to perm[j]; argsort(perm)[v] is v's position.
        return np.diff(values, axis=-1)[..., np.argsort(perm)]

    # Welford state: running mean, sum of squared deviations, sample variance.
    m = [np.zeros((1, V))]
    ss = [np.zeros((1, V))]
    s2 = [np.zeros((1, V))]

    if not adaptive:
        fact_V = math.factorial(V)
        if max_perms is ...:
            max_perms = fact_V

        if max_perms == fact_V:
            perms = np.array(list(permutations(range(V))))
            rs.shuffle(perms)

            # Check if we can store all combinations in the cache:
            n_combs = 2 ** V
            if cache_size is None or cache_size >= n_combs:
                # In that case, compute them now so tqdm is more informative.
                logger.info('Filling cache:')
                all_combs = (
                    comb
                    for size in range(V + 1)
                    for comb in combinations(range(V), size)
                )
                for comb in tqdm(all_combs, total=n_combs, disable=not use_tqdm):
                    value_f(comb)  # this stores it
        else:
            perms = np.stack(
                [rs.permutation(V) for _ in range(max_perms)],
                axis=0
            )

        n = 0
        for perm in tqdm(perms, total=max_perms, disable=not use_tqdm):
            n += 1
            diff = marginals(tuple(map(int, perm)))
            m.append(m[-1] + (diff - m[-1]) / n)
            ss.append(ss[-1] + (diff - m[-2]) * (diff - m[-1]))
            s2.append(ss[-1] / max(1, n - 1))

        mean = np.stack(m[1:], axis=0)
        std = np.sqrt(
            np.stack(s2[1:], axis=0)
            / np.arange(1, n + 1)[:, np.newaxis, np.newaxis]
        )
        return MeanStd(mean, std)

    assert max_perms is not ..., 'max_perms is required when adaptive=True'
    if min_perms is ...:
        min_perms = max_perms // 2
    assert 0 < min_perms < max_perms, 'expected 0 < min_perms < max_perms'

    # Number of contributions accumulated so far for each variable.
    ns = np.full((V,), min_perms)

    def gain(idx: int) -> float:
        """Estimated uncertainty reduction from one more sample of `idx`."""
        s2_ = s2[ns[idx]][:, idx].mean(0)  # variance averaged over samples
        # NOTE(review): this scales s2 by n**-1/2 inside the sqrt (n**-1/4
        # overall), while the returned std uses sqrt(s2 / n); confirm which
        # is intended before relying on the allocation order.
        return (
            np.sqrt(s2_ / np.sqrt(ns[idx]))
            - np.sqrt(s2_ / np.sqrt(ns[idx] + 1))
        )

    with tqdm(total=max_perms * V, disable=not use_tqdm) as tq:
        # Phase 1: full permutations update every variable.
        for n in range(1, min_perms + 1):
            perm = tuple(map(int, rs.permutation(range(V))))
            diff = marginals(perm)
            m.append(m[-1] + (diff - m[-1]) / n)
            ss.append(ss[-1] + (diff - m[-2]) * (diff - m[-1]))
            s2.append(ss[-1] / max(1, n - 1))
            tq.update(V)  # V steps

        # Phase 2: spend the remaining budget on the most promising variable.
        while ns.sum() < max_perms * V:
            idx = max(range(V), key=gain)
            n = ns[idx] = ns[idx] + 1  # update ns for what comes next

            # Generate a permutation just for that variable
            perm = tuple(map(int, rs.permutation(range(V))))
            i = perm.index(idx)
            diff = value_f(perm[:i + 1]) - value_f(perm[:i])

            if n == len(m):
                # First variable to reach n samples: add NaN rows.
                m.append(np.full(m[-1].shape, np.nan))
                ss.append(np.full(m[-1].shape, np.nan))
                s2.append(np.full(m[-1].shape, np.nan))

            prev = m[n - 1][:, idx]
            m[n][:, idx] = prev + (diff - prev) / n
            ss[n][:, idx] = ss[n - 1][:, idx] + (diff - prev) * (diff - m[n][:, idx])
            s2[n][:, idx] = ss[n][:, idx] / max(1, n - 1)

            tq.update(1)

    mean = np.stack(m[1:], axis=0)
    std = np.sqrt(
        np.stack(s2[1:], axis=0)
        / np.arange(1, ns.max() + 1)[:, np.newaxis, np.newaxis]
    )
    return MeanStd(mean, std)


def dgp_shap(
    x: pd.DataFrame,
    V: list[str],
    target: str,
    dgp: DGP_LIKE,
    model: Optional[Callable[[Any], np.ndarray]] = None,
    *,
    N: int = 1000,
    rs: Optional[np.random.RandomState] = None,
    **kwargs
) -> MeanStd:
    """Shapley values whose value function samples from a data-generating process.

    The value of a coalition S for each row of `x` is the mean, over N draws
    of `dgp` with the columns in S fixed to that row's values, of
    ``sample[target]`` (or of ``model(sample[V])`` when `model` is given).
    `rs` is shared between the DGP and permutation sampling; remaining
    keyword arguments are forwarded to `shap`.
    """
    if rs is None:
        rs = np.random.RandomState()

    # Define the value function
    def f(x: pd.DataFrame, subset: tuple[int, ...]) -> np.ndarray:
        n = len(x)
        # Each row of x is repeated N times consecutively, matching the
        # (n, N) reshape below.
        intv = {V[i]: x[V[i]].values.repeat(N) for i in subset}

        sample = dgp(N * n, rs=rs, **intv)
        if model is None:
            y = sample[target].values
        else:
            y = model(sample[V])

        # A model may return a pandas object, which has no `reshape`.
        if isinstance(y, (pd.Series, pd.DataFrame)):
            y = y.to_numpy()

        return y.reshape((n, N)).mean(axis=1)

    return shap(x, len(V), f, rs=rs, **kwargs)


def marginal_shap(
    x: pd.DataFrame,
    V: list[str],
    target: str,
    model: Callable[[Any], np.ndarray],
    train_data: pd.DataFrame,
    **kwargs
) -> MeanStd:
    """Shapley values with a marginal (background-data) value function.

    The value of a coalition S for each row of `x` is the mean of `model`
    over all rows of ``train_data[V]`` with the columns in S overwritten by
    that row's values. `target` is not used by this function. Remaining
    keyword arguments are forwarded to `shap`.
    """
    N = len(train_data)
    n = len(x)

    # The tiled background (train_data repeated once per row of x) does not
    # depend on the coalition, so build it once and copy it per call.
    background = pd.concat([train_data[V]] * n, axis=0)

    # Define the value function
    def f(x: pd.DataFrame, subset: tuple[int, ...]) -> np.ndarray:
        df = background.copy()

        # Row j of x is repeated N times, aligned with the j-th background copy.
        for col in (V[i] for i in subset):
            df[col] = x[col].values.repeat(N, axis=0)

        y = model(df)
        # A model may return a pandas object, which has no `reshape`.
        if isinstance(y, (pd.Series, pd.DataFrame)):
            y = y.to_numpy()

        return y.reshape((n, N)).mean(axis=1)

    return shap(x, len(V), f, **kwargs)


def dcg_shap(
    x: np.ndarray,
    V: tuple[NODE_LIKE, ...],
    target: NODE_LIKE,
    graph: CausalGraph,
    *,
    mc_n: int = 1000,  # Monte Carlo samples
    **kwargs
) -> MeanStd:
    """Shapley values whose value function samples a causal graph under intervention.

    The value of a coalition S is the mean, over `mc_n` draws of `target`
    from ``graph.sample`` with the nodes in S intervened on to their values
    in `x`. Unless a `cache_redirect` is given, coalitions that contain all
    parents of `target` are redirected to the parent set (see
    `_cache_redirect_parents`). Remaining keyword arguments go to `shap`.
    """
    # Preprocess the given sample and move to device
    if isinstance(x, np.ndarray):
        x = torch.Tensor(x.astype(float))

    x, n = graph._preprocess_x(x)
    x = {
        node: tensor.to(graph.device)
        for node, tensor in cast(dict[CausalNode, torch.Tensor], x).items()
    }

    # Transform V and y to CausalNode instances
    V = [graph[v] for v in V]
    target = cast(CausalNode, graph[target])
    # NOTE(review): V.index raises ValueError if a parent of target is not
    # among V; confirm that callers always include every parent.
    parents = tuple(sorted(V.index(node) for node in target.parents))

    # Define the value function
    def f(
        x: dict[CausalNode, torch.Tensor], subset: tuple[int, ...]
    ) -> np.ndarray:
        with torch.no_grad():
            # Tile x mc_n times along dim 0 so the (mc_n, -1) reshape below
            # averages the draws per x sample.
            # NOTE(review): assumes each x[node] is 2-D (n, d) — TODO confirm.
            interventions = {V[i]: x[V[i]].repeat(mc_n, 1) for i in subset}
            samples = graph.sample(
                mc_n * n, target_node=target, interventions=interventions
            )
            # reshape (not view) also works on non-contiguous sampler output.
            return samples.reshape(mc_n, -1).mean(0).cpu().numpy()

    kwargs.setdefault(
        'cache_redirect',
        functools.partial(_cache_redirect_parents, parents=parents),
    )

    return shap(x, len(V), f, **kwargs)
