import torch
import numpy as np
import typing


def in_out_split_noisy(
        clean_train_ys: list, 
        seed: int, 
        num_shadow: int, 
        num_canaries: int, 
        fixed_halves: typing.Optional[bool] = None,
    ) -> typing.Tuple[list, list, list]:
    # Everything from here on depends on the seed
    # All indices are relative to the full raw training set
    # All index arrays (except label noise order) are stored sorted in increasing order
    rng = np.random.default_rng(seed=seed)

    num_raw_train_samples = len(clean_train_ys)
    num_classes = 10
    clean_train_ys = torch.from_numpy(np.array(clean_train_ys))

    # 1) IN-OUT splits
    rng_splits_target, rng_splits_shadow, rng = rng.spawn(3)
    # Currently, we are not using any target models. However, keep rng for compatibility if we need them later.
    del rng_splits_target
    # This ensures that every sample is IN in exactly half of all shadow models if all samples were varied.
    # Calculate splits for all training samples, s.t. the membership is independent of the number of canaries
    # If the number of shadow models changes, then everything changes either way
    assert num_shadow % 2 == 0
    shadow_in_indices_t = np.argsort(
        rng_splits_shadow.uniform(size=(num_shadow, num_raw_train_samples)), axis=0
    )[: num_shadow // 2].T
    raw_shadow_in_indices = []
    for shadow_idx in range(num_shadow):
        raw_shadow_in_indices.append(
            torch.from_numpy(np.argwhere(np.any(shadow_in_indices_t == shadow_idx, axis=1)).flatten())
        )
    rng_splits_half, rng_splits_shadow = rng_splits_shadow.spawn(2)  # used later for fixed splits for validation
    del rng_splits_shadow

    # 2) Canary indices
    rng_canaries, rng = rng.spawn(2)
    canary_order = rng_canaries.permutation(num_raw_train_samples)
    del rng_canaries

    # Calculate proper IN indices depending on setting
    shadow_in_indices = []
    # Normal case; all non-canary samples are always IN
    canary_mask = torch.zeros(num_raw_train_samples, dtype=torch.bool)
    canary_mask[canary_order[: num_canaries]] = True

    if fixed_halves is None:
        for shadow_idx in range(num_shadow):
            current_in_mask = torch.zeros(num_raw_train_samples, dtype=torch.bool)
            current_in_mask[raw_shadow_in_indices[shadow_idx]] = True
            current_in_mask[~canary_mask] = True
            shadow_in_indices.append(torch.argwhere(current_in_mask).flatten())
    else:
        # Special case to validate the setting
        # Always only use half of CIFAR10, but either vary by shadow model, or use a fixed half of non-canaries
        if not fixed_halves:
            # Raw shadow indices are already half of the full training data
            shadow_in_indices = raw_shadow_in_indices
        else:
            # Need to calculate a fixed half of non-canaries
            canary_mask = torch.zeros(num_raw_train_samples, dtype=torch.bool)
            canary_mask[canary_order[: num_canaries]] = True
            fixed_membership_full = torch.from_numpy(rng_splits_half.random(num_raw_train_samples) < 0.5)
            for shadow_idx in range(num_shadow):
                current_in_mask = torch.zeros(num_raw_train_samples, dtype=torch.bool)
                # IN: IN canaries and fixed non-canaries
                current_in_mask[raw_shadow_in_indices[shadow_idx]] = True
                current_in_mask[~canary_mask] = False
                current_in_mask[(~canary_mask) & fixed_membership_full] = True
                shadow_in_indices.append(torch.argwhere(current_in_mask).flatten())
    del rng_splits_half

    # 3) Canary transforms
    rng_canary_transforms, rng = rng.spawn(2)
    # 3.1) Noisy labels for all samples
    rng_noise, rng_canary_transforms = rng_canary_transforms.spawn(2)
    label_changes = torch.from_numpy(rng_noise.integers(num_classes - 1, size=num_raw_train_samples))
    noisy_labels = torch.where(label_changes < clean_train_ys, label_changes, label_changes + 1)
    del rng_noise

    del rng

    noisy_targets = clean_train_ys.clone()
    canary_indices = canary_order[: num_canaries]
    noisy_targets[canary_indices] = noisy_labels[canary_indices]

    noisy_targets = list(noisy_targets.cpu().numpy())
    shadow_in_indices = [_.cpu().numpy() for _ in shadow_in_indices]

    return noisy_targets, shadow_in_indices, canary_indices