"""Abstract base classes for implementing biased exposure datasets.

See `biased_exposure/data/celeb_a.py` for an example implemented subclass.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from dataclasses import dataclass
from enum import auto, Enum, unique
from functools import partial, update_wrapper
import itertools
import numpy as np
from typing import Iterable
from typing import NamedTuple

import networkx as nx
import tensorflow_datasets as tfds


class Batch(NamedTuple):
  """A named tuple bundling four NumPy arrays.

  NOTE(review): nothing in the visible part of this module constructs or
  consumes a `Batch`, so the per-field notes below are inferred from the
  field names only -- confirm against the callers.
  """

  examples: np.ndarray  # Presumably the model inputs.
  targets: np.ndarray  # Presumably the labels to predict.
  discriminator: np.ndarray  # Presumably discriminator feature values.
  distractor: np.ndarray  # Presumably distractor feature values.


def threshold_attribute_combos(attr: np.ndarray, t: int) -> Iterable[set]:
  """Groups attributes whose pairwise value combinations are well supported.

  A pair of attributes (columns of `attr`) is kept when each of its four
  joint value combinations -- both set, both unset, and the two mixed cases --
  occurs in strictly more than `t` examples (rows of `attr`).

  Args:
    attr: The (example, attribute) matrix; entries are read by truthiness.
    t: Support threshold that every joint value combination must exceed.

  Returns:
    A list of sets of attribute (column) indices, one set per maximal clique
    of the graph whose edges are the kept attribute pairs.
  """
  num_attributes = attr.shape[1]
  present = attr.astype(bool)
  absent = np.logical_not(present)

  # One (state of first attribute, state of second attribute) pairing per
  # joint value combination of an attribute pair.
  state_pairings = (
    (present, present),
    (absent, absent),
    (present, absent),
    (absent, present),
  )

  # support[k, i, j] is the number of examples matching pairing `k` for the
  # attribute pair (i, j). Only entries with i < j are filled in; all other
  # entries stay at zero.
  support = np.zeros(
    (len(state_pairings), num_attributes, num_attributes), int
  )
  for i, j in itertools.combinations(range(num_attributes), 2):
    for k, (first, second) in enumerate(state_pairings):
      support[k, i, j] = np.count_nonzero(
        np.logical_and(first[:, i], second[:, j])
      )

  # Threshold the combinations: all four supports must exceed `t`.
  well_supported = np.all(support > t, axis=0)
  first_indices, second_indices = np.nonzero(well_supported)

  # The problem is equivalent to enumerating the maximal cliques.
  graph = nx.Graph()
  graph.add_edges_from(zip(first_indices, second_indices))
  return [set(clique) for clique in nx.clique.find_cliques(graph)]


def wrapped_partial(func, *args, **kwargs):
  """Binds arguments to `func` while carrying over `func`'s metadata.

  Args:
    func: The callable to partially apply.
    *args: Positional arguments to bind.
    **kwargs: Keyword arguments to bind.

  Returns:
    A `functools.partial` of `func` to which `functools.update_wrapper` has
    been applied, so attributes such as `__name__` and `__doc__` mirror
    those of `func`.
  """
  # `update_wrapper` returns the wrapper object it was given.
  return update_wrapper(partial(func, *args, **kwargs), func)


class AutoName(Enum):
  """`Enum` base class whose `auto()` values are the lowercased member names."""

  @staticmethod
  def _generate_next_value_(name, start, count, last_values):
    """Returns the value that `auto()` assigns to the member called `name`."""
    del start, count, last_values  # Only the member name is used.
    lowered_name = name.lower()
    return lowered_name


@unique
class DatasetSplit(AutoName):
  """The dataset splits handled by `BiasedExposureDataset`."""

  # Values come from `auto()`, which `AutoName` maps to the lowercased member
  # name, i.e. "train" and "eval".
  TRAIN = auto()
  EVAL = auto()


def zero_shot(discriminator_features, distractor_features):
  """Selects the examples in which the distractor feature is not set.

  Args:
    discriminator_features: Unused; accepted so that this function can be
      called with the same keyword arguments as the other exposure
      conditions.
    distractor_features: Per-example distractor feature values.

  Returns:
    The element-wise logical negation of `distractor_features`.
  """
  del discriminator_features  # Does not influence this condition.
  distractor_absent = np.logical_not(distractor_features)
  return distractor_absent


def cue_conflict(discriminator_features, distractor_features):
  """Selects the examples in which exactly one of the two features is set.

  Args:
    discriminator_features: Per-example discriminator feature values.
    distractor_features: Per-example distractor feature values.

  Returns:
    The element-wise logical exclusive-or of the two inputs.
  """
  exactly_one_set = np.logical_xor(discriminator_features, distractor_features)
  return exactly_one_set


def partial_exposure(discriminator_features, distractor_features):
  """Selects the examples in which the two features are not both set.

  Args:
    discriminator_features: Per-example discriminator feature values.
    distractor_features: Per-example distractor feature values.

  Returns:
    A boolean mask that is false exactly where both inputs are set.
  """
  # By De Morgan's law, not (a and b) == (not a) or (not b).
  return np.logical_or(
    np.logical_not(discriminator_features),
    np.logical_not(distractor_features),
  )


class ExposureCondition(Enum):
  """The available exposure conditions.

  Each member's value is a callable that takes `discriminator_features` and
  `distractor_features` and returns a boolean mask over the examples.
  """

  # Use a wrapper to prevent functions being seen as class methods: a plain
  # function assigned in an `Enum` body would not become a member.
  # NOTE(review): recent Python releases document that `functools.partial` is
  # slated to become a method descriptor, which may change how these values
  # are treated inside an `Enum` body -- confirm on the targeted version.
  ZERO_SHOT = wrapped_partial(zero_shot)
  CUE_CONFLICT = wrapped_partial(cue_conflict)
  PARTIAL_EXPOSURE = wrapped_partial(partial_exposure)


def ignore_divide_by_zero(a, b):
  """Divides `a` by `b` element-wise, mapping undefined quotients to zero.

  Args:
    a: Numerator; a scalar or an array broadcastable against `b`.
    b: Denominator; a scalar or an array broadcastable against `a`.

  Returns:
    A new array holding `a / b` wherever that quotient is finite and `0`
    wherever it is not (`x / 0` for any `x`, including negative `x`, and
    `0 / 0`). Scalar inputs yield a zero-dimensional array.
  """
  # From https://stackoverflow.com/a/32106804.
  with np.errstate(divide="ignore", invalid="ignore"):
    quotient = np.true_divide(a, b)
  # `np.isfinite` rejects +inf, -inf and NaN alike, and `np.where` also works
  # on scalar quotients, which do not support item assignment.
  return np.where(np.isfinite(quotient), quotient, 0.0)


def case_masking(condition_mask, discriminator_features, distractor_features):
  """Splits `condition_mask` into one mask per (discriminator, distractor) case.

  Args:
    condition_mask: Boolean mask over the examples to consider.
    discriminator_features: Per-example discriminator feature values.
    distractor_features: Per-example distractor feature values.

  Returns:
    An array indexed as `[disc][dist]` whose entry is the boolean mask of the
    examples that are selected by `condition_mask`, whose discriminator
    feature equals `bool(disc)` and whose distractor feature equals
    `bool(dist)`.
  """
  feature_values = (False, True)
  return np.array(
    [
      [
        np.logical_and.reduce(
          [
            condition_mask,
            discriminator_features == disc_value,
            distractor_features == dist_value,
          ]
        )
        for dist_value in feature_values
      ]
      for disc_value in feature_values
    ]
  )


def compute_case_counts(case_masks):
  """Counts the examples in each (discriminator, distractor) case.

  Args:
    case_masks: Per-case boolean masks indexed as `[disc][dist]`, e.g. the
      output of `case_masking`.

  Returns:
    A float array of shape (2, 2) whose `[disc, dist]` entry is the sum of
    `case_masks[disc][dist]`, i.e. the number of set entries of that mask.
  """
  # `np.sum` reduces each mask in native code, whereas the built-in `sum`
  # would iterate over the array one element at a time in Python.
  return np.array(
    [[np.sum(case_masks[disc][dist]) for dist in (0, 1)] for disc in (0, 1)],
    dtype=float,
  )


def relative_balance(
  condition_mask,
  discriminator_features,
  distractor_features,
  max_examples_per_condition,
):
  """Randomly subsamples `condition_mask` to balance the feature cases.

  The examples selected by `condition_mask` are split into the four
  (discriminator, distractor) value cases, a keep-probability is derived for
  each case from the case counts, and every example is then retained
  independently with the keep-probability of its case.

  Draws from NumPy's global random state (`np.random.uniform`), so the result
  is only reproducible if the caller seeds that state.

  Args:
    condition_mask: Boolean mask over the examples to subsample.
    discriminator_features: Per-example discriminator feature values.
    distractor_features: Per-example distractor feature values.
    max_examples_per_condition: If the smallest non-empty case holds more
      examples than this, all keep-probabilities are scaled down by the ratio
      of that case's count to this value.

  Returns:
    A boolean mask selecting a subset of `condition_mask`.

  Raises:
    AssertionError: If, within `condition_mask`, no example has the
      discriminator feature set or no example has it unset.
  """
  case_masks = case_masking(
    condition_mask, discriminator_features, distractor_features
  )
  # case_counts[disc, dist] is the number of examples in that case.
  case_counts = compute_case_counts(case_masks)

  # Summing over the last (distractor) axis gives one total per discriminator
  # value; both totals must be non-zero.
  assert np.all(
    np.sum(case_counts, axis=-1)
  ), "Positive and negative examples of the discriminator feature must exist."

  # Balance along discriminator and distractor axes separately.
  # Per discriminator value: (row total / case count), with empty cases mapped
  # to zero, then rescaled so the largest entry of each row is one. The
  # least-frequent non-empty case of a row therefore gets weight one.
  distractor_correction = ignore_divide_by_zero(
    np.sum(case_counts, axis=-1)[:, np.newaxis], case_counts
  )
  distractor_correction /= np.max(distractor_correction, axis=-1)[
    :, np.newaxis
  ]

  # Case counts after applying the distractor weights.
  balanced_case_counts = np.multiply(distractor_correction, case_counts)
  # Per discriminator value: (overall total / row total) of the reweighted
  # counts, rescaled so that the larger of the two weights is one.
  discriminator_correction = ignore_divide_by_zero(
    np.sum(balanced_case_counts),
    np.sum(balanced_case_counts, axis=-1)[:, np.newaxis],
  )
  discriminator_correction /= np.max(discriminator_correction, axis=0)[
    np.newaxis, :
  ]

  # Combine proportions.
  # The (2, 1) discriminator weights broadcast against the (2, 2) distractor
  # weights; the product is rescaled so that its largest entry is one.
  full_correction = np.multiply(
    discriminator_correction, distractor_correction
  )
  full_correction /= np.max(full_correction)

  # Enforce `max_examples_per_condition`.
  # Only applies when even the smallest non-empty case exceeds the limit; the
  # keep-probabilities are then divided by (smallest count / limit).
  if np.min(case_counts[case_counts != 0.0]) > max_examples_per_condition:
    full_correction /= (
      np.min(case_counts[case_counts != 0.0]) / max_examples_per_condition
    )

  # For each case, draw one fresh uniform sample per example and keep the
  # examples of that case whose draw falls below the case's keep-probability.
  correction_masks = []
  for disc in [0, 1]:
    for dist in [0, 1]:
      correction_masks += [
        np.logical_and(
          case_masks[disc][dist],
          np.random.uniform(size=condition_mask.shape[0])
          < full_correction[disc, dist],
        )
      ]

  # An example is retained if it was kept in any of the four cases.
  correction_mask = np.logical_or.reduce(correction_masks)
  condition_mask = np.logical_and(condition_mask, correction_mask)

  return condition_mask


class BalancingCondition(Enum):
  """The available strategies for balancing the training examples.

  Each member's value is a callable that takes `condition_mask`,
  `discriminator_features`, `distractor_features` and
  `max_examples_per_condition` and returns a boolean mask over the examples.
  """

  # Use a wrapper to prevent functions being seen as class methods: a plain
  # function assigned in an `Enum` body would not become a member.
  RELATIVE = wrapped_partial(relative_balance)


@dataclass
class BiasedExposureConfig(tfds.core.BuilderConfig):
  """`BuilderConfig` for a `BiasedExposureDataset`.

  NOTE(review): `@dataclass` is applied although this class annotates no
  fields and writes its own `__init__` -- confirm the decorator is intended.
  """

  def __init__(
    self,
    exposure_condition: ExposureCondition,
    balancing_condition: BalancingCondition,
    discriminator: int,
    distractor: int,
    **kwargs,
  ):
    """Stores the biased-exposure settings on top of the base config.

    Args:
      exposure_condition: Condition whose `value` callable builds the mask of
        training examples.
      balancing_condition: Condition whose `value` callable subsamples that
        mask.
      discriminator: Column index of the discriminator feature in the
        attribute matrix passed to the dataset's example generator.
      distractor: Column index of the distractor feature in that same
        attribute matrix.
      **kwargs: Forwarded unchanged to `tfds.core.BuilderConfig.__init__`.
    """
    super().__init__(**kwargs)
    self.exposure_condition = exposure_condition
    self.balancing_condition = balancing_condition
    self.discriminator = discriminator
    self.distractor = distractor


class BiasedExposureDataset(tfds.core.GeneratorBasedBuilder):
  """A binary classification dataset with latent combinatorial structure.

  Besides the builder config, this class reads `self.image_shape`,
  `self.attributes`, `self.citation` and `self.MAX_EXAMPLES_PER_CONDITION`,
  none of which it defines itself; subclasses are expected to provide them.
  """

  SEED = 123

  # The properties below expose fields of the active builder config.

  @property
  def distractor(self):
    """The `distractor` field of the builder config."""
    config = self.builder_config
    return config.distractor

  @property
  def discriminator(self):
    """The `discriminator` field of the builder config."""
    config = self.builder_config
    return config.discriminator

  @property
  def exposure_condition(self):
    """The `exposure_condition` field of the builder config."""
    config = self.builder_config
    return config.exposure_condition

  @property
  def balancing_condition(self):
    """The `balancing_condition` field of the builder config."""
    config = self.builder_config
    return config.balancing_condition

  @property
  def description(self):
    """The `description` field of the builder config."""
    config = self.builder_config
    return config.description

  def _info(self) -> tfds.core.DatasetInfo:
    """Builds the `DatasetInfo` describing the yielded example features."""
    description = self.description
    feature_spec = tfds.features.FeaturesDict(
      {
        "image": tfds.features.Image(shape=self.image_shape),
        "discriminator": tfds.features.ClassLabel(names=("no", "yes")),
        "distractor": tfds.features.ClassLabel(names=("no", "yes")),
        "attributes": dict.fromkeys(self.attributes, bool),
      }
    )
    return tfds.core.DatasetInfo(
      builder=self,
      description=description,
      features=feature_spec,
      citation=self.citation,
    )

  def _generate_examples(
    self, examples: Iterable, attributes: np.ndarray, split: DatasetSplit
  ):
    """Yields `(key, example)` pairs of this `BiasedExposureDataset`.

    Args:
      examples: The images, iterated in step with the rows of `attributes`.
      attributes: The (example, attribute) matrix.
      split: `DatasetSplit.TRAIN` restricts the examples to the exposure
        condition and balances them; `DatasetSplit.EVAL` yields everything.

    Yields:
      The example's position as a string, together with a dict holding the
      image, the 0/1 distractor and discriminator labels and all attributes
      keyed by name.

    Raises:
      ValueError: If `split` or the configured balancing condition is not
        recognised.
    """
    np.random.seed(self.SEED)

    if split == DatasetSplit.EVAL:
      # Evaluate on all examples.
      keep = np.ones(shape=attributes.shape[0], dtype=bool)

    elif split == DatasetSplit.TRAIN:
      feature_columns = dict(
        discriminator_features=attributes[:, self.discriminator],
        distractor_features=attributes[:, self.distractor],
      )

      # Present examples only from the exposure condition.
      keep = self.exposure_condition.value(**feature_columns)

      if self.balancing_condition != BalancingCondition.RELATIVE:
        raise ValueError("Unknown balancing condition.")

      keep = self.balancing_condition.value(
        condition_mask=keep,
        max_examples_per_condition=self.MAX_EXAMPLES_PER_CONDITION,
        **feature_columns,
      )

    else:
      raise ValueError("Unknown split.")

    for index, (example, example_attributes) in enumerate(
      zip(examples, attributes)
    ):
      if not keep[index]:
        # Skip training (and validation) examples outside the exposure
        # condition; the evaluation mask is all-true, so nothing is skipped
        # for that split.
        continue

      yield str(index), {
        "image": example,
        "distractor": int(bool(example_attributes[self.distractor])),
        "discriminator": int(bool(example_attributes[self.discriminator])),
        "attributes": dict(zip(self.attributes, example_attributes)),
      }
