#!/usr/bin/env python3
# Copyright 2004-present Facebook. All Rights Reserved.

import glob
import logging
import numpy as np
import os
import random
import torch
import torch.utils.data
import pandas as pd
import pyvista as pv
import progressbar as pb
from multiprocessing import *
import blosc
from collections import defaultdict
from tqdm import trange
from .normalizers import StandardScalerNormalizer

class OASISBrainDataset(torch.utils.data.Dataset):
    '''
    Tabular dataset that can serve training cases with missingness.

    Each item is a pair ``(model_input, vol)``: ``model_input`` holds the
    normalized covariates (followed by one "muter" flag per covariate when
    ``padding_muter`` is set; 1 marks a missing/masked covariate, 0 an
    observed one) and ``vol`` holds the normalized target variable.
    Normalizers are always fitted on the ``'train'`` entry of the split file,
    whichever split is being served.
    '''

    # Columns used when the caller does not pass ``covariate_names``. Kept as
    # a tuple so the default cannot be mutated and shared between instances.
    DEFAULT_COVARIATE_NAMES = ("AGE", "EDUC", "SES", "MMSE", "CDR", "eTIV", "ASF")

    def __init__(
            self,
            filename_datasource,
            filename_split,
            slt_percentile=None,
            covariate_names=None,
            tgt_var_name='nWBV',
            split='train',
            augtype='none',
            allow_missingness=True,
            padding_muter=True,
    ):
        '''
        Load, normalize and (optionally) augment one split of the data.

        Args:
            filename_datasource: path of the CSV file read by ``read_data``.
            filename_split: path of the YAML file mapping split names to IDs.
            slt_percentile: accepted for interface compatibility; not used here.
            covariate_names: covariate column names; ``None`` selects
                ``DEFAULT_COVARIATE_NAMES``.
            tgt_var_name: name of the target column.
            split: key of the split to serve.
            augtype: ``'full'`` or ``'none'`` (see ``customize_aug``).
            allow_missingness: serve samples with missing covariates; only
                honoured when ``'train'`` is a substring of ``split``.
            padding_muter: append one muter flag per covariate to each sample.

        Raises:
            ValueError: if ``augtype`` is not a supported value.
        '''
        if covariate_names is None:
            covariate_names = self.DEFAULT_COVARIATE_NAMES

        # (sic) the misspelled attribute is kept because external code may
        # read it; a correctly spelled alias is provided next to it.
        self.DATASETNANE = 'OASISBrain'
        self.DATASETNAME = self.DATASETNANE
        self.in_geo_features = 0
        # Private copy so later changes to the caller's list do not leak in.
        self.covariate_names = list(covariate_names)
        self.tgt_var_name = tgt_var_name
        self.used_columns = self.covariate_names + [tgt_var_name] + ['ID', 'PID']
        self.allow_missingness = allow_missingness if 'train' in split else False
        self.padding_muter = padding_muter

        # 1. read the requested split and the train split (parse the YAML once)
        split_dict = self.load_yaml_as_dict(filename_split)
        self.split = split_dict[split]
        self.train_split = split_dict['train']
        self.filename_datasource = filename_datasource

        df_data = self.read_data(self.split)
        df_data_train = self.read_data(self.train_split)

        # 2. fit the normalizers on the training rows only
        self.dict_covariate_normalizer = self.get_statistics_of_covariates(df_data_train)
        self.dict_geo_normalizer = self.get_statistics_of_geo_features(df_data_train)
        self.dict_normalizer = {**self.dict_geo_normalizer, **self.dict_covariate_normalizer}

        # 3. normalize both the training rows and the served rows
        (self.train_valid_dict_features,
         self.train_valid_arr_features,
         self.train_normed_data) = self.normalize(df_data_train)

        (self.valid_dict_features,
         self.valid_arr_features,
         self.normed_data) = self.normalize(df_data)

        # 4. augment the dataset
        self.customize_aug(augtype=augtype)
        self.prepared_data_with_nans = self.prepared_data.copy()

        # 5. process missingness
        # With allow_missingness the dataset serves samples with missing
        # covariates: NaNs are replaced by 0 and the muter columns indicate
        # which covariates are missing. Otherwise only complete rows are kept.
        if self.allow_missingness:
            self.prepared_data = self.prepared_data.fillna(0)
        else:
            self.prepared_data = self.prepared_data.dropna()
        self.NUM_OF_CASES = len(self.prepared_data)
        print("There are " + str(len(self.normed_data)) + " records.")
        print("There are " + str(self.NUM_OF_CASES) + " records after augmentation or filtering.")

        self.normalize_unique()

    @staticmethod
    def count_aug_cases(pd_data):
        '''
        Count the samples a full augmentation of ``pd_data`` would produce.

        For every row the number of non-null cells minus 3 is taken as the
        number of valid features ``k`` (presumably the 3 accounts for the ID,
        PID and target columns -- confirm against the caller), and the row
        contributes ``2**k - 1`` cases.
        '''
        num_of_valid_features = pd_data.count(axis=1).to_numpy() - 3
        return np.sum(np.power(2, num_of_valid_features) - 1)

    def __len__(self):
        '''Return the number of prepared samples.'''
        return self.NUM_OF_CASES

    def customize_aug(self, augtype='full'):
        '''
        Build ``self.prepared_data`` from ``self.normed_data``.

        Args:
            augtype: ``'full'`` enumerates missingness patterns per record,
                ``'none'`` keeps each record as it is.

        Raises:
            ValueError: for any other ``augtype``. Previously such a value was
                ignored here and only surfaced later as an ``AttributeError``
                on the missing ``prepared_data`` attribute.
        '''
        if augtype == 'full':
            self.prepared_data = self.fully_augment_dataset(self.normed_data)
        elif augtype == 'none':
            self.prepared_data = self.not_augment_dataset(self.normed_data)
        else:
            raise ValueError(
                "Unknown augtype %r; expected 'full' or 'none'." % (augtype,))

    def get_statistics_of_covariates(self, df_data_train):
        '''
        Fit one ``StandardScalerNormalizer`` per covariate.

        Only the first row of every unique ``ID`` (IDs visited in sorted
        order) contributes, and NaN values of a covariate are dropped before
        fitting. The un-normalized per-ID covariates are stored in
        ``self.unique_covariates``.

        Returns:
            dict mapping covariate name to its fitted normalizer (also stored
            in ``self.dict_covariate_normalizer``).
        '''
        train_ids = np.unique(df_data_train['ID'])
        list_unique_covariates = [
            df_data_train[df_data_train['ID'] == ith_id].iloc[0]
            for ith_id in train_ids
        ]
        df_unique_covariates = pd.DataFrame.from_records(list_unique_covariates)
        # Explicit copy: normalize_unique() later assigns into this frame, and
        # writing into a column slice of another frame is unreliable in pandas.
        self.unique_covariates = df_unique_covariates[self.covariate_names].copy()

        self.dict_covariate_normalizer = {}
        for ith_covariate in self.covariate_names:
            current_cov_val = df_unique_covariates[ith_covariate].dropna().values.reshape(-1, 1)
            normalizer = StandardScalerNormalizer()
            normalizer.fit(current_cov_val)
            self.dict_covariate_normalizer[ith_covariate] = normalizer

        return self.dict_covariate_normalizer

    def normalize_unique(self):
        '''Normalize ``self.unique_covariates`` column by column, in place.'''
        for ith_cov in self.covariate_names:
            raw_column = self.unique_covariates[ith_cov].values.reshape(-1, 1)
            self.unique_covariates[ith_cov] = self.dict_covariate_normalizer[ith_cov].transform(raw_column)
        return

    def get_statistics_of_geo_features(self, df_data_train):
        '''
        Fit a ``StandardScalerNormalizer`` on the non-NaN target values.

        Returns:
            dict mapping the target variable name to its fitted normalizer
            (also stored in ``self.dict_geo_normalizer``).
        '''
        names_of_geo_features = [self.tgt_var_name]
        self.dict_geo_normalizer = {}
        for ith_geo_feat in names_of_geo_features:
            current_geo_val = df_data_train[ith_geo_feat].dropna().values.reshape(-1, 1)
            normalizer = StandardScalerNormalizer()
            normalizer.fit(current_geo_val)
            self.dict_geo_normalizer[ith_geo_feat] = normalizer

        return self.dict_geo_normalizer

    def normalize(self, df_data):
        '''
        Apply the fitted normalizers to ``df_data``.

        Returns:
            A tuple ``(dict_covariates, arr_covariates, pd_data)``:
            the normalized covariates keyed by name, the same values stacked
            with one column per covariate, and a DataFrame holding ``ID``,
            ``PID``, the normalized covariates and the normalized target.
        '''
        dict_covariates = {}
        for ith_cov in self.covariate_names:
            arr_current_cov = np.array(df_data[ith_cov])
            arr_normed_cov = self.dict_covariate_normalizer[ith_cov].transform(arr_current_cov.reshape(-1, 1))
            dict_covariates[ith_cov] = arr_normed_cov.squeeze()

        arr_covariates = np.array([dict_covariates[name] for name in self.covariate_names]).T
        vol_values = self.dict_geo_normalizer[self.tgt_var_name].transform(
            np.array(df_data[self.tgt_var_name]).reshape(-1, 1))

        dict_normalized_pd_data = {'ID': df_data['ID'], 'PID': df_data['PID']}
        dict_normalized_pd_data.update(dict_covariates)
        dict_normalized_pd_data.update({self.tgt_var_name: vol_values.squeeze()})
        pd_data = pd.DataFrame.from_dict(dict_normalized_pd_data)
        return dict_covariates, arr_covariates, pd_data

    def binary_muter(self, N):
        '''
        Return the binary representation of ``N`` as a float array with one
        entry per covariate (most significant bit first).
        '''
        binaires = np.binary_repr(N, width=len(self.covariate_names))
        return np.array(list(binaires)).astype('float')

    def _build_row(self, record, covariate_values, muter):
        '''Assemble one output row from covariate values and their muter flags.'''
        if self.padding_muter:
            payload = np.concatenate([covariate_values, muter], axis=-1)
        else:
            payload = covariate_values

        row = self.make_dict_from_arr_data(payload)
        row[self.tgt_var_name] = record[self.tgt_var_name]
        row['ID'] = record['ID']
        # PID is carried along whenever the input record has it, so that both
        # augmentation modes expose the same columns to downstream consumers.
        if 'PID' in record:
            row['PID'] = record['PID']
        return row

    def _masked_rows(self, record, masks):
        '''Return one row per mask, with the masked covariates set to NaN.'''
        values = np.array([record[name] for name in self.covariate_names])
        natively_missing = np.isnan(values)

        rows = []
        for mask in masks:
            masked_values = values.copy()
            masked_values[mask == 1.] = np.nan  # 1 is nan and 0 is good
            # A covariate that is already NaN in the record is missing too,
            # so it must be flagged even when the mask itself leaves it alone;
            # otherwise fillna(0) would turn it into an "observed" zero.
            muter = np.maximum(mask, natively_missing)
            rows.append(self._build_row(record, masked_values, muter))
        return rows

    def _rows_to_frame(self, list_data):
        '''Turn row dicts into a de-duplicated frame without all-NaN covariate rows.'''
        pd_aug = pd.DataFrame.from_records(list_data)
        pd_aug = pd_aug.drop_duplicates(subset=self.covariate_names + [self.tgt_var_name, 'ID'])
        pd_aug = pd_aug.dropna(subset=self.covariate_names, how='all')
        return pd_aug

    def fully_augment_dataset(self, pd_data_ori):
        '''
        Expand every record into its missingness patterns.

        For ``N`` covariates the masks ``0 .. 2**N - 2`` are applied to each
        record (the all-missing mask is excluded). Rows that coincide on the
        covariates, target and ID are de-duplicated, and rows whose covariates
        are all NaN are dropped.
        '''
        list_data_dict = pd_data_ori.to_dict('records')
        n_augment = 2 ** len(self.covariate_names) - 1
        # The masks do not depend on the record, so build them only once.
        masks = [self.binary_muter(ith_aug) for ith_aug in range(n_augment)]

        list_data = []
        for data_idx in trange(len(list_data_dict)):
            list_data += self._masked_rows(list_data_dict[data_idx], masks)

        return self._rows_to_frame(list_data)

    def not_augment_dataset(self, pd_data_ori):
        '''
        Keep every record as it is; the muter flags mark the covariates that
        are NaN in the record. De-duplication and the all-NaN filter are the
        same as in ``fully_augment_dataset``.
        '''
        list_data_dict = pd_data_ori.to_dict('records')

        list_data = []
        for data_idx in trange(len(list_data_dict)):
            record = list_data_dict[data_idx]
            values = np.array([record[name] for name in self.covariate_names])
            list_data.append(self._build_row(record, values, np.isnan(values)))

        return self._rows_to_frame(list_data)

    def make_dict_from_arr_data(self, arr):
        '''
        Map a flat array onto named entries.

        ``arr`` holds either one value per covariate, or the covariate values
        followed by one muter flag per covariate (stored under
        ``'<name>_muter'``).

        Raises:
            ValueError: if ``len(arr)`` matches neither layout.
        '''
        n_covariates = len(self.covariate_names)
        if len(arr) not in (n_covariates, 2 * n_covariates):
            raise ValueError(
                "Expected %d or %d values, got %d."
                % (n_covariates, 2 * n_covariates, len(arr)))

        dict_data = {name: arr[idx] for idx, name in enumerate(self.covariate_names)}
        if len(arr) == 2 * n_covariates:
            for idx, name in enumerate(self.covariate_names):
                dict_data[name + '_muter'] = arr[idx + n_covariates]

        return dict_data

    def read_data(self, split):
        '''
        Read the CSV data source and return the rows whose ``ID`` is in
        ``split``, restricted to ``self.used_columns``; rows without an ID or
        without a target value are dropped. The unfiltered table is kept in
        ``self.df_data``.
        '''
        column_types = {"ID": str, 'PID': str,  "AGE": float, "SEX": float, "EDUC": float, "SES": float, "MMSE": float, "CDR": float, "eTIV": float, "ASF": float, "nWBV": float}
        self.df_data = pd.read_csv(self.filename_datasource, header=0, dtype=column_types)
        # select needed columns and rows
        df_data_split = self.df_data.loc[self.df_data['ID'].isin(split)]
        df_data_split = df_data_split[self.used_columns]
        df_data_split = df_data_split.dropna(subset=['ID', self.tgt_var_name])
        return df_data_split

    def load_yaml_as_dict(self, yaml_path):
        '''Parse the YAML file at ``yaml_path`` and return its content.'''
        import yaml
        with open(yaml_path, 'r') as f:
            # NOTE(review): FullLoader is more permissive than the safe loader
            # used by the module-level load_yaml_as_dict; consider
            # yaml.safe_load if split files can come from untrusted sources.
            config_dict = yaml.load(f, Loader=yaml.FullLoader)
        return config_dict

    def __getitem__(self, idx):
        '''
        Return ``(model_input, vol)`` for the ``idx``-th prepared row.

        ``model_input`` is a 1-D float32 tensor with the covariates, followed
        by the muter flags when ``padding_muter`` is set; ``vol`` is a float32
        tensor of shape ``(1,)`` with the target value.
        '''
        columns = list(self.covariate_names)
        if self.padding_muter:
            columns += [name + '_muter' for name in self.covariate_names]

        values = [float(self.prepared_data[column].iloc[idx]) for column in columns]
        model_input = torch.tensor(values, dtype=torch.float32)

        target = float(self.prepared_data[self.tgt_var_name].iloc[idx])
        vol = torch.tensor(target, dtype=torch.float32)[None, ...]

        return model_input, vol


def load_yaml_as_dict(yaml_path):
    """Parse the YAML file at ``yaml_path`` with the safe loader and return the result."""
    import yaml

    with open(yaml_path, "r") as handle:
        return yaml.safe_load(handle)



def get_OASISBrain_for_transport(spec, ds_dataset, split='test_multiple'):
    """
    Group the prepared scans of every patient into a youngest scan and the rest.

    Args:
        spec: mapping whose ``"Split"`` entry is the path of the split YAML file.
        ds_dataset: dataset object exposing a ``prepared_data`` DataFrame with
            ``ID``, ``PID`` and ``AGE`` columns.
        split: key of the split file entry to iterate; every element is
            expected to provide a ``'name'`` that is matched against ``PID``.

    Returns:
        A list with one dict per patient that has at least two distinct scans:
        ``'patient'`` (the name), ``'youngest_scan'`` (ID of the first scan
        with the minimum ``AGE``) and ``'other_scans'`` (array of the IDs of
        all scans with a strictly larger ``AGE``, ordered by ``AGE``).
    """
    timelines = load_yaml_as_dict(spec["Split"])[split]

    # One row per scan ID.
    df_data = ds_dataset.prepared_data.drop_duplicates(subset=['ID'])
    # The string conversion does not depend on the patient: do it once.
    pid_as_str = df_data['PID'].astype('str')

    list_patient_scans = []
    for patient in timelines:
        df_patient = df_data.loc[pid_as_str == patient['name']]
        if len(df_patient) < 2:
            continue

        # pandas' min() skips NaN, so one missing age cannot poison the minimum.
        min_age = df_patient['AGE'].min()
        youngest_scan = df_patient.loc[df_patient['AGE'] == min_age]
        other_scans = df_patient.loc[df_patient['AGE'] > min_age]

        # Sort on the raw values so the result is a plain positional indexer.
        age_order = np.argsort(other_scans['AGE'].values)

        list_patient_scans.append({
            'patient': patient['name'],
            'youngest_scan': youngest_scan['ID'].values[0],
            'other_scans': other_scans['ID'].values[age_order],
        })
    return list_patient_scans


def get_OASISBrain_pairs_for_transport(spec, ds_dataset, split='test_multiple'):
    """
    Draw (source, target) scan pairs within each patient.

    For every patient with at least two distinct scans, the scans are ordered
    by ``AGE``; each scan except the last one is used once as source and is
    paired with a randomly chosen scan that comes later in that order.

    Args:
        spec: mapping whose ``"Split"`` entry is the path of the split YAML file.
        ds_dataset: dataset object exposing a ``prepared_data`` DataFrame with
            ``ID``, ``PID`` and ``AGE`` columns.
        split: key of the split file entry to iterate; every element is
            expected to provide a ``'name'`` that is matched against ``PID``.

    Returns:
        A list of dicts with the keys ``'patient'``, ``'src_scan'`` and
        ``'tgt_scan'``.
    """
    # Re-seeds NumPy's *global* generator so the drawn pairs are reproducible
    # (note that this also affects later users of np.random).
    np.random.seed(1117)
    timelines = load_yaml_as_dict(spec["Split"])[split]

    # One row per scan ID.
    df_data = ds_dataset.prepared_data.drop_duplicates(subset=['ID'])
    # The string conversion does not depend on the patient: do it once.
    pid_as_str = df_data['PID'].astype('str')

    list_patient_scans = []
    for patient in timelines:
        df_patient = df_data.loc[pid_as_str == patient['name']]
        n_scans = len(df_patient)
        if n_scans < 2:
            continue

        # Index scans by their position in the age ordering instead of looking
        # them up by age value: with two scans of equal age a value lookup
        # always resolves to the first of them, which can yield a wrong scan
        # or a pair whose source and target are the same scan.
        ids_by_age = df_patient.sort_values('AGE', kind='mergesort')['ID'].values

        for ith_age in range(n_scans - 1):
            # Offset in [1, n_scans - ith_age - 1]: always a later scan.
            idx_rdm = np.random.randint(n_scans - ith_age - 1) + 1
            list_patient_scans.append({
                'patient': patient['name'],
                'src_scan': ids_by_age[ith_age],
                'tgt_scan': ids_by_age[ith_age + idx_rdm],
            })
    return list_patient_scans





def make_oasisbrain_model_input(ds_, ds_prepared_data, tgt_var_name):
    """
    Turn rows of prepared data into batched model inputs and targets.

    Args:
        ds_: object providing ``covariate_names`` and ``padding_muter``.
        ds_prepared_data: DataFrame holding the covariate columns, the
            ``'<name>_muter'`` columns (read only when ``ds_.padding_muter``
            is set) and the target column.
        tgt_var_name: name of the target column.

    Returns:
        ``(model_input, vol)``: a float tensor with one row per data row and
        one column per covariate (followed by one column per muter flag when
        padding is enabled), and a float tensor of shape ``(rows, 1)`` with
        the target values.
    """
    def column_as_tensor(column_name):
        return torch.tensor(ds_prepared_data[column_name].values).float()

    names = ds_.covariate_names

    # covariate values first ...
    feature_columns = [column_as_tensor(name) for name in names]

    # ... then, optionally, the muter flag of every covariate
    if ds_.padding_muter:
        feature_columns.extend(column_as_tensor(name + '_muter') for name in names)

    model_input = torch.stack(feature_columns, dim=-1).float()
    vol = column_as_tensor(tgt_var_name)[..., None]

    return model_input, vol


def get_oasisbrain_data_for_id(test_idx, ds_, pos=None):
    """
    Collect the prepared rows whose ``ID`` equals ``test_idx``.

    Args:
        test_idx: ID value to select.
        ds_: dataset object providing ``prepared_data`` and ``tgt_var_name``.
        pos: not used by this function.

    Returns:
        ``(model_input, csa, rows)``: the two tensors produced by
        ``make_oasisbrain_model_input`` for the selected rows, and the
        selected rows themselves.
    """
    all_rows = ds_.prepared_data
    id_matches = all_rows['ID'] == test_idx
    rows_for_id = all_rows[id_matches]

    model_input, csa = make_oasisbrain_model_input(ds_, rows_for_id, ds_.tgt_var_name)

    # Debug marker printed when the tensor and the frame disagree in length.
    lengths_agree = len(model_input) == len(rows_for_id)
    if not lengths_agree:
        print('1')

    return model_input, csa, rows_for_id



if __name__ == "__main__":
    # Smoke test: build the training split from hard-coded paths and print
    # how many samples it holds.
    smoke_test_kwargs = dict(
        filename_datasource="/playpen-raid/Author/LucidAtlas/data/OASISBrain/data.csv",
        filename_split="/playpen-raid/Author/LucidAtlas/data/OASISBrain/split.yaml",
        covariate_names=["AGE", "SEX", "EDUC", "SES", "MMSE", "eTIV", "ASF"],
        tgt_var_name="CDR",
        split='train',
    )
    ds_train = OASISBrainDataset(**smoke_test_kwargs)
    print(len(ds_train))