import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import tensorflow as tf
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from sklearn.neighbors import KNeighborsRegressor
import pandas as pd


class Custom_Dataset(Dataset):
    """Indexable dataset over a pair of tensors moved to ``device``.

    ``X`` and ``y`` must provide ``.to(device)``; both are moved once at
    construction time.  Items are ``(data, target, idx)`` when ``return_idx``
    is true and ``(data, target)`` otherwise.  If ``transform`` is truthy it is
    applied to the data item only, never to the target.
    """

    def __init__(self, X, y, return_idx=True, device=None, transform=None):
        self.device = device
        self.transform = transform
        self.return_idx = return_idx
        self.data = X.to(device)
        self.targets = y.to(device)
        # Length is taken from X as passed in.
        self.count = len(X)

    def __len__(self):
        return self.count

    def __getitem__(self, idx):
        sample = self.data[idx]
        if self.transform:
            sample = self.transform(sample)
        label = self.targets[idx]

        if not self.return_idx:
            return sample, label
        return sample, label, idx


def load_housing(numdp, seed, raw_data=False, **kwargs):
    """Load the housing CSV and return standardized train/test arrays.

    Reads ``./dataset/housing.csv``, drops the ``ocean_proximity`` column,
    imputes NaNs in the numerical columns with a KNN regressor, splits the rows
    with ``train_test_split`` and standardizes features and labels.
    ``median_house_value`` is used as the label.

    Args:
        numdp: Passed to ``train_test_split`` as ``train_size``.
        seed: Passed to ``train_test_split`` as ``random_state``.
        raw_data: When true, a fifth element (always ``None`` for this
            dataset) is appended to the returned tuple.
        **kwargs: Accepted and ignored (for example the ``flatten`` argument
            that ``load_data`` passes to every loader).

    Returns:
        ``(trdata, tedata, trlabel, telabel)`` as NumPy arrays, followed by
        ``None`` when ``raw_data`` is true.
    """
    def impute_knn(df):
        """Return ``df`` with NaNs in its numerical columns filled by KNN regression.

        Each numerical column containing NaN is predicted by a 5-neighbour
        ``KNeighborsRegressor`` fitted on the rows without any missing
        numerical value, using the numerical columns that have no NaN as
        features.  Non-numerical columns are appended unchanged after the
        numerical ones.
        """
        # separate dataframe into numerical/categorical
        ldf = df.select_dtypes(include=[np.number])           # numerical columns
        ldf_putaside = df.select_dtypes(exclude=[np.number])  # everything else
        # define columns w/ and w/o missing data
        cols_nan = ldf.columns[ldf.isna().any()].tolist()
        cols_no_nan = ldf.columns.difference(cols_nan).values

        for col in cols_nan:
            imp_test = ldf[ldf[col].isna()]   # rows missing this column: to be predicted
            # Recomputed every iteration because earlier imputations make more rows complete.
            imp_train = ldf.dropna()
            model = KNeighborsRegressor(n_neighbors=5)
            knr = model.fit(imp_train[cols_no_nan], imp_train[col])
            ldf.loc[df[col].isna(), col] = knr.predict(imp_test[cols_no_nan])

        return pd.concat([ldf, ldf_putaside], axis=1)

    df = pd.read_csv('./dataset/housing.csv')
    del df['ocean_proximity']

    df2 = impute_knn(df)
    # Print a column summary so the log shows whether any NaN remains.
    df2.info()

    # split test and train dataset
    trdata, tedata = train_test_split(df2, test_size=5000, train_size=numdp, random_state=seed)

    # Feature statistics are fitted on every row of the file (label column removed).
    # NOTE(review): this includes the test rows and rows in neither split - confirm intended.
    alldata = df2
    del alldata['median_house_value']
    trlabel, telabel = trdata['median_house_value'], tedata['median_house_value']
    del trdata['median_house_value'], tedata['median_house_value']

    scaler = StandardScaler()
    scaler.fit(alldata)
    trdata, tedata = scaler.transform(trdata), scaler.transform(tedata)

    # NOTE(review): both label sets are standardized with the *test* label statistics.
    label_mean, label_std = np.mean(telabel), np.std(telabel)
    trlabel, telabel = (trlabel - label_mean)/label_std, (telabel - label_mean)/label_std

    trdata, tedata, trlabel, telabel = np.array(trdata), np.array(tedata), np.array(trlabel), np.array(telabel)

    if raw_data:
        return trdata, tedata, trlabel, telabel, None
    return trdata, tedata, trlabel, telabel


def load_cifar10(numdp, seed, flatten=True, raw_data=False, **kwargs):
    """Load CIFAR-10 via torchvision and return normalized NumPy arrays.

    The images are scaled by 1/255, normalized per channel with fixed
    mean/std constants, moved from channel-last to channel-first layout and
    optionally flattened to one row per image.  The training set is then
    subsampled with ``train_test_split``; the test set is returned in full.

    Args:
        numdp: Passed to ``train_test_split`` as ``train_size``.
        seed: Passed to ``train_test_split`` as ``random_state``.
        flatten: When true, each image is reshaped to a 1-D feature vector.
        raw_data: When true, a fifth element (always ``None`` for this
            dataset) is appended to the returned tuple.
        **kwargs: Accepted and ignored.

    Returns:
        ``(trdata, tedata, trlabel, telabel)`` as NumPy arrays, followed by
        ``None`` when ``raw_data`` is true.
    """
    print('==> Preparing data..')

    # Per-channel statistics, shaped to broadcast over the channel-last image arrays.
    means = np.array([0.4914, 0.4822, 0.4465]).reshape([1, 1, 1, 3])
    stds = np.array([0.2023, 0.1994, 0.2010]).reshape([1, 1, 1, 3])

    train_data = datasets.CIFAR10(root='./dataset', train=True, download=True)
    test_data = datasets.CIFAR10(root='./dataset', train=False, download=True)

    trdata, tedata = train_data.data / 255, test_data.data / 255
    trlabel, telabel = train_data.targets, test_data.targets
    trdata, tedata = (trdata - means) / stds, (tedata - means) / stds
    # Move the channel axis from last to second position.
    trdata, tedata = trdata.transpose([0, 3, 1, 2]), tedata.transpose([0, 3, 1, 2])
    if flatten:
        trdata, tedata = trdata.reshape(trdata.shape[0], -1), tedata.reshape(tedata.shape[0], -1)

    # Subsample the training set; the held-out part of the split is discarded.
    trdata, _, trlabel, _ = train_test_split(
        trdata, trlabel, test_size=10000, train_size=numdp, random_state=seed)
    trdata, tedata, trlabel, telabel = np.array(trdata), np.array(tedata), np.array(trlabel), np.array(telabel)

    if raw_data:
        return trdata, tedata, trlabel, telabel, None
    return trdata, tedata, trlabel, telabel


def load_mnist(numdp, seed, flatten=True, raw_data=False, **kwargs):
    """Load MNIST via Keras and return per-pixel standardized NumPy arrays.

    Images are flattened, standardized with the per-pixel mean and standard
    deviation of the full training set, optionally reshaped to
    ``(N, 1, 28, 28)``, and the training set is subsampled with
    ``train_test_split``.  The test set is returned in full.

    Args:
        numdp: Passed to ``train_test_split`` as ``train_size``.
        seed: Passed to ``train_test_split`` as ``random_state``.
        flatten: When false, images are returned as ``(N, 1, 28, 28)``
            instead of one flat row per image.
        raw_data: When true, a fifth element is returned: the standardized
            *full* training set mapped back through the inverse of the
            standardization (it is not restricted to the ``numdp`` subsample),
            in the same layout as the training data.
        **kwargs: Accepted and ignored.

    Returns:
        ``(trdata, tedata, trlabel, telabel)`` as NumPy arrays, followed by
        the un-normalized training array when ``raw_data`` is true.
    """
    (train_x_raw, train_y_raw), (test_x_raw, test_y_raw) = tf.keras.datasets.mnist.load_data(
        path='mnist.npz'
    )

    # Flatten first so the statistics are computed per pixel position.
    train_x_raw = train_x_raw.reshape(train_x_raw.shape[0], -1)
    test_x_raw = test_x_raw.reshape(test_x_raw.shape[0], -1)
    mean_, std_ = train_x_raw.mean(axis=0), train_x_raw.std(axis=0)
    # The tiny epsilon only guards against division by exactly zero.
    # NOTE(review): for a pixel with zero training std, any test value that
    # differs from the training mean is divided by 1e-30 - confirm acceptable.
    scale = std_ + 1e-30
    train_x_raw, test_x_raw = (train_x_raw - mean_) / scale, (test_x_raw - mean_) / scale

    if not flatten:
        train_x_raw = train_x_raw.reshape(train_x_raw.shape[0], 1, 28, 28)
        test_x_raw = test_x_raw.reshape(test_x_raw.shape[0], 1, 28, 28)
        # Put the statistics into the same per-image layout; the flat (784,)
        # vectors cannot broadcast against (N, 1, 28, 28) when un-normalizing.
        mean_, scale = mean_.reshape(1, 28, 28), scale.reshape(1, 28, 28)

    # Subsample the training set; the held-out part of the split is discarded.
    trdata, _, trlabel, _ = train_test_split(
        train_x_raw, train_y_raw, test_size=10000, train_size=numdp, random_state=seed)
    tedata, telabel = test_x_raw, test_y_raw

    trdata, tedata, trlabel, telabel = np.array(trdata), np.array(tedata), np.array(trlabel), np.array(telabel)

    if raw_data:
        train_x_unnorm = np.array(train_x_raw * scale + mean_)
        return trdata, tedata, trlabel, telabel, train_x_unnorm
    return trdata, tedata, trlabel, telabel

def load_fashionmnist(numdp, seed, flatten=True, raw_data=False, **kwargs):
    """Load Fashion-MNIST via Keras and return per-pixel standardized NumPy arrays.

    Images are flattened, standardized with the per-pixel mean and standard
    deviation of the training set, and optionally reshaped to
    ``(N, 1, 28, 28)``.  Both the training and the test set are returned in
    full: no subsampling is performed.

    Args:
        numdp: Accepted for signature parity with the other loaders; unused.
        seed: Accepted for signature parity with the other loaders; unused.
        flatten: When false, images are returned as ``(N, 1, 28, 28)``
            instead of one flat row per image.
        raw_data: When true, a fifth element is returned: the standardized
            training set mapped back through the inverse of the
            standardization, in the same layout as the training data.
        **kwargs: Accepted and ignored.

    Returns:
        ``(trdata, tedata, trlabel, telabel)`` as NumPy arrays, followed by
        the un-normalized training array when ``raw_data`` is true.
    """
    (train_x_raw, train_y_raw), (test_x_raw, test_y_raw) = tf.keras.datasets.fashion_mnist.load_data()

    # Flatten first so the statistics are computed per pixel position.
    train_x_raw = train_x_raw.reshape(train_x_raw.shape[0], -1)
    test_x_raw = test_x_raw.reshape(test_x_raw.shape[0], -1)
    mean_, std_ = train_x_raw.mean(axis=0), train_x_raw.std(axis=0)
    # The tiny epsilon only guards against division by exactly zero.
    scale = std_ + 1e-30
    train_x_raw, test_x_raw = (train_x_raw - mean_) / scale, (test_x_raw - mean_) / scale

    if not flatten:
        train_x_raw = train_x_raw.reshape(train_x_raw.shape[0], 1, 28, 28)
        test_x_raw = test_x_raw.reshape(test_x_raw.shape[0], 1, 28, 28)
        # Put the statistics into the same per-image layout; the flat (784,)
        # vectors cannot broadcast against (N, 1, 28, 28) when un-normalizing.
        mean_, scale = mean_.reshape(1, 28, 28), scale.reshape(1, 28, 28)

    trdata, trlabel = train_x_raw, train_y_raw
    tedata, telabel = test_x_raw, test_y_raw

    trdata, tedata, trlabel, telabel = np.array(trdata), np.array(tedata), np.array(trlabel), np.array(telabel)

    if raw_data:
        train_x_unnorm = np.array(train_x_raw * scale + mean_)
        return trdata, tedata, trlabel, telabel, train_x_unnorm
    return trdata, tedata, trlabel, telabel
    
def load_data(numdp=5000, dataset="rideshare", seed=43, flatten=True, raw_data=False):
    """Dispatch to the module-level loader named ``load_<dataset>``.

    Args:
        numdp: Forwarded to the loader as ``numdp``.
        dataset: Suffix of the loader function name, e.g. ``"mnist"`` selects
            ``load_mnist``.
        seed: Forwarded to the loader as ``seed``.
        flatten: Forwarded to the loader as ``flatten``.
        raw_data: Forwarded to the loader; when true the loader must return
            five values instead of four.

    Returns:
        ``(trdata, tedata, trlabel, telabel)``, followed by the loader's fifth
        value when ``raw_data`` is true.

    Raises:
        NameError: If this module defines no callable ``load_<dataset>``.
        ValueError: If the loader returns the wrong number of values.
    """
    # Look the loader up by name instead of eval()-ing the string, so that
    # ``dataset`` can never be executed as an arbitrary expression.
    loader = globals().get("load_" + dataset)
    if not callable(loader):
        # NameError is what an unknown loader name has always produced here.
        raise NameError(f"no dataset loader named 'load_{dataset}' is defined")

    result = loader(numdp=numdp, seed=seed, flatten=flatten, raw_data=raw_data)

    # Unpack explicitly so a loader returning the wrong arity fails loudly.
    if raw_data:
        trdata, tedata, trlabel, telabel, train_x_unnorm = result
        return trdata, tedata, trlabel, telabel, train_x_unnorm
    trdata, tedata, trlabel, telabel = result
    return trdata, tedata, trlabel, telabel