import copy
import numpy as np
import torch
from PIL import Image
from torch.utils.data.dataset import Dataset
import os
from torchvision.utils import save_image
import torchvision.utils as tvu
import matplotlib.pyplot as plt
import random


def load_idx_slice_from_list(rawlist, slice_idx):
    """Gather the elements of ``rawlist`` at the positions in ``slice_idx``.

    Args:
        rawlist: Any indexable sequence.
        slice_idx: Iterable of positions into ``rawlist`` (order and
            repetitions are preserved).

    Returns:
        list: A new list with ``rawlist[p]`` for every ``p`` in ``slice_idx``.
    """
    picked = []
    for position in slice_idx:
        picked.append(rawlist[position])
    return picked


class PoisonLabelDataset(Dataset):
    """Poison-Label dataset wrapper.

    Applies a backdoor transformation to the samples flagged in
    ``poison_idx`` and relabels them to ``target_label``. When the wrapped
    dataset is in training mode every sample is kept; otherwise only the
    flagged (poisoned) samples are kept.

    Args:
        dataset (Dataset): The dataset to be wrapped. It must expose
            ``train``, ``data``, ``targets``, ``pre_transform``,
            ``primary_transform``, ``remaining_transform`` and ``prefetch``
            (plus ``mean`` and ``std`` when ``prefetch`` is true).
        transform (callable): The backdoor transformations.
        poison_idx (np.array): A 0/1 (clean/poisoned) array with
            shape `(len(dataset), )`. Any array-like is accepted.
        target_label (int): The target label.
    """

    def __init__(self, dataset, transform, poison_idx, target_label):
        super(PoisonLabelDataset, self).__init__()
        # Normalise to an ndarray: the boolean mask below silently misbehaves
        # on plain lists (``list == 1`` is a single bool, not a mask).
        poison_idx = np.asarray(poison_idx)
        print("copy start")
        # The wrapped dataset is shared rather than deep-copied; attributes
        # read from it below alias the original objects.
        self.dataset = dataset
        print("copy end")
        self.train = self.dataset.train
        if self.train:
            print(f"num of train samples:{len(self.dataset.data)}")
            print('Num of poison samples:',np.sum(poison_idx))
            self.data = self.dataset.data
            self.targets = self.dataset.targets
            self.poison_idx = poison_idx
        else:
            # Only fetch poison data when testing.
            poison_positions = np.nonzero(poison_idx)[0]
            print(f"num of test samples:{len(self.dataset.data)}")
            print(poison_positions)
            print(poison_idx)
            print(np.sum(poison_idx))

            self.data = load_idx_slice_from_list(self.dataset.data, poison_positions)
            self.targets = load_idx_slice_from_list(self.dataset.targets, poison_positions)

            # Every kept sample is flagged, so this keeps only the 1 entries.
            self.poison_idx = poison_idx[poison_idx == 1]
        self.pre_transform = self.dataset.pre_transform
        self.primary_transform = self.dataset.primary_transform
        self.remaining_transform = self.dataset.remaining_transform
        self.prefetch = self.dataset.prefetch
        if self.prefetch:
            self.mean, self.std = self.dataset.mean, self.dataset.std

        self.bd_transform = transform
        self.target_label = target_label

        # Maps sample index -> purified image; filled by update_purified_imgs.
        self.purified_imgs = {}

    def __getitem__(self, index):
        """Return one sample as a dict.

        String entries of ``self.data`` are treated as image file paths and
        loaded as RGB; other entries are used as-is. Flagged samples get the
        backdoor transform and ``target_label``.

        Returns:
            dict: ``img``, ``target`` (possibly the target label), ``poison``
            (0/1), ``origin`` (the original target), ``index`` and
            ``num_imgs`` (always 1).
        """
        if isinstance(self.data[index], str):
            with open(self.data[index], "rb") as f:
                img = np.array(Image.open(f).convert("RGB"))
        else:
            img = self.data[index]
        target = self.targets[index]
        poison = 0
        origin = target  # original target

        if self.poison_idx[index] == 1:
            img = self.bd_first_augment(img, bd_transform=self.bd_transform)
            target = self.target_label
            poison = 1
        else:
            img = self.bd_first_augment(img, bd_transform=None)
        item = {"img": img, "target": target, "poison": poison, "origin": origin, "index": index, "num_imgs": 1}
        return item

    def __len__(self):
        """Return the number of kept samples."""
        return len(self.data)

    def bd_first_augment(self, img, bd_transform=None, clean_test=False):
        """Run the pre-transform, optional backdoor, then the remaining transforms.

        Args:
            img: Image as an array accepted by ``Image.fromarray``.
            bd_transform (callable, optional): Backdoor transform applied to
                the ndarray after ``pre_transform``; skipped when None.
            clean_test (bool): If True, call ``bd_transform`` with
                ``clean_test=True``.

        Returns:
            Output of ``remaining_transform``; when ``prefetch`` is set, it is
            converted to a channel-first uint8 tensor (assumes an HWC array —
            TODO confirm for non-RGB inputs).
        """
        # Pre-processing transformation (HWC ndarray->HWC ndarray).
        img = Image.fromarray(img)
        img = self.pre_transform(img)

        # Backdoor transformation operates on the ndarray form.
        img = np.array(img)
        if bd_transform is not None:
            if not clean_test:
                img = bd_transform(img)
            else:
                img = bd_transform(img, clean_test=True)

        img = Image.fromarray(img)

        img = self.primary_transform(img)
        img = self.remaining_transform(img)

        if self.prefetch:
            # HWC ndarray->CHW tensor with C=3.
            img = np.rollaxis(np.array(img, dtype=np.uint8), 2)
            img = torch.from_numpy(img)

        return img

    def update_purified_imgs(self, imgs, indices):
        """Store purified images keyed by their sample index.

        Args:
            imgs: Indexable collection of images, aligned with ``indices``.
            indices: Indexable collection of sample indices.
        """
        self.purified_imgs.update({indices[i]: imgs[i] for i in range(len(indices))})
        return


class MixMatchDataset(Dataset):
    """Semi-supervised MixMatch dataset.

    Selects either the labeled or the unlabeled subset of a wrapped dataset.
    Unlabeled items are fetched twice so that each item carries two
    independently produced images (``img1``/``img2``).

    Args:
        dataset (Dataset): The dataset to be wrapped. Its items must be dicts
            with an ``"img"`` entry, and it must expose ``prefetch`` (plus
            ``mean`` and ``std`` when ``prefetch`` is true).
        semi_idx (np.array): A 0/1 array with shape ``(len(dataset), )`` in
            which 1 marks a labeled sample and 0 an unlabeled one. Any
            array-like is accepted.
        labeled (bool): If True, creates dataset from labeled set, otherwise creates from unlabeled
            set (default: True).
    """

    def __init__(self, dataset, semi_idx, labeled=True):
        super(MixMatchDataset, self).__init__()
        print('copy start')
        self.dataset = dataset
        print('copy end')
        # Normalise to an ndarray: ``list == 1`` would be a single bool and
        # np.nonzero on it does not produce the intended mask positions.
        semi_idx = np.asarray(semi_idx)
        wanted_flag = 1 if labeled else 0
        self.semi_indice = np.nonzero(semi_idx == wanted_flag)[0]
        self.labeled = labeled
        self.prefetch = self.dataset.prefetch
        if self.prefetch:
            self.mean, self.std = self.dataset.mean, self.dataset.std

    def save_images(self, img1, img2, save_dir):
        """Write ``img1``/``img2`` as ``img1.png``/``img2.png`` into ``save_dir``."""
        os.makedirs(save_dir, exist_ok=True)
        save_image(img1, os.path.join(save_dir, "img1.png"))
        save_image(img2, os.path.join(save_dir, "img2.png"))

    def __getitem__(self, index):
        """Return the ``index``-th item of the selected subset.

        Labeled items are the wrapped item plus ``labeled=True``. Unlabeled
        items fetch the wrapped sample twice and replace ``img`` with
        ``img1``/``img2``. They also set ``num_imgs=2`` and ``labeled=False``;
        all other fields come from the first fetch.
        """
        sample_idx = self.semi_indice[index]
        if self.labeled:
            item = self.dataset[sample_idx]
            item["labeled"] = True
        else:
            # Two separate fetches: any randomness in the wrapped dataset's
            # transforms yields two different views of the same sample.
            item1 = self.dataset[sample_idx]
            item2 = self.dataset[sample_idx]
            img1, img2 = item1.pop("img"), item2.pop("img")

            item1.update({"img1": img1, "img2": img2, "num_imgs": 2})
            item = item1
            item["labeled"] = False
        return item

    def __len__(self):
        """Return the size of the selected (labeled or unlabeled) subset."""
        return len(self.semi_indice)
