import os
import sys
import zipfile
import numpy as np
import pandas as pd

import h5py

from scipy.io import arff

import torch
from torch import nn
import torchvision
from torchvision import transforms, datasets
from torch import distributions
from torch.utils.data import TensorDataset, DataLoader

from sklearn.covariance import LedoitWolf
from sklearn.preprocessing import MinMaxScaler, StandardScaler, QuantileTransformer
from sklearn.model_selection import KFold, train_test_split


from tqdm import tqdm
from joblib import Parallel, delayed, load, dump

import ssl


def path_to_matrix(path):
    """Load the dataset identified by the short name ``path``.

    Tabular names ('news', 'letter', 'concrete', 'air', 'review', 'credit',
    'energy', 'song', 'wine', 'ctg') are read from fixed locations under
    ``./data`` and returned as a single float32 numpy matrix.

    Image names ('mnist', 'cifar0', 'cifar1', 'cifar2') are downloaded through
    torchvision into ``./data`` and returned as a triple
    ``(train, test, (channels, height, width))`` where ``train`` and ``test``
    hold one flattened image per row, scaled by 1/255.  ``cifarK`` keeps only
    colour channel ``K`` of CIFAR-10.

    Any other name prints an error message and terminates via ``sys.exit()``.
    """

    def _cifar_channel(channel):
        # Load the train split first, then the test split, keeping a single
        # colour channel of every image and flattening it to 32*32 values.
        splits = []
        for is_train in (True, False):
            cifar = torchvision.datasets.CIFAR10(root='./data', train=is_train, download=True)
            pixels = torch.from_numpy(cifar.data / 255).permute(0, 3, 1, 2)
            splits.append(pixels[:, channel, :, :].reshape(-1, 32*32).float())
        return splits[0], splits[1], (1, 32, 32)

    if path == 'news':
        frame = pd.read_csv('./data/OnlineNewsPopularity/OnlineNewsPopularity.csv')
        # the first column is dropped
        return frame.values[:, 1:].astype('float32')

    if path == 'letter':
        # NOTE(review): no header=None here, so pandas consumes the first line
        # of the file as a header row -- confirm that is intended.
        frame = pd.read_csv('./data/letter-recognition.data')
        return frame.values[:, 1:].astype('float32')

    if path == 'concrete':
        frame = pd.read_excel("./data/Concrete_Data.xls")
        return frame.values.astype('float32')

    if path == 'air':
        # drop the first two and the last two columns, then incomplete rows
        frame = pd.read_excel("./data/AirQualityUCI.xlsx", engine='openpyxl').iloc[:, 2:-2]
        frame = frame.dropna()
        return frame.values.astype('float32')

    if path == 'review':
        frame = pd.read_csv("./data/google_review_ratings.csv").iloc[:, 1:-1]
        # one cell is overwritten by hand before the float conversion
        frame.iloc[2712, 10] = '2'
        frame = frame.dropna()
        return frame.values.astype('float32')

    if path == 'credit':
        frame = pd.read_excel("./data/default of credit card clients.xls", header=1)
        # the first and the last column are dropped
        return frame.values[:, 1:-1].astype('float32')

    if path == 'energy':
        # first column is time stamp
        frame = pd.read_csv("./data/energydata_complete.csv").iloc[:, 1:]
        return frame.values.astype('float32')

    if path == 'song':
        # keep the first 463715 rows without column 0, then a fixed-seed 20% sample
        frame = pd.read_csv("./data/YearPredictionMSD.txt", header=None).iloc[:463715, 1:]
        frame = frame.sample(frac=0.2, random_state=1000)
        return frame.values.astype('float32')

    if path == 'wine':
        # red wines are stacked on top of white wines
        red = pd.read_csv('./data/winequality-red.csv', sep=';')
        white = pd.read_csv('./data/winequality-white.csv', sep=';')
        return pd.concat([red, white], axis=0).values.astype('float32')

    if path == 'ctg':
        frame = pd.read_excel("./data/CTG.xls", sheet_name='Data', header=1, usecols=list(range(10, 31)), nrows=2126)
        return frame.values.astype('float32')

    if path == 'mnist':
        train_split = datasets.MNIST(root='./data', train=True, download=True)
        test_split = datasets.MNIST(root='./data', train=False, download=True)
        train_flat = train_split.data.reshape((-1, 28*28)) / 255
        test_flat = test_split.data.reshape((-1, 28*28)) / 255
        return train_flat, test_flat, (1, 28, 28)

    cifar_channels = {'cifar0': 0, 'cifar1': 1, 'cifar2': 2}
    if path in cifar_channels:
        return _cifar_channel(cifar_channels[path])

    print("Not a valid dataset\n\n")
    sys.exit()
        
#####################################################################
#
#                    For tabular data
#
#####################################################################

def preprocess(data):
    """Return ``data`` rescaled column-wise by a freshly fitted MinMaxScaler.

    The scaler uses scikit-learn's default settings and is discarded after
    the transform, so the fitted minima/maxima are not available to callers.
    """
    return MinMaxScaler().fit_transform(data)

def create_k_fold(matrix, fold_id, n_splits=5):
    """Return the ``(train, test)`` rows of one fold of an unshuffled K-fold split.

    Parameters
    ----------
    matrix : array-like indexable by integer index arrays
        Data whose first axis is split into folds.
    fold_id : int
        Zero-based index of the fold used as the test part.
    n_splits : int, optional
        Number of folds (default 5).

    Returns
    -------
    tuple
        ``(matrix[train_index], matrix[test_index])`` for the requested fold.

    Raises
    ------
    ValueError
        If ``fold_id`` does not select one of the ``n_splits`` folds.
    """
    # Validate first: otherwise an unknown fold id would only surface later as
    # an unbound local variable.
    if not 0 <= fold_id < n_splits:
        raise ValueError(
            f"fold_id must be in [0, {n_splits - 1}], got {fold_id!r}"
        )

    kf = KFold(n_splits=n_splits, shuffle=False)
    for i, (train_index, test_index) in enumerate(kf.split(matrix)):
        if i == fold_id:
            return matrix[train_index], matrix[test_index]

    # Only reachable when fold_id is inside the range but not a whole number.
    raise ValueError(f"fold_id must be an integer fold index, got {fold_id!r}")


def create_tabular_mask(matrix, drp_percent, how='MCAR', seed=0):
    """Draw a 0/1 missingness indicator for ``matrix`` (1 marks a missing cell).

    ``how='MCAR'``: every cell is missing independently with probability
    ``drp_percent``.

    ``how='MAR'``: ``drp_percent`` is not used.  The row sums of the first
    ``int(0.7 * n_features)`` columns are min-max scaled to [-2, 2] and pushed
    through a sigmoid to obtain a per-row probability; rows drawn as missing
    lose all remaining columns.  The fraction of such rows is printed.

    Any other ``how`` prints an error message and terminates via ``sys.exit()``.
    All randomness comes from ``np.random.RandomState(seed)``.
    """
    rng = np.random.RandomState(seed)

    if how == 'MCAR':
        return rng.binomial(1, drp_percent, matrix.shape)

    if how != 'MAR':
        print("Not a valid missing pattern\n\n")
        sys.exit()

    # columns before `split` drive the missingness of the columns after it
    split = int(matrix.shape[1] * 0.7)
    score = matrix[:, :split].sum(1)
    score = (score - score.min()) / (score.max() - score.min())
    score = score * 4 - 2

    # sigmoid of the scaled score is the probability that a row is affected
    row_prob = 1. / (1. + np.exp(-score))
    row_missing = rng.binomial(1, row_prob.reshape(-1, ), len(row_prob))
    print(row_missing.mean())

    indicator = np.zeros(matrix.shape).astype(int)
    indicator[row_missing.astype(bool), split:] = 1
    return indicator


def fill_missingness(matrix, mask):
    """Return a copy of ``matrix`` whose masked cells hold resampled values.

    For each column, the cells flagged by ``mask`` (non-zero) are replaced by
    values drawn with replacement from the unflagged cells of the same column.
    Draws use the global ``np.random`` state; ``matrix`` itself is not modified.
    """
    filled = matrix.copy()
    for col in range(matrix.shape[1]):
        col_mask = mask[:, col]
        is_missing = col_mask.astype(bool)
        # donors are the observed entries of this column
        donors = matrix[:, col][~is_missing]
        draws = np.random.choice(donors, col_mask.sum(), replace=True)
        filled[:, col][is_missing] = draws

    return filled


def infer_imputation(batch_EM, nf, data, mask, args):
    """Impute ``data`` batch by batch and return the clamped reconstructions.

    For every batch the rows are mapped through ``nf`` (the first element of
    its output is used), completed by ``batch_EM.complete_gpu(..., mode='test')``
    and mapped back with ``nf.inverse``.  The result is moved to the CPU and
    clamped to [0, 1].

    Parameters
    ----------
    batch_EM : object exposing ``complete_gpu(z, mask, mode=...)``
    nf : model exposing ``eval()``, ``__call__`` and ``inverse``; it is put in
        eval mode by this function.
        NOTE(review): assumed to already live on the device selected below.
    data, mask : torch.Tensor
        Tensors with matching first dimension.
    args : object with a ``batch_size_test`` attribute.

    Returns
    -------
    torch.Tensor
        CPU tensor with the concatenated reconstructions, in input order.
    """
    # Use the GPU when there is one (same as the former unconditional .cuda()),
    # but keep the function usable on CPU-only machines.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    loader = DataLoader(TensorDataset(data, mask), shuffle=False,
                        batch_size=args.batch_size_test)

    nf.eval()
    imputed = []
    with torch.no_grad():
        for data_rows, mask_rows in loader:
            data_rows = data_rows.to(device)
            mask_rows = mask_rows.to(device)

            z = nf(data_rows)[0]
            z_hat = batch_EM.complete_gpu(z, mask_rows, mode='test')
            x_hat = nf.inverse(z_hat).cpu()
            imputed.append(torch.clamp(x_hat, 0, 1))

    return torch.cat(imputed, 0)

def update_imputation(dataset, batch_EM, nf, args):
    """Recompute the imputed values of ``dataset`` and store them in ``dataset.image``.

    ``dataset`` must yield ``(x_dot, x_origin, mask)`` triples; ``x_origin`` is
    not used here.  Each batch is mapped through ``nf`` (first output element),
    completed by ``batch_EM.complete_gpu(..., mode='test')``, mapped back with
    ``nf.inverse`` and clamped to [0, 1].  Only the cells where ``mask`` is 1
    take the new value; the other cells keep ``x_dot``.

    Parameters
    ----------
    dataset : indexable dataset accepted by ``DataLoader``; its ``image``
        attribute is overwritten with the concatenated CPU result.
    batch_EM : object exposing ``complete_gpu(z, mask, mode=...)``
    nf : model exposing ``__call__`` and ``inverse``.
        NOTE(review): assumed to already live on the device selected below;
        its train/eval mode is left untouched.
    args : object with a ``batch_size_test`` attribute.
    """
    # Use the GPU when there is one (same as the former unconditional .cuda()),
    # but keep the function usable on CPU-only machines.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    loader = DataLoader(dataset, shuffle=False, batch_size=args.batch_size_test)

    batches = []
    with torch.no_grad():
        for x_dot, x_origin, mask in loader:
            x_dot = x_dot.to(device)
            mask = mask.to(device)

            z = nf(x_dot)[0]
            z_hat = batch_EM.complete_gpu(z, mask, mode='test')
            x_hat = torch.clamp(nf.inverse(z_hat), 0, 1)

            # overwrite only the masked cells, keep the rest of x_dot
            x_hat = x_hat * mask + x_dot * (1 - mask)
            batches.append(x_hat.cpu())

    dataset.image = torch.cat(batches, 0)


#####################################################################
#
#                    For image data
#
#####################################################################

# follow the logic in MCFlow
def create_img_masks(drp_percent, img_shp, num_tr, num_te, seed):
    """Draw per-pixel 0/1 missingness masks for the train and test images.

    Each pixel is marked missing (1) when a uniform [0, 1) draw is strictly
    below ``drp_percent``.  One mask row of ``height * width`` entries is
    produced per image; the channel count in ``img_shp[0]`` is not used.

    Parameters
    ----------
    drp_percent : float
        Probability that a pixel is marked missing.
    img_shp : sequence
        ``(channels, height, width)`` of one image.
    num_tr, num_te : int
        Number of train / test images.
    seed : int
        Seed of the ``np.random.RandomState`` used for all draws.

    Returns
    -------
    tuple of torch.Tensor
        ``(mask_train, mask_test)`` with shapes ``(num_tr, height * width)``
        and ``(num_te, height * width)``.
    """
    pixels_per_img = img_shp[1] * img_shp[2]
    rng = np.random.RandomState(seed)

    def _draw(num_imgs):
        # One vectorised draw fills the array in row-major order from the same
        # RandomState stream as one scalar uniform() call per pixel, so the
        # masks match the image-by-image, pixel-by-pixel loop they replace.
        dropped = rng.uniform(size=(num_imgs, pixels_per_img)) < drp_percent
        return torch.from_numpy(dropped.astype(int))

    # train masks are drawn before test masks to keep the stream order
    mask_train = _draw(num_tr)
    mask_test = _draw(num_te)
    return mask_train, mask_test


def fill_img_missingness(tr_data, te_data, mask_train, mask_test, shape):
    """Initialise the masked pixels of the train and test images.

    Both splits are passed, train first, to ``initialize_nneighbor_radnommat``
    together with their masks and the image ``shape``; the two results are
    returned as ``(train, test)``.
    """
    filled = [
        initialize_nneighbor_radnommat(images, masks, shape)
        for images, masks in ((tr_data, mask_train), (te_data, mask_test))
    ]
    return filled[0], filled[1]


def _ring_neighbors(mask, x, y, layer):
    """Return the observed pixels on the square ring ``layer`` steps from (x, y).

    The four sides are visited in a fixed order (right column, bottom row,
    left column, top row) over half-open ranges.  As a consequence the
    ``(x + layer, y + layer)`` corner is never visited and the
    ``(x - layer, y - layer)`` corner can be listed twice.  This order and
    multiplicity are kept on purpose because the caller picks a donor by
    index into the returned list.
    """
    height, width = mask.shape
    top, bottom = x - layer, x + layer
    left, right = y - layer, y + layer
    rows = range(max(top, 0), min(bottom, height))
    cols = range(max(left, 0), min(right, width))

    found = []
    # Columns are bounded by the mask width and rows by its height, so
    # non-square images are handled correctly.
    if 0 <= right < width:
        found.extend((r, right) for r in rows if mask[r, right] == 0)
    if 0 <= bottom < height:
        found.extend((bottom, c) for c in cols if mask[bottom, c] == 0)
    if 0 <= left < width:
        found.extend((r, left) for r in rows if mask[r, left] == 0)
    if 0 <= top < height:
        found.extend((top, c) for c in cols if mask[top, c] == 0)
    return found


def nn_impute(imgs, masks):
    """Fill masked pixels with the value of a randomly chosen nearby observed pixel.

    For every pixel whose mask value is 1, square rings of growing radius are
    searched until at least one pixel with mask value 0 is found; one of the
    candidates is picked with ``np.random.randint`` (global numpy RNG) and its
    value is copied for all channels.  Donor values are always read from the
    input image, never from already imputed pixels.

    If the ring search finds nothing inside the image, any observed pixel of
    the image may be chosen.  An image without a single observed pixel is
    returned unchanged, because there is nothing to copy from.

    Parameters
    ----------
    imgs : numpy.ndarray, shape (n, channels, height, width)
    masks : numpy.ndarray, shape (n, height, width); 1 = missing, 0 = observed

    Returns
    -------
    numpy.ndarray, shape (n, channels * height * width)
        One flattened, imputed image per row.  The inputs are not modified.
    """
    filled = []
    for i in range(imgs.shape[0]):
        img, mask = imgs[i], masks[i]
        updater = img.copy()

        observed = np.argwhere(mask == 0)
        if len(observed):
            # A ring whose radius reaches max(height, width) lies completely
            # outside the image, so the search can stop there.
            max_layer = max(mask.shape)
            for x, y in zip(*np.where(mask == 1.)):
                neighbors = []
                layer = 1
                while not neighbors and layer < max_layer:
                    neighbors = _ring_neighbors(mask, x, y, layer)
                    layer += 1
                if not neighbors:
                    # Only happens when every observed pixel sits on the
                    # diagonal corner that the ring scan never visits.
                    neighbors = observed

                loc = np.random.randint(len(neighbors))
                donor_x, donor_y = neighbors[loc]
                updater[:, x, y] = img[:, donor_x, donor_y]

        filled.append(updater.reshape((-1,)))

    if not filled:
        # keep the result 2-D so it can be concatenated with other chunks
        return np.empty((0, int(np.prod(imgs.shape[1:]))), dtype=imgs.dtype)
    return np.array(filled)

def initialize_nneighbor_radnommat(data, mask, shape, n_jobs=20):
    """Run ``nn_impute`` over all images in parallel and return the result as a tensor.

    Parameters
    ----------
    data : torch.Tensor
        Image data reshapeable to ``(-1, shape[0], shape[1], shape[2])``.
    mask : torch.Tensor
        Missingness masks reshapeable to ``(-1, shape[1], shape[2])``.
    shape : sequence
        ``(channels, height, width)`` of one image.
    n_jobs : int, optional
        Upper bound on the number of chunks / joblib workers (default 20).

    Returns
    -------
    torch.Tensor
        The per-chunk ``nn_impute`` outputs concatenated along axis 0.
    """
    imgs = data.reshape((-1, shape[0], shape[1], shape[2])).numpy()
    masks = mask.reshape((-1, shape[1], shape[2])).numpy()

    # Never create more chunks than images: empty chunks add no work and their
    # results may not have the dimensionality needed by np.concatenate.
    n_chunks = max(1, min(n_jobs, imgs.shape[0]))

    data_lst = np.array_split(imgs, n_chunks, 0)
    mask_lst = np.array_split(masks, n_chunks, 0)
    res = Parallel(n_jobs=n_chunks)(
        delayed(nn_impute)(data_part, mask_part)
        for data_part, mask_part in zip(data_lst, mask_lst)
    )
    return torch.from_numpy(np.concatenate(res, 0))