import os
import re
import numpy as np
import pandas as pd
from scipy import stats
import math

from mut_dataset import *

class BaseDatasetGenerator:
    """Load window locations, labels and the track selection from an HDF5 file.

    After construction:
        ``self.genome_locations`` -- contents of the ``idx`` dataset.
        ``self.idxs`` -- indices of windows whose mappability is >= ``map_thresh``
            (all windows when the file has no mappability dataset).
        ``self.below_mapp`` -- integer indices of windows below the threshold.
        ``self.labels_lst`` -- one array per entry of ``label_ids``.
        ``self.selected_tracks`` -- indices into the last axis of ``x_data``.
    """
    predictor_ds_id = 'x_data'
    index_ds_id = 'idx'
    mappability_ds_id = 'mappability'

    def __init__(self, file_path, label_ids, map_thresh, heldout_ratio, track_file=None):
        """
        Args:
            file_path: path to the HDF5 data file.
            label_ids: names of the label datasets to load.
            map_thresh: windows with mappability >= this value are kept in ``self.idxs``.
            heldout_ratio: not used by this class; accepted so subclasses share the signature.
            track_file: optional track selection file, resolved relative to this module's directory.
        """
        print('Loading data and labels from file {}...'.format(file_path))
        self.file_path = file_path
        self.label_ids = label_ids
        # Context manager so the file handle is released even if a dataset is missing.
        with h5py.File(self.file_path, 'r') as h5f:
            self.genome_locations = h5f[self.index_ds_id][:]
            if self.mappability_ds_id in h5f:
                mappability = h5f[self.mappability_ds_id][:]  # read once, used twice
                self.idxs = np.where(mappability >= map_thresh)[0]
                self.below_mapp = np.where(mappability < map_thresh)[0]
            else:
                self.idxs = np.arange(len(self.genome_locations))
                # Integer dtype: an empty *float* array cannot be used as an index.
                self.below_mapp = np.zeros(0, dtype=int)
            self.labels_lst = [h5f[label_id][:] for label_id in self.label_ids]
            x_shape = h5f[self.predictor_ds_id].shape

        if track_file is not None:
            self.selected_tracks = self.load_track_selection_file(os.path.join(os.path.dirname(__file__), track_file))
        else:
            self.selected_tracks = np.arange(x_shape[2])
        print('Input data is of size: {}'
              .format((len(self.idxs), x_shape[1], len(self.selected_tracks))))

    @staticmethod
    def tokens_match(strg, search=re.compile(r'[^:0-9]').search):
        """Return True if ``strg`` contains only digits and colons (True for '')."""
        return not bool(search(strg))

    def load_track_selection_file(self, file_path):
        """Parse a track selection file into a list of track indices.

        Blank lines, whitespace-only lines and lines starting with '#' are skipped.
        Every other line is either a single index ``x`` or a half-open range
        ``x:y``, which expands to ``x, x+1, ..., y-1`` and requires ``x < y``.

        Returns:
            List of ints in file order (duplicates are kept).

        Raises:
            AssertionError: if a line is malformed (line numbers are 1-based).
        """
        track_lst = []
        with open(file_path, 'r') as f:
            for line_no, raw_line in enumerate(f, start=1):
                line = raw_line.strip()
                if not line or line.startswith('#'):
                    continue
                assert self.tokens_match(line), \
                    'Expected track selection lines to contain only digits and colons. Found: {} in line #{}.'.format(line, line_no)

                split_l = line.split(':')
                assert len(split_l) <= 2, \
                    'Expected track selection lines to contain only one colon. Found: {} in line #{}.'.format(line, line_no)
                assert all(part.isdigit() for part in split_l), \
                    'Expected to have a number in both sides of the colon. Found: {} in line #{}.'.format(line, line_no)

                if len(split_l) == 1:
                    track_lst.append(int(split_l[0]))
                else:
                    start, stop = int(split_l[0]), int(split_l[1])
                    assert start < stop, 'Expected x < y in pair x:y. Found: {} in line #{}.'.format(line, line_no)
                    track_lst.extend(range(start, stop))

        print('Selected {} tracks: \n{}'.format(len(track_lst), track_lst))
        return track_lst

    def get_heldout_dataset(self):
        """Wrap the held-out windows in a SimpleDatasetFromH5.

        Requires ``self.heldout_idxs`` to have been set (this class does not set it).
        """
        return SimpleDatasetFromH5(self.file_path,
                                   self.label_ids,
                                   self.heldout_idxs,
                                   self.genome_locations[self.heldout_idxs],
                                   self.selected_tracks,
                                   self.predictor_ds_id)


class DatasetGenerator(BaseDatasetGenerator):
    """Hold out part of the valid windows and split the rest into train/test datasets.

    After construction, ``self.heldout_idxs`` holds the held-out window indices and
    ``self.idxs`` holds the remaining valid windows; the two sets are disjoint.
    """
    predictor_ds_id = 'x_data'
    index_ds_id = 'idx'

    def __init__(self, file_path, label_ids, map_thresh, heldout_ratio, heldout_file=None, track_file=None):
        """
        Args:
            file_path: path to the HDF5 data file.
            label_ids: names of the label datasets to load.
            map_thresh: mappability threshold forwarded to the base class.
            heldout_ratio: fraction of the valid windows moved at random to the
                held-out set; ignored when ``heldout_file`` is given.
            heldout_file: optional tab-separated file listing the held-out windows
                (parsed by ``extract_heldout_set``).
            track_file: optional track selection file forwarded to the base class.
        """
        super(DatasetGenerator, self).__init__(file_path, label_ids, map_thresh, heldout_ratio, track_file)

        if heldout_file is not None:
            print('Using predefined held-out samples from {}'.format(heldout_file))
            self.heldout_idxs = self.extract_heldout_set(heldout_file)
        else:
            print('Extracting held-out set at random...')
            del_idxs = np.random.choice(np.arange(len(self.idxs)), size=int(heldout_ratio * len(self.idxs)), replace=False)
            self.heldout_idxs = self.idxs[del_idxs]
            self.idxs = np.delete(self.idxs, del_idxs)
        overlap = np.intersect1d(self.idxs, self.heldout_idxs)
        assert len(overlap) == 0, \
            'Found {} intersecting elements with the held-out set'.format(len(overlap))

    def extract_heldout_set(self, heldout_path):
        """Move the windows listed in a tab-separated predictions file out of ``self.idxs``.

        The first line of the file is treated as a header and skipped. Columns are
        CHROM, START, END, Y_TRUE, Y_PRED, STD, PVAL, RANK, and rows with missing
        fields are dropped. Each remaining row must match exactly one window on
        (CHROM, START). Its Y_TRUE must equal ``self.labels_lst[0]`` for that
        window, and the window must currently be in ``self.idxs``.

        Returns:
            Integer array of the held-out window indices, in file order.

        Raises:
            AssertionError: on any of the consistency violations above.
        """
        cols = ['CHROM', 'START', 'END', 'Y_TRUE', 'Y_PRED', 'STD', 'PVAL', 'RANK']
        with open(heldout_path, 'r') as f:
            rows = [line.split('\t') for line in f.read().split('\n')]
        heldout_chr_loc_df = pd.DataFrame(rows, columns=cols) \
            .drop(0).reset_index(drop=True).dropna(axis=0, how='any')

        origin_chrs = self.genome_locations[:, 0]
        origin_starts = self.genome_locations[:, 1]
        origin_labels = self.labels_lst[0]
        set_origin_idxs = []
        for _, row in heldout_chr_loc_df.iterrows():
            origin_idx = np.where((origin_chrs == int(row.CHROM)) & (origin_starts == int(row.START)))[0]
            assert len(origin_idx) == 1, 'Found {} matches for location {}'.format(len(origin_idx), row)
            origin_idx = origin_idx[0]
            assert float(row.Y_TRUE) == float(origin_labels[origin_idx]), \
                'Mismatch of ground truth mutation count. Expected {}, but found {}.' \
                .format(row.Y_TRUE, origin_labels[origin_idx])
            # Positions of this window inside self.idxs (computed once, not per check).
            pos = np.where(self.idxs == origin_idx)[0]
            assert len(pos) != 0, \
                'Expected the following to be in the data set, but wasn\'t found \n{}'.format(row)
            assert len(pos) == 1, 'Expected index {} to appear once, but found {}' \
                .format(origin_idx, len(pos))
            self.idxs = np.delete(self.idxs, pos)
            set_origin_idxs.append(origin_idx)
        print('Heldout {} windows.'.format(len(set_origin_idxs)))
        # Array (not list) so both held-out branches of __init__ yield the same type.
        return np.array(set_origin_idxs, dtype=int)

    def _make_dataset(self, idxs):
        """Wrap the given window indices in a SimpleDatasetFromH5."""
        return SimpleDatasetFromH5(self.file_path,
                                   self.label_ids,
                                   idxs,
                                   self.genome_locations[idxs],
                                   self.selected_tracks,
                                   self.predictor_ds_id)

    def get_datasets(self, split_method='random', train_ratio=0.8, resample=0):
        """Split ``self.idxs`` into train/test datasets.

        Args:
            split_method: 'random' or 'chr' (per-chromosome contiguous split).
            train_ratio: fraction of windows assigned to the training set.
            resample: currently unused (resampling is not wired in).

        Returns:
            (train_ds, test_ds) SimpleDatasetFromH5 instances.
        """
        if split_method == 'random':
            train_idxs, test_idxs = self.split_randomly(train_ratio)
        elif split_method == 'chr':
            train_idxs, test_idxs = self.split_by_chromosome(train_ratio)
        else:
            raise Exception('Expected split_method to be \'random\' or \'chr\', but found {}'.format(split_method))

        train_ds = self._make_dataset(train_idxs)
        test_ds = self._make_dataset(test_idxs)
        return train_ds, test_ds

    def split_randomly(self, train_ratio):
        """Shuffle the valid windows and cut them at ``train_ratio``.

        ``self.idxs`` itself is left untouched; a shuffled copy is split.
        """
        print('Splitting data at random...')
        shuffled = np.random.permutation(self.idxs)
        split_idx = int(train_ratio * len(shuffled))
        return shuffled[:split_idx], shuffled[split_idx:]

    def split_by_chromosome(self, train_ratio):
        """Per chromosome, put the first ``train_ratio`` of its valid windows in train.

        Only windows in ``self.idxs`` are considered, so below-mappability and
        held-out windows never reach the train/test sets. Chromosome IDs are taken
        from the first column of ``self.genome_locations`` and are iterated as
        1..max; windows with other IDs (e.g. 0) are not assigned.
        """
        print('Splitting data by chromosome...')
        valid_idxs = np.sort(self.idxs)  # sort: contiguous (genome-ordered) chunks per chromosome
        chrs = self.genome_locations[valid_idxs, 0]
        max_chr = int(np.max(self.genome_locations[:, 0]))
        train_idxs_lst = []
        test_idxs_lst = []
        for c in range(1, max_chr + 1):
            print('Chromosome {}...'.format(c))
            chr_idxs = valid_idxs[chrs == c]
            split_idx = int(train_ratio * len(chr_idxs))
            train_idxs_lst.extend(chr_idxs[:split_idx].tolist())
            test_idxs_lst.extend(chr_idxs[split_idx:].tolist())
        return (np.sort(np.asarray(train_idxs_lst, dtype=int)),
                np.sort(np.asarray(test_idxs_lst, dtype=int)))

    # TODO: resample_to_uniform (uniform resampling over label values) is not implemented yet.

    def resample_high_count_uniformly(self, idxs, size, sample_thresh, confidence=0.997):
        """Resample windows so that every label-count bin is drawn with equal probability.

        Labels are truncated to int and binned one count per bin. The regular bins
        end at the ``confidence`` quantile (lower order statistic) of the labels,
        and a count equal to that cap joins the last regular bin. All larger counts
        share one overflow bin. Windows with negative labels are never drawn.

        Each draw picks a non-empty bin uniformly, then a window within it
        uniformly. A window is retired once it has been drawn ``sample_thresh``
        times, and sampling stops early when every window is retired.

        Args:
            idxs: window indices (into the HDF5 datasets) to resample from.
            size: number of samples to draw.
            sample_thresh: maximum number of times one window may be drawn.
            confidence: quantile in [0, 1] used as the last regular bin edge.

        Returns:
            (predictors, [labels]) for the drawn samples; predictors has shape
            (n_drawn, x_data.shape[1], len(self.selected_tracks)).
        """
        assert len(self.labels_lst) == 1, 'Resampling for multiple targets wasn\'t implemanted.'
        idxs = np.asarray(idxs, dtype=int)
        y = self.labels_lst[0][idxs]
        y_int = y.astype(int)

        # 'lower' quantile of the labels (an actual label value); at least 1 so bins exist.
        sorted_y = np.sort(y_int)
        max_bin = 1
        if len(sorted_y):
            max_bin = max(int(sorted_y[int(np.floor(confidence * (len(sorted_y) - 1)))]), 1)
        bin_of = np.where(y_int < max_bin, y_int, np.where(y_int == max_bin, max_bin - 1, max_bin))

        # (bin id, positions into idxs); empty bins are excluded because they can never be drawn.
        pools = [(b, np.where(bin_of == b)[0].tolist()) for b in range(max_bin + 1)]
        pools = [(b, pool) for b, pool in pools if pool]

        draw_counts = np.zeros(len(idxs), dtype=int)
        drawn = []
        for i in range(size):
            if not pools:
                print('Generated {} samples instead of {} since all resamples were maximized'.format(i, size))
                break
            if i % 10000 == 0: print('Resampled {}/{} samples'.format(i, size))
            c = np.random.randint(len(pools))
            bin_id, pool = pools[c]
            j = np.random.randint(len(pool))
            s = pool[j]
            drawn.append(s)
            draw_counts[s] += 1
            if draw_counts[s] >= sample_thresh:
                # O(1) removal; order inside a pool does not matter for uniform draws.
                pool[j] = pool[-1]
                pool.pop()
                if not pool:
                    print('All samples of class {} were maximaly resampled {} times'.format(bin_id, sample_thresh))
                    pools.pop(c)

        drawn = np.asarray(drawn, dtype=int)
        labels = y[drawn]
        # h5py fancy indexing needs increasing indices: read each unique window once.
        uniq, inverse = np.unique(idxs[drawn], return_inverse=True)
        with h5py.File(self.file_path, 'r') as h5f:
            x_ds = h5f[self.predictor_ds_id]
            if len(uniq) == 0:
                predictors = np.zeros((0, x_ds.shape[1], len(self.selected_tracks)))
            else:
                predictors = x_ds[uniq][:, :, self.selected_tracks][inverse]
        return predictors, [labels]


class KFoldDatasetGenerator(BaseDatasetGenerator):
    """Partition the valid windows (``self.idxs``) into ``k`` folds for cross-validation.

    ``self.ds_idxs`` is a 1-D object array whose i-th element is the integer index
    array of fold i (folds may differ in length by a few windows).
    """
    predictor_ds_id = 'x_data'
    index_ds_id = 'idx'

    def __init__(self, file_path, label_ids, k, map_thresh, split_method='random', resample=0, track_file=None):
        """
        Args:
            file_path: path to the HDF5 data file.
            label_ids: names of the label datasets to load.
            k: number of folds (positive integer).
            map_thresh: mappability threshold forwarded to the base class.
            split_method: 'random' or 'chr'.
            resample: currently unused (resampling is not wired in).
            track_file: optional track selection file forwarded to the base class.
        """
        # No held-out ratio in k-fold mode. track_file is passed by keyword so it is not
        # consumed by the base class's positional `heldout_ratio` parameter.
        super(KFoldDatasetGenerator, self).__init__(file_path, label_ids, map_thresh, None, track_file=track_file)
        self.k = k
        self.divide_datasets(split_method, k, resample)

    @staticmethod
    def _split_into_chunks(idxs, k):
        """Split ``idxs`` into ``k`` contiguous chunks whose sizes differ by at most one."""
        n = len(idxs)
        # Integer arithmetic: exact bounds, and the last bound is always n (nothing dropped).
        bounds = [(i * n) // k for i in range(k + 1)]
        return [idxs[bounds[i]:bounds[i + 1]] for i in range(k)]

    @staticmethod
    def _as_fold_array(folds):
        """Pack per-fold index arrays into a 1-D object array (folds may be ragged)."""
        packed = np.empty(len(folds), dtype=object)
        for i, fold in enumerate(folds):
            packed[i] = np.asarray(fold, dtype=int)
        return packed

    def divide_datasets(self, split_method, kfold, resample=0):
        """Fill ``self.ds_idxs`` with ``kfold`` folds using ``split_method``.

        Raises:
            AssertionError: if ``kfold`` is not a positive integer.
            Exception: if ``split_method`` is not 'random' or 'chr'.
        """
        assert isinstance(kfold, (int, np.integer)), \
            'Expected type of hyperparameter \'kfold\' to be int but found {}'.format(type(kfold))
        assert kfold >= 1, 'Expected hyperparameter \'kfold\' to be positive but found {}'.format(kfold)
        if split_method == 'random':
            self.ds_idxs = self.split_randomly(kfold)
        elif split_method == 'chr':
            self.ds_idxs = self.split_by_chromosome(kfold)
        else:
            raise Exception('Expected split_method to be \'random\' or \'chr\', but found {}'.format(split_method))

    def split_randomly(self, k):
        """Shuffle a copy of ``self.idxs`` and cut it into ``k`` near-equal folds."""
        print('Splitting data to {} parts at random...'.format(k))
        shuffled = np.random.permutation(self.idxs)
        return self._as_fold_array(self._split_into_chunks(shuffled, k))

    def split_by_chromosome(self, k):
        """Cut each chromosome's valid windows into ``k`` contiguous parts, one per fold.

        Only windows in ``self.idxs`` are used, so below-mappability windows are not
        placed in any fold. Chromosome IDs are taken from the first column of
        ``self.genome_locations`` and are iterated as 1..max.
        """
        print('Splitting data to {} parts by chromosome...'.format(k))
        valid_idxs = np.sort(self.idxs)
        chrs = self.genome_locations[valid_idxs, 0]
        max_chr = int(np.max(self.genome_locations[:, 0]))
        datasets_idxs_lst = [[] for _ in range(k)]
        for c in range(1, max_chr + 1):
            print('Chromosome {}...'.format(c))
            chr_chunks = self._split_into_chunks(valid_idxs[chrs == c], k)
            for fold, chunk in zip(datasets_idxs_lst, chr_chunks):
                fold.extend(chunk.tolist())
        return self._as_fold_array([np.sort(fold) for fold in datasets_idxs_lst])

    def _make_dataset(self, idxs):
        """Wrap the given window indices in a SimpleDatasetFromH5."""
        return SimpleDatasetFromH5(self.file_path, self.label_ids, idxs, self.genome_locations[idxs],
                                   self.selected_tracks, self.predictor_ds_id)

    def get_datasets(self, fold_idx):
        """Return (train_ds, test_ds) where fold ``fold_idx`` is the test set.

        Negative ``fold_idx`` counts from the end; out-of-range raises IndexError.
        """
        fold_idx = range(self.k)[fold_idx]  # normalise negatives / validate range
        test_idxs = self.ds_idxs[fold_idx]
        train_parts = [self.ds_idxs[i] for i in range(self.k) if i != fold_idx]
        train_idxs = np.concatenate(train_parts) if train_parts else np.zeros(0, dtype=int)
        test_ds = self._make_dataset(test_idxs)
        train_ds = self._make_dataset(train_idxs)
        return train_ds, test_ds

    def get_below_mapp(self):
        """Wrap the windows below the mappability threshold in a SimpleDatasetFromH5."""
        # Integer dtype guards against an empty float array, which numpy refuses as an index.
        below = np.asarray(self.below_mapp, dtype=int)
        return self._make_dataset(below)
