import os

import numpy as np
import pandas as pd
import torch
import torchvision
from torch.utils.data import DataLoader, TensorDataset
from torchvision.transforms import transforms


def _read_dfs(path, dataset):
    """Load whichever raw data files exist for a dataset.

    Looks inside ``<path>/<dataset>/`` for, in this order, the combined file
    ``<dataset>_R.dat`` and the pre-split files ``<dataset>_train_R.dat`` and
    ``<dataset>_test_R.dat``. Missing files are skipped silently.

    Args:
        path: Directory that contains one sub-directory per dataset.
        dataset: Dataset name (used for both the sub-directory and filenames).

    Returns:
        A list of DataFrames, one per file found, in the order given above.
        Each frame is parsed as tab-separated text; the first column is used
        as the index while parsing and then discarded in favour of a fresh
        0..n-1 index.
    """
    candidates = (
        f'{dataset}_R.dat',
        f'{dataset}_train_R.dat',
        f'{dataset}_test_R.dat',
    )

    frames = []
    for filename in candidates:
        full_path = os.path.join(path, dataset, filename)
        if not os.path.exists(full_path):
            continue
        frame = pd.read_csv(full_path, sep='\t', index_col=0)
        frames.append(frame.reset_index(drop=True))
    return frames


def _split(num_data, ratio, seed):
    """Randomly partition the indices ``0..num_data-1`` into two groups.

    Args:
        num_data: Total number of samples.
        ratio: Fraction of samples assigned to the first group; the group
            size is ``int(num_data * ratio)`` (truncated).
        seed: Seed for NumPy's global random generator.

    Returns:
        A pair ``(index1, index2)`` of integer index arrays. Together they
        contain every index exactly once.

    Note:
        This re-seeds NumPy's *global* RNG via ``np.random.seed``, so it
        affects any later ``np.random`` calls made by the caller.
    """
    np.random.seed(seed)
    order = np.arange(num_data)
    np.random.shuffle(order)

    cut = int(num_data * ratio)
    return order[:cut], order[cut:]


def _normalize(arr):
    """Standardize each column of a 2-D array to zero mean and unit variance.

    Columns whose standard deviation is exactly zero are only centered
    (they end up all zeros) so that no division by zero occurs.

    Args:
        arr: 2-D array; statistics are computed along axis 0.

    Returns:
        A new array of the same shape; the input is not modified.
    """
    col_mean = np.mean(arr, axis=0)
    col_std = np.std(arr, axis=0)

    scaled = arr - col_mean
    varying = col_std != 0  # constant columns are left centered, not divided
    scaled[:, varying] = scaled[:, varying] / col_std[varying]
    return scaled


def _to_loader(x, y, batch_size, shuffle=False):
    """Wrap feature and target arrays in a ``DataLoader``.

    Args:
        x: Feature array; converted with ``torch.tensor`` keeping its dtype.
        y: NumPy target array. Any integer dtype is converted to
            ``torch.long``; every other dtype is converted to ``torch.float``.
        batch_size: Number of samples per batch.
        shuffle: Whether the loader reshuffles the data every epoch.

    Returns:
        A ``DataLoader`` over ``TensorDataset(x, y)``.
    """
    x = torch.tensor(x)
    # ``np.int`` was removed in NumPy 1.24 and, before that, was just the
    # builtin ``int`` (matching only the platform default integer dtype).
    # ``np.issubdtype`` works on every NumPy version and accepts any integer
    # width, so e.g. int32 labels are no longer silently cast to float.
    y_type = torch.long if np.issubdtype(y.dtype, np.integer) else torch.float
    y = torch.tensor(y, dtype=y_type)
    return DataLoader(TensorDataset(x, y), batch_size=batch_size, shuffle=shuffle)


def _to_img_dataset(data_path, dataset, train):
    """Build a torchvision image dataset with tensor conversion + normalization.

    Args:
        data_path: Base data directory; files are stored under
            ``<data_path>/torchvision`` and downloaded there if absent.
        dataset: ``'mnist'`` or ``'fashion'`` (FashionMNIST).
        train: Passed through as the dataset's ``train`` flag.

    Returns:
        The torchvision dataset object.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    # Pick the dataset class and the (mean, std) arguments for Normalize.
    if dataset == 'mnist':
        dataset_cls = torchvision.datasets.MNIST
        mean, std = (0.1307,), (0.3081,)
    elif dataset == 'fashion':
        dataset_cls = torchvision.datasets.FashionMNIST
        mean, std = (0.5,), (0.5,)
    else:
        raise ValueError(dataset)

    root = os.path.join(data_path, 'torchvision')
    pipeline = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    return dataset_cls(root, train=train, download=True, transform=pipeline)


def read_uci(path, dataset, seed=138):
    """Load a UCI dataset from disk as train/test NumPy arrays.

    The last column of each data file is treated as the target and all
    preceding columns as features (cast to float32).

    * One file found: the rows are split randomly 80/20 into train/test
      using ``seed``.
    * Two files found: the first is used as train and the second as test,
      and each feature matrix is standardized independently with its own
      column statistics.

    NOTE(review): features are standardized only in the two-file case; the
    single-file case returns the raw values. Confirm this is intentional.

    Args:
        path: Directory containing one sub-directory per dataset.
        dataset: Name of the dataset to load.
        seed: Seed for the random split (single-file case only).

    Returns:
        A dict with keys ``trn_x``, ``trn_y``, ``test_x``, ``test_y``,
        ``nx`` (number of feature columns), ``ny`` (largest training target
        plus one; presumably the class count for 0-based labels) and ``nd``
        (total number of train + test rows).

    Raises:
        ValueError: If the dataset directory does not exist, or if the number
            of data files found is neither one nor two.
    """
    def features_and_targets(frame):
        # Everything but the last column is a feature; the last is the target.
        return (frame.iloc[:, :-1].values.astype(np.float32),
                frame.iloc[:, -1].values)

    if not os.path.exists(os.path.join(path, dataset)):
        raise ValueError(dataset)

    frames = _read_dfs(path, dataset)
    num_frames = len(frames)

    if num_frames == 1:
        all_x, all_y = features_and_targets(frames[0])
        trn_idx, test_idx = _split(all_x.shape[0], ratio=0.8, seed=seed)
        trn_x, trn_y = all_x[trn_idx], all_y[trn_idx]
        test_x, test_y = all_x[test_idx], all_y[test_idx]
    elif num_frames == 2:
        trn_x, trn_y = features_and_targets(frames[0])
        test_x, test_y = features_and_targets(frames[1])
        trn_x = _normalize(trn_x)
        test_x = _normalize(test_x)
    else:
        raise ValueError(dataset)

    return {
        'trn_x': trn_x,
        'trn_y': trn_y,
        'test_x': test_x,
        'test_y': test_y,
        'nx': trn_x.shape[1],
        'ny': trn_y.max() + 1,
        'nd': trn_y.shape[0] + test_y.shape[0],
    }


def read_as_dict(data_path, dataset, batch_size, seed=2019):
    """Load a dataset and return its loaders plus basic size information.

    Three kinds of dataset are supported:

    * ``'mnist'`` / ``'fashion'``: torchvision image datasets. The sizes are
      hard-coded (``nd=70000``, ``nx=784``, ``ny=10``).
    * ``'synthetic'``: ``trn_x.npy``, ``test_x.npy``, ``trn_y.npy`` and
      ``test_y.npy`` read from ``<data_path>/synthetic``.
    * anything else: treated as a UCI dataset under ``<data_path>/uci``. The
      returned dict then also contains the raw arrays from ``read_uci``, and
      only in this case is the training loader shuffled.

    Args:
        data_path: Base data directory.
        dataset: Dataset name.
        batch_size: Batch size for both loaders.
        seed: Seed applied to NumPy's global RNG on entry. It is not
            forwarded to ``read_uci``, which is called with its own default
            seed.

    Returns:
        A dict with at least ``trn_loader``, ``test_loader``, ``nd``, ``nx``
        and ``ny``.
    """
    np.random.seed(seed)

    if dataset in {'mnist', 'fashion'}:
        trn_images = _to_img_dataset(data_path, dataset, train=True)
        test_images = _to_img_dataset(data_path, dataset, train=False)
        return {
            'trn_loader': DataLoader(trn_images, batch_size),
            'test_loader': DataLoader(test_images, batch_size),
            'nd': 70000,
            'nx': 784,
            'ny': 10,
        }

    if dataset == 'synthetic':
        folder = os.path.join(data_path, dataset)

        def load_tensor(filename):
            return torch.from_numpy(np.load(os.path.join(folder, filename)))

        trn_x = load_tensor('trn_x.npy')
        test_x = load_tensor('test_x.npy')
        trn_y = load_tensor('trn_y.npy')
        test_y = load_tensor('test_y.npy')

        return {
            'trn_loader': DataLoader(TensorDataset(trn_x, trn_y), batch_size),
            'test_loader': DataLoader(TensorDataset(test_x, test_y), batch_size),
            'nd': trn_x.size(0) + test_x.size(0),
            'nx': trn_x.size(1),
            'ny': trn_y.max().item() + 1,
        }

    # Fallback: every other name is looked up as a UCI dataset.
    data_dict = read_uci(os.path.join(data_path, 'uci'), dataset)
    data_dict['trn_loader'] = _to_loader(
        data_dict['trn_x'], data_dict['trn_y'], batch_size, shuffle=True)
    data_dict['test_loader'] = _to_loader(
        data_dict['test_x'], data_dict['test_y'], batch_size)
    return data_dict


def get_uci_datasets(path):
    """Return the UCI dataset names listed in ``<path>/uci.txt``.

    The file is read one name per line. Surrounding whitespace is stripped,
    blank lines are ignored, and every occurrence of the dataset
    ``'molec-biol-protein-second'`` is left out of the result.

    Args:
        path: Directory containing ``uci.txt``.

    Returns:
        The dataset names as a list of strings, in file order.
    """
    excluded = 'molec-biol-protein-second'
    with open(os.path.join(path, 'uci.txt')) as f:
        stripped = (line.strip() for line in f)
        # Skip blank lines (they would otherwise become empty dataset names)
        # and drop the excluded dataset wherever it appears, not just once.
        return [name for name in stripped if name and name != excluded]
