import h5py
import torch
import numpy as np
from torch.utils.data import Dataset

from sklearn import decomposition
from sklearn.neighbors import NearestNeighbors

import scipy.stats
from sklearn import linear_model

class SimpleDataset(Dataset):
    """In-memory dataset pairing one input array with any number of label arrays.

    Item ``idx`` is the float tensor built from ``data[idx]`` plus a list with
    the float tensor built from ``labels[idx]`` for every label array.
    """

    def __init__(self, data, labels_lst):
        self.data = data
        # Shallow copy: later edits to the caller's list do not affect us.
        self.labels_lst = list(labels_lst)

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        def as_float(value):
            return torch.tensor(value).float()

        return as_float(self.data[idx]), [as_float(labels[idx]) for labels in self.labels_lst]

    def get_data_shape(self):
        """Return the shape of the underlying input array."""
        return self.data.shape

    def get_train_set_length(self, train_ratio):
        """Return how many leading samples a ``train_ratio`` split assigns to training (truncated)."""
        return int(self.data.shape[0] * train_ratio)


class BaseDatasetFromH5(Dataset):
    """Shared index bookkeeping for datasets backed by an H5 file.

    Holds the selected row indices, their location metadata and the track
    selection; subclasses provide ``__getitem__``. The length is the number
    of selected rows.
    """

    def __init__(self, preprocessed_idxs, chr_locations, selected_tracks):
        self.preprocessed_idxs, self.chr_locations, self.selected_tracks = (
            preprocessed_idxs,
            chr_locations,
            selected_tracks,
        )

    def __len__(self):
        return len(self.preprocessed_idxs)

    def get_set_indices(self):
        """Return the row indices this dataset was built from, as given."""
        return self.preprocessed_idxs

    def get_chromosome_locations(self):
        """Return the location metadata passed at construction, unchanged."""
        return self.chr_locations

def knn_simulation(x, y, seed=None, n_components=50, n_neighbors=500):
    """Simulate noisy labels from k-nearest-neighbour label statistics.

    The features are scaled by the coefficients of a linear regression of
    ``y`` on ``x`` and projected onto their leading principal components.
    For every sample, the mean and standard deviation of ``y`` over its
    ``n_neighbors`` nearest neighbours in that space parameterise a
    gamma-Poisson (negative binomial) draw. The sample itself counts as a
    neighbour, because the fitted points are also the query points.

    Args:
        x: Feature matrix, (n_samples, n_features). Anything ``np.asarray``
            accepts, e.g. a CPU torch tensor.
        y: Labels, one per sample. Assumed 1-D; TODO confirm with callers.
        seed: Optional seed for NumPy's global RNG, applied right before
            sampling. ``0`` is a valid seed.
        n_components: Number of PCA components to keep (default 50).
        n_neighbors: Neighbourhood size for the label statistics (default 500).

    Returns:
        ``(means, stds, ysim)`` as float32 tensors, one value per sample.
        Samples whose neighbourhood has zero spread or a non-positive mean
        cannot parameterise the distribution. For those, ``ysim`` holds the
        neighbourhood mean instead of a random draw.
    """
    x = np.asarray(x)
    y = np.asarray(y)

    print('linear regression')
    regr = linear_model.LinearRegression()
    regr.fit(x, y)
    # Weight each feature by its regression coefficient so the PCA/kNN
    # geometry emphasises directions that are predictive of y.
    x_weight = x * regr.coef_

    print('pca')
    pca = decomposition.PCA(n_components=n_components)
    x_pca = pca.fit_transform(x_weight)
    var_expl = np.sum(pca.explained_variance_ratio_)
    print("\tvariance explained: {:03f}".format(var_expl))

    print('knn')
    nbrs = NearestNeighbors(n_neighbors=n_neighbors, algorithm='ball_tree').fit(x_pca)
    _, indices = nbrs.kneighbors(x_pca)

    neighbour_y = y[indices]  # one row of neighbour labels per sample (for 1-D y)
    emeans_mean = np.mean(neighbour_y, axis=1)
    estds_mean = np.std(neighbour_y, axis=1)

    print('simulating')
    # `is not None` rather than truthiness, so that seed=0 is honoured.
    if seed is not None:
        np.random.seed(seed)

    # The gamma rate gets shape alpha = mean^2 / var and scale
    # theta = var / mean, i.e. the neighbourhood mean and variance.
    # nbinom(n=alpha, p=1/(1+theta)) is the matching gamma-Poisson mixture,
    # with mean `mean` and variance `mean + var`.
    variance = estds_mean ** 2
    # Zero spread or a non-positive mean gives 0/0 or x/0 parameters, which
    # nbinom rejects. Those samples keep their neighbourhood mean. When every
    # sample is valid, the draws (and RNG stream) match the unmasked call.
    valid = (emeans_mean > 0) & (variance > 0)
    ysim = np.array(emeans_mean, dtype=np.float64)
    if np.any(valid):
        mean_v = emeans_mean[valid]
        var_v = variance[valid]
        alpha = mean_v ** 2 / var_v
        theta = var_v / mean_v
        p = 1 / (theta + 1)
        ysim[valid] = scipy.stats.nbinom.rvs(n=alpha, p=p)

    return torch.tensor(emeans_mean).float(), torch.tensor(estds_mean).float(), torch.tensor(ysim).float()
    

class SimpleDatasetFromH5(BaseDatasetFromH5):
    """H5 dataset held in memory, whose targets are kNN-simulated labels.

    At construction, every row listed in ``preprocessed_idxs`` is read from
    the file. ``knn_simulation`` is then run on the inputs averaged over
    axis 1 against the first label in ``label_ids``. Items return the
    simulated draw, or the neighbourhood mean when ``is_test`` is true.
    """

    def __init__(self, h5_file, label_ids, preprocessed_idxs, chr_locations, selected_tracks, data_id, is_test):
        """Load the selected rows and simulate their targets.

        Args:
            h5_file: Path of the H5 file to read.
            label_ids: H5 dataset keys of the labels. Only the first is used.
            preprocessed_idxs: Row indices into the H5 datasets. Item ``i``
                of this dataset corresponds to ``preprocessed_idxs[i]``.
            chr_locations: Location metadata, stored as-is.
            selected_tracks: Indices into the last input axis that
                ``__getitem__`` keeps.
            data_id: H5 dataset key of the input array.
            is_test: If true, items return the kNN mean instead of a draw.
        """
        super(SimpleDatasetFromH5, self).__init__(preprocessed_idxs, chr_locations, selected_tracks)
        print('Loading data and labels from file {}...'.format(h5_file))
        self.is_test = is_test
        with h5py.File(h5_file, 'r') as h5f:
            self.data = torch.tensor(self._read_rows(h5f[data_id], self.preprocessed_idxs)).float()
            # Only the first label drives the simulation, so only it is read.
            labels = self._read_rows(h5f[label_ids[0]], self.preprocessed_idxs)
        print('Loaded input data of size: {}'.format(self.data.shape))
        print('Simulating data...')
        self.emeans, self.estds, self.ysim = knn_simulation(self.data.mean(axis=1), labels)

    @staticmethod
    def _read_rows(dataset, idxs):
        """Return ``dataset`` rows for ``idxs``, ordered like ``idxs``."""
        # h5py fancy indexing needs increasing indices, so read in sorted
        # order and then undo the sort. Row i then matches idxs[i], as in
        # get_set_indices() and LazyLoadDatasetFromH5.
        idxs = np.asarray(idxs)
        order = np.argsort(idxs, kind='stable')
        rows = dataset[idxs[order]]
        if np.array_equal(order, np.arange(len(order))):
            return rows  # already sorted: skip the extra full-array copy
        return rows[np.argsort(order)]

    def __getitem__(self, idx):
        """Return ``(X, [target])``, with ``X`` restricted to the selected tracks."""
        X = self.data[idx, :, self.selected_tracks]
        if self.is_test:
            y_lst = [self.emeans[idx]]
        else:
            y_lst = [self.ysim[idx]]
        return X, y_lst

    def get_data_shape(self):
        """Return the shape of the full in-memory input tensor (all tracks)."""
        return self.data.shape

    def get_stds(self):
        """Return the per-sample neighbourhood label standard deviations as NumPy."""
        return self.estds.numpy()


class LazyLoadDatasetFromH5(BaseDatasetFromH5):
    """H5-backed dataset that reads each sample from disk when it is requested.

    Only the file path and dataset keys are kept. The file is opened and
    closed again inside every ``__getitem__`` / ``get_data_shape`` call.
    """

    def __init__(self, h5_file, label_ids, preprocessed_idxs, chr_locations, selected_tracks, data_id):
        super().__init__(preprocessed_idxs, chr_locations, selected_tracks)
        self.h5_file = h5_file
        self.label_ids = label_ids
        self.data_id = data_id

    def __getitem__(self, idx):
        """Return ``(X, labels)`` for H5 row ``preprocessed_idxs[idx]``, as float tensors."""
        row = self.preprocessed_idxs[idx]
        with h5py.File(self.h5_file, 'r') as h5f:
            features = torch.tensor(h5f[self.data_id][row, :, self.selected_tracks]).float()
            labels = [torch.tensor(h5f[key][row]).float() for key in self.label_ids]
        return features, labels

    def get_data_shape(self):
        """Return ``(n_selected_rows, input.shape[1], n_selected_tracks)``."""
        with h5py.File(self.h5_file, 'r') as h5f:
            width = h5f[self.data_id].shape[1]
        return (len(self.preprocessed_idxs), width, len(self.selected_tracks))
