import os
import re
import h5py
import random
import numpy as np
import pandas as pd
from scipy import stats


class DataLoader:
    """Load predictors, labels and genome locations from an HDF5 file and split them.

    The file is read once, fully into memory, in ``__init__``. The split
    methods then return index arrays, and ``get_datasets`` /
    ``get_kfold_datasets`` turn those indices into ``{'x', 'y', 'locs'}``
    dictionaries.

    Attributes set by ``__init__``:
        genome_locations: contents of the ``idx`` dataset; column 0 is used as
            a 1-based chromosome number by the chromosome-wise splits.
        idxs: row indices eligible for splitting (all rows, or only the rows
            whose ``mappability`` value is >= ``min_mappability`` when that
            dataset exists in the file).
        labels_lst: one array per entry of ``label_ids``, in the same order.
        data: the ``x_data`` dataset, expanded to 3 dimensions if it has fewer.
        selected_tracks: track indices (all tracks, or the ones listed in the
            track selection file).
    """

    predictor_ds_id = 'x_data'
    index_ds_id = 'idx'
    mappability_ds_id = 'mappability'
    # Rows with a mappability value below this threshold are excluded from splits.
    min_mappability = 0.7

    def __init__(self, file_path, label_ids, track_file):
        """Read all datasets from ``file_path`` into memory.

        Args:
            file_path: path of the HDF5 file to read.
            label_ids: names of the label datasets inside the file.
            track_file: path of a track selection file, resolved relative to
                the directory of this module, or ``None`` to select every track.
        """
        print('Loading data and labels from file {}...'.format(file_path))
        self.file_path = file_path
        self.label_ids = label_ids

        # Everything is copied into memory with `[:]`, so the file handle can
        # (and should) be released as soon as the reads are done.
        with h5py.File(self.file_path, 'r') as h5f:
            self.genome_locations = h5f[self.index_ds_id][:]
            if self.mappability_ds_id in h5f:
                print('Extracting only locations with mappability >= {}'.format(self.min_mappability))
                self.idxs = np.where(h5f[self.mappability_ds_id][:] >= self.min_mappability)[0]
            else:
                self.idxs = np.arange(len(self.genome_locations))
            self.labels_lst = [h5f[label_id][:] for label_id in self.label_ids]
            self.data = h5f[self.predictor_ds_id][:]

        # Shape as stored in the file; reported below before any expansion.
        stored_shape = self.data.shape
        if self.data.ndim < 3:
            self.data = np.expand_dims(self.data, axis=1)

        if track_file is not None:
            self.selected_tracks = self.load_track_selection_file(
                os.path.join(os.path.dirname(__file__), track_file))
        else:
            self.selected_tracks = np.arange(self.data.shape[2])
        # NOTE(review): `selected_tracks` is stored but never applied to `data`
        # by the methods of this class -- presumably callers index with it
        # themselves; confirm before relying on the size printed below.

        print('Input data is of size: {}'
              .format((len(self.idxs), stored_shape[1], len(self.selected_tracks))))

    @staticmethod
    def tokens_match(strg, search=re.compile(r'[^:0-9]').search):
        """Return True if ``strg`` contains nothing but ASCII digits and colons.

        An empty string also yields True, since it has no offending character.
        """
        return not bool(search(strg))

    @staticmethod
    def _require(condition, message):
        """Raise ``AssertionError(message)`` unless ``condition`` holds.

        Raised explicitly rather than through ``assert`` so the validation
        still runs under ``python -O``; the exception type is kept so existing
        callers that catch ``AssertionError`` keep working.
        """
        if not condition:
            raise AssertionError(message)

    def load_track_selection_file(self, file_path):
        """Parse a track selection file into a flat list of track indices.

        Each non-empty line that does not start with ``#`` is either a single
        index ``x`` or a half-open range ``x:y`` (``x`` included, ``y``
        excluded, ``x < y``). Surrounding whitespace is ignored, so blank lines
        consisting only of spaces or ``\\r\\n`` are skipped as well.

        Args:
            file_path: path of the selection file.

        Returns:
            list of int, in file order (duplicates are not removed).

        Raises:
            AssertionError: if a line is malformed; the message contains the
                offending text and its 1-based line number.
        """
        with open(file_path, 'r') as f:
            lines = f.readlines()

        track_lst = []
        for line_no, raw_line in enumerate(lines, 1):
            entry = raw_line.strip()
            if not entry or entry.startswith('#'):
                continue

            self._require(
                self.tokens_match(entry),
                'Expected track selection lines to contain only digits and colons. Found: {} in line #{}.'
                .format(entry, line_no))

            parts = entry.split(':')
            self._require(
                len(parts) <= 2,
                'Expected track selection lines to contain only one colon. Found: {} in line #{}.'
                .format(entry, line_no))
            self._require(
                all(part.isdigit() for part in parts),
                'Expected to have a number in both sides of the colon. Found: {} in line #{}.'
                .format(entry, line_no))

            if len(parts) == 1:
                track_lst.append(int(parts[0]))
            else:
                start, stop = int(parts[0]), int(parts[1])
                self._require(
                    start < stop,
                    'Expected x < y in pair x:y. Found: {} in line #{}.'.format(entry, line_no))
                track_lst.extend(range(start, stop))

        print('Selected {} tracks: \n{}'.format(len(track_lst), track_lst))
        return track_lst

    def _chromosome_groups(self):
        """Yield, per chromosome 1..max, the sorted eligible row indices on it.

        Only rows listed in ``self.idxs`` are considered, so chromosome-wise
        splits honour the same mappability filter as the random splits.
        """
        all_chrs = self.genome_locations[:, 0]
        eligible = np.sort(np.asarray(self.idxs))  # np.sort returns a copy
        eligible_chrs = all_chrs[eligible]
        for chrom in range(1, int(np.max(all_chrs)) + 1):
            print('Chromosome {}...'.format(chrom))
            yield eligible[eligible_chrs == chrom]

    @staticmethod
    def _merge_sorted(parts):
        """Concatenate integer index arrays into one sorted integer array."""
        if not parts:
            # np.sort([]) would be float64, which cannot be used as an index.
            return np.array([], dtype=np.intp)
        return np.sort(np.concatenate(parts))

    def split_randomly(self, train_ratio):
        """Randomly partition ``self.idxs`` into train and test indices.

        Args:
            train_ratio: fraction of the eligible rows assigned to the train set.

        Returns:
            ``(train_idxs, test_idxs)``, each a sorted index array.
        """
        print('Splitting data at random...')
        # Shuffle a copy: `self.idxs` is shared state and must not be reordered.
        shuffled = np.array(self.idxs, copy=True)
        split_idx = int(train_ratio * len(shuffled))
        # NumPy's global RNG is re-seeded from Python's `random` module, so
        # seeding `random` is what makes this split reproducible.
        np.random.seed(random.randint(0, 1000000))
        np.random.shuffle(shuffled)
        return np.sort(shuffled[:split_idx]), np.sort(shuffled[split_idx:])

    def split_by_chromosome(self, train_ratio):
        """Split every chromosome's eligible rows at ``train_ratio``.

        Within each chromosome the first ``train_ratio`` share of the rows (in
        index order) goes to the train set and the remainder to the test set.

        Returns:
            ``(train_idxs, test_idxs)``, each a sorted index array.
        """
        print('Splitting data by chromosome...')
        train_parts = []
        test_parts = []
        for chr_idxs in self._chromosome_groups():
            split_idx = int(train_ratio * len(chr_idxs))
            train_parts.append(chr_idxs[:split_idx])
            test_parts.append(chr_idxs[split_idx:])
        return self._merge_sorted(train_parts), self._merge_sorted(test_parts)

    def _subset(self, idxs):
        """Build the ``{'x', 'y', 'locs'}`` dictionary for the rows in ``idxs``.

        ``'x'`` is the predictor data averaged over axis 1, ``'y'`` maps each
        label id to its selected values and ``'locs'`` holds the matching rows
        of ``genome_locations``.
        """
        return {
            'x': self.data[idxs].mean(axis=1),
            'y': {key: labels[idxs] for key, labels in zip(self.label_ids, self.labels_lst)},
            'locs': self.genome_locations[idxs],
        }

    def get_datasets(self, split_method='random', train_ratio=0.8):
        """Return ``(train, test)`` dataset dictionaries.

        Args:
            split_method: ``'random'`` or ``'chr'`` (chromosome-wise).
            train_ratio: fraction of rows assigned to the train set.

        Raises:
            ValueError: if ``split_method`` is not one of the two options.
        """
        if split_method == 'random':
            train_idxs, test_idxs = self.split_randomly(train_ratio)
        elif split_method == 'chr':
            train_idxs, test_idxs = self.split_by_chromosome(train_ratio)
        else:
            raise ValueError('Expected split_method to be \'random\' or \'chr\', but found {}'.format(split_method))

        print('Splitting data to train and test sets...')
        return self._subset(train_idxs), self._subset(test_idxs)

    def split_kfold_randomly(self, k):
        """Randomly split ``self.idxs`` into ``k`` equally sized folds.

        Up to ``k - 1`` leftover indices are dropped so that all folds have the
        same length. Indices inside a fold are left in shuffled order.

        Returns:
            Array of shape ``(k, len(self.idxs) // k)``.
        """
        print('Splitting data to {} parts at random...'.format(k))
        # Shuffle a copy: `self.idxs` is shared state and must not be reordered.
        shuffled = np.array(self.idxs, copy=True)
        set_size = len(shuffled) // k
        np.random.shuffle(shuffled)
        return np.array([shuffled[i * set_size:(i + 1) * set_size] for i in range(k)])

    def split_kfold_by_chromosome(self, k):
        """Split every chromosome's eligible rows into ``k`` consecutive folds.

        Each chromosome contributes ``len(chromosome rows) // k`` rows to every
        fold; leftover rows of a chromosome are dropped.

        Returns:
            Array with one sorted row of indices per fold.
        """
        print('Splitting data to {} parts by chromosome...'.format(k))
        fold_parts = [[] for _ in range(k)]
        for chr_idxs in self._chromosome_groups():
            set_size = len(chr_idxs) // k
            for i, parts in enumerate(fold_parts):
                parts.append(chr_idxs[i * set_size:(i + 1) * set_size])
        return np.array([self._merge_sorted(parts) for parts in fold_parts])

    def get_kfold_datasets(self, split_method='random', k=5):
        """Return a list of ``k`` dataset dictionaries, one per fold.

        Args:
            split_method: ``'random'`` or ``'chr'`` (chromosome-wise).
            k: number of folds.

        Raises:
            ValueError: if ``split_method`` is not one of the two options.
        """
        if split_method == 'random':
            k_idxs = self.split_kfold_randomly(k)
        elif split_method == 'chr':
            k_idxs = self.split_kfold_by_chromosome(k)
        else:
            raise ValueError('Expected split_method to be \'random\' or \'chr\', but found {}'.format(split_method))

        print('Splitting data to train and test sets...')
        return [self._subset(k_idxs[i]) for i in range(k)]
