import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset, ConcatDataset
from datasets import LabelPoisoner, StripePoisoner, PatchPoisoner, MappedDataset, dataset_to_tensors
import jax.numpy as np

# Class directory names (ImageFolder class keys) of the two classes used here.
source = "n04243546"
target = "n02096294"

# Geometric preprocessing applied by the ImageFolder datasets at load time.
pre_transforms = transforms.Compose(
    [transforms.Resize(224), transforms.CenterCrop((224, 224))]
)

train_ds = datasets.ImageFolder("/tmp/train", transform=pre_transforms)
val_ds = datasets.ImageFolder("/data/yfcc-tmp/imagenet/val", transform=pre_transforms)


def _indexes_with_label(dataset, label):
    """Return the positions in ``dataset.imgs`` whose class index equals ``label``.

    Args:
        dataset: An ``ImageFolder``; its ``imgs`` attribute is a list of
            ``(path, class_index)`` tuples.
        label: The class index to select.

    Returns:
        A list of integer sample positions, in dataset order.
    """
    return [i for i, (_, class_index) in enumerate(dataset.imgs) if class_index == label]


# Use the class -> index mapping each dataset built from its *own* root.
# Re-scanning a different directory (as was done before) can yield indices
# that do not match the labels stored in ``.imgs`` whenever the directory
# contents differ, and fails outright when that other path is not mounted.
train_class_map = train_ds.class_to_idx
train_source_index = train_class_map[source]
train_target_index = train_class_map[target]

source_train_indexes = _indexes_with_label(train_ds, train_source_index)
target_train_indexes = _indexes_with_label(train_ds, train_target_index)

source_train = Subset(train_ds, source_train_indexes)
target_train = Subset(train_ds, target_train_indexes)

val_class_map = val_ds.class_to_idx
val_source_index = val_class_map[source]
val_target_index = val_class_map[target]

source_val_indexes = _indexes_with_label(val_ds, val_source_index)
target_val_indexes = _indexes_with_label(val_ds, val_target_index)

source_val = Subset(val_ds, source_val_indexes)
target_val = Subset(val_ds, target_val_indexes)


# Binary labels used in the exported dataset.
source_label, target_label = 0, 1

# NOTE(review): presumably applies a size-9 patch and relabels the sample to
# ``target_label`` -- confirm against the local ``datasets`` module.
poisoner = LabelPoisoner(PatchPoisoner(size=9), target_label=target_label)


def clean_relabel(xy):
    """Map a clean ``(x, class_index)`` sample onto the binary experiment labels.

    Args:
        xy: A ``(x, y)`` pair where ``y`` is the dataset class index of the
            source or target class, taken from either the train or the
            validation dataset.

    Returns:
        A tuple ``(x, label)`` where ``x`` is passed through unchanged and
        ``label`` is ``source_label`` or ``target_label``.

    Raises:
        KeyError: If ``y`` is not one of the known source/target class indices.
    """
    x, y = xy
    # Train and validation indices come from separate class maps, so both
    # pairs must be listed. The previous literal repeated the train keys
    # twice, which left validation indices unmapped whenever they differed.
    label_by_index = {
        train_source_index: source_label,
        train_target_index: target_label,
        val_source_index: source_label,
        val_target_index: target_label,
    }
    return x, label_by_index[y]


# Tensor conversion followed by per-channel normalisation (mean, std).
_to_tensor = transforms.ToTensor()
_normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
post_transforms = transforms.Compose([_to_tensor, _normalize])


def _apply_post_transforms(sample):
    """Run ``post_transforms`` on element 0 of ``sample``; keep element 1 as is."""
    return (post_transforms(sample[0]), sample[1])


def _first_axis_to_last(x):
    """Convert tensor ``x`` to an array and move its first axis to the end."""
    return np.moveaxis(x.numpy(), 0, -1)


# Order matters: it fixes the sample order of the saved array.
# NOTE(review): the poisoner is applied to the *target*-class subsets here;
# confirm that is intended rather than the source-class subsets.
_train_parts = [
    MappedDataset(source_train, clean_relabel),
    MappedDataset(target_train, poisoner),
    MappedDataset(target_train, clean_relabel),
]
_val_parts = [
    MappedDataset(source_val, clean_relabel),
    MappedDataset(target_val, poisoner),
    MappedDataset(target_val, clean_relabel),
]
data = ConcatDataset(_train_parts + _val_parts)

_transformed = MappedDataset(data, _apply_post_transforms)
tensor = dataset_to_tensors(_transformed, None, xmap=_first_axis_to_last)

# Only the first element of the result is written out.
np.save(f"output/X-{source}-{target}-pp9-perm.npy", tensor[0])
