import torch
from torchvision import transforms, datasets
from functools import partial
import os
import random
from collections import defaultdict
import numpy as np
from sklearn import datasets as datasets_sk
from PIL import Image

def add_noise(x, dequantize=False):
    """Optionally apply uniform dequantization noise to a tensor.

    When ``dequantize`` is true the input is scaled by 255, a noise value
    drawn uniformly from [0, 1) is added to every element, and the result
    is divided by 256:  ``(x * 255 + u) / 256``.  For inputs in [0, 1] the
    output therefore lies in [0, 1).

    Args:
        x: floating-point tensor (the pipelines in this file pass the output
            of ``transforms.ToTensor()``).
        dequantize: when false, ``x`` is returned untouched (same object).

    Returns:
        ``x`` itself if ``dequantize`` is false, otherwise a new tensor with
        the same shape and dtype as ``x``.
    """
    if not dequantize:
        return x
    # rand_like draws U[0, 1) noise with x's shape, dtype and device; it
    # replaces the legacy ``x.new().resize_as_(x).uniform_()`` idiom.
    noise = torch.rand_like(x)
    return (x * 255 + noise) / 256


class MNIST(datasets.MNIST):
    """torchvision MNIST restricted to a labeled or an unlabeled subset.

    Args:
        labels_per_class: number of examples per class placed in the labeled
            subset.  ``-1`` means "every example of the class".
        train: forwarded to ``datasets.MNIST``.  The labeled/unlabeled split
            is only performed when this is true.
        transform: forwarded to ``datasets.MNIST``.
        download: forwarded to ``datasets.MNIST``.
        seed: base seed; the examples of class ``c`` are shuffled with
            ``random.Random(seed + c)`` before the labeled ones are taken.
        labeled: (train only) keep the labeled subset when true, otherwise
            keep the unlabeled subset.
        pu: enable positive/unlabeled mode.  The dataset is first restricted
            to ``pu_config.use_classes``; targets then become ``1`` for
            classes in ``pu_config.positive_classes`` and ``0`` otherwise.
            In this mode only class ``1`` contributes labeled examples,
            unless ``labels_per_class == -1``.
        pu_config: object exposing ``use_classes`` and ``positive_classes``
            iterables; required when ``pu`` is true.

    Attributes set here:
        labeled_idxs / unlabeled_idxs: indices into the dataset as it was
            *before* the labeled/unlabeled selection (but after PU
            filtering).  For ``train=False`` every index is labeled.

    Raises:
        ValueError: if ``pu`` is true and ``pu_config`` is ``None``.
    """

    def __init__(self,
                 labels_per_class,
                 train,
                 transform,
                 download,
                 seed,
                 labeled=False,
                 pu=False,
                 pu_config=None):
        super().__init__(root="data",
                         train=train,
                         transform=transform,
                         download=download)
        if pu:
            if pu_config is None:
                raise ValueError("pu_config is required when pu=True")

            # Keep only the samples whose class is listed in use_classes.
            # A boolean OR (rather than summing and testing ``== 1``) stays
            # correct if a class is listed twice or the list is empty.
            keep = torch.zeros_like(self.targets, dtype=torch.bool)
            for c in pu_config.use_classes:
                keep |= self.targets == c
            self.targets = self.targets[keep]
            self.data = self.data[keep]

            # Binary relabeling: 1 = positive class, 0 = everything else.
            positive = torch.zeros_like(self.targets, dtype=torch.bool)
            for c in pu_config.positive_classes:
                positive |= self.targets == c
            self.targets = positive.long()

        # Group sample indices by class, in order of first appearance.
        classwise = defaultdict(list)
        for idx, c in enumerate(self.targets.tolist()):
            classwise[c].append(idx)

        if train:
            # -1 means "take everything"; a plain ``[:-1]`` slice would
            # silently drop the last shuffled sample of every class.
            take = None if labels_per_class == -1 else labels_per_class

            labeled_idxs = []
            for c, idxs in classwise.items():
                random.Random(seed + c).shuffle(idxs)
                if not pu or c == 1 or labels_per_class == -1:
                    labeled_idxs.extend(idxs[:take])

            unlabeled_idxs = sorted(
                set(range(len(self.targets))) - set(labeled_idxs))
            self.labeled_idxs = labeled_idxs
            self.unlabeled_idxs = unlabeled_idxs
            if unlabeled_idxs:
                # Unlabeled samples carry the sentinel target -1.
                self.targets[unlabeled_idxs] = -1

            chosen = torch.as_tensor(
                labeled_idxs if labeled else unlabeled_idxs,
                dtype=torch.long)
            self.data = self.data[chosen]
            self.targets = self.targets[chosen]
        else:
            self.labeled_idxs = list(range(len(self.targets)))
            self.unlabeled_idxs = []

    def __len__(self):
        """Return the number of samples kept after subset selection."""
        return len(self.data)


class SVHN(datasets.SVHN):
    """torchvision SVHN restricted to a labeled or an unlabeled subset.

    Args:
        labels_per_class: number of examples per class placed in the labeled
            subset.  ``-1`` means "every example of the class".
        train: selects the ``"train"`` split when true, ``"test"`` otherwise.
            The labeled/unlabeled split is only performed when this is true.
        transform: forwarded to ``datasets.SVHN``.
        download: forwarded to ``datasets.SVHN``.
        seed: base seed; the examples of class ``c`` are shuffled with
            ``random.Random(seed + c)`` before the labeled ones are taken.
        labeled: (train only) keep the labeled subset when true, otherwise
            keep the unlabeled subset.
        pu: enable positive/unlabeled mode.  The dataset is first restricted
            to ``pu_config.use_classes``; targets then become ``1`` for
            classes in ``pu_config.positive_classes`` and ``0`` otherwise.
            In this mode only class ``1`` contributes labeled examples,
            unless ``labels_per_class == -1``.
        pu_config: object exposing ``use_classes`` and ``positive_classes``
            iterables; required when ``pu`` is true.

    Attributes set here:
        targets: alias of ``labels``; both hold the final target array.
        labeled_idxs / unlabeled_idxs: indices into the dataset as it was
            *before* the labeled/unlabeled selection (but after PU
            filtering).  For ``train=False`` every index is labeled.

    Raises:
        ValueError: if ``pu`` is true and ``pu_config`` is ``None``.
    """

    def __init__(self,
                 labels_per_class,
                 train,
                 transform,
                 download,
                 seed,
                 labeled=False,
                 pu=False,
                 pu_config=None):
        # Honour the caller's ``download`` flag instead of forcing True.
        super().__init__(root="data",
                         split="train" if train else "test",
                         transform=transform,
                         download=download)

        # The parent class stores targets under ``labels``; work on
        # ``targets`` here and write the result back at the end.
        self.targets = self.labels
        if pu:
            if pu_config is None:
                raise ValueError("pu_config is required when pu=True")

            # Keep only the samples whose class is listed in use_classes.
            # A boolean OR (rather than summing and testing ``== 1``) stays
            # correct if a class is listed twice or the list is empty.
            keep = np.zeros(len(self.targets), dtype=bool)
            for c in pu_config.use_classes:
                keep |= self.targets == c
            self.targets = self.targets[keep]
            self.data = self.data[keep]

            # Binary relabeling: 1 = positive class, 0 = everything else.
            positive = np.zeros(len(self.targets), dtype=bool)
            for c in pu_config.positive_classes:
                positive |= self.targets == c
            self.targets = positive.astype(np.int64)

        # Group sample indices by class, in order of first appearance.
        classwise = defaultdict(list)
        for idx, c in enumerate(self.targets.tolist()):
            classwise[c].append(idx)

        if train:
            # -1 means "take everything"; a plain ``[:-1]`` slice would
            # silently drop the last shuffled sample of every class.
            take = None if labels_per_class == -1 else labels_per_class

            labeled_idxs = []
            for c, idxs in classwise.items():
                random.Random(seed + c).shuffle(idxs)
                if not pu or c == 1 or labels_per_class == -1:
                    labeled_idxs.extend(idxs[:take])

            unlabeled_idxs = sorted(
                set(range(len(self.targets))) - set(labeled_idxs))
            self.labeled_idxs = labeled_idxs
            self.unlabeled_idxs = unlabeled_idxs
            if unlabeled_idxs:
                # Unlabeled samples carry the sentinel target -1.
                self.targets[unlabeled_idxs] = -1

            chosen = labeled_idxs if labeled else unlabeled_idxs
            self.data = self.data[chosen]
            self.targets = self.targets[chosen]
        else:
            self.labeled_idxs = list(range(len(self.targets)))
            self.unlabeled_idxs = []

        self.labels = self.targets

    def __len__(self):
        """Return the number of samples kept after subset selection."""
        return len(self.data)


class CIFAR10(datasets.CIFAR10):
    """torchvision CIFAR10 restricted to a labeled or an unlabeled subset.

    Args:
        labels_per_class: number of examples per class placed in the labeled
            subset.  ``-1`` means "every example is labeled"; in that case
            no subset selection is applied and the full training set is kept
            regardless of ``labeled``.
        train: forwarded to ``datasets.CIFAR10``.  The labeled/unlabeled
            split is only performed when this is true.
        transform: forwarded to ``datasets.CIFAR10``.
        download: forwarded to ``datasets.CIFAR10``.
        seed: base seed; the examples of class ``c`` are shuffled with
            ``random.Random(seed + c)`` before the labeled ones are taken.
        labeled: (train only) keep the labeled subset when true, otherwise
            keep the unlabeled subset.

    Attributes set here:
        targets: converted to a ``torch.LongTensor`` when ``train`` is true;
            left as provided by the parent class otherwise.
        labeled_idxs / unlabeled_idxs: indices into the full split, i.e. as
            it was before the labeled/unlabeled selection.  For
            ``train=False`` every index is labeled.
    """

    def __init__(self,
                 labels_per_class,
                 train,
                 transform,
                 download,
                 seed,
                 labeled=False):
        # Honour the caller's ``download`` flag instead of forcing True.
        super().__init__(root="data",
                         train=train,
                         transform=transform,
                         download=download)

        # Group sample indices by class, in order of first appearance.
        classwise = defaultdict(list)
        for idx, c in enumerate(self.targets):
            classwise[c].append(idx)

        if train:
            self.targets = torch.LongTensor(self.targets)

            # -1 means "take everything"; a plain ``[:-1]`` slice would
            # silently drop the last shuffled sample of every class (and
            # mark it as unlabeled below).
            take = None if labels_per_class == -1 else labels_per_class

            labeled_idxs = []
            for c, idxs in classwise.items():
                random.Random(seed + c).shuffle(idxs)
                labeled_idxs.extend(idxs[:take])

            unlabeled_idxs = sorted(
                set(range(len(self.targets))) - set(labeled_idxs))
            self.labeled_idxs = labeled_idxs
            self.unlabeled_idxs = unlabeled_idxs
            if unlabeled_idxs:
                # Unlabeled samples carry the sentinel target -1.
                self.targets[unlabeled_idxs] = -1

            if labels_per_class != -1:
                chosen = labeled_idxs if labeled else unlabeled_idxs
                self.data = self.data[chosen]
                self.targets = self.targets[
                    torch.as_tensor(chosen, dtype=torch.long)]
        else:
            self.labeled_idxs = list(range(len(self.targets)))
            self.unlabeled_idxs = []

    def __len__(self):
        """Return the number of samples kept after subset selection."""
        return len(self.data)


def get_dataset(config, uniform_dequantization=False, evaluation=False):
    if config.data.dataset == "MNIST":
        train_transform = transforms.Compose([
            transforms.Pad(2),
            transforms.ToTensor(),
            partial(add_noise, dequantize=uniform_dequantization),
        ])
        test_transform = transforms.Compose([
            transforms.Pad(2),
            transforms.ToTensor(),
            partial(add_noise, dequantize=uniform_dequantization),
        ])
        train_labelled = MNIST(train=True,
                               labels_per_class=config.data.labels_per_class,
                               transform=train_transform,
                               download=True,
                               seed=config.seed,
                               labeled=True,
                               pu = config.data.pu,
                               pu_config=config.data.pu_config)
        train_unlabelled = MNIST(train=True,
                                 labels_per_class=config.data.labels_per_class,
                                 transform=train_transform,
                                 download=True,
                                 seed=config.seed,
                                 labeled=False,
                                 pu = config.data.pu,
                                 pu_config=config.data.pu_config)
        train_for_score = MNIST(train=True,
                              labels_per_class=-1,
                              transform=test_transform,
                              download=True,
                              seed=config.seed,
                              labeled=True,
                              pu = config.data.pu,
                              pu_config=config.data.pu_config)
        test = MNIST(train=False,
                     labels_per_class=config.data.labels_per_class,
                     transform=test_transform,
                     download=True,
                     seed=config.seed,
                     pu = config.data.pu,
                     pu_config=config.data.pu_config)
    
        train_loader = {
            "labeled":
            torch.utils.data.DataLoader(
                train_labelled,
                batch_size=config.training.labeled_batch_size //
                torch.cuda.device_count(),
                num_workers=4,
                shuffle=True),
            "unlabeled":
            torch.utils.data.DataLoader(
                train_unlabelled,
                batch_size=(config.training.batch_size -
                            config.training.labeled_batch_size) //
                torch.cuda.device_count(),
                num_workers=4,
                shuffle=True),
            'score': torch.utils.data.DataLoader(
                train_for_score,
                batch_size=config.training.batch_size // torch.cuda.device_count(),
                num_workers=4,
                shuffle=True),
        }
        test_loader = torch.utils.data.DataLoader(test,
                                                  batch_size=config.training.batch_size //
                                                  torch.cuda.device_count(),
                                                  num_workers=4)
    
    elif config.data.dataset == 'CIFAR10':
        train_transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.Resize(32,antialias=True),
            transforms.ToTensor(),
            partial(add_noise, dequantize=uniform_dequantization),
        ])
        test_transform = transforms.Compose([
            transforms.ToTensor(),
            partial(add_noise, dequantize=uniform_dequantization),
        ])
        train_labelled = CIFAR10(train=True,
                                 labels_per_class=config.data.labels_per_class,
                                 transform=train_transform,
                                 download=True,
                                 seed=config.seed,
                                 labeled=True)
        train_unlabelled = CIFAR10(
                train=True,
                labels_per_class=config.data.labels_per_class,
                transform=train_transform,
                download=True,
                seed=config.seed,
                labeled=False)
        train_for_score = CIFAR10(train=True,
                                 labels_per_class=config.data.labels_per_class,
                                 transform=train_transform,
                                 download=True,
                                 seed=config.seed,
                                 labeled=True)
        test = CIFAR10(train=False,
                       labels_per_class=config.data.labels_per_class,
                       transform=test_transform,
                       download=True,
                       seed=config.seed,)
        train_loader = {
            'score': torch.utils.data.DataLoader(
                train_for_score,
                batch_size= config.training.batch_size//torch.cuda.device_count(),
                num_workers=4,
                shuffle=True),
            "labeled":
            torch.utils.data.DataLoader(
                train_labelled,
                batch_size=config.training.labeled_batch_size //
                torch.cuda.device_count(),shuffle=True,
                num_workers=4),
            "unlabeled":
            torch.utils.data.DataLoader(
                train_unlabelled,
                batch_size=(config.training.batch_size -
                            config.training.labeled_batch_size) //
                torch.cuda.device_count(),shuffle=True,
                num_workers=4),
            }
        
        test_loader = torch.utils.data.DataLoader(test,
                                                  batch_size=config.training.batch_size //
                                                  torch.cuda.device_count(),
                                                  num_workers=4,shuffle=True)
    elif config.data.dataset == 'SVHN':
        train_transform = transforms.Compose([
            transforms.ToTensor(),
            partial(add_noise, dequantize=uniform_dequantization),
        ])
        test_transform = transforms.Compose([
            transforms.ToTensor(),
            partial(add_noise, dequantize=uniform_dequantization),
        ])
        train_labelled = SVHN(train=True,
                              labels_per_class=config.data.labels_per_class,
                              transform=train_transform,
                              download=True,seed=config.seed,
                              labeled=True,
                              pu = config.data.pu,
                              pu_config=config.data.pu_config)
        train_unlabelled = SVHN(train=True,
                                labels_per_class=config.data.labels_per_class,
                                transform=train_transform,
                                download=True,seed=config.seed,
                                labeled=False,
                                pu = config.data.pu,
                                pu_config=config.data.pu_config)
        train_for_score = SVHN(train=True,
                              labels_per_class=-1,
                              transform=test_transform,
                              download=True,seed=config.seed,
                              labeled=True,
                              pu = config.data.pu,
                              pu_config=config.data.pu_config)
        test = SVHN(train=False,
                    labels_per_class=config.data.labels_per_class,
                    transform=test_transform,
                    download=True,seed=config.seed,
                    pu = config.data.pu,
                    pu_config=config.data.pu_config)

        train_loader = {
            "labeled":
            torch.utils.data.DataLoader(
                train_labelled,
                batch_size=config.training.labeled_batch_size //
                torch.cuda.device_count(),shuffle=True,
                num_workers=4),
            "unlabeled":
            torch.utils.data.DataLoader(
                train_unlabelled,
                batch_size=(config.training.batch_size -
                            config.training.labeled_batch_size) //
                torch.cuda.device_count(),shuffle=True,
                num_workers=4),
            'score': torch.utils.data.DataLoader(
                train_for_score,
                batch_size=config.training.batch_size//torch.cuda.device_count(),
                num_workers=4,
                shuffle=True),
        }
        test_loader = torch.utils.data.DataLoader(test,
                                                  batch_size=config.training.batch_size //
                                                  torch.cuda.device_count(),shuffle=True,
                                                  num_workers=4)
    else:
        raise Exception('Unknown Dataset')
    return train_loader, test_loader
