import torch
import numpy as np
import typing


class Cutout(object):
    """Randomly mask out one or more square patches from an image.

    Patch centres are drawn from NumPy's global random state
    (``np.random.randint``), so seeding ``np.random`` makes the transform
    reproducible.

    Args:
        n_holes (int): Number of patches to cut out of each image.
        length (int): Nominal side length (in pixels) of each square patch.
            Each patch spans ``length // 2`` pixels on either side of its
            centre, so an odd ``length`` yields a side of ``length - 1``.
            Patches that reach past the image border are clipped.
    """
    def __init__(self, n_holes, length):
        self.n_holes = n_holes
        self.length = length

    def __call__(self, img):
        """Return a copy of ``img`` with ``n_holes`` patches set to zero.

        Args:
            img (Tensor): Tensor image of size (C, H, W).
        Returns:
            Tensor: Image with n_holes patches zeroed out. The same spatial
            mask is applied to every channel; the input is not modified.
        """
        h = img.size(1)
        w = img.size(2)
        half = self.length // 2

        mask = np.ones((h, w), np.float32)

        for _ in range(self.n_holes):
            # Draw order (row first, then column) is kept stable so that
            # seeded runs stay reproducible.
            y = np.random.randint(h)
            x = np.random.randint(w)

            # Clip to the image so patches near the border are truncated.
            y1 = np.clip(y - half, 0, h)
            y2 = np.clip(y + half, 0, h)
            x1 = np.clip(x - half, 0, w)
            x2 = np.clip(x + half, 0, w)

            mask[y1:y2, x1:x2] = 0.

        # Build the mask on the image's device; a CPU-only mask would make
        # the multiplication fail for images living on an accelerator.
        mask = torch.from_numpy(mask).to(img.device)
        mask = mask.expand_as(img)

        return img * mask


def in_out_split_noisy(
        clean_train_ys: list, 
        seed: int, 
        num_shadow: int, 
        num_canaries: int, 
        fixed_halves: typing.Optional[bool] = None,
    ) -> typing.Tuple[list, list, list]:
    # Everything from here on depends on the seed
    # All indices are relative to the full raw training set
    # All index arrays (except label noise order) are stored sorted in increasing order
    rng = np.random.default_rng(seed=seed)

    num_raw_train_samples = len(clean_train_ys)
    num_classes = 10
    clean_train_ys = torch.from_numpy(np.array(clean_train_ys))

    # 1) IN-OUT splits
    rng_splits_target, rng_splits_shadow, rng = rng.spawn(3)
    # Currently, we are not using any target models. However, keep rng for compatibility if we need them later.
    del rng_splits_target
    # This ensures that every sample is IN in exactly half of all shadow models if all samples were varied.
    # Calculate splits for all training samples, s.t. the membership is independent of the number of canaries
    # If the number of shadow models changes, then everything changes either way
    assert num_shadow % 2 == 0
    shadow_in_indices_t = np.argsort(
        rng_splits_shadow.uniform(size=(num_shadow, num_raw_train_samples)), axis=0
    )[: num_shadow // 2].T
    raw_shadow_in_indices = []
    for shadow_idx in range(num_shadow):
        raw_shadow_in_indices.append(
            torch.from_numpy(np.argwhere(np.any(shadow_in_indices_t == shadow_idx, axis=1)).flatten())
        )
    rng_splits_half, rng_splits_shadow = rng_splits_shadow.spawn(2)  # used later for fixed splits for validation
    del rng_splits_shadow

    # 2) Canary indices
    rng_canaries, rng = rng.spawn(2)
    canary_order = rng_canaries.permutation(num_raw_train_samples)
    del rng_canaries

    # Calculate proper IN indices depending on setting
    shadow_in_indices = []
    # Normal case; all non-canary samples are always IN
    canary_mask = torch.zeros(num_raw_train_samples, dtype=torch.bool)
    canary_mask[canary_order[: num_canaries]] = True

    if fixed_halves is None:
        for shadow_idx in range(num_shadow):
            current_in_mask = torch.zeros(num_raw_train_samples, dtype=torch.bool)
            current_in_mask[raw_shadow_in_indices[shadow_idx]] = True
            current_in_mask[~canary_mask] = True
            shadow_in_indices.append(torch.argwhere(current_in_mask).flatten())
    else:
        # Special case to validate the setting
        # Always only use half of CIFAR10, but either vary by shadow model, or use a fixed half of non-canaries
        if not fixed_halves:
            # Raw shadow indices are already half of the full training data
            shadow_in_indices = raw_shadow_in_indices
        else:
            # Need to calculate a fixed half of non-canaries
            canary_mask = torch.zeros(num_raw_train_samples, dtype=torch.bool)
            canary_mask[canary_order[: num_canaries]] = True
            fixed_membership_full = torch.from_numpy(rng_splits_half.random(num_raw_train_samples) < 0.5)
            for shadow_idx in range(num_shadow):
                current_in_mask = torch.zeros(num_raw_train_samples, dtype=torch.bool)
                # IN: IN canaries and fixed non-canaries
                current_in_mask[raw_shadow_in_indices[shadow_idx]] = True
                current_in_mask[~canary_mask] = False
                current_in_mask[(~canary_mask) & fixed_membership_full] = True
                shadow_in_indices.append(torch.argwhere(current_in_mask).flatten())
    del rng_splits_half

    # 3) Canary transforms
    rng_canary_transforms, rng = rng.spawn(2)
    # 3.1) Noisy labels for all samples
    rng_noise, rng_canary_transforms = rng_canary_transforms.spawn(2)
    label_changes = torch.from_numpy(rng_noise.integers(num_classes - 1, size=num_raw_train_samples))
    noisy_labels = torch.where(label_changes < clean_train_ys, label_changes, label_changes + 1)
    del rng_noise

    del rng

    noisy_targets = clean_train_ys.clone()
    canary_indices = canary_order[: num_canaries]
    noisy_targets[canary_indices] = noisy_labels[canary_indices]

    noisy_targets = list(noisy_targets.cpu().numpy())
    shadow_in_indices = [_.cpu().numpy() for _ in shadow_in_indices]

    return noisy_targets, shadow_in_indices, canary_indices


def in_out_split_avg_case(
        dataset_size: int, 
        seed: int, 
        num_shadow: int, 
) -> list:
    """Compute seeded IN index sets for each shadow model.

    For every sample, the shadow models are put into a random order and the
    resulting permutation is thresholded, so each sample is IN for exactly
    ``int(0.5 * num_shadow)`` shadow models.

    Args:
        dataset_size: Number of samples in the dataset.
        seed: Seed of the random generator.
        num_shadow: Number of shadow models.

    Returns:
        A list with one integer NumPy array per shadow model, holding the
        IN sample indices for that model in increasing order.
    """
    rng = np.random.default_rng(seed=seed)
    scores = rng.uniform(0, 1, size=(num_shadow, dataset_size))
    # Each column of the argsort is a permutation of 0..num_shadow-1, so
    # thresholding it marks the same number of entries in every column.
    order = scores.argsort(0)
    keep = order < int(0.5 * num_shadow)
    # flatnonzero returns an integer index array even for an all-False row.
    return [np.flatnonzero(row) for row in keep]