import h5py
import torch
import numpy as np
from torch.utils.data import Dataset

from sklearn import decomposition
from sklearn.neighbors import NearestNeighbors


class SimpleDataset(Dataset):
    """In-memory dataset pairing each row of ``data`` with one entry per label array.

    ``data`` and every element of ``labels_lst`` are indexed with the same
    ``idx``; each indexed value is converted to a float tensor on access.
    """

    def __init__(self, data, labels_lst):
        self.data = data
        # Shallow copy: later changes to the caller's list do not affect us.
        self.labels_lst = list(labels_lst)

    def __len__(self):
        n_rows = self.data.shape[0]
        return n_rows

    def __getitem__(self, idx):
        """Return ``(features, [label_0, label_1, ...])`` for row ``idx`` as float tensors."""
        features = torch.tensor(self.data[idx]).float()
        targets = []
        for label_arr in self.labels_lst:
            targets.append(torch.tensor(label_arr[idx]).float())
        return features, targets

    def get_data_shape(self):
        """Return the shape of the underlying ``data`` container."""
        shape = self.data.shape
        return shape

    def get_train_set_length(self, train_ratio):
        """Return ``int(train_ratio * n_rows)``, i.e. the truncated size of a training split."""
        n_rows = self.data.shape[0]
        return int(train_ratio * n_rows)


class BaseDatasetFromH5(Dataset):
    """Shared state for the HDF5-backed datasets.

    Stores the file row indices making up the dataset, the ``chr_locations``
    object and the track selection exactly as given (no copies). The dataset
    length is the number of row indices.
    """

    def __init__(self, preprocessed_idxs, chr_locations, selected_tracks):
        # Stored by reference; subclasses read these attributes directly.
        self.selected_tracks = selected_tracks
        self.chr_locations = chr_locations
        self.preprocessed_idxs = preprocessed_idxs

    def __len__(self):
        n_items = len(self.preprocessed_idxs)
        return n_items

    def get_set_indices(self):
        """Return the stored ``preprocessed_idxs`` object itself."""
        idxs = self.preprocessed_idxs
        return idxs

    def get_chromosome_locations(self):
        """Return the ``chr_locations`` object passed at construction."""
        locations = self.chr_locations
        return locations


class SimpleDatasetFromH5(BaseDatasetFromH5):
    """Eagerly loaded HDF5 dataset that also returns a local label spread per item.

    All rows listed in ``preprocessed_idxs`` are read from ``h5_file`` into
    memory as float tensors. For every item, the labels of its nearest
    neighbours in a PCA projection of ``data.mean(axis=1)`` are gathered. Their
    standard deviation (``estd_lst``) and mean (``emean_lst``) are stored, and
    ``__getitem__`` returns the standard deviations alongside the labels.

    Item ``i`` corresponds to file row ``preprocessed_idxs[i]``, which is the
    same mapping ``get_set_indices()`` exposes. Rows are still read from the
    file in sorted order.
    """

    def __init__(self, h5_file, label_ids, preprocessed_idxs, chr_locations, selected_tracks, data_id,
                 n_components=60, n_neighbors=500):
        """Load the selected rows and compute per-item neighbourhood label statistics.

        Args:
            h5_file: path of the HDF5 file to read.
            label_ids: names of the label datasets; each yields one target per item.
            preprocessed_idxs: file row indices forming this dataset. Duplicates
                are not supported because h5py fancy indexing rejects them.
            chr_locations: stored unchanged by the base class.
            selected_tracks: indices along the last axis of the input that
                ``__getitem__`` returns.
            data_id: name of the input dataset in the file.
            n_components: maximum number of PCA components used for the
                neighbour search. It is clamped to ``min(n_samples, n_features)``.
            n_neighbors: neighbourhood size per item. It is clamped to the
                number of samples.
        """
        super(SimpleDatasetFromH5, self).__init__(preprocessed_idxs, chr_locations, selected_tracks)
        print('Loading data and labels from file {}...'.format(h5_file))
        self.data, self.labels_lst = self._load_rows(h5_file, data_id, label_ids, self.preprocessed_idxs)
        print('Loaded input data of size: {}'.format(self.data.shape))
        print('Computing empirical variance...')
        self.pca_data = self._project(self.data, n_components)
        neighbours = self._neighbour_indices(self.pca_data, n_neighbors)
        # neighbours[i] holds positions (in item order) of item i's neighbours.
        self.estd_lst = [lbl[neighbours].std(axis=1) for lbl in self.labels_lst]
        self.emean_lst = [lbl[neighbours].mean(axis=1) for lbl in self.labels_lst]

    @staticmethod
    def _load_rows(h5_file, data_id, label_ids, row_idxs):
        """Read rows ``row_idxs`` of the input and label datasets as float tensors, in ``row_idxs`` order."""
        row_idxs = np.asarray(row_idxs)
        # h5py fancy indexing needs increasing indices, so read in sorted order...
        order = np.argsort(row_idxs, kind='stable')
        sorted_rows = row_idxs[order]
        # ...and scatter back so that position i holds file row row_idxs[i].
        restore = np.empty_like(order)
        restore[order] = np.arange(order.size)
        already_sorted = np.array_equal(sorted_rows, row_idxs)

        def read(h5f, name):
            rows = h5f[name][sorted_rows]
            # Skip the reorder copy when the input order is already sorted.
            return rows if already_sorted else rows[restore]

        with h5py.File(h5_file, 'r') as h5f:
            data = torch.tensor(read(h5f, data_id)).float()
            labels = [torch.tensor(read(h5f, l)).float() for l in label_ids]
        return data, labels

    @staticmethod
    def _project(data, n_components):
        """PCA-project ``data.mean(axis=1)``, using at most ``n_components`` components."""
        features = data.mean(axis=1)
        # PCA rejects more components than min(n_samples, n_features).
        n_components = min(n_components, features.shape[0], features.shape[1])
        return decomposition.PCA(n_components=n_components).fit_transform(features)

    @staticmethod
    def _neighbour_indices(points, n_neighbors):
        """Return an ``(n_samples, k)`` array of nearest-neighbour positions for each row of ``points``."""
        # kneighbors raises if asked for more neighbours than fitted samples.
        k = min(n_neighbors, points.shape[0])
        nbrs = NearestNeighbors(n_neighbors=k, algorithm='ball_tree').fit(points)
        _, indices = nbrs.kneighbors(points)
        return indices

    def __getitem__(self, idx):
        """Return ``(X, labels, label_stds)``. ``X`` keeps only ``selected_tracks`` on its last axis."""
        X = self.data[idx, :, self.selected_tracks]
        y1_lst = [l[idx] for l in self.labels_lst]
        y2_lst = [e[idx] for e in self.estd_lst]
        return X, y1_lst, y2_lst

    def get_data_shape(self):
        """Return ``(n_items, data.shape[1], n_selected_tracks)``.

        The last axis counts only ``selected_tracks``, because that is what
        ``__getitem__`` returns. LazyLoadDatasetFromH5 reports its shape the
        same way.
        """
        return (self.data.shape[0], self.data.shape[1], len(self.selected_tracks))


class LazyLoadDatasetFromH5(BaseDatasetFromH5):
    """HDF5-backed dataset that reads each item from disk on access.

    Item ``i`` consists of two parts:

    * File row ``preprocessed_idxs[i]`` of ``data_id``, restricted on its last
      axis to ``selected_tracks`` and returned in the given track order.
    * That same row of every dataset named in ``label_ids``.

    The file is opened and closed on every access, so no handle is kept on the
    instance.
    """

    def __init__(self, h5_file, label_ids, preprocessed_idxs, chr_locations, selected_tracks, data_id):
        super(LazyLoadDatasetFromH5, self).__init__(preprocessed_idxs, chr_locations, selected_tracks)
        self.h5_file = h5_file
        self.label_ids = label_ids
        self.data_id = data_id

    def _track_selection(self):
        """Return ``(unique_sorted_tracks, inverse)`` such that ``unique_sorted_tracks[inverse]`` equals ``selected_tracks``.

        h5py fancy indexing only accepts strictly increasing index lists, so
        the unique sorted tracks are read and then reordered. Assumes
        ``selected_tracks`` is a 1-D sequence of ints; TODO confirm with callers.
        """
        read_idxs, inverse = np.unique(np.asarray(self.selected_tracks), return_inverse=True)
        return read_idxs, inverse.reshape(-1)

    def __getitem__(self, idx):
        """Return ``(X, [label_0, ...])`` as float tensors for item ``idx``."""
        data_idx = self.preprocessed_idxs[idx]
        read_idxs, inverse = self._track_selection()
        with h5py.File(self.h5_file, 'r') as db:
            row = db[self.data_id][data_idx, :, read_idxs]
            y_lst = [torch.tensor(db[l][data_idx]).float() for l in self.label_ids]
        # Restore the caller's track order (and any repeated tracks).
        X = torch.tensor(row[:, inverse]).float()
        return X, y_lst

    def get_data_shape(self):
        """Return ``(n_items, data.shape[1], n_selected_tracks)`` without loading the data."""
        with h5py.File(self.h5_file, 'r') as db:
            return (len(self.preprocessed_idxs), db[self.data_id].shape[1], len(self.selected_tracks))
