import time
import random 
import numpy as np
import torch 
import torch.nn as nn
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import math
from torch.autograd import Variable
from torch.utils.data.dataset import Dataset
from torch.utils.data import random_split
import os
from PIL import Image
import torchvision.models as models
import torch.nn.functional as F
from layers.feat_noise import Noise
from tqdm import tqdm

from utils import report


class ConditionalMNIST(dsets.MNIST):
    """MNIST dataset restricted to the samples of a single class.

    The parent ``dsets.MNIST`` dataset is constructed first; afterwards
    ``self.data`` and ``self.targets`` are replaced by the subset whose label
    equals ``class_ix``, so length and indexing refer to that subset only.
    If no sample matches, both become empty.

    Args:
        root: Forwarded unchanged to ``dsets.MNIST``.
        class_ix: Label value to keep.
        train: Forwarded unchanged to ``dsets.MNIST``.
        transform: Forwarded unchanged to ``dsets.MNIST``.
        target_transform: Forwarded unchanged to ``dsets.MNIST``.
        download: Forwarded unchanged to ``dsets.MNIST``.
    """

    def __init__(self, root, class_ix, train=True, transform=None, target_transform=None, download=False):
        super().__init__(root, train=train,
                         transform=transform, target_transform=target_transform,
                         download=download)
        self.class_ix = class_ix

        # One vectorised boolean mask instead of a per-sample Python copy loop.
        # Mask indexing returns new tensors and preserves the parent's dtypes
        # (the previous loop allocated uint8 targets, downcasting the labels).
        # NOTE(review): assumes self.targets is a tensor, as the original
        # `.item()` call already did.
        keep = self.targets == class_ix
        self.data = self.data[keep]
        self.targets = self.targets[keep]
            

class ConditionalCIFAR10(dsets.CIFAR10):
    """CIFAR-10 dataset restricted to the samples of a single class.

    The parent ``dsets.CIFAR10`` dataset is constructed first; afterwards
    ``self.data`` and ``self.targets`` are replaced by the subset whose label
    equals ``class_ix``, so length and indexing refer to that subset only.
    If no sample matches, both become empty.

    Args:
        root: Forwarded unchanged to ``dsets.CIFAR10``.
        class_ix: Label value to keep.
        train: Forwarded unchanged to ``dsets.CIFAR10``.
        transform: Forwarded unchanged to ``dsets.CIFAR10``.
        target_transform: Forwarded unchanged to ``dsets.CIFAR10``.
        download: Forwarded unchanged to ``dsets.CIFAR10``.
    """

    def __init__(self, root, class_ix, train=True, transform=None, target_transform=None, download=False):
        super().__init__(root, train=train,
                         transform=transform, target_transform=target_transform,
                         download=download)
        self.class_ix = class_ix

        # Vectorised selection instead of a per-sample copy loop.
        # NOTE(review): assumes self.data is a NumPy array indexable by a
        # boolean mask along its first axis, as the original loop implied.
        labels = np.asarray(self.targets)
        keep = labels == class_ix
        self.data = self.data[keep]
        # Keep targets as a list of Python ints; the previous np.zeros(...)
        # buffer was float64, so every label came back as a float.
        self.targets = labels[keep].tolist()
