import numpy as np
import pandas as pd
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import transforms


def split_and_normalize_as_tensors(X, y, test_size=0.25, seed=0):
    """Split (X, y) into train/test parts, standardize features, return float tensors.

    The StandardScaler is fitted on the training rows only and then applied to
    both splits, so no test-set statistics leak into the normalization.

    Args:
        X: Feature matrix accepted by ``train_test_split`` / ``StandardScaler``.
        y: Targets aligned with ``X``.
        test_size: Fraction (or count) of samples placed in the test split.
        seed: ``random_state`` forwarded to ``train_test_split``.

    Returns:
        ``(X_train, X_test, y_train, y_test)`` as float32 torch tensors.
    """
    feats_tr, feats_te, targ_tr, targ_te = train_test_split(
        X, y, test_size=test_size, random_state=seed
    )

    scaler = StandardScaler().fit(feats_tr)

    def to_float_tensor(arr):
        return torch.tensor(arr).float()

    return (
        to_float_tensor(scaler.transform(feats_tr)),
        to_float_tensor(scaler.transform(feats_te)),
        to_float_tensor(targ_tr),
        to_float_tensor(targ_te),
    )


def train_val_split_tensors(X, y, val_size=0.25, seed=0):
    """Randomly split already-built tensors into training and validation parts.

    Args:
        X: Tensor indexable by an index tensor along its first dimension;
            its first dimension must match ``len(y)``.
        y: Target tensor aligned with ``X``.
        val_size: Fraction of samples assigned to the validation split
            (rounded down via ``int``).
        seed: Seed for the shuffling permutation. A private generator is used,
            so the global torch RNG state is left untouched.

    Returns:
        ``(X_train, X_val, y_train, y_val)``.
    """
    n_samples = len(y)
    n_val = int(val_size * n_samples)

    generator = torch.Generator().manual_seed(seed)
    perm = torch.randperm(n_samples, generator=generator)

    # The first ``n_val`` shuffled indices form the validation set, so that
    # ``val_size`` really is the validation fraction (the previous code gave
    # this slice to the training set instead).
    val_idx = perm[:n_val]
    train_idx = perm[n_val:]
    return X[train_idx], X[val_idx], y[train_idx], y[val_idx]


def subsample(X, y, n=1000, seed=None):
    """Draw up to ``n`` rows of ``(X, y)`` uniformly at random without replacement.

    Args:
        X: Array indexable by an integer index array along axis 0.
        y: Targets aligned with ``X``; ``len(y)`` sets the population size.
        n: Maximum number of rows to keep. If ``n >= len(y)`` every row is
            returned, in shuffled order.
        seed: Optional seed. ``None`` (the default) draws from NumPy's global
            RNG exactly as before; an int gives a reproducible draw without
            touching the global RNG state.

    Returns:
        ``(X_sub, y_sub)`` selected with the same row indices.
    """
    if seed is None:
        perm = np.random.permutation(len(y))
    else:
        perm = np.random.default_rng(seed).permutation(len(y))
    idx = perm[:n]
    return X[idx], y[idx]


def load_data(dataset):
    """Load a small tabular classification dataset as NumPy arrays.

    ``'abalone'`` and ``'car'`` are fetched from OpenML via scikit-learn;
    ``'iris'`` comes from scikit-learn's bundled datasets; ``'wine'``,
    ``'balance-scale'`` and ``'transfusion'`` are read from CSV files under
    ``./data/classification/``.

    Args:
        dataset: One of ``'abalone'``, ``'car'``, ``'iris'``, ``'wine'``,
            ``'balance-scale'`` or ``'transfusion'``.

    Returns:
        ``(X, y)`` where ``X`` is a float32 feature matrix and ``y`` is a
        float32 label vector.

    Raises:
        ValueError: If ``dataset`` is not recognized, or if the ``car`` data
            contains a category value missing from the ordinal encoding.
    """
    if dataset == 'abalone':
        X, y = datasets.fetch_openml('abalone', return_X_y=True)
        # Drop the first column (presumably the categorical 'Sex' attribute —
        # confirm) and shift labels down by one so they start at 0.
        X = X.values[:, 1:].astype(np.float32)
        y = y.values.astype(np.float32) - 1

        # Keep only the first ten (shifted) label values.
        sub_mask = y < 10
        X = X[sub_mask, :]
        y = y[sub_mask]

        # NOTE(review): subsample draws from the unseeded global NumPy RNG, so
        # the 1000 selected rows differ between runs unless the caller seeds.
        X, y = subsample(X, y, n=1000)
        return X, y

    if dataset == 'car':
        X, y = datasets.fetch_openml('car', return_X_y=True)

        # Ordinal encoding for each attribute.
        remap = {
            'buying': {'low': 0, 'med': 1, 'high': 2, 'vhigh': 3},
            'maint': {'low': 0, 'med': 1, 'high': 2, 'vhigh': 3},
            'doors': {'2': 0, '3': 1, '4': 2, '5more': 3},
            'persons': {'2': 0, '4': 1, 'more': 2},
            'lug_boot': {'small': 0, 'med': 1, 'big': 2},
            'safety': {'low': 0, 'med': 1, 'high': 2},
        }
        # Map column by column on string values: DataFrame.replace on
        # categorical columns is deprecated in recent pandas. Any value that
        # is missing from the table becomes NaN and is rejected below.
        X = np.column_stack([
            X[col].astype(str).map(remap[col]).to_numpy(dtype=np.float32)
            for col in X.columns
        ])
        if np.isnan(X).any():
            raise ValueError("Unexpected category value in the 'car' dataset.")

        # NOTE(review): 'P' as the positive class matches a binarized (P/N)
        # version of the OpenML 'car' dataset; confirm which version
        # fetch_openml resolves to, otherwise every label becomes 0.
        y = (y == 'P').values.astype(np.float32)
        return X, y

    if dataset == 'iris':
        X, y = datasets.load_iris(return_X_y=True)
        # Cast so iris returns the same float32 dtypes as every other dataset.
        return X.astype(np.float32), y.astype(np.float32)

    if dataset == 'wine':
        data_file = './data/classification/wine.csv'
        df = pd.read_csv(data_file, header=None)
        # Column 0 is the class label (shifted to start at 0); the rest are features.
        X = df.loc[:, 1:].values.astype(np.float32)
        y = (df[0].values - 1).astype(np.float32)
        return X, y

    if dataset == 'balance-scale':
        data_file = './data/classification/balance-scale.csv'
        df = pd.read_csv(data_file, header=None)
        X = df.loc[:, 1:].values.astype(np.float32)
        # Column 0 holds the class letter. Build a fresh array instead of
        # mutating df[0].values in place, which may be read-only or a
        # non-NumPy array under newer pandas; unknown letters raise KeyError.
        label_ids = {'L': 0., 'R': 1., 'B': 2.}
        y = np.array([label_ids[label] for label in df[0]], dtype=np.float32)
        return X, y

    if dataset == 'transfusion':
        data_file = './data/classification/transfusion.csv'
        df = pd.read_csv(data_file)
        # All columns except the last are features; the label column is
        # selected by name.
        X = df.iloc[:, :-1].values.astype(np.float32)
        y = (df['whether he/she donated blood in March 2007'].values).astype(np.float32)
        return X, y

    raise ValueError(f'Unknown dataset: {dataset!r}')


def _imagenet_224_transform():
    """Return ToTensor -> Resize(224) -> normalization with ImageNet statistics."""
    return transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize(224),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])


def load_torch_data(dataset, eval_on='test'):
    """Build torchvision train/evaluation datasets for an image benchmark.

    All data is stored under ``./data``.

    Args:
        dataset: One of ``'mnist'``, ``'fmnist'``, ``'cifar10'``, ``'svhn'``,
            ``'stl10'`` or ``'cars196'``.
        eval_on: ``'test'`` to evaluate on the official test split, or
            ``'valid'`` to hold out 20% of the training split (fixed seed 0)
            as the evaluation set. In that case the official test split is
            not returned.

    Returns:
        ``(train_ds, eval_ds)``. In ``'valid'`` mode both are subsets
        produced by ``torch.utils.data.random_split`` on the training set.

    Raises:
        ValueError: If ``eval_on`` or ``dataset`` is not recognized. The
            ``eval_on`` check runs before any dataset is downloaded.
    """
    if eval_on not in ('valid', 'test'):
        raise ValueError('Invalid eval set specified.')

    if dataset == 'mnist':
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])
        train_ds = torchvision.datasets.MNIST(
            './data', train=True, download=True, transform=transform
        )
        test_ds = torchvision.datasets.MNIST(
            './data', train=False, download=False, transform=transform
        )

    elif dataset == 'fmnist':
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ])
        train_ds = torchvision.datasets.FashionMNIST(
            './data', train=True, download=True, transform=transform
        )
        test_ds = torchvision.datasets.FashionMNIST(
            './data', train=False, download=False, transform=transform
        )

    elif dataset == 'cifar10':
        # Resized to 224x224 with ImageNet statistics for the deep metric
        # learning experiment. Other CIFAR-10 experiments use native 32x32
        # inputs (ResNet-20) with CIFAR-10 statistics
        # mean=(0.4914, 0.4822, 0.4465), std=(0.247, 0.243, 0.261);
        # swap the transform accordingly.
        transform = _imagenet_224_transform()
        train_ds = torchvision.datasets.CIFAR10(
            './data', train=True, download=True, transform=transform
        )
        test_ds = torchvision.datasets.CIFAR10(
            './data', train=False, transform=transform
        )

    elif dataset == 'svhn':
        # Alternative native-resolution setup: ToTensor + Normalize with SVHN
        # statistics mean=(0.4380, 0.4440, 0.4730), std=(0.1980, 0.2010, 0.1970).
        transform = _imagenet_224_transform()
        train_ds = torchvision.datasets.SVHN(
            './data', split='train', download=True, transform=transform
        )
        test_ds = torchvision.datasets.SVHN(
            './data', split='test', download=True, transform=transform
        )

    elif dataset == 'stl10':
        # Alternative normalization with STL-10 statistics:
        # mean=(0.4467, 0.4398, 0.4066), std=(0.2603, 0.2566, 0.2713).
        transform = _imagenet_224_transform()
        train_ds = torchvision.datasets.STL10(
            './data', split='train', download=True, transform=transform
        )
        test_ds = torchvision.datasets.STL10(
            './data', split='test', download=True, transform=transform
        )

    elif dataset == 'cars196':
        normalize = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
        # Random crop + flip augmentation for training; deterministic center
        # crop for evaluation.
        train_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.RandomCrop(224),
            transforms.RandomHorizontalFlip(),
            normalize,
        ])
        test_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            normalize,
        ])
        train_ds = torchvision.datasets.StanfordCars(
            './data', split='train', download=True, transform=train_transform
        )
        test_ds = torchvision.datasets.StanfordCars(
            './data', split='test', download=True, transform=test_transform
        )

    else:
        raise ValueError(f'Unknown dataset: {dataset!r}')

    if eval_on == 'valid':
        # 80/20 split of the training set. A dedicated generator seeded with 0
        # gives a fixed split without reseeding the global torch RNG.
        # NOTE(review): the validation subset inherits the *training*
        # transform (random crop/flip for cars196), not test_transform.
        train_len = int(0.8 * len(train_ds))
        train_ds, test_ds = torch.utils.data.random_split(
            train_ds,
            [train_len, len(train_ds) - train_len],
            generator=torch.Generator().manual_seed(0),
        )

    return train_ds, test_ds
