from itertools import chain, combinations
import torch
import numpy as np


def powerset(iterable, only_consider_whole_set=False):
    """Enumerate subsets of views together with binary membership indicators.

    Args:
      iterable: The views (e.g. view indices). It is read exactly once, so a
        one-shot iterator or generator is accepted.
      only_consider_whole_set: If True, return only the full set of views.
        Otherwise return every subset with at least two elements.

    Returns:
      (subsets, binary_indicator):
        - ``subsets`` is a list of tuples in ``itertools.combinations`` order
          (by size, then by position).
        - ``binary_indicator[j][i]`` is 1 if the i-th view is in ``subsets[j]``
          and 0 otherwise.
    """
    views = list(iterable)
    if only_consider_whole_set:
        # combinations(views, len(views)) yields exactly one tuple: all views
        # (the empty tuple when there are no views).
        subsets = [tuple(views)]
    else:
        subsets = list(
            chain.from_iterable(
                combinations(views, r) for r in range(2, len(views) + 1)
            )
        )
    binary_indicator = [
        [int(view in subset) for view in views] for subset in subsets
    ]
    return subsets, binary_indicator


def retrieve_content_style(zs):
    """Split per-view entries into shared "content" and per-view "style".

    Args:
      zs: Array-like with a ``tolist()`` method. Each row holds the entries of
        one view (noted as shape [n_views * nz]; TODO confirm with callers).

    Returns:
      (content, style):
        - ``content`` is the set of entries present in every row.
        - ``style`` is a list with one set per row, holding that row's entries
          that are not in ``content``.
    """
    row_sets = [set(row) for row in zs.tolist()]
    # The first row is the seed; intersect it with all remaining rows.
    content = row_sets[0].intersection(*row_sets[1:])
    style = [row - content for row in row_sets]
    return content, style


def content_style_from_subsets(views, zs, only_consider_whole_set=False):
    """Compute content/style sets for every view subset that shares content.

    Args:
      views: Views to enumerate subsets of; passed to ``powerset``.
      zs: Indexed as ``zs[subset, :]``, so presumably a 2-D array/tensor with
        one row per view (TODO confirm).
      only_consider_whole_set: Forwarded to ``powerset``.

    Returns:
      (final_subsets, final_subset_masks, content_dict, style_dict), restricted
      to subsets whose rows share at least one entry:
        - ``final_subsets`` and ``final_subset_masks`` are the kept subsets and
          their membership indicators.
        - ``content_dict[subset]`` is the shared-entry set of that subset.
        - ``style_dict[subset][view]`` is that view's non-shared entries.
    """
    subsets, indicators = powerset(views, only_consider_whole_set)

    final_subsets, final_subset_masks = [], []
    content_dict, style_dict = {}, {}
    for subset, mask in zip(subsets, indicators):
        content, style = retrieve_content_style(zs[subset, :])
        if not content:
            # The views in this subset share nothing; skip it entirely.
            continue
        final_subsets.append(subset)
        final_subset_masks.append(mask)
        content_dict[subset] = content
        style_dict[subset] = dict(zip(subset, style))
    return final_subsets, final_subset_masks, content_dict, style_dict


def unpack_item_list(lst):
    """Recursively convert a nested list/tuple of scalar containers to Python values.

    Nested tuples and lists both come back as lists. Every non-sequence leaf
    is converted with its ``.item()`` method (e.g. 0-d torch tensors or NumPy
    scalars).
    """
    return [
        unpack_item_list(entry) if isinstance(entry, (tuple, list)) else entry.item()
        for entry in lst
    ]


# Smallest positive normal float32 value. Used as a floor before taking a log
# so that log(0) = -inf never occurs.
EPSILON = np.finfo(np.dtype("float32")).tiny


def topk_gumble_softmax(k, logits, tau, hard):
    """Draw a relaxed (differentiable) k-hot sample from ``logits``.

    Gumbel(0, 1) noise is added to ``logits``, then ``k`` rounds of softmax are
    taken over dim 1. Before each round, ``log(max(1 - p_prev, EPSILON))`` is
    added to the logits, which pushes down entries the previous round already
    selected. The ``k`` soft one-hot vectors are summed.

    Args:
      k: Number of entries to select per slice along dim 1.
      logits: Unnormalized scores with categories along dim 1 (presumably
        shape (batch, n_categories); TODO confirm with callers). Works on any
        device; the helper tensor is created on ``logits.device``.
      tau: Softmax temperature (``logits`` are divided by it).
      hard: If True, use the straight-through estimator. The forward value is
        the exact top-k 0/1 indicator of the relaxed sample, and gradients are
        those of the relaxed sample.

    Returns:
      Tensor with the same shape as ``logits``.
    """
    gumbel = torch.distributions.gumbel.Gumbel(
        torch.zeros_like(logits), torch.ones_like(logits)
    )
    logits = logits + gumbel.sample()

    # Floor for the mask before the log. Created once and on the input's
    # device, so CPU tensors work and no tensor is re-allocated each round.
    eps = torch.tensor([EPSILON], device=logits.device)

    # Continuous top-k: accumulate k successive soft one-hot vectors.
    khot = torch.zeros_like(logits)
    onehot_approx = torch.zeros_like(logits)
    for _ in range(k):
        khot_mask = torch.max(1.0 - onehot_approx, eps)
        logits = logits + torch.log(khot_mask)
        onehot_approx = torch.nn.functional.softmax(logits / tau, dim=1)
        khot = khot + onehot_approx

    if not hard:
        return khot

    # Straight-through: the forward pass equals the hard indicator, while
    # gradients flow through the relaxed ``khot``.
    _, ind = torch.topk(khot, k, dim=1)
    khot_hard = torch.zeros_like(khot).scatter_(1, ind, 1)
    return khot_hard - khot.detach() + khot


class ConfigDict(object):
    """Expose the entries of a mapping as attributes.

    Each key of the mapping becomes an attribute (``cfg.lr``). The mapping
    itself is kept as ``self.dict`` and used by ``get``.

    Caveats:
      - The mapping is stored by reference, not copied. Later changes to it
        show up through ``get`` but not through the attributes set at
        construction.
      - Keys are set after ``self.dict``, so a key named e.g. ``"dict"`` or
        ``"get"`` overwrites that attribute or method.
    """

    def __init__(self, dict) -> None:
        # The parameter shadows the builtin ``dict``. The name is kept so
        # callers passing it by keyword keep working.
        self.dict = dict
        for k, v in dict.items():
            setattr(self, k, v)

    def get(self, key, default=None):
        """Return ``self.dict[key]``, or ``default`` (None unless given) if absent."""
        return self.dict.get(key, default)

    def __repr__(self):
        return f"{type(self).__name__}({self.dict!r})"


def _stack_batches(batches):
    """Concatenate per-batch arrays along axis 0; None if there are no batches."""
    if not batches:
        return None
    if len(batches) == 1:
        # A single batch is returned unchanged (no vstack), so e.g. 1-D
        # outputs keep their rank.
        return batches[0]
    return np.vstack(batches)


def generate_batch_factor_code(
    ground_truth_data, representation_function, num_points, random_state, batch_size
):
    """Sample a single training sample based on a mini-batch of ground-truth data.

    Args:
      ground_truth_data: GroundTruthData to be sampled from. Its
        ``sample(n, random_state)`` must return ``(factors, observations)``.
      representation_function: Function that takes observation as input and
        outputs a representation.
      num_points: Number of points to sample.
      random_state: Numpy random state used for randomness.
      batch_size: Batchsize to sample points.

    Returns:
      representations: Codes (num_codes, num_points)-np array.
      factors: Factors generating the codes (num_factors, num_points)-np array.
      If ``num_points <= 0`` nothing is sampled and both values are
      ``np.transpose(None)``.

    Raises:
      ValueError: If ``num_points > 0`` and ``batch_size <= 0``.
    """
    if num_points > 0 and batch_size <= 0:
        # Without this guard the loop below would never advance.
        raise ValueError("batch_size must be positive, got %r" % (batch_size,))

    factor_batches, representation_batches = [], []
    num_sampled = 0
    while num_sampled < num_points:
        num_points_iter = min(num_points - num_sampled, batch_size)
        current_factors, current_observations = ground_truth_data.sample(
            num_points_iter, random_state
        )
        factor_batches.append(current_factors)
        representation_batches.append(representation_function(current_observations))
        num_sampled += num_points_iter

    # Stack once at the end. Re-vstacking the growing array after every
    # batch would copy all earlier rows each time (quadratic work).
    representations = _stack_batches(representation_batches)
    factors = _stack_batches(factor_batches)
    return np.transpose(representations), np.transpose(factors)