import os
import sys

import sklearn
import torch
from PIL import Image
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms

from base import BaseDataLoader
from data_loader.cifar10 import get_cifar10
from data_loader.cifar100 import get_cifar100
from parse_config import ConfigParser
from data_loader.augmentations import Augmentation, CutoutDefault
from data_loader.augmentation_archive import autoaug_policy, autoaug_paper_cifar10, fa_reduced_cifar10  # , autoaug_imagenet_policy, svhn_policies


class NoisyDataset(Dataset):
    """In-memory dataset whose training labels are resampled through a 2x2 noise matrix.

    Args:
        is_train: if True, labels are corrupted once at construction time and
            ``__getitem__`` returns 6-tuples; otherwise labels are kept as given
            and ``__getitem__`` returns 4-tuples.
        x: feature rows indexable by integer. For training it must also expose
            ``.shape[0]`` (e.g. a NumPy array).
        y: integer class labels aligned with ``x``.
        num_classes: number of classes; label corruption only supports 2.
        transform_train: optional callable applied to each training tuple.
        transform_test: optional callable applied to each evaluation tuple.
        e0: probability that a clean label 1 is flipped to 0.
        e1: probability that a clean label 0 is flipped to 1.

    After construction with ``is_train=True``, ``self.y`` is a Python list in
    which every label is a one-element list ``[k]``, regardless of dataset size.
    """

    def __init__(self, is_train, x, y, num_classes, transform_train, transform_test, e0, e1):
        self.x = x
        self.y = y
        self.num_classes = num_classes
        self.transform_train = transform_train
        self.transform_test = transform_test
        self.e0 = e0
        self.e1 = e1
        self.is_train = is_train

        if is_train:
            # (N, 1) tensor -> list of one-element lists.
            self.y = self._load_noise_label().tolist()

    def _load_transition_matrix(self):
        """Return T where T[i, j] = P(noisy label j | clean label i).

        NOTE(review): e1 controls the 0 -> 1 flip and e0 the 1 -> 0 flip; the
        naming looks inverted relative to the row index — confirm intent.
        """
        # Fixed float32 so the matmul with the float32 one-hot matrix never
        # hits a dtype mismatch (e.g. when e0/e1 are NumPy float64 scalars).
        return torch.tensor([[1 - self.e1, self.e1],
                             [self.e0, 1 - self.e0]], dtype=torch.float32)

    def _load_noise_label(self):
        """Sample one noisy label per example; returns a LongTensor of shape (N, 1).

        Sampling uses torch's global RNG, so call ``torch.manual_seed`` first
        for reproducible noise.

        Raises:
            ValueError: if ``num_classes`` is not 2 (the noise matrix is 2x2).
        """
        if self.num_classes != 2:
            raise ValueError(
                'NoisyDataset label noise supports exactly 2 classes, got %r' % (self.num_classes,))
        labels = torch.tensor(self.y, dtype=torch.long)
        y_onehot = torch.nn.functional.one_hot(labels, self.num_classes).type(torch.float32)
        # Row i of the product is the noisy-label distribution of example i.
        # No squeeze(): keeping the (N, C) shape makes multinomial return (N, 1)
        # even when N == 1.
        noisy_probs = torch.matmul(y_onehot, self._load_transition_matrix())
        return torch.multinomial(noisy_probs, num_samples=1)

    def __getitem__(self, index):
        """Return one (optionally transformed) sample tuple.

        Training: ``(x[i], x[(i + 1) % N], x[i], y[i], 0, 0)``. The second entry
        is the next row, wrapping around at the end.
        Evaluation: ``(x[i], y[i], 0, 0)``.
        NOTE(review): the meaning of the trailing zeros is defined by downstream
        consumers — confirm.
        """
        if self.is_train:
            sample = (self.x[index], self.x[(index + 1) % self.x.shape[0]], self.x[index], self.y[index], 0, 0)
            transform = self.transform_train
        else:
            sample = (self.x[index], self.y[index], 0, 0)
            transform = self.transform_test

        if transform is not None:
            sample = transform(sample)
        return sample

    def __len__(self):
        return len(self.x)

class ToTensor_t:
    """Turn a 6-field training tuple into tensors.

    The first three fields (feature arrays) become float32 tensors, the
    label becomes a tensor with torch's inferred dtype, and the last two
    fields are passed through untouched.
    """

    def __call__(self, sample):
        feat_a, feat_b, feat_c, label, extra_a, extra_b = sample
        floats = [torch.tensor(feat).float() for feat in (feat_a, feat_b, feat_c)]
        return (*floats, torch.tensor(label), extra_a, extra_b)

class ToTensor_v:
    """Turn a 4-field evaluation tuple into tensors.

    The feature array becomes a float32 tensor, the label a tensor with
    torch's inferred dtype; the last two fields are returned unchanged.
    """

    def __call__(self, sample):
        features, label, extra_a, extra_b = sample
        feature_tensor = torch.tensor(features).float()
        label_tensor = torch.tensor(label)
        return feature_tensor, label_tensor, extra_a, extra_b


class BinaryCancerDataLoader(BaseDataLoader):
    """Loader for sklearn's breast-cancer dataset with synthetic binary label noise.

    The data is split 80/20 (stratified, ``random_state=0``) and standardized.
    Training labels are corrupted by ``NoisyDataset`` using the noise rates in
    ``T``; validation labels are left clean.

    Args:
        data_dir, num_batches, training: accepted for interface parity with the
            other loaders; unused here.
        batch_size, shuffle, validation_split, num_workers, pin_memory:
            forwarded to ``BaseDataLoader``.
        T: sequence ``(e0, e1)`` of label-flip probabilities passed to
            ``NoisyDataset``.

    Raises:
        ValueError: if ``T`` is None.
    """

    def __init__(self, data_dir, batch_size, shuffle=True, validation_split=0.0, num_batches=0, training=True, num_workers=4, pin_memory=True, T=None):
        # Import the submodule explicitly: a bare ``import sklearn`` does not
        # guarantee that ``sklearn.datasets`` is available as an attribute.
        from sklearn.datasets import load_breast_cancer

        if T is None:
            raise ValueError('BinaryCancerDataLoader requires T=(e0, e1) label-noise rates')
        e0, e1 = T[0], T[1]

        bc = load_breast_cancer()
        x, y = bc.data, bc.target
        self.num_features = x.shape[1]

        # Split first, then fit the scaler on the training part only, so test
        # statistics do not leak into training. The stratified split depends
        # only on len(x) and y, so it is identical to splitting scaled data.
        x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=0, stratify=y, shuffle=True)
        scaler = StandardScaler().fit(x_train)
        x_train = scaler.transform(x_train)
        x_test = scaler.transform(x_test)

        print("X_train.shape: ", x_train.shape)
        print("X_test.shape: ", x_test.shape)
        print("y_train.shape: ", y_train.shape)
        print("y_test.shape: ", y_test.shape)

        composed_transform_train = transforms.Compose([ToTensor_t()])
        composed_transform_test = transforms.Compose([ToTensor_v()])
        self.train_dataset = NoisyDataset(is_train=True, x=x_train, y=y_train, num_classes=2,
                                          transform_train=composed_transform_train,
                                          transform_test=composed_transform_test, e0=e0, e1=e1)
        self.val_dataset = NoisyDataset(is_train=False, x=x_test, y=y_test, num_classes=2,
                                        transform_train=composed_transform_train,
                                        transform_test=composed_transform_test, e0=e0, e1=e1)
        super().__init__(self.train_dataset, batch_size, shuffle, validation_split, num_workers, pin_memory,
                         val_dataset=self.val_dataset)

    def run_loader(self):
        """No-op placeholder kept for interface parity with the other loaders."""
        pass


class CIFAR10DataLoader(BaseDataLoader):
    """CIFAR-10 train/validation loader driven by the global ``ConfigParser``.

    Builds the standard train and validation transforms. An extra augmented
    training view is built only when ``train_loss.args.ratio_consistency > 0``;
    otherwise ``None`` is passed. All of these are handed to ``get_cifar10``
    together with the label-noise file path.

    ``num_batches`` and ``T`` are accepted for interface parity with the other
    loaders but are not used here.

    Raises:
        ValueError: if ``data_augmentation.type`` names an unknown policy.
    """

    def __init__(self, data_dir, batch_size, shuffle=True, validation_split=0.0, num_batches=0, training=True, num_workers=4, pin_memory=True, T=None):
        config = ConfigParser.get_instance()
        cfg_trainer = config['trainer']
        # Per-channel CIFAR-10 mean / std, shared by every view built below.
        mean, std = (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)

        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
        transform_val = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])

        if config['train_loss']['args']['ratio_consistency'] > 0:
            transform_train_aug = transforms.Compose([
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(mean, std),
            ])
            aug_type = config['data_augmentation']['type']
            if aug_type is not None:
                autoaug = transforms.Compose([])
                if aug_type == 'fa_reduced_cifar10':
                    autoaug.transforms.insert(0, Augmentation(fa_reduced_cifar10()))
                elif aug_type == 'autoaug_cifar10':
                    autoaug.transforms.insert(0, Augmentation(autoaug_paper_cifar10()))
                elif aug_type == 'autoaug_extend':
                    autoaug.transforms.insert(0, Augmentation(autoaug_policy()))
                elif aug_type == 'default':
                    pass  # no policy ops are added
                else:
                    raise ValueError('not found augmentations. %s' % aug_type)
                # Policy ops go first, on the input image, before crop/flip/ToTensor.
                transform_train_aug.transforms.insert(0, autoaug)

                # Cutout is appended last, i.e. it runs after Normalize.
                if config['data_augmentation']['cutout'] > 0:
                    transform_train_aug.transforms.append(CutoutDefault(config['data_augmentation']['cutout']))
        else:
            transform_train_aug = None

        self.data_dir = data_dir
        # NOTE(review): the dataset root comes from the config, not from the
        # ``data_dir`` argument; presumably they are the same — confirm.
        root = config['data_loader']['args']['data_dir']
        # os.path.join yields the same path as plain concatenation when ``root``
        # ends with a separator, and stays correct when it does not.
        noise_file = os.path.join(
            root, 'CIFAR10_%.1f_Asym_%s.json' % (cfg_trainer['percent'], cfg_trainer['asym']))

        self.train_dataset, self.val_dataset = get_cifar10(root, cfg_trainer, train=training,
                                                           transform_train=transform_train,
                                                           transform_train_aug=transform_train_aug,
                                                           transform_val=transform_val,
                                                           noise_file=noise_file)

        super().__init__(self.train_dataset, batch_size, shuffle, validation_split, num_workers, pin_memory,
                         val_dataset=self.val_dataset)

    def run_loader(self):
        """No-op placeholder kept for interface parity with the other loaders."""
        pass


class CIFAR100DataLoader(BaseDataLoader):
    """CIFAR-100 train/validation loader driven by the global ``ConfigParser``.

    Builds train, validation and augmented-train transforms (the augmented view
    is always built here) and hands them to ``get_cifar100`` together with the
    label-noise file path.

    ``num_batches`` and ``T`` are accepted for interface parity with the other
    loaders but are not used here.

    Raises:
        ValueError: if ``data_augmentation.type`` names an unknown policy.
    """

    def __init__(self, data_dir, batch_size, shuffle=True, validation_split=0.0, num_batches=0, training=True, num_workers=4, pin_memory=True, T=None):
        config = ConfigParser.get_instance()
        cfg_trainer = config['trainer']
        # Per-channel CIFAR-100 mean / std. Every view, including the augmented
        # one, uses these, so all views are normalized identically.
        mean, std = (0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)

        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
        transform_val = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
        transform_train_aug = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])

        aug_type = config['data_augmentation']['type']
        if aug_type is not None:
            autoaug = transforms.Compose([])
            if aug_type == 'fa_reduced_cifar10':
                autoaug.transforms.insert(0, Augmentation(fa_reduced_cifar10()))
            elif aug_type == 'autoaug_cifar10':
                autoaug.transforms.insert(0, Augmentation(autoaug_paper_cifar10()))
            elif aug_type == 'autoaug_extend':
                autoaug.transforms.insert(0, Augmentation(autoaug_policy()))
            elif aug_type == 'default':
                pass  # no policy ops are added
            else:
                raise ValueError('not found augmentations. %s' % aug_type)
            # Policy ops go first, on the input image, before crop/flip/ToTensor.
            transform_train_aug.transforms.insert(0, autoaug)

            # Cutout is appended last, i.e. it runs after Normalize.
            if config['data_augmentation']['cutout'] > 0:
                transform_train_aug.transforms.append(CutoutDefault(config['data_augmentation']['cutout']))

        self.data_dir = data_dir
        # NOTE(review): the dataset root comes from the config, not from the
        # ``data_dir`` argument; presumably they are the same — confirm.
        root = config['data_loader']['args']['data_dir']
        # os.path.join yields the same path as plain concatenation when ``root``
        # ends with a separator, and stays correct when it does not.
        noise_file = os.path.join(
            root, 'CIFAR100_%.1f_Asym_%s.json' % (cfg_trainer['percent'], cfg_trainer['asym']))

        self.train_dataset, self.val_dataset = get_cifar100(root, cfg_trainer, train=training,
                                                            transform_train=transform_train,
                                                            transform_train_aug=transform_train_aug,
                                                            transform_val=transform_val,
                                                            noise_file=noise_file)

        super().__init__(self.train_dataset, batch_size, shuffle, validation_split, num_workers, pin_memory,
                         val_dataset=self.val_dataset)

    def run_loader(self):
        """No-op placeholder kept for interface parity with the other loaders."""
        pass
