import os
import numpy as np

from cb_evaluation_api import class_evaluation


class DataPack:
    """Plain container bundling loaded arrays with their bookkeeping info.

    Every attribute starts out as ``None`` and is filled in by whoever
    creates the pack.
    """

    def __init__(self):
        # Payload: the signal values and their labels.
        self.data = self.label = None

        # Bookkeeping used to trace the payload back to its origin.
        self.loc = self.patient_list = None


class DataHandler:
    """Locate, load and window the per-patient, per-level ``.npz`` database files.

    The handler builds file paths from the dataset name, patient id, level
    index and label variant, stacks the stored ``data`` / ``label`` / ``loc``
    arrays of every requested patient, and can cut the stacked signals into
    overlapping windows.
    """

    def __init__(
            self,
            database_save_dir='/data/CL_database/',     # The path of database file holder
            data_name='SEEG',                               # The name of database
            exp_id=None,
            patient_list=None,                              # The list of patient group
            noise_ratio=None,                               # The noise ratio for the original labels
            window_time=1,          # The unit is second
            slide_time=0.5,         # The unit is second
            num_level=5,            # The number of levels
    ):
        """Store the configuration and derive the window/slide lengths.

        For 'SEEG' (250 Hz) and 'Sleep' (100 Hz) ``window_time`` and
        ``slide_time`` are converted from seconds to a number of samples.
        For any other dataset no sample rate is known here, so both values
        are stored unchanged as ``window_len`` / ``slide_len``.

        ``exp_id`` is only used to build the file name of model-generated
        labels; ``noise_ratio`` is only used to build the directory name of
        non-SEEG datasets.
        """
        assert data_name in ['SEEG', 'fNIRS_2', 'Sleep'], \
            f"data_name must be one of 'SEEG', 'fNIRS_2', 'Sleep', got {data_name!r}"
        self.database_save_dir = os.path.join(database_save_dir, data_name)
        self.data_name = data_name
        self.exp_id = exp_id
        self.patient_list = patient_list
        self.noise_ratio = noise_ratio

        if data_name in ['SEEG', 'Sleep']:
            if data_name == 'SEEG':
                self.sample_rate = 250
            else:
                self.sample_rate = 100
            self.window_len = int(window_time * self.sample_rate)
            self.slide_len = int(slide_time * self.sample_rate)
        else:
            self.sample_rate = None
            self.window_len = window_time
            self.slide_len = slide_time

        self.num_level = num_level

    def obtain_database_dir(self, pa, level, clean_label=False, model_label=False, random_noise=False):
        """Return the path of the ``.npz`` file for one patient and one level.

        The flags select the label variant through the file suffix, checked
        in this order: ``clean_label`` -> ``_clean``, ``model_label`` ->
        ``_model_<exp_id>``, ``random_noise`` -> ``_random``, otherwise no
        suffix.

        Raises:
            ValueError: if the path depends on ``noise_ratio`` but none was
                configured.
        """
        if self.data_name == 'SEEG':
            pre_fix = f'{pa}/level{level}_sample'
        elif clean_label and random_noise:
            pre_fix = f'0/s{pa}_level{level}_sample'
        else:
            if self.noise_ratio is None:
                raise ValueError(
                    f'noise_ratio must be set to locate {self.data_name} database files'
                )
            # Round instead of truncating: 0.29 * 100 evaluates to
            # 28.999999999999996, which int() alone would turn into 28.
            # NOTE(review): confirm the script that writes the database
            # derives the directory name with the same rounding.
            noise_percent = int(round(self.noise_ratio * 100))
            pre_fix = f'{noise_percent}/s{pa}_level{level}_sample'

        if clean_label:
            suffix = '_clean.npz'
        elif model_label:
            suffix = f'_model_{self.exp_id}.npz'
        elif random_noise:
            suffix = '_random.npz'
        else:
            suffix = '.npz'
        return os.path.join(self.database_save_dir, pre_fix + suffix)

    def __get_database__(self, segment=True, clean_label=False, model_label=False, random_noise=False):
        """Load and stack the database files of every patient and level.

        Each file must provide the arrays ``data``, ``label`` and ``loc``.
        2-D ``data`` arrays get a trailing axis of size 1 appended.

        With ``random_noise`` set and ``clean_label`` unset, the stacked
        arrays are returned directly and ``segment`` is ignored. Otherwise
        the stacked pack is returned as is, or passed through
        ``get_segment_data`` when ``segment`` is true.

        Raises:
            ValueError: if ``patient_list`` is missing or empty.
        """
        if self.patient_list is None or len(self.patient_list) == 0:
            raise ValueError('patient_list must contain at least one patient')

        data_pack = DataPack()
        data_pack.data = [[] for _ in range(self.num_level)]
        data_pack.label = [[] for _ in range(self.num_level)]
        data_pack.loc = [[] for _ in range(self.num_level)]
        data_pack.patient_list = self.patient_list

        for pa in self.patient_list:
            for level in range(self.num_level):
                load_path = self.obtain_database_dir(pa, level, clean_label, model_label, random_noise)
                print(f'Loading the labels from: {load_path}')
                # The archive object keeps its file open until closed, so
                # read the arrays inside a context manager.
                with np.load(load_path) as all_data:
                    data = all_data['data']
                    label = all_data['label']
                    loc = all_data['loc']

                if data.ndim == 2:
                    data = np.expand_dims(data, axis=-1)

                data_pack.data[level].append(data)
                data_pack.label[level].append(label)
                data_pack.loc[level].append(loc)

        first_data = data_pack.data[0][0]
        first_label = data_pack.label[0][0]
        first_loc = data_pack.loc[0][0]

        if random_noise and not clean_label:
            # The patient axis is merged into the second axis; the last three
            # axes of each file's data (two for loc) are kept.
            data_pack.data = np.array(data_pack.data).reshape([self.num_level, -1, *first_data.shape[-3:]])
            # num_level x seg_big_num x seg_small_num
            data_pack.label = np.array(data_pack.label).reshape([self.num_level, -1, first_label.shape[-1]])
            data_pack.loc = np.array(data_pack.loc).reshape([self.num_level, -1, *first_loc.shape[-2:]])

            print('Total BIG Segment Number for', self.patient_list, ':',
                  data_pack.data.shape[0] * data_pack.data.shape[1])
            print('SMALL Segment Number for each big segment:', data_pack.data.shape[-3])
            return data_pack

        # num_level x (pa * sample_num) x length (x n_features)
        data_pack.data = np.array(data_pack.data).reshape([self.num_level, -1, *first_data.shape[-2:]])
        data_pack.label = np.array(data_pack.label).reshape([self.num_level, -1, first_label.shape[-1]])
        data_pack.loc = np.array(data_pack.loc).reshape([self.num_level, -1, first_loc.shape[-1]])

        print('Total BIG Segment Number for', self.patient_list, ':', data_pack.data.shape[0] * data_pack.data.shape[1])
        if segment:
            return self.get_segment_data(data_pack)
        return data_pack

    def get_data(self, segment=True, clean_label=False, model_label=False, random_noise=False):
        """Public entry point; forwards all arguments to ``__get_database__``."""
        return self.__get_database__(segment, clean_label, model_label, random_noise)

    def get_segment_data(self, data_pack):
        """Cut every sample of ``data_pack`` into sliding windows.

        Windows are ``window_len`` long and start every ``slide_len`` steps
        along the second-to-last axis of ``data_pack.data``. Each window gets
        a single label: the window maximum when the window mean is at least
        halfway between minimum and maximum, otherwise the window minimum.

        Returns:
            A new ``DataPack`` whose data is laid out as
            num_level x sample_num x segment_num x n_features x window_size.

        Raises:
            ValueError: if ``slide_len`` is not positive or the signal is
                shorter than one window.
        """
        # num_level x sample_num x length (x n_features)
        data, label, loc = data_pack.data, data_pack.label, data_pack.loc

        if self.slide_len <= 0:
            raise ValueError(f'slide_len must be positive, got {self.slide_len}')

        starts = np.arange(0, data.shape[-2] - self.window_len + 1, self.slide_len)
        if len(starts) == 0:
            raise ValueError(
                f'signal length {data.shape[-2]} is shorter than window_len {self.window_len}'
            )

        data_windows, label_windows, loc_windows = [], [], []
        for s in starts:
            e = s + self.window_len
            data_windows.append(data[:, :, s:e])
            label_windows.append(label[:, :, s:e])
            # Shift the last two loc columns by the window offsets; the last
            # column must be derived before the second-to-last is updated.
            # NOTE(review): presumably these columns hold the start/end
            # position of the segment - confirm against the database writer.
            window_loc = loc.copy()
            window_loc[:, :, -1] = window_loc[:, :, -2] + e
            window_loc[:, :, -2] += s
            loc_windows.append(window_loc)

        new_data_pack = DataPack()
        new_data_pack.patient_list = data_pack.patient_list

        # segment_num x num_level x sample_num x window_size x n_features ->
        # num_level x sample_num x segment_num x window_size x n_features
        stacked_data = np.stack(data_windows, axis=-3)
        # Swap the last two axes: ... x n_features x window_size
        new_data_pack.data = stacked_data.transpose(0, 1, 2, 4, 3)

        # num_level x sample_num x segment_num x window_size
        stacked_label = np.stack(label_windows, axis=-2)
        max_label = stacked_label.max(axis=-1)
        min_label = stacked_label.min(axis=-1)
        mean_label = stacked_label.mean(axis=-1)
        flag_label = (mean_label >= (max_label + min_label) / 2)
        # num_level x sample_num x segment_num
        new_data_pack.label = np.where(flag_label, max_label, min_label)

        # num_level x sample_num x segment_num x 3/4
        new_data_pack.loc = np.stack(loc_windows, axis=-2)

        print('SMALL Segment Number for each big segment:', new_data_pack.data.shape[-3])
        return new_data_pack

    @staticmethod
    def model_evaluation(true_label, pred_label, n_class):
        """Delegate to ``class_evaluation`` and return its result unchanged."""
        return class_evaluation(true_label, pred_label, n_class)
