import torch
from torchvision.datasets import CIFAR10, CIFAR100
import numpy as np
from PIL import Image

class CIFAR10Index(CIFAR10):
    """CIFAR-10 dataset whose items are ``(image, target, index)`` triples.

    Optionally adds a perturbation ``delta`` to the stored images once, at
    construction time:

    * ``trigger=True``: ``delta`` is added to every image and every target is
      replaced by ``poisoned_class``.
    * ``trigger=False``: a random ``ratio`` fraction of the images labelled
      ``poisoned_class`` receives ``delta``. Labels are left untouched unless
      ``detect=True``, in which case targets become binary (1 for the
      perturbed samples, 0 for everything else).

    Args:
        delta: Perturbation with the shape of a single image, either already
            matching ``self.data[0]`` or in a layout that ``permute(1, 2, 0)``
            converts to it. It is multiplied by 255 before being added to the
            uint8 pixels, so it is presumably expressed on a [0, 1] pixel
            scale (confirm against the caller). ``None`` disables poisoning.
        ratio: Fraction in [0, 1] of the ``poisoned_class`` samples to perturb
            when ``trigger`` is False.
        poisoned_class: Class label selecting the samples to perturb, and the
            label assigned to every sample when ``trigger`` is True.
        trigger: Perturb the whole dataset and relabel it to
            ``poisoned_class``.
        detect: Replace targets with binary perturbed / not-perturbed labels
            (only used when ``trigger`` is False and ``delta`` is given).
        **kwargs: Forwarded unchanged to the parent dataset constructor.

    Raises:
        ValueError: If ``ratio`` is outside [0, 1], if ``delta`` cannot be
            brought to the shape of one image, or if ``trigger`` is True while
            ``poisoned_class`` is None.
    """

    def __init__(self, delta: torch.FloatTensor = None, ratio=0.01, poisoned_class=2, trigger=False, detect=False, **kwargs):
        super().__init__(**kwargs)
        self.delta = delta
        self.test_trigger_performance = trigger
        self.poisoned_class = poisoned_class
        self.detect = detect

        # Explicit raise instead of assert: asserts vanish under ``python -O``.
        if not 0.0 <= ratio <= 1.0:
            raise ValueError(f'ratio must be within [0, 1], got {ratio}')

        if self.delta is None:
            return

        image_shape = self.data[0].shape
        if self.delta.shape != image_shape:
            self.delta = self.delta.permute(1, 2, 0)
            if self.delta.shape != image_shape:
                raise ValueError(
                    f'delta shape {tuple(delta.shape)} is incompatible with '
                    f'image shape {tuple(image_shape)}'
                )

        # detach() so a perturbation that still tracks gradients can be
        # converted; Tensor.numpy() refuses tensors that require grad.
        delta_np = self.delta.detach().mul(255).cpu().numpy()

        if self.test_trigger_performance:
            # Validate before touching self.data so a failed construction
            # does not leave half-modified state behind.
            if self.poisoned_class is None:
                raise ValueError(
                    'poisoned_class must be specified when trigger is True'
                )
            self.data = self._add_delta(self.data, delta_np)
            self.targets = [self.poisoned_class] * len(self.targets)
            return

        class_indices = np.where(np.array(self.targets) == self.poisoned_class)[0]
        set_size = int(len(class_indices) * ratio)

        # Only draw from the global NumPy RNG when something is selected.
        if set_size > 0:
            chosen = np.random.choice(class_indices, size=set_size, replace=False)
        else:
            chosen = class_indices[:0]

        if self.detect:
            # Relabel even when nothing was selected, so detection targets are
            # always binary rather than leaking the original class labels.
            self.targets[:] = [0] * len(self.targets)
            for sample_idx in chosen:
                self.targets[sample_idx] = 1

        if set_size > 0:
            self.data[chosen] = self._add_delta(self.data[chosen], delta_np)

    @staticmethod
    def _add_delta(images, delta_np):
        """Return ``images + delta_np`` clipped to [0, 255] as uint8."""
        return np.clip(images.astype(np.float32) + delta_np, 0, 255).astype(np.uint8)

    def __getitem__(self, idx):
        """Return ``(image, target, idx)`` for the sample at ``idx``.

        The image is converted to a PIL image and passed through
        ``self.transform`` when one is configured. ``target_transform`` is not
        applied here.
        """
        img, target = self.data[idx], self.targets[idx]
        img = Image.fromarray(img)
        # transform is optional in torchvision datasets; calling None crashes.
        if self.transform is not None:
            img = self.transform(img)
        return img, target, idx


class CIFAR100Index(CIFAR100):
    """CIFAR-100 dataset whose items are ``(image, target, index)`` triples.

    Optionally adds a per-sample perturbation to the stored images once, at
    construction time.

    Args:
        delta: Perturbation tensor. Either one entry per sample, or exactly
            100 entries, in which case entry ``c`` is used for every sample
            whose target is ``c``. The result must match ``self.data.shape``
            directly or after ``permute(0, 2, 3, 1)``. It is multiplied by 255
            before being added to the uint8 pixels, so it is presumably
            expressed on a [0, 1] pixel scale (confirm against the caller).
            ``None`` disables poisoning.
        ratio: Only the first ``int(len(data) * ratio)`` samples keep their
            perturbation; it is zeroed for the rest.
        poisoned_class: When not ``-1``, only samples with this target keep
            their perturbation.
        **kwargs: Forwarded unchanged to the parent dataset constructor.

    Raises:
        ValueError: If ``delta`` cannot be brought to the shape of the data,
            or if ``poisoned_class`` is neither ``-1`` nor in ``range(100)``.

    After construction ``self.delta`` holds the applied perturbation as a
    NumPy array on the 0-255 scale (or ``None``). The tensor passed in by the
    caller is never modified.
    """

    def __init__(self, delta: torch.FloatTensor = None, ratio=1.0, poisoned_class=-1, **kwargs):
        super().__init__(**kwargs)
        self.delta = delta

        if self.delta is None:
            return

        # detach() so a perturbation that still tracks gradients can be
        # converted to NumPy later on.
        perturbation = self.delta.detach().cpu()
        if len(perturbation) == 100:
            # Per-class perturbation: expand to one entry per sample.
            perturbation = perturbation[torch.tensor(self.targets)]
        # Compare the expanded tensor (not the raw argument) so a per-class
        # delta that is already channel-last is not permuted by mistake.
        if perturbation.shape != self.data.shape:
            perturbation = perturbation.permute(0, 2, 3, 1)
            if perturbation.shape != self.data.shape:
                raise ValueError(
                    f'delta shape {tuple(delta.shape)} is incompatible with '
                    f'data shape {tuple(self.data.shape)}'
                )

        # mul() allocates a fresh tensor, so the masking below never writes
        # into the caller's delta.
        delta_np = perturbation.mul(255).numpy()

        set_size = int(len(self.data) * ratio)
        if set_size < len(self.data):
            delta_np[set_size:] = 0.0

        if poisoned_class != -1:
            if poisoned_class not in range(100):
                raise ValueError(
                    f'poisoned_class must be -1 or in range(100), got {poisoned_class}'
                )
            # Vectorised replacement for the per-sample loop.
            delta_np[np.asarray(self.targets) != poisoned_class] = 0.0

        self.delta = delta_np
        self.data = np.clip(self.data.astype(np.float32) + self.delta, 0, 255).astype(np.uint8)

    def __getitem__(self, idx):
        """Return ``(image, target, idx)`` for the sample at ``idx``.

        The image is converted to a PIL image and passed through
        ``self.transform`` when one is configured. ``target_transform`` is not
        applied here.
        """
        img, target = self.data[idx], self.targets[idx]
        img = Image.fromarray(img)
        # transform is optional in torchvision datasets; calling None crashes.
        if self.transform is not None:
            img = self.transform(img)
        return img, target, idx




class CIFAR10WatermarkIndex(CIFAR10):
    """CIFAR-10 wrapper that exposes every image twice: clean and perturbed.

    The dataset reports twice the length of the underlying data. Indices in
    ``[0, len(data))`` yield the clean image with target 0; indices in
    ``[len(data), 2 * len(data))`` yield the same images with a perturbation
    added and target 1. The original class labels are not returned.

    Args:
        watermark: A single perturbation shared by all images, with a leading
            dimension of size 1 (it is squeezed away). Its image part must
            match ``self.data[0]`` directly or after ``permute(0, 2, 3, 1)``.
            Takes precedence over ``poison`` when both are given.
        poison: One perturbation per image; must match ``self.data.shape``
            directly or after ``permute(0, 2, 3, 1)``.
        **kwargs: Forwarded unchanged to the parent dataset constructor.

    Both perturbations are multiplied by 255 before being added to the uint8
    pixels, so they are presumably expressed on a [0, 1] pixel scale (confirm
    against the caller).

    Raises:
        ValueError: If ``watermark`` or ``poison`` cannot be brought to the
            expected shape.
    """

    def __init__(self, watermark: torch.FloatTensor = None, poison=None, **kwargs):
        super().__init__(**kwargs)
        self.watermark = watermark
        self.poison = poison

        if self.watermark is not None:
            if self.watermark[0].shape != self.data[0].shape:
                self.watermark = self.watermark.permute(0, 2, 3, 1)
            if self.watermark[0].shape != self.data[0].shape:
                raise ValueError(
                    f'watermark shape {tuple(watermark.shape)} is incompatible '
                    f'with image shape {tuple(self.data[0].shape)}'
                )

            # detach() so a tensor that still tracks gradients can be
            # converted; Tensor.numpy() refuses tensors that require grad.
            self.watermark = self.watermark.detach().mul(255).cpu().numpy()
            # Drop the leading size-1 dimension so it broadcasts per image.
            self.watermark = np.squeeze(self.watermark, axis=0)

        if self.poison is not None:
            if self.poison.shape != self.data.shape:
                self.poison = self.poison.permute(0, 2, 3, 1)
            if self.poison.shape != self.data.shape:
                raise ValueError(
                    f'poison shape {tuple(poison.shape)} is incompatible '
                    f'with data shape {tuple(self.data.shape)}'
                )

            self.poison = self.poison.detach().mul(255).cpu().numpy()

    def __getitem__(self, idx):
        """Return ``(image, target, idx)`` for position ``idx``.

        For the perturbed half the returned ``idx`` has ``len(self.data)``
        subtracted, i.e. it is the index of the underlying clean image rather
        than the position that was requested.
        NOTE(review): presumably intentional; confirm against the consumers.

        Raises:
            ValueError: If a perturbed sample is requested but neither a
                watermark nor a poison was supplied.
        """
        clean_count = len(self.data)
        if idx < clean_count:
            img, target = self.data[idx], 0
        else:
            idx -= clean_count
            if self.watermark is not None:
                perturbation = self.watermark
            elif self.poison is not None:
                perturbation = self.poison[idx]
            else:
                # The original raised a set literal, which itself fails with
                # "exceptions must derive from BaseException".
                raise ValueError('There must exist watermarks or poisons.')
            img = np.clip(self.data[idx].astype(np.float32) + perturbation, 0, 255).astype(np.uint8)
            target = 1
        img = Image.fromarray(img)
        # transform is optional in torchvision datasets; calling None crashes.
        if self.transform is not None:
            img = self.transform(img)
        return img, target, idx

    def __len__(self) -> int:
        """Return twice the number of stored images (clean + perturbed)."""
        return 2 * len(self.data)