import copy
import random

import numpy as np
import PIL
from PIL import Image
from torchvision.transforms import functional as F
from torchvision.transforms import Compose
from attack.originalimagenet import Origdataset
import torchvision.transforms as transforms
from  torchvision.datasets import CIFAR10

from .base import *
from settings import base_args, base_config
# Module-level aliases for the shared settings objects imported from `settings`;
# the classes below read their options (e.g. `args.img_size`) through `args`.
args, config = base_args, base_config

class CleanDataset(Origdataset):
    """Unpoisoned view of an ``Origdataset`` with the poisoned-dataset item format.

    Every item is returned as ``(img, labels)`` where ``labels`` is a dict
    with the keys ``'label_orig'`` and ``'label_pois'``. Because nothing is
    poisoned here, both entries hold the same target.

    Args:
        benign_dataset: Dataset whose ``transform`` and ``target_transform``
            attributes are forwarded to the ``Origdataset`` constructor.
    """

    def __init__(self,
                 benign_dataset):
        super(CleanDataset, self).__init__(
            args,
            args.datasets_root_dir,
            transform=benign_dataset.transform,
            target_transform=benign_dataset.target_transform,
            download=True)
        # The transform handed to the parent above is deliberately replaced:
        # images are always resized to a square of side ``args.img_size`` and
        # converted to a tensor, whatever ``benign_dataset`` was using.
        self.transform = transforms.Compose([
            transforms.Resize((args.img_size, args.img_size)),
            transforms.ToTensor(),
        ])

    def __getitem__(self, index):
        """Load one sample and return it with its (identical) label pair.

        Args:
            index: Position of the sample in ``self.samples``.

        Returns:
            Tuple ``(img, labels)``: ``img`` is the output of
            ``self.transform`` and ``labels`` is
            ``{'label_orig': target, 'label_pois': target}``.
        """
        # ``samples`` and ``loader`` are presumably provided by the parent
        # class (ImageFolder-style) -- confirm against ``Origdataset``.
        path, target = self.samples[index]
        img = self.transform(self.loader(path))

        # NOTE(review): ``self.target_transform`` is not applied to ``target``
        # here; the raw stored label is returned for both keys.
        return img, {'label_orig': target, 'label_pois': target}
    
    
class CleanDatasetCIFAR10(CIFAR10):
    """Unpoisoned CIFAR10 dataset with the poisoned-dataset item format.

    Every item is returned as ``(img, labels)`` where ``labels`` is a dict
    with the keys ``'label_orig'`` and ``'label_pois'``. Because nothing is
    poisoned here, both entries hold the same target.

    Args:
        benign_dataset: Dataset whose ``root``, ``train``, ``transform`` and
            ``target_transform`` attributes are forwarded to the ``CIFAR10``
            constructor.
    """

    def __init__(self,
                 benign_dataset):
        super(CleanDatasetCIFAR10, self).__init__(
            benign_dataset.root,
            benign_dataset.train,
            benign_dataset.transform,
            benign_dataset.target_transform,
            download=True)
        # The transform handed to the parent above is deliberately replaced:
        # images are always resized to a square of side ``args.img_size`` and
        # converted to a tensor, whatever ``benign_dataset`` was using.
        self.transform = transforms.Compose([
            transforms.Resize((args.img_size, args.img_size)),
            transforms.ToTensor(),
        ])

    def __getitem__(self, index):
        """Fetch one sample and return it with its (identical) label pair.

        Args:
            index: Position of the sample in ``self.data`` / ``self.targets``.

        Returns:
            Tuple ``(img, labels)``: ``img`` is the output of
            ``self.transform`` and ``labels`` is
            ``{'label_orig': target, 'label_pois': target}``.
        """
        target = int(self.targets[index])
        # The stored sample is wrapped in a PIL image first because the
        # transform pipeline starts with ``Resize``.
        img = self.transform(Image.fromarray(self.data[index]))

        # NOTE(review): ``self.target_transform`` is not applied to ``target``
        # here; the raw stored label is returned for both keys.
        return img, {'label_orig': target, 'label_pois': target}
    


class Clean(Base):
    """Attack-style wrapper that leaves the data untouched.

    It exposes ``poisoned_train_dataset`` and ``poisoned_test_dataset`` like
    an attack would, but both are built from the clean dataset classes of
    this module, so their "poisoned" labels equal the original labels.

    Args:
        train_dataset: Benign training dataset.
        test_dataset: Benign test dataset.
        model: Model object, stored on the instance and passed to ``Base``.
        loss: Loss object, passed to ``Base``.
        schedule: Passed through to ``Base`` unchanged. Defaults to ``None``.
        seed: Passed through to ``Base`` unchanged. Defaults to ``0``.
        deterministic: Passed through to ``Base`` unchanged. Defaults to
            ``False``.
    """

    def __init__(self,
                 train_dataset,
                 test_dataset,
                 model,
                 loss,
                 schedule=None,
                 seed=0,
                 deterministic=False):
        # The model is attached before the base initialiser runs; that order
        # is kept in case ``Base.__init__`` relies on it.
        self.model = model
        super().__init__(
            train_dataset=train_dataset,
            test_dataset=test_dataset,
            model=model,
            loss=loss,
            schedule=schedule,
            seed=seed,
            deterministic=deterministic,
        )

        # 'Cifar10' gets the CIFAR10-backed wrapper; every other dataset name
        # falls back to the ``Origdataset``-backed one.
        uses_cifar10 = args.dataset == 'Cifar10'
        wrapper_cls = CleanDatasetCIFAR10 if uses_cifar10 else CleanDataset

        self.poisoned_train_dataset = wrapper_cls(train_dataset)
        self.poisoned_test_dataset = wrapper_cls(test_dataset)