from torch.utils.data import Dataset
import itertools
import numpy as np
import torch
from torch.distributions import Categorical

from typing import Callable
import os
import random
import numpy as np
from math import comb
from itertools import product

from torch import Tensor
import torch.nn as nn
from typing import Callable



def set_seed(seed: int):
    """
    Seed the random number generators used in this module.

    Exports ``PYTHONHASHSEED``, seeds Python's ``random``, NumPy's global
    generator, ``torch.manual_seed`` and ``torch.cuda.manual_seed`` with the
    same value, then sets the cuDNN ``deterministic`` flag.

    Note: a function with the same name is defined again later in this
    module; at import time the later definition replaces this one.

    Args:
        seed (int): The seed shared by all generators.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    # Every library-level seeder takes the same integer; call them in turn.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for seed_fn in seeders:
        seed_fn(seed)
    torch.backends.cudnn.deterministic = True

def get_kernelshap_weight(m: int, N: int):
    """
    Return the KernelSHAP weight ``(N - 1) / (C(N, m) * (N - m) * m)``.

    Note: a function with the same name is defined again later in this
    module; at import time the later definition replaces this one.

    Args:
        m (int): Coalition size.
        N (int): Total number of players.

    Raises:
        ZeroDivisionError: If the denominator is zero, i.e. ``m == 0`` or
            ``m >= N``.
    """
    denominator = comb(N, m) * (N - m) * m
    return (N - 1) / denominator

def get_shapley_weight(m: int, N: int):
    """
    Return the Shapley weight ``1 / (C(N, m) * (N - m))``.

    This equals ``m! * (N - m - 1)! / N!``, the weight of a coalition of
    size ``m`` that does not contain the player being evaluated.

    Note: a function with the same name is defined again later in this
    module; at import time the later definition replaces this one.

    Args:
        m (int): Coalition size; must be smaller than ``N``.
        N (int): Total number of players.
    """
    assert m < N - 0.5
    n_coalitions = comb(N, m)
    n_outside = N - m
    return 1 / (n_coalitions * n_outside)

def compute_shapley_exactly(data: Tensor, model: Callable) -> Tensor:
    """
    Compute exact Shapley values by enumerating all ``2 ** feat_dim`` masks.

    A coalition is applied by multiplying the input with a 0/1 mask, so an
    absent feature is fed to the model as zero. The per-size weights come
    from ``get_shapley_weight``.

    The weights and the model output are brought to a common dtype before
    the final contraction, so a model that returns e.g. float64 no longer
    fails on a float32/float64 mismatch; float32 outputs are processed
    exactly as before.

    Note: a function with the same name is defined again later in this
    module; at import time the later definition replaces this one.

    Args:
        data (Tensor): The size of data is [batch_size, feat_dim].
        model (Callable): The callable function or class to produce the output [batch_size, out_dim].

    Returns:
        Tensor: Shapley values of size [batch_size, out_dim, feat_dim].
    """
    device = data.device
    N, in_dim = data.size(0), data.size(1)

    # generate all masks, from the empty coalition to the full one
    masks = torch.tensor(list(product((0, 1), repeat=in_dim))).float().to(device)  # numM x inD
    coalition_size = masks.sum(dim=1).long()  # numM

    # evaluate the model on every masked copy of every sample
    with torch.no_grad():
        masked_data = (data.unsqueeze(1) * masks.unsqueeze(0)).reshape(-1, in_dim)  # (N x numM) x inD
        output = model(masked_data)  # (N x numM) x outD
        out_dim = output.size(-1)
        output = output.reshape(N, -1, out_dim).detach()  # N x numM x outD

    # einsum does not promote dtypes, so pick the common one explicitly
    calc_dtype = torch.promote_types(torch.float32, output.dtype)
    output = output.to(calc_dtype)

    # kernel[s] is the weight of a coalition of size s that excludes the feature;
    # built in float64 so a float64 computation keeps full precision
    kernel = torch.tensor(
        [get_shapley_weight(_m, in_dim) for _m in range(in_dim)],
        dtype=torch.float64,
    ).to(calc_dtype).to(device)  # inD

    # feature in the mask     -> +kernel[size - 1]  (coalition without it has size - 1)
    # feature not in the mask -> -kernel[size]
    # The clamps only touch entries that torch.where never selects
    # (empty mask has no included feature, full mask has no excluded one).
    has_feature = masks.t() > 0.5  # inD x numM
    weight_in = kernel[(coalition_size - 1).clamp(min=0)]  # numM
    weight_out = kernel[coalition_size.clamp(max=in_dim - 1)]  # numM
    weights = torch.where(has_feature, weight_in, -weight_out)  # inD x numM

    # compute the shapley value
    shapleys = torch.einsum("mj,ijk->ikm", weights, output)  # N x outD x inD
    return shapleys

def get_masks(in_dim: int) -> Tensor:
    """
    Enumerate every binary mask over ``in_dim`` features.

    Args:
        in_dim (int): Number of features.

    Returns:
        Tensor: Float tensor of size [2 ** in_dim, in_dim]. Rows follow
        ``itertools.product`` order, starting with the all-zeros mask and
        ending with the all-ones mask.
    """
    all_subsets = list(product((0, 1), repeat=in_dim))  # numM tuples of length inD
    return torch.tensor(all_subsets).float()

def get_weights(len_mask: Tensor, in_dim: int) -> Tensor:
    """
    Look up a per-mask weight from the number of active features in each mask.

    The weight for size ``i`` (``1 <= i <= in_dim - 1``) is
    ``(in_dim - 1) / (in_dim - i) / i / C(in_dim, i)``.

    The lookup table is created on ``len_mask``'s device, so the function
    also works when ``len_mask`` lives on an accelerator (a CPU table cannot
    be indexed with a CUDA index tensor).

    Sizes are not validated: a size of 0 maps, through negative indexing,
    onto the entry for size ``in_dim - 1``, and a size of ``in_dim`` raises
    ``IndexError``.

    Args:
        len_mask (Tensor): Number of active features per mask, size [numM].
        in_dim (int): Total number of features.

    Returns:
        Tensor: Weights of size [numM, 1], on the same device as ``len_mask``.
    """
    meta_weight = torch.tensor(
        [
            (in_dim - 1) / (in_dim - _i) / _i / comb(in_dim, _i)
            for _i in range(1, in_dim)
        ],
        device=len_mask.device,
    )  # inD - 1, entry k belongs to size k + 1
    weights = meta_weight[(len_mask - 1).long()].unsqueeze(-1)  # numM x 1
    return weights

def weighted_shapley_loss(shapleys, data, masks, weights,
                          model: nn.Module):
    """
    Compute the loss of weighted OLS of the Shapley values.

    For every sample and mask, the model output on the masked input (minus
    the output on an all-zero input) is compared with the sum of the Shapley
    values of the features kept by the mask. The model is evaluated without
    gradient tracking; gradients flow only through ``shapleys``.

    Args:
        shapleys (Tensor): Shapley values of size [N, out_dim, in_dim].
        data (Tensor): Inputs of size [N, in_dim].
        masks (Tensor): Binary masks of size [numM, in_dim].
        weights (Tensor): Per-mask weights that broadcast against
            [N, numM, out_dim], e.g. size [numM, 1].
        model (nn.Module): Maps [n, in_dim] to [n, out_dim].

    Returns:
        loss_pred (Tensor): Scalar (0-dim) tensor with the mean weighted
        squared error.
    """
    batch_size, in_dim = data.size(0), data.size(-1)

    # regression target: masked model outputs relative to the zero baseline
    with torch.no_grad():
        baseline = model(torch.zeros_like(data[0:1])).detach()  # 1 x outD
        out_dim = baseline.size(-1)
        masked_inputs = (data.unsqueeze(1) * masks.unsqueeze(0)).reshape(-1, in_dim)  # (N x numM) x inD
        target = model(masked_inputs).view(batch_size, -1, out_dim).detach()  # N x numM x outD
        target -= baseline.unsqueeze(0)

    # additive prediction: sum of the Shapley values selected by each mask
    fitted = torch.einsum("ijk,mk->imj", shapleys, masks)  # N x numM x outD
    squared_error = (target - fitted) ** 2
    per_sample = (squared_error * weights).sum(dim=1)  # N x outD
    return per_sample.mean()

def normalize_shapley(shapleys, model, data):
    """
    Normalize the shapley value.

    Shifts every feature's value by the same amount so that, per sample and
    output, the Shapley values sum to ``model(data) - model(0)``.

    The two forward passes only provide a constant target, so they run
    under ``torch.no_grad()`` instead of building an autograd graph that is
    discarded right away.

    Args:
        shapleys (Tensor): Shapley values of size [N, out_dim, in_dim].
            Modified in place.
        model (Callable): Maps [n, in_dim] to [n, out_dim].
        data (Tensor): Inputs of size [N, in_dim].

    Returns:
        Tensor: The same ``shapleys`` tensor after the in-place correction.
    """
    in_dim = data.size(-1)
    with torch.no_grad():
        outbase = model(torch.zeros_like(data[0:1]))  # 1 x outD
        output_all = model(data) - outbase  # N x outD
    sum_shap = torch.sum(shapleys, dim=-1)  # N x outD
    # spread the gap between the current sum and the target evenly over features
    shapleys -= (sum_shap - output_all).unsqueeze(-1) / in_dim
    return shapleys

def get_phi_matrix(in_dim: int) -> Tensor:
    """
    Get the phi matrix for the linear regression.

    With ``d = (in_dim - 1) / in_dim`` and ``H`` the sum of ``1 / k`` for
    ``k = 1 .. in_dim - 1``, every entry equals ``d * H - d`` and the
    diagonal additionally receives ``d``.

    Args:
        in_dim (int): Number of features.

    Returns:
        Tensor: Matrix of size [in_dim, in_dim].
    """
    diag_gain = (in_dim - 1) / in_dim
    harmonic = sum(1 / k for k in range(1, in_dim))
    coefa = diag_gain * harmonic
    coefb = coefa - diag_gain

    off_diagonal = torch.full((in_dim, in_dim), coefb)
    return off_diagonal + diag_gain * torch.eye(in_dim)

def get_phi_matrix_inv(in_dim: int) -> Tensor:
    """
    Get the inverse of the phi matrix.

    The phi matrix has the form ``coefb * ones + d * eye`` with
    ``d = (in_dim - 1) / in_dim``; its inverse is
    ``(1 / d) * eye - coefc * ones`` with
    ``coefc = in_dim^2 * coefb / ((in_dim - 1) * (in_dim^2 * coefb + in_dim - 1))``.

    Args:
        in_dim (int): Number of features.

    Returns:
        Tensor: Matrix of size [in_dim, in_dim].

    Raises:
        ZeroDivisionError: If ``in_dim`` is 1.
    """
    diag_gain = (in_dim - 1) / in_dim
    harmonic = sum(1 / k for k in range(1, in_dim))
    coefb = diag_gain * harmonic - diag_gain
    squared = in_dim ** 2
    coefc = squared * coefb / ((in_dim - 1) * (squared * coefb + in_dim - 1))

    scaled_identity = in_dim / (in_dim - 1) * torch.eye(in_dim)
    return scaled_identity - torch.full((in_dim, in_dim), coefc)

def set_seed(seed: int):
    """
    Seed the random number generators used in this module.

    Exports ``PYTHONHASHSEED``, seeds Python's ``random``, NumPy's global
    generator, ``torch.manual_seed`` and ``torch.cuda.manual_seed`` with the
    same value, then sets the cuDNN ``deterministic`` flag.

    Note: this redefines a function of the same name that appears earlier
    in this module; this later definition is the one left bound.

    Args:
        seed (int): The seed shared by all generators.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    # Every library-level seeder takes the same integer; call them in turn.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for seed_fn in seeders:
        seed_fn(seed)
    torch.backends.cudnn.deterministic = True

def get_kernelshap_weight(m: int, N: int):
    """
    Return the KernelSHAP weight ``(N - 1) / (C(N, m) * (N - m) * m)``.

    Note: this redefines a function of the same name that appears earlier
    in this module; this later definition is the one left bound.

    Args:
        m (int): Coalition size.
        N (int): Total number of players.

    Raises:
        ZeroDivisionError: If the denominator is zero, i.e. ``m == 0`` or
            ``m >= N``.
    """
    denominator = comb(N, m) * (N - m) * m
    return (N - 1) / denominator

def get_shapley_weight(m: int, N: int):
    """
    Return the Shapley weight ``1 / (C(N, m) * (N - m))``.

    This equals ``m! * (N - m - 1)! / N!``, the weight of a coalition of
    size ``m`` that does not contain the player being evaluated.

    Note: this redefines a function of the same name that appears earlier
    in this module; this later definition is the one left bound.

    Args:
        m (int): Coalition size; must be smaller than ``N``.
        N (int): Total number of players.
    """
    assert m < N - 0.5
    n_coalitions = comb(N, m)
    n_outside = N - m
    return 1 / (n_coalitions * n_outside)

def get_shapley_weight_mat(in_dim: int, masks: Tensor) -> Tensor:
    """
    Get the weight matrix for the Shapley value.

    For mask ``r`` with ``s`` active features and feature ``i``, the entry
    ``[i, r]`` is ``+get_shapley_weight(s - 1, in_dim)`` when the mask
    contains feature ``i`` and ``-get_shapley_weight(s, in_dim)`` otherwise.

    Rows of ``masks`` may come in any order and need not cover every
    subset: each row is handled according to its own content, so the
    all-zeros and all-ones masks are no longer required to sit at the
    first and last position.

    Args:
        in_dim (int): Number of features.
        masks (Tensor): Binary masks of size [numM, in_dim].

    Returns:
        Tensor: Weights of size [in_dim, numM], on the device of ``masks``.
    """
    device = masks.device
    len_mask = torch.sum(masks, dim=1).long()  # numM
    # generate weight: entry s belongs to a coalition of size s without the feature
    meta_weight = torch.tensor([
        get_shapley_weight(_m, in_dim) for _m in range(in_dim)
    ]).to(device)  # inD
    # generate the matrix
    # The clamps keep the lookups in range for the empty mask (size - 1 == -1)
    # and the full mask (size == in_dim); those looked-up values are multiplied
    # by an all-zero row of `masks` / `1 - masks` below and never contribute.
    pos_weight = meta_weight[(len_mask - 1).clamp(min=0)]  # numM
    neg_weight = meta_weight[len_mask.clamp(max=in_dim - 1)]  # numM
    weights = pos_weight.unsqueeze(-1) * masks - neg_weight.unsqueeze(-1) * (1 - masks)  # numM x inD
    return weights.t()  # inD x numM

def compute_shapley_exactly(data: Tensor, model: Callable) -> Tensor:
    """
    Compute exact Shapley values by enumerating all ``2 ** feat_dim`` masks.

    A coalition is applied by multiplying the input with a 0/1 mask, so an
    absent feature is fed to the model as zero. The per-size weights come
    from ``get_shapley_weight``.

    The weights and the model output are brought to a common dtype before
    the final contraction, so a model that returns e.g. float64 no longer
    fails on a float32/float64 mismatch; float32 outputs are processed
    exactly as before.

    Note: this redefines a function of the same name that appears earlier
    in this module; this later definition is the one left bound.

    Args:
        data (Tensor): The size of data is [batch_size, feat_dim].
        model (Callable): The callable function or class to produce the output [batch_size, out_dim].

    Returns:
        Tensor: Shapley values of size [batch_size, out_dim, feat_dim].
    """
    device = data.device
    N, in_dim = data.size(0), data.size(1)

    # generate all masks, from the empty coalition to the full one
    masks = torch.tensor(list(product((0, 1), repeat=in_dim))).float().to(device)  # numM x inD
    coalition_size = masks.sum(dim=1).long()  # numM

    # evaluate the model on every masked copy of every sample
    with torch.no_grad():
        masked_data = (data.unsqueeze(1) * masks.unsqueeze(0)).reshape(-1, in_dim)  # (N x numM) x inD
        output = model(masked_data)  # (N x numM) x outD
        out_dim = output.size(-1)
        output = output.reshape(N, -1, out_dim).detach()  # N x numM x outD

    # einsum does not promote dtypes, so pick the common one explicitly
    calc_dtype = torch.promote_types(torch.float32, output.dtype)
    output = output.to(calc_dtype)

    # kernel[s] is the weight of a coalition of size s that excludes the feature;
    # built in float64 so a float64 computation keeps full precision
    kernel = torch.tensor(
        [get_shapley_weight(_m, in_dim) for _m in range(in_dim)],
        dtype=torch.float64,
    ).to(calc_dtype).to(device)  # inD

    # feature in the mask     -> +kernel[size - 1]  (coalition without it has size - 1)
    # feature not in the mask -> -kernel[size]
    # The clamps only touch entries that torch.where never selects
    # (empty mask has no included feature, full mask has no excluded one).
    has_feature = masks.t() > 0.5  # inD x numM
    weight_in = kernel[(coalition_size - 1).clamp(min=0)]  # numM
    weight_out = kernel[coalition_size.clamp(max=in_dim - 1)]  # numM
    weights = torch.where(has_feature, weight_in, -weight_out)  # inD x numM

    # compute the shapley value
    shapleys = torch.einsum("mj,ijk->ikm", weights, output)  # N x outD x inD
    return shapleys

        
class ExactShapleySampler:
    '''
    For sampling player subsets from the Shapley distribution.

    A subset size k in 1 .. num_players - 1 is drawn with probability
    proportional to 1 / (k * (num_players - k)); the k members are then
    placed at uniformly random positions.

    Args:
      num_players: number of players.
      device: device on which the sampled masks are created (CPU if None).
    '''

    def __init__(self, num_players, device=None):
        '''
        Precompute the size distribution and the mask templates.

        Args:
            num_players: number of players.
            device: device on which the sampled masks are created (CPU if None).
        '''
        self.device = torch.device('cpu') if device is None else device

        # Unnormalized probability of each subset size 1 .. num_players - 1.
        sizes = torch.arange(1, num_players)
        unnormalized = 1 / (sizes * (num_players - sizes))
        self.w_sum = unnormalized.sum()
        self.w = unnormalized / self.w_sum
        self.categorical = Categorical(probs=self.w)

        # Template masks: row r has ones in its first r + 1 columns.
        template = torch.ones(
            num_players - 1, num_players,
            dtype=torch.float32, device=self.device)
        self.tril = template.tril(diagonal=0)

    def get_weight_sum(self):
        '''Return the normalizing constant of the size distribution.'''
        return self.w_sum

    def sample(self, batch_size, paired_sampling=None):
        '''
        Generate sample.

        Args:
          batch_size: number of samples.
          paired_sampling: whether to use paired sampling. When truthy, every
            odd row is replaced by the complement of the row before it.

        Returns:
          Float tensor of 0/1 masks with size [batch_size, num_players].
        '''
        # Category c stands for a subset of size c + 1, i.e. template row c.
        row_index = self.categorical.sample([batch_size]).to(self.device)
        ordered = self.tril[row_index]
        # Shuffle each row independently by sorting random keys.
        shuffle = torch.rand_like(ordered).argsort(dim=-1)
        S = ordered.gather(-1, shuffle)
        if paired_sampling:
            num_pairs = batch_size // 2
            # Flip the values in odd rows
            S[1::2] = 1 - S[0::2][:num_pairs]
        return S

class DatasetRepeat(Dataset):
    '''
    A wrapper around multiple datasets that allows repeated elements when the
    dataset sizes don't match. The number of elements is the maximum dataset
    size; shorter datasets are cycled through by taking the index modulo
    their own length.

    Each wrapped dataset must return an iterable (e.g. a tuple) per item;
    the items of all datasets are concatenated into one flat tuple.

    Args:
      datasets: list of dataset objects.
    '''

    def __init__(self, datasets):
        assert all(isinstance(dset, Dataset) for dset in datasets)

        # Get maximum number of elements (builtin max keeps it a plain int,
        # which is what __len__ should return).
        items = [len(dset) for dset in datasets]
        self.dsets = datasets
        self.items = items
        self.num_items = max(items)

    def __getitem__(self, index):
        # IndexError (rather than an assert) is what the sequence protocol
        # expects: it lets `for x in dataset` / `list(dataset)` stop at the
        # end and keeps the bound check alive under `python -O`.
        if not 0 <= index < self.num_items:
            raise IndexError(
                f"index {index} is out of range for DatasetRepeat "
                f"of length {self.num_items}")
        parts = (dset[index % num] for dset, num in zip(self.dsets, self.items))
        return tuple(itertools.chain.from_iterable(parts))

    def __len__(self):
        return self.num_items