import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from pprint import pprint
import doctest
from typing import Union, Iterable, List, Tuple
from tqdm import tqdm


def generate_all_masks(length: int) -> list:
    '''
    Enumerate every boolean mask of the given length, in binary counting order
    :param length: <int> number of elements (players); must be non-negative
    :return: a list of 2**length masks, each a list of `length` Python bools.
        Mask i is the binary representation of i with the most significant bit
        first, so the first mask is all-False and the last one is all-True.

    >>> generate_all_masks(2)
    [[False, False], [False, True], [True, False], [True, True]]
    >>> generate_all_masks(0)
    [[]]
    '''
    if length < 0:
        raise ValueError(f"[length] must be non-negative, got {length}.")
    # Bit (length - 1 - j) of the integer code drives position j, i.e. the most
    # significant bit maps to position 0 (the order np.binary_repr produced).
    # Unlike np.binary_repr(0, width=0), which returns '0', this yields the
    # single empty mask [] when length == 0.
    return [
        [bool((code >> (length - 1 - j)) & 1) for j in range(length)]
        for code in range(2 ** length)
    ]


def set_to_index(A):
    '''
    Convert a boolean mask to its integer index (A[0] is the most significant bit)
    :param A: <np.ndarray> bool (n_dim,)
    :return: <int> the index; 0 for an empty mask. The mask of index i in the
        ordering of generate_all_masks(n_dim) maps back to i.

    >>> set_to_index(np.array([1, 0, 0, 1, 0]).astype(bool))
    18
    '''
    assert len(A.shape) == 1
    index = 0
    # Horner evaluation over Python ints: arbitrary precision, so masks longer
    # than 63 entries cannot overflow, and an empty mask yields int 0 (not 0.0).
    for bit in A:
        index = index * 2 + int(bit)
    return index


def is_A_subset_B(A, B):
    '''
    Check whether the set encoded by mask A is contained in the set encoded by mask B
    :param A: <numpy.ndarray> bool (n_dim, )
    :param B: <numpy.ndarray> bool (n_dim, )
    :return: <numpy.bool_> True iff no element is present in A but absent from B
    '''
    assert A.shape[0] == B.shape[0]
    # A is a subset of B exactly when nothing in A lies outside B.
    outside_of_B = np.logical_and(A, np.logical_not(B))
    return np.logical_not(np.any(outside_of_B))


def is_A_subset_Bs(A, Bs):
    '''
    Check, for every row B of Bs, whether mask A is contained in B
    :param A: <numpy.ndarray> bool (n_dim, )
    :param Bs: <numpy.ndarray> bool (n, n_dim)
    :return: <numpy.ndarray> bool (n, ); entry k is True iff A is a subset of Bs[k]
    '''
    assert A.shape[0] == Bs.shape[1]
    # A row fails as soon as some element of A is missing from it.
    misses = np.logical_and(A, np.logical_not(Bs))
    return np.logical_not(misses.any(axis=1))


def select_subset(As, B):
    '''
    Keep only the rows A of As that are contained in B
    :param As: <numpy.ndarray> bool (n, n_dim)
    :param B: <numpy.ndarray> bool (n_dim, )
    :return: <numpy.ndarray> the rows of As that are subsets of B, in their original order
    '''
    assert As.shape[1] == B.shape[0]
    # A row is dropped when it switches on any element that B leaves off.
    has_outside = np.any(np.logical_and(As, np.logical_not(B)), axis=1)
    keep = np.logical_not(has_outside)
    return As[keep]


def set_minus(A: np.ndarray, B: np.ndarray) -> np.ndarray:
    '''
    Compute the set difference "A minus B" of two masks
    :param A: <numpy.ndarray> bool (n_dim, )
    :param B: <numpy.ndarray> bool (n_dim, ); 0/1 integer masks are accepted too
    :return: a new array with A's dtype: A with every position set in B cleared

    >>> set_minus(A=np.array([1, 1, 0, 0, 1, 1], dtype=bool), B=np.array([1, 0, 0, 0, 0, 1], dtype=bool))
    array([False,  True, False, False,  True, False])

    >>> set_minus(A=np.array([1, 1, 0, 0, 1, 1], dtype=bool), B=np.array([1, 0, 1, 0, 0, 1], dtype=bool))
    array([False,  True, False, False,  True, False])
    '''
    assert A.shape[0] == B.shape[0] and len(A.shape) == 1 and len(B.shape) == 1
    result = A.copy()  # never modify the caller's A
    # Cast B to bool first: an integer 0/1 array used directly as an index would
    # be treated as a list of positions (fancy indexing), not as a mask.
    result[np.asarray(B, dtype=bool)] = False
    return result


def get_subset(A):
    '''
    Enumerate all subsets of the set encoded by mask A
    :param A: <numpy.ndarray> bool (n_dim, )
    :return: <numpy.ndarray> bool (2**|A|, n_dim). Row i switches on the members of A
        picked by the binary digits of i, with the first member of A as the most
        significant digit (row 0 is empty, the last row equals A).

    >>> get_subset(np.array([1, 0, 0, 1, 0, 1], dtype=bool))
    array([[False, False, False, False, False, False],
           [False, False, False, False, False,  True],
           [False, False, False,  True, False, False],
           [False, False, False,  True, False,  True],
           [ True, False, False, False, False, False],
           [ True, False, False, False, False,  True],
           [ True, False, False,  True, False, False],
           [ True, False, False,  True, False,  True]])
    '''
    assert len(A.shape) == 1
    n_dim = A.shape[0]
    n_members = A.sum()
    codes = np.arange(2 ** n_members)
    # Shift amounts (k-1, ..., 0): column j of `bits` holds digit j of each code, MSB first.
    shifts = np.arange(n_members - 1, -1, -1)
    bits = ((codes[:, np.newaxis] >> shifts) & 1).astype(bool)
    subsets = np.zeros((codes.shape[0], n_dim), dtype=bool)
    subsets[:, A] = bits
    return subsets


def flatten(x):
    '''
    Flatten an irregular (arbitrarily nested) structure of iterables

    Reference <https://stackoverflow.com/questions/2158395/flatten-an-irregular-list-of-lists>

    Strings/bytes and 0-d arrays/tensors are treated as atoms instead of being
    iterated: a one-character string iterates to itself (infinite recursion),
    and iterating a 0-d numpy array or tensor raises TypeError.

    >>> flatten(((1, 2), 3, 4))
    (1, 2, 3, 4)
    >>> flatten((("ab", 1), ["c"]))
    ('ab', 1, 'c')

    :param x: an atom or a (possibly nested) iterable of atoms
    :return: a tuple of the atoms in depth-first order when x is iterable;
        the one-element list [x] when x itself is an atom
    '''
    is_atom = (
        not isinstance(x, Iterable)
        or isinstance(x, (str, bytes))
        or getattr(x, "ndim", None) == 0
    )
    if is_atom:
        return [x]
    return tuple(atom for item in x for atom in flatten(item))



def generate_subset_masks(set_mask, all_masks):
    '''
    Select, from [all_masks], every mask L that is a subset of the set S given by [set_mask]
    :param set_mask: mask of S; it is expanded to the shape of [all_masks] (presumably (n_dim, ))
    :param all_masks: <torch.BoolTensor> (n_masks, n_dim) candidate masks
    :return: (the masks of all L contained in S,
              <torch.BoolTensor> (n_masks, ) bool index of those L in [all_masks])
    '''
    broadcast_S = set_mask.expand_as(all_masks)
    # L leaves S iff some position is on in L but off in S.
    escapes_S = torch.logical_and(all_masks, torch.logical_not(broadcast_S))
    inside_S = torch.logical_not(torch.any(escapes_S, dim=1))
    return all_masks[inside_S], inside_S


def generate_reverse_subset_masks(set_mask, all_masks):
    '''
    For a given S, select every mask of the form N minus L with L a subset of S
    (equivalently: every mask that contains all elements outside S)
    :param set_mask: mask of S; it is expanded to the shape of [all_masks] (presumably (n_dim, ))
    :param all_masks: <torch.BoolTensor> (n_masks, n_dim) candidate masks
    :return: (the selected masks, <torch.BoolTensor> (n_masks, ) bool index of them in [all_masks])
    '''
    broadcast_S = set_mask.expand_as(all_masks)
    # A mask is rejected iff some position is off both in S and in the mask.
    uncovered = torch.logical_and(torch.logical_not(broadcast_S), torch.logical_not(all_masks))
    qualifies = torch.logical_not(torch.any(uncovered, dim=1))
    return all_masks[qualifies], qualifies


def generate_set_with_intersection_masks(set_mask, all_masks):
    '''
    For a given S, select every mask L that shares at least one element with S
    :param set_mask: mask of S; it is expanded to the shape of [all_masks] (presumably (n_dim, ))
    :param all_masks: <torch.BoolTensor> (n_masks, n_dim) candidate masks
    :return: (the masks of all L intersecting S, <torch.BoolTensor> (n_masks, ) bool index of them in [all_masks])
    '''
    broadcast_S = set_mask.expand_as(all_masks)
    overlaps = torch.logical_and(all_masks, broadcast_S).any(dim=1)
    return all_masks[overlaps], overlaps


def get_reward(values, selected_dim, **kwargs):
    '''
    Reduce raw model outputs to one reward per sample
    :param values: <torch.Tensor> model outputs, one row per sample; the modes below index dim 1.
        Modes that pick the target from the data look at the LAST row -- presumably the
        output on the unmasked input (TODO confirm with callers).
    :param selected_dim: which reward to compute
        None               -- return [values] unchanged
        "max"              -- the column that is largest in the last row
        "0"                -- column 0
        "gt"               -- column kwargs["gt"]
        "gt-log-odds"      -- log-odds of the softmax probability of class kwargs["gt"]
        "max-log-odds"     -- log-odds of the softmax probability of the last row's top class
        "gt-logistic-odds" -- (n, 1) logits; log-odds of the sigmoid probability of binary label kwargs["gt"]
        "logistic-odds"    -- (n, 1) logits; log-odds of the class the last row's sigmoid rounds to
        "gt-prob-log-odds" -- (n, 1) probabilities; log-odds of binary label kwargs["gt"]
        "prob-log-odds"    -- (n, 1) probabilities; log-odds of the class the last row rounds to
        log-odds(p) is computed as log(p / (1 - p + eps) + eps) with eps = 1e-7
    :param kwargs: "gt" (the ground-truth label) is required by the "gt*" modes
    :return: the reward tensor (one entry per row for 2-D [values], except mode None)
    :raises Exception: for an unknown [selected_dim]
    '''
    eps = 1e-7

    def log_odds(prob):
        # eps keeps the ratio and the log finite when prob reaches 0 or 1
        return torch.log(prob / (1 - prob + eps) + eps)

    def ground_truth(binary=False):
        assert "gt" in kwargs
        label = kwargs["gt"]
        if binary:
            assert label == 0 or label == 1
        return label

    def single_column(tensor):
        assert len(tensor.shape) == 2 and tensor.shape[1] == 1
        return tensor[:, 0]

    if selected_dim is None:
        return values

    if selected_dim == "max":
        # the dimension predicted on the last row
        return values[:, torch.argmax(values[-1])]
    if selected_dim == "0":
        return values[:, 0]
    if selected_dim == "gt":
        # the ground-truth dimension
        return values[:, ground_truth()]
    if selected_dim == "gt-log-odds":
        label = ground_truth()
        return log_odds(torch.softmax(values, dim=1)[:, label])
    if selected_dim == "max-log-odds":
        probs = torch.softmax(values, dim=1)
        return log_odds(probs[:, torch.argmax(probs[-1])])
    if selected_dim == "gt-logistic-odds":
        label = ground_truth(binary=True)
        prob_pos = torch.sigmoid(single_column(values))
        return log_odds(1 - prob_pos if label == 0 else prob_pos)
    if selected_dim == "logistic-odds":
        prob_pos = torch.sigmoid(single_column(values))
        flip = torch.round(prob_pos[-1]) == 0.
        return log_odds(1 - prob_pos if flip else prob_pos)
    if selected_dim == "gt-prob-log-odds":
        label = ground_truth(binary=True)
        prob_pos = single_column(values)
        return log_odds(1 - prob_pos if label == 0 else prob_pos)
    if selected_dim == "prob-log-odds":
        prob_pos = single_column(values)
        flip = torch.round(prob_pos[-1]) == 0.
        return log_odds(1 - prob_pos if flip else prob_pos)

    raise Exception(f"Unknown [selected_dim] {selected_dim}.")


def calculate_all_subset_outputs_tabular(
        model: nn.Module,
        input: torch.Tensor,
        baseline: torch.Tensor,
        calc_bs: Union[None, int] = None,
):
    '''
    Evaluate [model] on every masked version of a single tabular sample
    :param model: the target model, called on consecutive batches of masked inputs
    :param input: <torch.Tensor> (n_attributes, ) the sample to explain
    :param baseline: values used for masked-out attributes; must be expandable to
        (2**n_attributes, n_attributes)
    :param calc_bs: number of masked inputs per model call; None runs all of them in one call
    :return: (masks <torch.BoolTensor> (2**n_attributes, n_attributes) on the device of [input],
              model outputs concatenated along dim 0 in the same order as masks)
    '''
    assert len(input.shape) == 1
    n_attributes = input.shape[0]
    masks = torch.BoolTensor(generate_all_masks(n_attributes)).to(input.device)
    # Attribute j keeps its original value where masks[:, j] is True, else takes the baseline.
    masked_inputs = torch.where(masks, input.expand_as(masks), baseline.expand_as(masks))

    n_masks = masks.shape[0]
    batch_size = n_masks if calc_bs is None else calc_bs
    n_batches = int(np.ceil(n_masks / batch_size))
    batch_outputs = [
        model(masked_inputs[b * batch_size:(b + 1) * batch_size])
        for b in range(n_batches)
    ]
    return masks, torch.cat(batch_outputs, dim=0)


def get_mask_input_func(grid_width: int):
    '''
    Build a function that masks images cell-wise on a square grid
    :param grid_width: <int> side length (in pixels) of each grid cell
    :return: generate_masked_input(image, baseline, grid_indices_list) -> masked images
    '''

    def generate_masked_input(image: torch.Tensor, baseline: torch.Tensor, grid_indices_list: List):
        '''
        Create one masked image per entry of [grid_indices_list]
        :param image: <torch.Tensor> 4-D (1, C, H, W) image; its batch dim is expanded to B
        :param baseline: <torch.Tensor> expandable to (B, C, H, W); used where a cell is masked out
        :param grid_indices_list: length-B list; entry i is a (possibly nested) collection of
            row-major grid-cell indices whose pixels are taken from [image] in sample i
        :return: <torch.Tensor> (B, C, H, W) masked images
        '''
        device = image.device
        _, image_channel, image_height, image_width = image.shape
        # ceil: cells on the right/bottom border may extend past the image.
        grid_num_h = int(np.ceil(image_height / grid_width))
        grid_num_w = int(np.ceil(image_width / grid_width))
        grid_num = grid_num_h * grid_num_w

        batch_size = len(grid_indices_list)
        masks = torch.zeros(batch_size, image_channel, grid_num, dtype=torch.float32)
        for i, grid_indices in enumerate(grid_indices_list):
            masks[i, :, list(flatten(grid_indices))] = 1

        # Nearest-neighbour upsampling by the integer factor grid_width turns each
        # cell value into a grid_width x grid_width pixel block.
        masks = masks.view(batch_size, image_channel, grid_num_h, grid_num_w)
        masks = F.interpolate(
            masks,
            size=[grid_width * grid_num_h, grid_width * grid_num_w],
            mode="nearest",
        )
        # Crop the padded border cells back to the image size.
        masks = masks[:, :, :image_height, :image_width].to(device)

        # expand() returns views; the previous .clone() calls materialized two
        # full B x C x H x W copies without changing the result.
        expanded_image = image.expand(batch_size, image_channel, image_height, image_width)
        expanded_baseline = baseline.expand(batch_size, image_channel, image_height, image_width)
        return expanded_image * masks + expanded_baseline * (1 - masks)

    return generate_masked_input


def calculate_all_subset_outputs_image(
        model: nn.Module,
        input: torch.Tensor,
        baseline: torch.Tensor,
        grid_width: Union[None, int] = None,
        calc_bs: Union[None, int] = None,
        all_players: Union[None, tuple] = None,
        background: Union[None, tuple] = None,
):
    '''
    Evaluate [model] on the image masked by every subset of players
    :param model: the target model, called on batches of masked images
    :param input: <torch.Tensor> (C, H, W) or (1, C, H, W) image to explain
    :param baseline: <torch.Tensor> (C, H, W) or 4-D baseline image used where a cell is masked out
    :param grid_width: side length in pixels of each grid cell (required in practice despite the None default)
    :param calc_bs: number of masked images per model call; None runs all of them in one call
    :param all_players: n_players entries, each a grid-cell index or a (nested) collection of
        indices; None makes every grid cell its own player
    :param background: grid-cell indices kept in every masked image; only used when
        [all_players] is given
    :return: (masks <torch.BoolTensor> (2**n_players, n_players) on the CPU,
              model outputs in mask order, moved to the device of [input])
    '''
    device = input.device
    if len(input.shape) == 3:
        input = input.unsqueeze(0)
    if len(baseline.shape) == 3:
        baseline = baseline.unsqueeze(0)
    assert len(input.shape) == 4
    assert len(baseline.shape) == 4

    image_height, image_width = input.shape[2], input.shape[3]
    grid_num = int(np.ceil(image_height / grid_width)) * int(np.ceil(image_width / grid_width))
    mask_input_fn = get_mask_input_func(grid_width=grid_width)

    if all_players is None:
        # Every grid cell is its own player; [background] is not consulted here.
        n_players = grid_num
        player_pool = np.arange(grid_num).astype(int)
        always_on = []
    else:
        n_players = len(all_players)
        player_pool = np.array(all_players, dtype=object)
        always_on = [] if background is None else background

    masks = torch.BoolTensor(generate_all_masks(n_players))
    grid_indices_list = [
        list(flatten([player_pool[player_mask], always_on]))
        for player_mask in masks
    ]

    if calc_bs is None:
        calc_bs = masks.shape[0]
    assert len(grid_indices_list) == masks.shape[0]

    n_batches = int(np.ceil(len(grid_indices_list) / calc_bs))
    outputs = []
    for batch_id in tqdm(range(n_batches), ncols=100, desc="Calc model outputs"):
        start = batch_id * calc_bs
        masked_image_batch = mask_input_fn(
            image=input, baseline=baseline,
            grid_indices_list=grid_indices_list[start:start + calc_bs],
        )
        # Each batch's outputs are parked on the CPU (presumably to free device
        # memory between batches) and moved back to [input]'s device at the end.
        outputs.append(model(masked_image_batch).cpu())
    return masks, torch.cat(outputs, dim=0).to(device)


def calculate_all_subset_outputs_pytorch(
        model: nn.Module,
        input: torch.Tensor,
        baseline: torch.Tensor,
        grid_width: Union[None, int] = None,
        calc_bs: Union[None, int] = None,
        all_players: Union[None, tuple] = None,
        background: Union[None, tuple] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    '''
    Return the model outputs on all possible subsets (masked versions) of the input
    :param model: the target model
    :param input: a single tabular vector (n_attributes, ) when grid_width is None,
        otherwise an image (C, H, W) or (1, C, H, W)
    :param baseline: the baseline value(s) used for masked-out attributes / grid cells
    :param grid_width: None selects the tabular path; otherwise the grid-cell size for images
    :param calc_bs: number of masked inputs per model call; None runs all in one call
    :param all_players: (image path only) player definitions; None makes every grid cell a player
    :param background: (image path only) grid cells always kept from the input
    :return: masks and the outputs
    '''
    if grid_width is None:  # tabular data
        return calculate_all_subset_outputs_tabular(
            model=model, input=input, baseline=baseline,
            calc_bs=calc_bs,
        )
    # image data
    return calculate_all_subset_outputs_image(
        model=model, input=input, baseline=baseline,
        grid_width=grid_width, calc_bs=calc_bs,
        all_players=all_players, background=background,
    )


def calculate_all_subset_outputs_function(
    model,
    input: torch.Tensor,
    baseline: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    '''
    Return the outputs of a callable model on all possible subsets of a tabular input,
    evaluated in a single call with gradient tracking disabled
    :param model: any callable taking a (2**n_attributes, n_attributes) batch
    :param input: <torch.Tensor> (n_attributes, ) a single input vector
    :param baseline: the baseline value(s) for masked-out attributes; must be expandable
        to (2**n_attributes, n_attributes)
    :return: (masks <torch.BoolTensor> (2**n_attributes, n_attributes), the model outputs)
    '''
    assert len(input.shape) == 1
    n_attributes = input.shape[0]
    device = input.device
    masks = torch.BoolTensor(generate_all_masks(n_attributes)).to(device)
    # Keep the input value where the mask is True, otherwise use the baseline.
    masked_inputs = torch.where(masks, input.expand_as(masks), baseline.expand_as(masks))
    with torch.no_grad():
        outputs = model(masked_inputs)
    return masks, outputs


def calculate_all_subset_outputs(
        model, input, baseline,
        grid_width=None, calc_bs=None,
        all_players=None, background=None,
):
    '''
    Return the model outputs on all possible subsets of the input
    Forwards every argument to calculate_all_subset_outputs_pytorch, where
    grid_width=None selects the tabular path and any other value the image path.
    :return: (masks, outputs) as produced by calculate_all_subset_outputs_pytorch
    '''
    # NOTE: only the PyTorch path is wired up. Dispatching plain functions (and
    # "AndSum" models) to calculate_all_subset_outputs_function is still a TODO.
    return calculate_all_subset_outputs_pytorch(
        model=model, input=input, baseline=baseline,
        grid_width=grid_width, calc_bs=calc_bs,
        all_players=all_players, background=background,
    )


def calculate_given_subset_outputs_image(
        model: nn.Module,
        input: torch.Tensor,
        baseline: torch.Tensor,
        masks: torch.Tensor,
        grid_width: Union[None, int] = None,
        all_players: Union[None, tuple] = None,
        background: Union[None, tuple] = None,
):
    '''
    Evaluate [model] on the image masked by each of the given player subsets
    :param model: the target model, called once on all masked images
    :param input: <torch.Tensor> (C, H, W) or (1, C, H, W) image to explain
    :param baseline: baseline image used where a cell is masked out; a 3-D baseline gets
        a leading batch dim
    :param masks: bool (n_masks, n_players) tensor or array; row i selects the players kept in sample i
    :param grid_width: side length (in pixels) of each grid cell
    :param all_players: n_players entries, each a grid-cell index or a (nested) collection of
        indices; None makes every grid cell its own player
    :param background: grid-cell indices kept in every sample (None means none). Ignored when
        all_players is None, matching calculate_all_subset_outputs_image
    :return: ([masks] unchanged, model outputs on the n_masks masked images)
    '''
    # Accept unbatched (C, H, W) images like calculate_all_subset_outputs_image;
    # the masking function unpacks a 4-D shape.
    if len(input.shape) == 3:
        input = input.unsqueeze(0)
    if len(baseline.shape) == 3:
        baseline = baseline.unsqueeze(0)
    assert len(input.shape) == 4

    if all_players is None:
        image_height, image_width = input.shape[2], input.shape[3]
        grid_num = int(np.ceil(image_height / grid_width)) * int(np.ceil(image_width / grid_width))
        player_pool = np.arange(grid_num).astype(int)
        background = []
    else:
        player_pool = np.array(all_players, dtype=object)
        if background is None:
            background = []

    # Index the numpy player array with a CPU numpy bool mask: a CUDA tensor
    # cannot be converted to numpy implicitly.
    if torch.is_tensor(masks):
        masks_np = masks.detach().cpu().numpy().astype(bool)
    else:
        masks_np = np.asarray(masks, dtype=bool)

    grid_indices_list = [
        list(flatten([background, player_pool[player_mask]]))
        for player_mask in masks_np
    ]

    mask_input_fn = get_mask_input_func(grid_width=grid_width)
    masked_inputs = mask_input_fn(image=input, baseline=baseline, grid_indices_list=grid_indices_list)
    outputs = model(masked_inputs)

    return masks, outputs


def calculate_given_subset_outputs(model, input, baseline, masks, grid_width=None, all_players=None, background=None):
    '''
    Return the model outputs on the given subsets of the input
    Only image models are supported: [model] must be an nn.Module and [grid_width] not None.
    :return: (masks, outputs) as produced by calculate_given_subset_outputs_image
    :raises NotImplementedError: for any other model type or when grid_width is None
    '''
    if isinstance(model, nn.Module) and grid_width is not None:
        return calculate_given_subset_outputs_image(
            model=model, input=input, baseline=baseline, masks=masks,
            grid_width=grid_width, all_players=all_players, background=background,
        )
    # Report grid_width too: an nn.Module with grid_width=None also lands here.
    raise NotImplementedError(
        f"Unexpected model type: {type(model)} with grid_width={grid_width}; "
        f"only nn.Module image models (grid_width is not None) are supported."
    )


# =========================================
#     plot (for debugging use)
# =========================================
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable


def _plot_colorbar(ax, im):
    '''
    Attach a colorbar for [im] in a slim axes appended to the right of [ax]
    :param ax: the matplotlib axes showing the image
    :param im: the mappable returned by imshow
    '''
    # The new axes is 5% as wide as ax and sits 0.05 inch to its right.
    cax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
    plt.colorbar(im, cax=cax)


def plot_grayscale_image(image, save_path):
    '''
    Save [image] as a grayscale picture with a colorbar (debugging helper)
    :param image: array-like accepted by plt.imshow (presumably 2-D)
    :param save_path: output path passed to savefig
    '''
    fig = plt.figure(figsize=(4, 4))
    try:
        im = plt.imshow(image, cmap="gray")
        _plot_colorbar(ax=plt.gca(), im=im)
        plt.tight_layout()
        fig.savefig(save_path, dpi=100)
    finally:
        # Close only the figure created here, even if plotting/saving fails;
        # plt.close("all") would also discard figures owned by the caller.
        plt.close(fig)



if __name__ == '__main__':
    # Run the executable examples embedded in this module's docstrings.
    # doctest is already imported at the top of the module.
    doctest.testmod()