import time
import os
import numpy as np
import random
import torch
import torch.backends
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, Subset
from torchvision import datasets, transforms
from scipy.ndimage.interpolation import rotate as scipyrotate
from networks import MLP, ConvNet, LeNet, AlexNet, AlexNetBN, VGG11, VGG11BN, ResNet18, ResNet18BN_AP, ResNet18BN
import torchvision
import kornia as K
import tqdm
from PIL import Image
import typing
import pickle


def setup_seed(seed):
    """Seed every RNG used by this project and make cuDNN deterministic.

    Seeds Python's ``random``, NumPy's global RNG, the torch CPU generator and the
    generators of the current and all CUDA devices with ``seed``.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    for seed_cuda in (torch.cuda.manual_seed_all, torch.cuda.manual_seed):
        seed_cuda(seed)
    torch.backends.cudnn.deterministic = True


class TensorDataset(Dataset):
    """Map-style dataset over in-memory image and label tensors.

    ``images`` is indexed along its first dimension (n x c x h x w in this project).
    Both tensors are detached at construction and images are cast to float32.
    If ``transform`` is truthy, it is applied to each image in ``__getitem__``.
    """

    def __init__(self, images, labels, transform=None):
        self.images = images.detach().float()
        self.labels = labels.detach()
        self.transform = transform

    def __len__(self):
        return self.images.size(0)

    def __getitem__(self, index):
        image, label = self.images[index], self.labels[index]
        if not self.transform:
            return image, label
        return self.transform(image), label



def get_default_convnet_setting():
    """Return the default ConvNet hyper-parameters.

    Returns:
        Tuple ``(net_width, net_depth, net_act, net_norm, net_pooling)``.
    """
    width, depth = 128, 3
    activation, normalization, pooling = 'relu', 'instancenorm', 'avgpooling'
    return width, depth, activation, normalization, pooling


def _stack_dataset(dataset, device):
    """Materialize ``dataset`` into a stacked image tensor on ``device`` and an int64 CPU label tensor."""
    images, labels = [], []
    for i in range(len(dataset)):
        im, lab = dataset[i]
        images.append(im)
        labels.append(lab)
    images = torch.stack(images, dim=0).to(device)
    labels = torch.tensor(labels, dtype=torch.long, device="cpu")
    return images, labels


def get_dataset(args, noisy_targets=True, zca=False):
    """Load the CIFAR-10 train/test splits, optionally with noisy labels and ZCA whitening.

    Args:
        args: namespace with ``data_path`` (torchvision download/cache root) and, when
            ``noisy_targets`` is true, ``lira_path`` (directory holding ``noisy_targets.npy``).
        noisy_targets: if true, overwrite the training labels with the array stored in
            ``<args.lira_path>/noisy_targets.npy`` (its length must equal the train-set size).
        zca: if false, images are tensorized and mean/std normalized. If true, images are
            only tensorized; a kornia ZCA whitening is fitted on the whole training split and
            applied to both splits, which are returned as in-memory ``TensorDataset``s.

    Returns:
        ``(trainset, testset)``.
    """
    if zca:
        transform = transforms.Compose([transforms.ToTensor()])
    else:
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])

    # Original training data; no augmentation at this stage.
    trainset = torchvision.datasets.CIFAR10(root=args.data_path, train=True, download=True, transform=transform)
    testset = torchvision.datasets.CIFAR10(root=args.data_path, train=False, download=True, transform=transform)

    if noisy_targets:
        loaded_targets = np.load(os.path.join(args.lira_path, "noisy_targets.npy"))
        assert len(loaded_targets) == len(trainset)
        for pos, noisy_target in enumerate(loaded_targets):
            trainset.targets[pos] = noisy_target

    if zca:
        # Fit on GPU when available; the original always called .cuda() and crashed on CPU-only hosts.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        whitening = K.enhance.ZCAWhitening(eps=0.1, compute_inv=True)

        print("Train ZCA")
        images, labels = _stack_dataset(trainset, device)
        whitening.fit(images)
        trainset = TensorDataset(whitening(images).to("cpu"), labels)
        del images  # release the device copy before materializing the test split

        print("Test ZCA")
        images, labels = _stack_dataset(testset, device)
        testset = TensorDataset(whitening(images).to("cpu"), labels)

    return trainset, testset


def get_network(model, channel, num_classes, im_size=(32, 32)):
    """Instantiate a network by name and move it to the available device.

    Args:
        model: architecture name.
            - Plain architectures take only ``channel``/``num_classes``: 'MLP', 'LeNet',
              'AlexNet', 'AlexNetBN', 'VGG11', 'VGG11BN', 'ResNet18', 'ResNet18BN_AP',
              'ResNet18BN'.
            - 'ConvNet' uses the defaults from ``get_default_convnet_setting()``. Its variants
              each override one default: depth ('ConvNetD1'..'ConvNetD4'), width
              ('ConvNetW32'..'ConvNetW256'), activation ('ConvNetAS', 'ConvNetAR',
              'ConvNetAL', 'ConvNetASwish'), normalization ('ConvNetNN', 'ConvNetBN',
              'ConvNetLN', 'ConvNetIN', 'ConvNetGN') or pooling ('ConvNetNP', 'ConvNetMP',
              'ConvNetAP').
            - 'ConvNetASwishBN' overrides both activation and normalization.
        channel: number of input channels.
        num_classes: number of output classes.
        im_size: input spatial size, forwarded to ConvNet variants only.

    Returns:
        The network on 'cuda' when at least one GPU is visible (wrapped in
        ``nn.DataParallel`` when there are several), otherwise on 'cpu'.

    Exits the process (``SystemExit``) for an unknown ``model`` name.
    """
    net_width, net_depth, net_act, net_norm, net_pooling = get_default_convnet_setting()

    plain_models = {
        'MLP': MLP,
        'LeNet': LeNet,
        'AlexNet': AlexNet,
        'AlexNetBN': AlexNetBN,
        'VGG11': VGG11,
        'VGG11BN': VGG11BN,
        'ResNet18': ResNet18,
        'ResNet18BN_AP': ResNet18BN_AP,
        'ResNet18BN': ResNet18BN,
    }
    # Each ConvNet variant overrides some of the default ConvNet settings.
    convnet_overrides = {
        'ConvNet': {},
        'ConvNetD1': {'net_depth': 1},
        'ConvNetD2': {'net_depth': 2},
        'ConvNetD3': {'net_depth': 3},
        'ConvNetD4': {'net_depth': 4},
        'ConvNetW32': {'net_width': 32},
        'ConvNetW64': {'net_width': 64},
        'ConvNetW128': {'net_width': 128},
        'ConvNetW256': {'net_width': 256},
        'ConvNetAS': {'net_act': 'sigmoid'},
        'ConvNetAR': {'net_act': 'relu'},
        'ConvNetAL': {'net_act': 'leakyrelu'},
        'ConvNetASwish': {'net_act': 'swish'},
        'ConvNetASwishBN': {'net_act': 'swish', 'net_norm': 'batchnorm'},
        'ConvNetNN': {'net_norm': 'none'},
        'ConvNetBN': {'net_norm': 'batchnorm'},
        'ConvNetLN': {'net_norm': 'layernorm'},
        'ConvNetIN': {'net_norm': 'instancenorm'},
        'ConvNetGN': {'net_norm': 'groupnorm'},
        'ConvNetNP': {'net_pooling': 'none'},
        'ConvNetMP': {'net_pooling': 'maxpooling'},
        'ConvNetAP': {'net_pooling': 'avgpooling'},
    }

    if model in plain_models:
        net = plain_models[model](channel=channel, num_classes=num_classes)
    elif model in convnet_overrides:
        config = dict(
            net_width=net_width,
            net_depth=net_depth,
            net_act=net_act,
            net_norm=net_norm,
            net_pooling=net_pooling,
        )
        config.update(convnet_overrides[model])
        net = ConvNet(channel=channel, num_classes=num_classes, im_size=im_size, **config)
    else:
        import sys
        # sys.exit rather than the site-module ``exit`` helper, which is not guaranteed to exist.
        sys.exit('unknown model: %s' % model)

    gpu_num = torch.cuda.device_count()
    if gpu_num > 0:
        device = 'cuda'
        if gpu_num > 1:
            net = nn.DataParallel(net)
    else:
        device = 'cpu'
    net = net.to(device)

    return net


def get_time():
    """Return the current local time as a ``[YYYY-mm-dd HH:MM:SS]`` string."""
    now = time.localtime()
    stamp = time.strftime("[%Y-%m-%d %H:%M:%S]", now)
    return stamp


def epoch(mode, dataloader, net, optimizer, criterion, args, aug):
    """Run one pass over ``dataloader``, training ``net`` when ``mode == 'train'``.

    Args:
        mode: 'train' updates ``net`` with ``optimizer``. Any other value evaluates
            (``net.eval()``, gradient tracking disabled).
        dataloader: iterable of batches whose first two items are (images, labels).
        net: model, moved to ``args.device``.
        optimizer: used only in train mode (may be ``None`` otherwise).
        criterion: loss callable ``criterion(output, labels)``, moved to ``args.device``
            when it is an ``nn.Module``.
        args: namespace with ``device``. When ``aug`` is true it also needs ``dsa`` plus
            either ``dsa_strategy``/``dsa_param`` (DSA) or ``dc_aug_param`` (DC augment).
        aug: whether to augment the images before the forward pass.

    Returns:
        ``(loss_avg, acc_avg)``: per-sample mean loss and accuracy (fraction in [0, 1])
        over all samples seen. Raises ZeroDivisionError for an empty ``dataloader``.
    """
    is_train = mode == 'train'
    net = net.to(args.device)
    # Plain-function criteria (e.g. a soft cross-entropy helper) have no ``.to``.
    if isinstance(criterion, nn.Module):
        criterion = criterion.to(args.device)
    net.train(is_train)

    loss_sum, num_correct, num_exp = 0.0, 0, 0
    # Evaluation never calls backward(), so skip building autograd graphs.
    with torch.set_grad_enabled(is_train):
        for datum in dataloader:
            img = datum[0].float().to(args.device)
            if aug:
                if args.dsa:
                    img = DiffAugment(img, args.dsa_strategy, param=args.dsa_param)
                else:
                    img = augment(img, args.dc_aug_param, device=args.device)
            lab = datum[1].long().to(args.device)
            n_b = lab.shape[0]

            output = net(img)
            loss = criterion(output, lab)

            loss_sum += loss.item() * n_b
            num_correct += (output.argmax(dim=-1) == lab).sum().item()
            num_exp += n_b

            if is_train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

    return loss_sum / num_exp, num_correct / num_exp



def evaluate_synset(it_eval, net, images_train, labels_train, testloader, args):
    """Train ``net`` on a (synthetic) training set, then report its test accuracy.

    Training runs for ``args.epochs + 1`` epochs with augmentation enabled, using SGD
    (momentum 0.9, weight decay 5e-4) at ``args.lr_net``. After epoch
    ``args.epochs // 2 + 1`` the learning rate is divided by 10. The test pass is
    un-augmented, and a one-line summary is printed.

    Returns:
        ``(net, acc_train, acc_test)``, where ``acc_train`` is from the last training epoch.
    """
    device = args.device
    net = net.to(device)
    images_train = images_train.to(device)
    labels_train = labels_train.to(device)

    current_lr = float(args.lr_net)
    num_epochs = int(args.epochs)
    decay_at = [num_epochs // 2 + 1]

    def make_optimizer(rate):
        return torch.optim.SGD(net.parameters(), lr=rate, momentum=0.9, weight_decay=0.0005)

    optimizer = make_optimizer(current_lr)
    criterion = nn.CrossEntropyLoss().to(device)
    trainloader = torch.utils.data.DataLoader(
        TensorDataset(images_train, labels_train),
        batch_size=args.batch_train,
        shuffle=True,
        num_workers=0,
    )

    t_start = time.time()
    for ep in tqdm.tqdm(range(num_epochs + 1)):
        loss_train, acc_train = epoch('train', trainloader, net, optimizer, criterion, args, aug=True)
        if ep in decay_at:
            current_lr *= 0.1
            # A fresh optimizer is built here, which also resets the momentum buffers.
            optimizer = make_optimizer(current_lr)
    time_train = time.time() - t_start

    _, acc_test = epoch('test', testloader, net, optimizer, criterion, args, aug=False)
    print('%s Evaluate_%02d: epoch = %04d train time = %d s train loss = %.6f train acc = %.4f, test acc = %.4f' % (get_time(), it_eval, num_epochs, int(time_train), loss_train, acc_train, acc_test))

    return net, acc_train, acc_test



def augment(images, dc_aug_param, device):
    """Apply one randomly chosen DC-style augmentation to every image, in place.

    Args:
        images: batch tensor indexed as (N, C, H, W). It is modified in place and also
            returned.
        dc_aug_param: dict with these keys:
            - 'strategy': '_'-separated subset of 'crop', 'scale', 'rotate', 'noise',
              or 'none'. Unknown tokens are ignored.
            - 'crop': padding in pixels.
            - 'scale': maximum relative resize.
            - 'rotate': maximum angle in degrees.
            - 'noise': Gaussian std.
            ``None`` or strategy 'none' returns ``images`` untouched.
        device: device for temporary tensors.

    Returns:
        ``images``.
    """
    # This can be sped up in the future (per-sample Python loop).
    if dc_aug_param is None or dc_aug_param['strategy'] == 'none':
        return images

    scale = dc_aug_param['scale']
    crop = dc_aug_param['crop']
    rotate = dc_aug_param['rotate']
    noise = dc_aug_param['noise']
    strategy = dc_aug_param['strategy']

    shape = images.shape
    # Per-channel batch means, used as fill values for padding/rotation.
    mean = [float(torch.mean(images[:, c])) for c in range(shape[1])]

    def cropfun(i):
        if crop == 0:
            return  # zero padding makes the crop an identity (and permutation(0) is empty)
        im_ = torch.zeros(shape[1], shape[2] + crop * 2, shape[3] + crop * 2, dtype=torch.float, device=device)
        for c in range(shape[1]):
            im_[c] = mean[c]
        im_[:, crop:crop + shape[2], crop:crop + shape[3]] = images[i]
        r, c = np.random.permutation(crop * 2)[0], np.random.permutation(crop * 2)[0]
        images[i] = im_[:, r:r + shape[2], c:c + shape[3]]

    def scalefun(i):
        h = int((np.random.uniform(1 - scale, 1 + scale)) * shape[2])
        # Width scales from the image width (shape[3]); previously shape[2] was used,
        # which distorted non-square images.
        w = int((np.random.uniform(1 - scale, 1 + scale)) * shape[3])
        tmp = F.interpolate(images[i:i + 1], [h, w])[0]
        mhw = max(h, w, shape[2], shape[3])
        im_ = torch.zeros(shape[1], mhw, mhw, dtype=torch.float, device=device)
        r = int((mhw - h) / 2)
        c = int((mhw - w) / 2)
        im_[:, r:r + h, c:c + w] = tmp
        r = int((mhw - shape[2]) / 2)
        c = int((mhw - shape[3]) / 2)
        images[i] = im_[:, r:r + shape[2], c:c + shape[3]]

    def rotatefun(i):
        if rotate == 0:
            return  # np.random.randint(0, 0) would raise; a 0-degree rotation is the identity
        im_ = scipyrotate(images[i].cpu().data.numpy(), angle=np.random.randint(-rotate, rotate), axes=(-2, -1), cval=np.mean(mean))
        r = int((im_.shape[-2] - shape[-2]) / 2)
        c = int((im_.shape[-1] - shape[-1]) / 2)
        images[i] = torch.tensor(im_[:, r:r + shape[-2], c:c + shape[-1]], dtype=torch.float, device=device)

    def noisefun(i):
        images[i] = images[i] + noise * torch.randn(shape[1:], dtype=torch.float, device=device)

    handlers = {'crop': cropfun, 'scale': scalefun, 'rotate': rotatefun, 'noise': noisefun}
    augs = strategy.split('_')

    for i in range(shape[0]):
        choice = np.random.permutation(augs)[0]  # randomly pick one augmentation per image
        handler = handlers.get(choice)
        if handler is not None:
            handler(i)

    return images



def get_daparam(dataset, model, model_eval, ipc):
    """Return the DC augmentation settings used when evaluating ``model_eval``.

    Augmentation does not always help, so it is enabled only for some settings:
    - 'ConvNetBN' evaluation uses 'crop_noise', because augmentation makes training with
      BatchNorm easier.
    - otherwise MNIST uses 'crop_scale_rotate';
    - everything else uses 'none'.

    ``model`` and ``ipc`` are accepted for interface compatibility but unused.
    """
    if model_eval in ('ConvNetBN',):
        strategy = 'crop_noise'
    elif dataset == 'MNIST':
        strategy = 'crop_scale_rotate'
    else:
        strategy = 'none'

    return {
        'crop': 4,
        'scale': 0.2,
        'rotate': 45,
        'noise': 0.001,
        'strategy': strategy,
    }


class ParamDiffAug:
    """Hyper-parameters for differentiable Siamese augmentation (DSA).

    ``DiffAugment`` also stores its per-call state (``Siamese``, ``latestseed``) on
    instances of this class.
    """

    def __init__(self):
        # 'S' applies one randomly chosen strategy token per call; 'M' applies every token.
        self.aug_mode = 'S'
        self.prob_flip = 0.5        # probability of a horizontal flip
        self.ratio_scale = 1.2      # scale factors drawn from [1/ratio, ratio)
        self.ratio_rotate = 15.0    # max rotation in degrees
        self.ratio_crop_pad = 0.125  # max translation as a fraction of the image size
        self.ratio_cutout = 0.5     # cutout side as a fraction of the image side (0.5 x 0.5)
        self.brightness = 1.0
        self.saturation = 2.0
        self.contrast = 0.5


def set_seed_DiffAug(param):
    """Reseed torch's global RNG from ``param.latestseed`` and advance it by one.

    A ``latestseed`` of -1 means "not Siamese / no reseeding", and the call is a no-op.
    """
    seed = param.latestseed
    if seed != -1:
        torch.random.manual_seed(seed)
        param.latestseed = seed + 1


def DiffAugment(x, strategy='', seed=-1, param=None):
    """Apply differentiable (Siamese) augmentation to a batch.

    Args:
        x: image batch; the individual ops index it as (N, C, H, W).
        strategy: '_'-separated keys of ``AUGMENT_FNS`` (e.g. 'color_crop_cutout').
            An empty string, ``None``, 'none' or 'None' returns ``x`` unchanged.
        seed: -1 draws fresh random parameters per sample. Any other value enables Siamese
            mode: the whole batch shares one set of parameters, and the torch RNG is reseeded
            from ``seed``, ``seed + 1``, ... before each draw, so equal seeds give equal
            augmentations.
        param: ``ParamDiffAug``-like settings object. It is mutated (``Siamese``,
            ``latestseed``). ``None`` uses a fresh ``ParamDiffAug()``.

    Returns:
        The augmented batch (contiguous).

    Exits the process (``SystemExit``) for an unknown ``param.aug_mode``.
    """
    if not strategy or strategy in ('None', 'none'):
        return x
    if param is None:
        # Previously this crashed with AttributeError on the first attribute write.
        param = ParamDiffAug()

    param.Siamese = seed != -1
    param.latestseed = seed

    tokens = strategy.split('_')
    if param.aug_mode == 'M':  # original: apply every listed family in order
        for token in tokens:
            for fn in AUGMENT_FNS[token]:
                x = fn(x, param)
    elif param.aug_mode == 'S':  # single: pick one family at random
        set_seed_DiffAug(param)
        token = tokens[torch.randint(0, len(tokens), size=(1,)).item()]
        for fn in AUGMENT_FNS[token]:
            x = fn(x, param)
    else:
        import sys
        sys.exit('unknown augmentation mode: %s' % param.aug_mode)
    return x.contiguous()


def rand_scale(x, param):
    """Randomly rescale each image via an affine warp.

    The x and y factors are drawn independently from
    [1/param.ratio_scale, param.ratio_scale). In Siamese mode the whole batch shares the
    first sample's factors. Out-of-range samples are zero-filled (``grid_sample`` defaults).
    """
    ratio = param.ratio_scale
    n = x.shape[0]
    set_seed_DiffAug(param)
    sx = torch.rand(n, device=x.device) * (ratio - 1.0 / ratio) + 1.0 / ratio
    set_seed_DiffAug(param)
    sy = torch.rand(n, device=x.device) * (ratio - 1.0 / ratio) + 1.0 / ratio
    # Build the (N, 2, 3) affine matrices with tensor ops instead of a per-sample Python list.
    theta = torch.zeros(n, 2, 3, device=x.device, dtype=torch.float)
    theta[:, 0, 0] = sx
    theta[:, 1, 1] = sy
    if param.Siamese:
        theta[:] = theta[0]
    # align_corners=False is torch's default; passing it explicitly silences the UserWarning.
    grid = F.affine_grid(theta, x.shape, align_corners=False)
    return F.grid_sample(x, grid, align_corners=False)

def rand_rotate(x, param):
    """Randomly rotate each image by an angle in [-param.ratio_rotate, param.ratio_rotate) degrees.

    In Siamese mode the whole batch shares the first sample's angle. Out-of-range samples
    are zero-filled (``grid_sample`` defaults).
    """
    ratio = param.ratio_rotate
    set_seed_DiffAug(param)
    angle = (torch.rand(x.shape[0], device=x.device) - 0.5) * 2 * ratio / 180 * float(np.pi)
    cos, sin_neg, sin = torch.cos(angle), torch.sin(-angle), torch.sin(angle)
    zero = torch.zeros_like(angle)
    # Rows [cos, -sin, 0] and [sin, cos, 0], built vectorized as an (N, 2, 3) tensor.
    theta = torch.stack(
        [
            torch.stack([cos, sin_neg, zero], dim=1),
            torch.stack([sin, cos, zero], dim=1),
        ],
        dim=1,
    ).to(torch.float)
    if param.Siamese:
        theta[:] = theta[0]
    # align_corners=False is torch's default; passing it explicitly silences the UserWarning.
    grid = F.affine_grid(theta, x.shape, align_corners=False)
    return F.grid_sample(x, grid, align_corners=False)

def rand_flip(x, param):
    """Flip each image along dim 3 with probability ``param.prob_flip``.

    In Siamese mode the whole batch shares the first sample's coin toss.
    """
    set_seed_DiffAug(param)
    coin = torch.rand(x.size(0), 1, 1, 1, device=x.device)
    if param.Siamese:
        coin[:] = coin[0]
    flipped = x.flip(3)
    return torch.where(coin < param.prob_flip, flipped, x)

def rand_brightness(x, param):
    """Add a per-image offset drawn from [-brightness/2, brightness/2) to every pixel.

    In Siamese mode the whole batch shares the first sample's offset.
    """
    set_seed_DiffAug(param)
    draw = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device)
    if param.Siamese:
        draw[:] = draw[0]
    offset = (draw - 0.5) * param.brightness
    return x + offset

def rand_saturation(x, param):
    """Scale each pixel's deviation from its across-channel mean by a random factor.

    The factor is drawn per image from [0, param.saturation). In Siamese mode the whole
    batch shares the first sample's factor.
    """
    channel_mean = x.mean(dim=1, keepdim=True)
    set_seed_DiffAug(param)
    draw = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device)
    if param.Siamese:
        draw[:] = draw[0]
    factor = draw * param.saturation
    return (x - channel_mean) * factor + channel_mean

def rand_contrast(x, param):
    """Scale each image's deviation from its global mean by a random factor.

    The factor is drawn per image from [param.contrast, param.contrast + 1). In Siamese
    mode the whole batch shares the first sample's factor.
    """
    image_mean = x.mean(dim=[1, 2, 3], keepdim=True)
    set_seed_DiffAug(param)
    draw = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device)
    if param.Siamese:
        draw[:] = draw[0]
    factor = draw + param.contrast
    return (x - image_mean) * factor + image_mean

def rand_crop(x, param):
    """Randomly translate each image by an integer shift, filling vacated pixels with zeros.

    Shifts along dims 2 and 3 are drawn from [-s, s], where
    s = round(size * param.ratio_crop_pad). In Siamese mode the whole batch shares the
    first sample's shift.
    """
    ratio = param.ratio_crop_pad
    shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
    set_seed_DiffAug(param)
    translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device)
    set_seed_DiffAug(param)
    translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device)
    if param.Siamese:
        translation_x[:] = translation_x[0]
        translation_y[:] = translation_y[0]
    # indexing='ij' is the historical default; stating it silences torch's deprecation warning.
    grid_batch, grid_x, grid_y = torch.meshgrid(
        torch.arange(x.size(0), dtype=torch.long, device=x.device),
        torch.arange(x.size(2), dtype=torch.long, device=x.device),
        torch.arange(x.size(3), dtype=torch.long, device=x.device),
        indexing='ij',
    )
    # The +1 and the clamp to [0, size + 1] index into a 1-pixel zero border added by F.pad,
    # so out-of-range positions read zeros.
    grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)
    grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)
    x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
    x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2)
    return x

def rand_cutout(x, param):
    """Zero out one random rectangle per image.

    The rectangle covers about ``param.ratio_cutout`` of each side. Its center is drawn
    uniformly, and it is clipped at the image borders. In Siamese mode the whole batch
    shares the first sample's rectangle.
    """
    ratio = param.ratio_cutout
    cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
    set_seed_DiffAug(param)
    offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device)
    set_seed_DiffAug(param)
    offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device)
    if param.Siamese:
        offset_x[:] = offset_x[0]
        offset_y[:] = offset_y[0]
    # indexing='ij' is the historical default; stating it silences torch's deprecation warning.
    grid_batch, grid_x, grid_y = torch.meshgrid(
        torch.arange(x.size(0), dtype=torch.long, device=x.device),
        torch.arange(cutout_size[0], dtype=torch.long, device=x.device),
        torch.arange(cutout_size[1], dtype=torch.long, device=x.device),
        indexing='ij',
    )
    grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1)
    grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1)
    mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device)
    mask[grid_batch, grid_x, grid_y] = 0
    x = x * mask.unsqueeze(1)
    return x


# Strategy token -> ops applied in order when DiffAugment selects that token
# ('color' expands to brightness, then saturation, then contrast).
AUGMENT_FNS = dict(
    color=[rand_brightness, rand_saturation, rand_contrast],
    crop=[rand_crop],
    cutout=[rand_cutout],
    flip=[rand_flip],
    scale=[rand_scale],
    rotate=[rand_rotate],
)


def SoftCrossEntropy(inputs, target, reduction='average'):
    """Cross-entropy between predicted logits and soft targets that are themselves logits.

    ``target`` is turned into a distribution with softmax over dim 1, and
    ``-log_softmax(inputs)`` is weighted by it.

    Args:
        inputs: prediction logits, class dimension 1.
        target: target logits, same shape as ``inputs``.
        reduction: one of:
            - 'average' (default) or 'mean': the summed loss divided by the batch size
              ``inputs.shape[0]``.
            - 'sum': the total.
            - 'none': per-sample losses (summed over dim 1 only).

    Raises:
        ValueError: for any other ``reduction``. The argument was previously ignored.
    """
    neg_log_probs = -F.log_softmax(inputs, dim=1)
    target_probs = F.softmax(target, dim=1)
    elementwise = torch.mul(neg_log_probs, target_probs)
    if reduction in ('average', 'mean'):
        return torch.sum(elementwise) / inputs.shape[0]
    if reduction == 'sum':
        return torch.sum(elementwise)
    if reduction == 'none':
        return elementwise.sum(dim=1)
    raise ValueError(f"unknown reduction: {reduction!r}")


def test_accuracy(model, test_loader):
    """Evaluate ``model`` on ``test_loader``, print and return loss/accuracy.

    Batches are moved to the device of the model's first parameter. Parameter-less models
    fall back to CUDA when available, otherwise CPU; previously ``.cuda()`` was
    hard-coded.

    Returns:
        ``(acc, test_loss)``: accuracy in percent and mean per-sample cross-entropy,
        both normalized by ``len(test_loader.dataset)``.
    """
    model.eval()
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    test_loss = 0.0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # reduction='sum' replaces the deprecated size_average=False (sum of batch loss).
            test_loss += F.cross_entropy(output, target, reduction='sum').item()
            pred = output.argmax(dim=1)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    acc = 100. * correct / len(test_loader.dataset)
    print('\n Test_set: Average loss: {:.4f}, Accuracy: {:.4f}\n'.format(test_loss, acc))
    return acc, test_loss


# Evaluation helper, not a pytest test: keep pytest from collecting it when imported.
test_accuracy.__test__ = False


class DiffusionDataset(Dataset):
    """In-memory dataset of generated images stored as ``<index>.jpg`` files in one folder.

    Every image is loaded eagerly at construction and resized to 32x32 with LANCZOS.
    The label of ``<index>.jpg`` is ``noisy_targets[index]``. ``transform``, if truthy,
    is applied to the PIL image in ``__getitem__``.
    """

    def __init__(self, root_dir, noisy_targets, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.samples = self._make_dataset(noisy_targets)

    def _make_dataset(self, labels):
        """Load every ``<index>.jpg`` in ``root_dir`` as a ``(PIL image, labels[index])`` pair."""
        samples = []
        # Sorted so the sample order does not depend on filesystem listing order.
        for file_name in sorted(os.listdir(self.root_dir)):
            stem, ext = os.path.splitext(file_name)
            if ext != '.jpg':
                continue  # e.g. hidden/system files; int() on their names would crash
            label = labels[int(stem)]
            # The context manager closes the underlying file immediately. Previously only
            # the resized copy was closed, so file handles piled up ("too many images for PIL").
            # resize() loads the pixels and returns an image independent of the file.
            with Image.open(os.path.join(self.root_dir, file_name)) as raw:
                img = raw.resize((32, 32), Image.LANCZOS)
            samples.append((img, label))
        return samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img, target = self.samples[idx]
        if self.transform:
            img = self.transform(img)
        return img, target


def prepare_lira_files(args):
    """Create (or verify) the LiRA split files under ``args.lira_path`` and return the splits.

    Non-average case (``args.avg_case`` falsy): calls ``in_out_split_noisy`` (seed 0) and
    persists ``noisy_targets.npy``, ``canary_indices.npy`` and ``indices/indice_<exp_id>.npy``.
    Average case: calls ``in_out_split_avg_case`` (seed 0) and persists only the indices file.

    Each file is written on first use. On later runs, ``assert`` checks that the freshly
    computed values equal the stored ones.

    Returns:
        ``shadow_in_indices``: per-shadow-model IN indices, indexed by experiment id.
    """
    def save_or_verify(path, values):
        array = np.array(values)
        if os.path.exists(path):
            assert np.array_equal(array, np.load(path))
        else:
            np.save(path, array)

    clean_cifar10_trainset, _ = get_dataset(args, noisy_targets=False)
    indices_path = os.path.join(args.lira_path, 'indices')
    exp_indices_file = os.path.join(indices_path, f'indice_{args.exp_id}.npy')

    if args.avg_case:
        shadow_in_indices = in_out_split_avg_case(
            dataset_size=len(clean_cifar10_trainset),
            seed=0,
            num_shadow=args.num_shadow,
        )
        os.makedirs(indices_path, exist_ok=True)
    else:
        noisy_targets, shadow_in_indices, canary_indices = in_out_split_noisy(
            clean_train_ys=clean_cifar10_trainset.targets,
            seed=0,
            num_shadow=args.num_shadow,
            num_canaries=args.num_canaries,
        )
        os.makedirs(indices_path, exist_ok=True)
        save_or_verify(os.path.join(args.lira_path, 'noisy_targets.npy'), noisy_targets)
        save_or_verify(os.path.join(args.lira_path, 'canary_indices.npy'), canary_indices)

    save_or_verify(exp_indices_file, shadow_in_indices[args.exp_id])
    return shadow_in_indices


def _load_cifar10_train(args, transform):
    """Return the CIFAR-10 training set with ``transform`` set; labels are noisy unless ``args.avg_case``."""
    dst_train, _ = get_dataset(args, noisy_targets=not args.avg_case)
    dst_train.transform = transform
    return dst_train


def _build_scheduler(args, optimizer, num_steps_per_epoch):
    """Create the LR scheduler for ``args.model_type``.

    - ResNet18 / ResNet18BN: one epoch of linear warm-up, then a x0.2 decay at epochs
      60, 120 and 160. Milestones are expressed in optimizer steps, which presumably means
      the caller steps the scheduler once per batch (confirm against the training loop).
    - ConvNet: one x0.1 decay at milestone ``args.epochs // 2 + 1``.

    Raises:
        ValueError: for any other model type.
    """
    if args.model_type in ('ResNet18', 'ResNet18BN'):
        warmup = torch.optim.lr_scheduler.LinearLR(
            optimizer,
            start_factor=1.0 / num_steps_per_epoch,
            end_factor=1.0,
            total_iters=num_steps_per_epoch,
        )
        decay = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=[e * num_steps_per_epoch for e in (60, 120, 160)],
            gamma=0.2,
        )
        return torch.optim.lr_scheduler.SequentialLR(
            optimizer,
            schedulers=[warmup, decay],
            milestones=[1 * num_steps_per_epoch],
        )
    if args.model_type == 'ConvNet':
        return torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=[args.epochs // 2 + 1],
            gamma=0.1,
        )
    raise ValueError(f"Unknown model type: {args.model_type}")


def get_preparation(args):
    """Build everything needed to train one shadow model for ``args.method``.

    Methods:
        'cifar10': the shadow model's IN subset of CIFAR-10.
        'random': ``args.num_coreset`` IN samples drawn without replacement (rng seeded by
            ``exp_id``).
        'forgetting': the IN samples minus the first entries of the stored
            ``cifar10_sorted.pkl`` ordering.
        'DM' / 'DSA' / 'MTT' / 'DATM': distilled images loaded from ``args.syn_data_path``.
            MTT/DATM are paired with a ZCA-whitened test set, and DATM also loads
            ``args.lr_net`` and uses ``SoftCrossEntropy``.
        'Diffusion': generated images from ``fold_0`` .. ``fold_{trainset_size-1}``.

    Returns:
        ``(train_loader, test_loader, model, optimizer, criterion, scheduler)``.

    Exits the process (``SystemExit``) for an unknown method. Raises ValueError for an
    unsupported ``args.model_type``.
    """
    # Create the LiRA files if they are not in the folder yet.
    shadow_in_indices = prepare_lira_files(args)
    in_indices = shadow_in_indices[args.exp_id]

    normalize = transforms.Normalize(
        (0.4914, 0.4822, 0.4465),
        (0.2023, 0.1994, 0.2010),
    )
    if args.use_dd_aug:
        # Tensorize + normalize only; no flip/crop is applied here in this mode.
        train_transform = transforms.Compose([transforms.ToTensor(), normalize])
        train_transform_tensordataset = None
    else:
        # Transform for shadow-model training on raw (PIL) CIFAR-10 images.
        train_transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, padding=4),
            transforms.ToTensor(),
            normalize,
        ])
        # Distilled sets are already tensors, so only the geometric augmentations apply.
        train_transform_tensordataset = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, padding=4),
        ])

    zca = False
    criterion = nn.CrossEntropyLoss().to(args.device)
    method = args.method
    exp_dir = os.path.join(args.syn_data_path, f'exp_{args.exp_id}')

    if method == 'cifar10':
        dst_train = Subset(_load_cifar10_train(args, train_transform), in_indices)
    elif method == 'random':
        full_train = _load_cifar10_train(args, train_transform)
        rng = np.random.default_rng(seed=args.exp_id)
        dst_train = Subset(full_train, rng.choice(in_indices, size=args.num_coreset, replace=False))
    elif method == 'forgetting':
        full_train = _load_cifar10_train(args, train_transform)
        # NOTE: pickle.load can execute arbitrary code; only point syn_data_path at trusted files.
        with open(os.path.join(exp_dir, 'cifar10_sorted.pkl'), 'rb') as file:
            ordered_indx = pickle.load(file)['indices']
        # Clamp at 0: a negative count would make the slice below drop the wrong entries.
        remove_n = max(len(ordered_indx) - args.num_coreset, 0)
        elements_to_remove = np.array(ordered_indx)[:remove_n]
        dst_train = Subset(full_train, np.setdiff1d(in_indices, elements_to_remove))
    elif method in ('DM', 'DSA'):
        data = torch.load(os.path.join(args.syn_data_path, f'exp_{args.exp_id}/res_{method}_CIFAR10_ConvNet_1000ipc.pt'))
        images, labels = data['data'][0][0], data['data'][0][1]
        dst_train = TensorDataset(images.to(args.device), labels.to(args.device), transform=train_transform_tensordataset)
    elif method == 'MTT':
        images = torch.load(os.path.join(args.syn_data_path, f'exp_{args.exp_id}/images_best.pt'))
        labels = torch.load(os.path.join(args.syn_data_path, f'exp_{args.exp_id}/labels_best.pt'))
        dst_train = TensorDataset(images.to(args.device), labels.to(args.device), transform=train_transform_tensordataset)
        # The test set is ZCA-whitened to match. An inverse-ZCA + Normalize alternative
        # was prototyped here earlier and left disabled.
        zca = True
    elif method == 'DATM':
        normal_dir = f'exp_{args.exp_id}/Normal'
        images = torch.load(os.path.join(args.syn_data_path, f'{normal_dir}/images_best.pt'))
        labels = torch.load(os.path.join(args.syn_data_path, f'{normal_dir}/labels_best.pt'))
        # NOTE(review): lr_best.pt is passed to SGD as loaded (possibly a tensor); confirm.
        args.lr_net = torch.load(os.path.join(args.syn_data_path, f'{normal_dir}/lr_best.pt'))
        dst_train = TensorDataset(images.to(args.device), labels.to(args.device), transform=train_transform_tensordataset)
        zca = True
        criterion = SoftCrossEntropy
    elif method == 'Diffusion':
        noisy_targets = np.load(os.path.join(args.lira_path, 'noisy_targets.npy'))
        syn_datasets = [
            DiffusionDataset(
                root_dir=os.path.join(exp_dir, f'fold_{folder_id}'),
                noisy_targets=noisy_targets,
                transform=train_transform,
            )
            for folder_id in tqdm.tqdm(range(args.trainset_size))
        ]
        dst_train = torch.utils.data.ConcatDataset(syn_datasets)
    else:
        import sys
        sys.exit(f'unknown method: {args.method}')

    print(f"Dataset Size: {len(dst_train)}")
    train_loader = torch.utils.data.DataLoader(dst_train, batch_size=256, shuffle=True, num_workers=0)
    _, testset = get_dataset(args, noisy_targets=False, zca=zca)
    test_loader = torch.utils.data.DataLoader(testset, batch_size=2048, shuffle=False, num_workers=4)

    model = get_network(args.model_type, channel=3, num_classes=10)
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr_net, momentum=0.9, weight_decay=5e-4)
    # len(train_loader) gives the batch count without loading and augmenting every image.
    scheduler = _build_scheduler(args, optimizer, len(train_loader))

    return train_loader, test_loader, model, optimizer, criterion, scheduler


def in_out_split_noisy(
        clean_train_ys: list, 
        seed: int, 
        num_shadow: int, 
        num_canaries: int, 
        fixed_halves: typing.Optional[bool] = None,
    ) -> typing.Tuple[list, list, list]:
    # Everything from here on depends on the seed
    # All indices are relative to the full raw training set
    # All index arrays (except label noise order) are stored sorted in increasing order
    rng = np.random.default_rng(seed=seed)

    num_raw_train_samples = len(clean_train_ys)
    num_classes = 10
    clean_train_ys = torch.from_numpy(np.array(clean_train_ys))

    # 1) IN-OUT splits
    rng_splits_target, rng_splits_shadow, rng = rng.spawn(3)
    # Currently, we are not using any target models. However, keep rng for compatibility if we need them later.
    del rng_splits_target
    # This ensures that every sample is IN in exactly half of all shadow models if all samples were varied.
    # Calculate splits for all training samples, s.t. the membership is independent of the number of canaries
    # If the number of shadow models changes, then everything changes either way
    assert num_shadow % 2 == 0
    shadow_in_indices_t = np.argsort(
        rng_splits_shadow.uniform(size=(num_shadow, num_raw_train_samples)), axis=0
    )[: num_shadow // 2].T
    raw_shadow_in_indices = []
    for shadow_idx in range(num_shadow):
        raw_shadow_in_indices.append(
            torch.from_numpy(np.argwhere(np.any(shadow_in_indices_t == shadow_idx, axis=1)).flatten())
        )
    rng_splits_half, rng_splits_shadow = rng_splits_shadow.spawn(2)  # used later for fixed splits for validation
    del rng_splits_shadow

    # 2) Canary indices
    rng_canaries, rng = rng.spawn(2)
    canary_order = rng_canaries.permutation(num_raw_train_samples)
    del rng_canaries

    # Calculate proper IN indices depending on setting
    shadow_in_indices = []
    # Normal case; all non-canary samples are always IN
    canary_mask = torch.zeros(num_raw_train_samples, dtype=torch.bool)
    canary_mask[canary_order[: num_canaries]] = True

    if fixed_halves is None:
        for shadow_idx in range(num_shadow):
            current_in_mask = torch.zeros(num_raw_train_samples, dtype=torch.bool)
            current_in_mask[raw_shadow_in_indices[shadow_idx]] = True
            current_in_mask[~canary_mask] = True
            shadow_in_indices.append(torch.argwhere(current_in_mask).flatten())
    else:
        # Special case to validate the setting
        # Always only use half of CIFAR10, but either vary by shadow model, or use a fixed half of non-canaries
        if not fixed_halves:
            # Raw shadow indices are already half of the full training data
            shadow_in_indices = raw_shadow_in_indices
        else:
            # Need to calculate a fixed half of non-canaries
            canary_mask = torch.zeros(num_raw_train_samples, dtype=torch.bool)
            canary_mask[canary_order[: num_canaries]] = True
            fixed_membership_full = torch.from_numpy(rng_splits_half.random(num_raw_train_samples) < 0.5)
            for shadow_idx in range(num_shadow):
                current_in_mask = torch.zeros(num_raw_train_samples, dtype=torch.bool)
                # IN: IN canaries and fixed non-canaries
                current_in_mask[raw_shadow_in_indices[shadow_idx]] = True
                current_in_mask[~canary_mask] = False
                current_in_mask[(~canary_mask) & fixed_membership_full] = True
                shadow_in_indices.append(torch.argwhere(current_in_mask).flatten())
    del rng_splits_half

    # 3) Canary transforms
    rng_canary_transforms, rng = rng.spawn(2)
    # 3.1) Noisy labels for all samples
    rng_noise, rng_canary_transforms = rng_canary_transforms.spawn(2)
    label_changes = torch.from_numpy(rng_noise.integers(num_classes - 1, size=num_raw_train_samples))
    noisy_labels = torch.where(label_changes < clean_train_ys, label_changes, label_changes + 1)
    del rng_noise

    del rng

    noisy_targets = clean_train_ys.clone()
    canary_indices = canary_order[: num_canaries]
    noisy_targets[canary_indices] = noisy_labels[canary_indices]

    noisy_targets = list(noisy_targets.cpu().numpy())
    shadow_in_indices = [_.cpu().numpy() for _ in shadow_in_indices]

    return noisy_targets, shadow_in_indices, canary_indices


def in_out_split_avg_case(
        dataset_size: int, 
        seed: int, 
        num_shadow: int, 
) -> list:
    """Draw random IN/OUT splits for ``num_shadow`` shadow models (average-case setting).

    For every sample, a random ranking of the shadow models is drawn (argsort of uniform
    scores along the shadow axis). Because each column of that ranking is a permutation
    of ``0..num_shadow-1``, every sample is IN for exactly ``num_shadow // 2`` shadow models.

    Args:
        dataset_size: Number of training samples.
        seed: Seed for ``np.random.default_rng``.
        num_shadow: Number of shadow models.

    Returns:
        List of length ``num_shadow``; entry ``e`` is a sorted integer NumPy array of the
        sample indices that are IN for shadow model ``e`` (possibly empty).
    """
    rng = np.random.default_rng(seed=seed)
    scores = rng.uniform(0, 1, size=(num_shadow, dataset_size))
    keep = scores.argsort(0) < num_shadow // 2
    # np.flatnonzero always returns an integer array, so even empty splits stay usable as indices.
    return [np.flatnonzero(row) for row in keep]


def get_zca(cifar10_path: str, device: str = 'cuda') -> K.enhance.ZCAWhitening:
    """Fit a ZCA whitening transform on the full CIFAR-10 training set.

    The training images are loaded from ``cifar10_path`` (no download is attempted) and
    converted with ``transforms.ToTensor()``. They are stacked into one tensor on
    ``device`` and used to fit ``K.enhance.ZCAWhitening(eps=0.1, compute_inv=True)``.

    Args:
        cifar10_path: Root directory passed to ``torchvision.datasets.CIFAR10``; the
            dataset must already be present there.
        device: Device the stacked training images are moved to before fitting. Defaults
            to ``'cuda'`` (the previously hard-coded value); pass ``'cpu'`` on machines
            without a GPU.

    Returns:
        The fitted ``ZCAWhitening`` module.
    """
    cifar10_train = torchvision.datasets.CIFAR10(
        root=cifar10_path, 
        train=True, 
        download=False,
        transform=transforms.Compose([transforms.ToTensor(),]),
    )
    # Index explicitly (rather than iterating the Dataset) to visit each sample exactly once.
    cifar10_images = torch.stack(
        [cifar10_train[i][0] for i in range(len(cifar10_train))], dim=0
    ).to(device)
    zca = K.enhance.ZCAWhitening(eps=0.1, compute_inv=True)
    zca.fit(cifar10_images)

    return zca
