import torch
import os
import numpy as np
import pandas as pd
import random

from torch.utils.data.sampler import SubsetRandomSampler
from functools import reduce
from operator import __or__
import torchvision.datasets as datasets
import torchvision.transforms as transforms


from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
from torch.utils.data import TensorDataset, DataLoader

# File name (under ``<data_dir>/uci``) and ``pandas.read_csv`` keyword arguments
# for every regression dataset that ``dataset_reg`` knows how to load.
_UCI_REG_SOURCES = {
    'Boston': ("housing.data", {"header": None, "sep": r"\s+"}),
    'Concrete': ("Concrete_Data.csv", {"header": None}),
    'Energy': ("ENB2012_data.csv", {"header": None}),
    'Kin8nm': ("kin8nm.csv", {"index_col": 0}),
    'Power': ("ccpp.csv", {"sep": ","}),
    'Wine': ("winequality-red.csv", {"sep": ";"}),
}


def _to_float_tensor(array):
    """Convert an array-like into a torch tensor cast with ``.float()``."""
    return torch.tensor(array).float()


def dataset_reg(batch_size, dataset, data_dir, valid_ratio, test_ratio = 0.1, seed=0, valid_seed=0, loader = True, mu_mode = 'zero', loading_mode = 'rbf'):
    """Load a UCI regression dataset and split it into train / valid / test.

    The CSV is read from ``<data_dir>/uci/<file>``. Every column except the
    last is treated as a feature and rescaled to [0, 1] with ``MinMaxScaler``;
    the last column is the target, reshaped to ``(-1, 1)``.

    NOTE(review): the scaler is fitted on the full table *before* the
    train/test split, so test rows influence the scaling. Left unchanged
    because altering it would change every downstream number.

    Args:
        batch_size: Batch size of the training loader; ``-1`` means one batch
            holding the whole training set.
        dataset: One of 'Boston', 'Concrete', 'Energy', 'Kin8nm', 'Power',
            'Wine'.
        data_dir: Directory that contains the ``uci`` sub-directory.
        valid_ratio: Passed as ``test_size`` when carving a validation set out
            of the training split. No validation set is made unless it is > 0.
        test_ratio: Passed as ``test_size`` for the train/test split.
        seed: ``random_state`` of the train/test split.
        valid_seed: ``random_state`` of the train/valid split.
        loader: When equal to ``True`` return data loaders plus size
            information, otherwise return the raw tensors.
        mu_mode: Unused by this function.
        loading_mode: Unused by this function.

    Returns:
        If ``loader == True``:
            ``((trainloader, validloader, testloader),
            (n_train, n_valid, n_test, p_data, num_classes))`` where the
            validation and test loaders each deliver their split in a single
            batch, ``p_data`` is the number of feature columns and
            ``num_classes`` is always 1. ``validloader`` and ``n_valid`` are
            ``None`` when no validation set was made.
        Otherwise:
            ``(x_train, y_train, x_valid, y_valid, x_test, y_test)`` with
            ``x_valid`` / ``y_valid`` set to ``None`` when no validation set
            was made.

    Raises:
        AssertionError: If ``dataset`` is not one of the supported names.
    """
    # tuple(...) keeps the original list-membership semantics (compares with ==).
    assert dataset in tuple(_UCI_REG_SOURCES), "Error dataset; name:{}".format(dataset)

    ### dataset
    file_name, read_kwargs = _UCI_REG_SOURCES[dataset]
    data = pd.read_csv(os.path.join(data_dir, "uci", file_name), **read_kwargs)
    data_x = MinMaxScaler((0,1)).fit_transform(data.iloc[:, :-1].astype(np.float64))
    data_y = np.array(data.iloc[:,-1]).reshape(-1,1)

    ### split
    x_train, x_test, y_train, y_test = train_test_split(data_x, data_y, test_size=test_ratio, random_state=seed)
    if valid_ratio > 0:
        x_train, x_valid, y_train, y_valid = train_test_split(x_train, y_train, test_size=valid_ratio, random_state=valid_seed)
        x_valid, y_valid = _to_float_tensor(x_valid), _to_float_tensor(y_valid)
    else:
        x_valid, y_valid = None, None
    x_train, y_train = _to_float_tensor(x_train), _to_float_tensor(y_train)
    x_test, y_test = _to_float_tensor(x_test), _to_float_tensor(y_test)

    # `== True` (not plain truthiness) is kept on purpose so that non-bool
    # arguments are treated exactly as before.
    if not (loader == True):
        return x_train, y_train, x_valid, y_valid, x_test, y_test

    ### loader
    train_batch_size = x_train.shape[0] if batch_size == -1 else batch_size
    trainloader = DataLoader(TensorDataset(x_train, y_train), batch_size=train_batch_size)
    if x_valid is None:
        validloader, n_valid = None, None
    else:
        n_valid = x_valid.shape[0]
        validloader = DataLoader(TensorDataset(x_valid, y_valid), batch_size=n_valid)
    n_test = x_test.shape[0]
    testloader = DataLoader(TensorDataset(x_test, y_test), batch_size=n_test)

    n_train, p_data, num_classes = x_train.shape[0], x_train.shape[1], 1
    return (trainloader, validloader, testloader), (n_train, n_valid, n_test, p_data, num_classes)

def batch_dataset_robust(batch_size, in_dataset, data_dir, out_dataset):
    """Build the evaluation data used for robustness / out-of-distribution tests.

    Args:
        batch_size: Batch size of the out-of-distribution test loader.
        in_dataset: In-distribution dataset name, 'cifar10' or 'cifar100'.
        data_dir: Root directory; datasets are stored in / downloaded to
            ``<data_dir>/<dataset name>``.
        out_dataset: Out-of-distribution dataset name. Only 'svhn' is
            implemented.

    Returns:
        ``(corrupted_set, corrupted_data_pth, out_testloader)`` where
        ``corrupted_set`` is the torchvision test split of ``in_dataset``
        (with ``ToTensor`` only), ``corrupted_data_pth`` is the path
        ``<data_dir>/<in_dataset>-c`` (only constructed here, not read), and
        ``out_testloader`` is an unshuffled loader over the SVHN test split.

    Raises:
        ValueError: If ``in_dataset`` or ``out_dataset`` is not a known name.
        NotImplementedError: If ``out_dataset`` is 'lsun', which has no loader.
    """
    # Validate both names before any download starts; previously an unknown
    # name only failed later with an UnboundLocalError.
    if in_dataset not in ("cifar10", "cifar100"):
        raise ValueError("Error in_dataset; name:{}".format(in_dataset))
    if out_dataset == 'lsun':
        raise NotImplementedError("out_dataset 'lsun' is not implemented")
    if out_dataset != 'svhn':
        raise ValueError("Error out_dataset; name:{}".format(out_dataset))

    in_dataset_cls = datasets.CIFAR10 if in_dataset == "cifar10" else datasets.CIFAR100
    corrupted_set = in_dataset_cls(
        os.path.join(data_dir, in_dataset),
        train=False,
        download=True,
        transform=transforms.Compose([transforms.ToTensor()]),
    )
    corrupted_data_pth = os.path.join(data_dir, "{}-c".format(in_dataset))

    out_testset = datasets.SVHN(
        os.path.join(data_dir, out_dataset),
        split="test",
        download=True,
        transform=transforms.Compose([transforms.ToTensor()]),
    )
    out_testloader = torch.utils.data.DataLoader(out_testset, batch_size=batch_size, shuffle=False, num_workers=0)
    return corrupted_set, corrupted_data_pth, out_testloader

def batch_dataset(batch_size, dataset, data_dir, transform_valid = "train", augment = "standard", n_data=None, n_valid = 0, num_workers = 0):
    """Create train / valid / test loaders for an image classification dataset.

    The torchvision training split is indexed ``0 .. n_data - 1``; the first
    ``n_data - n_valid`` indices feed the training loader and the remaining
    ones the validation loader, each through a ``SubsetRandomSampler``.

    Args:
        batch_size: Batch size used by all three loaders.
        dataset: 'cifar10' or 'cifar100'. 'mnist' is an accepted name but has
            no implementation.
        data_dir: Root directory; data is stored in / downloaded to
            ``<data_dir>/<dataset>``.
        transform_valid: ``"train"`` applies the training augmentation to the
            validation set; any other value applies the test transform.
        augment: Augmentation scheme; only 'standard' (random horizontal flip
            and random 32x32 crop with padding 4) exists.
        n_data: Number of leading training-split samples to use. Defaults to
            50000 when ``None``.
        n_valid: Number of those samples held out for validation.
        num_workers: Worker count passed to every loader.

    Returns:
        ``(trainloader, validloader, testloader, num_classes)``.

    Raises:
        AssertionError: If ``dataset`` or ``augment`` is not an accepted name.
        NotImplementedError: If ``dataset`` is 'mnist'.
    """
    assert dataset in ['mnist','cifar10','cifar100'], "Error dataset; name:{}".format(dataset)
    assert augment in ['standard'], "Error augmentation; name:{}".format(augment)

    # 'mnist' passes the check above but no transform, default size or dataset
    # class was ever defined for it; fail with a clear message instead of an
    # UnboundLocalError further down.
    if dataset == 'mnist':
        raise NotImplementedError("Dataset 'mnist' is not implemented in batch_dataset")

    ############################################################################################################################################
    # Transforms
    ############################################################################################################################################
    # No normalization is applied. Channel statistics kept for reference:
    #   cifar10 : mean [125.3/255, 123.0/255, 113.9/255], std [63.0/255, 62.1/255, 66.7/255]
    #   cifar100: mean [129.3/255, 124.1/255, 112.4/255], std [68.2/255, 65.4/255, 70.4/255]
    transform_train = transforms.Compose([transforms.RandomHorizontalFlip(p=0.5),transforms.RandomCrop(32, padding=4),
                                          transforms.ToTensor()])
    transform_test = transforms.Compose([transforms.ToTensor()])
    if n_data is None:
        n_data = 50000

    ############################################################################################################################################
    # Dataset (train / valid / test)
    ############################################################################################################################################
    transform_valid = transform_train if transform_valid == "train" else transform_test

    n_train = n_data - n_valid
    indices = list(range(n_data))

    train_sampler = SubsetRandomSampler(torch.tensor(indices[:n_train]))
    valid_sampler = SubsetRandomSampler(torch.tensor(indices[n_train:]))

    ### dataset
    dataset_cls, num_classes = {
        'cifar10': (datasets.CIFAR10, 10),
        'cifar100': (datasets.CIFAR100, 100),
    }[dataset]
    root = os.path.join(data_dir, dataset)
    trainset = dataset_cls(root, train = True, download = True, transform = transform_train)
    # Same underlying training split as trainset, but with its own transform.
    validset = dataset_cls(root, train = True, download = True, transform = transform_valid)
    testset = dataset_cls(root, train = False, download = True, transform = transform_test)

    ### loader
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, sampler = train_sampler, num_workers=num_workers)
    validloader = torch.utils.data.DataLoader(validset, batch_size=batch_size, sampler = valid_sampler, num_workers=num_workers)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    return trainloader, validloader, testloader, num_classes