#!/usr/bin/env python3
# Copyright 2004-present Facebook. All Rights Reserved.


import torch
import numpy as np
from scipy.stats import norm

import torch
import torch.utils.data
import pandas as pd
import pyvista as pv
import progressbar as pb
from multiprocessing import *
import blosc
from collections import defaultdict
from tqdm import trange
from scipy.interpolate import interp1d
import numpy as np
from .normalizers import StandardScalerNormalizer
def get_train_split(dict_):
    """Return the training ids from a parsed split file.

    Prefers the ``'train_val'`` entry when it exists, otherwise falls back to
    ``'train'``.

    Args:
        dict_: mapping of split name -> list of ids.

    Raises:
        KeyError: if neither ``'train_val'`` nor ``'train'`` is present.
    """
    preferred_key = 'train_val' if 'train_val' in dict_ else 'train'
    return dict_[preferred_key]






class PediatricAirwayDataset(torch.utils.data.Dataset):
    """Airway dataset whose samples may have missing (masked) covariates.

    A sample is ``(model_input, target)``:

    * ``model_input`` is ``[pos, *covariates]``. When ``padding_muter`` is
      True it is followed by one 0/1 flag per covariate (1 = that covariate
      is missing).
    * ``target`` is a length-1 tensor holding ``tgt_var_name``.

    Positions, targets and covariates are standardized with normalizers
    fitted on the training split (``train_val`` if the split file has it,
    otherwise ``train``), whichever split is being served.

    Args:
        filename_datasource: Excel file read with ``pd.read_excel``. It must
            hold the covariate and target columns plus ``pos``, ``id`` and
            ``PID``.
        filename_split: YAML file mapping split names to lists of ``id`` values.
        slt_percentile: if not None, keep only rows whose ``pos`` equals the
            stored position closest to this percentile of ``pos``.
        covariate_names: covariate columns fed to the model.
        tgt_var_name: target column.
        split: key of ``filename_split`` to serve.
        augtype: ``'none'`` keeps each record once. ``'full'`` also emits
            copies with subsets of the covariates masked out.
        allow_missingness: honoured only for splits whose name contains
            ``'train'``. If True, remaining NaNs are replaced by 0 and missing
            covariates stay flagged by the muter columns. Otherwise rows with
            any NaN are dropped.
        padding_muter: append the per-covariate missingness flags to inputs.
    """

    def __init__(
            self,
            filename_datasource,
            filename_split,
            slt_percentile=None,
            covariate_names=('AGE', 'WEIGHT', 'HEIGHT', 'SEX'),
            tgt_var_name='csa',
            split='train',
            augtype='none',
            allow_missingness=True,
            padding_muter=True,
    ):
        self.DATASETNANE = 'Airway'  # (sic) attribute name kept for compatibility
        self.in_geo_features = 1
        self.geo_var_name = 'pos'
        # Fresh list: never aliases the caller's list or a shared default.
        # This also lets tuples be passed.
        self.covariate_names = list(covariate_names)
        self.tgt_var_name = tgt_var_name
        self.used_columns = self.covariate_names + [tgt_var_name, 'pos', 'id', 'PID']
        # Missingness is only ever produced for training splits.
        self.allow_missingness = allow_missingness if 'train' in split else False
        self.padding_muter = padding_muter

        # 1. Read the served split and the training split.
        #    Both the YAML file and the source file are read only once.
        dict_splits = self.load_yaml_as_dict(filename_split)
        self.split = dict_splits[split]
        self.train_split = get_train_split(dict_splits)
        self.filename_datasource = filename_datasource
        self.df_data = None  # filled and cached by the first read_data call

        df_data = self.read_data(self.split)
        df_data_train = self.read_data(self.train_split)

        # 2. Fit normalizers on the training split only.
        self.dict_covariate_normalizer = self.get_statistics_of_covariates(df_data_train)
        self.dict_geo_normalizer = self.get_statistics_of_geo_features(df_data_train)
        self.dict_normalizer = {**self.dict_geo_normalizer, **self.dict_covariate_normalizer}

        # 3. Normalize both the training split and the served split.
        (self.train_valid_dict_features,
         self.train_valid_arr_features,
         self.train_normed_data) = self.normalize(df_data_train, slt_percentile)
        (self.valid_dict_features,
         self.valid_arr_features,
         self.normed_data) = self.normalize(df_data, slt_percentile)

        # 4. Build the (optionally augmented) records.
        self.customize_aug(augtype=augtype)
        self.prepared_data_with_nans = self.prepared_data.copy()

        # 5. Handle missingness. Either keep incomplete samples (NaN -> 0,
        #    with the muter columns recording what was missing) or keep only
        #    complete samples.
        if self.allow_missingness:
            self.prepared_data = self.prepared_data.fillna(0)
        else:
            self.prepared_data = self.prepared_data.dropna()
        self.NUM_OF_CASES = len(self.prepared_data)

        print(f"There are {len(self.normed_data)} records.")
        print(f"There are {self.NUM_OF_CASES} records after augmentation or filtering.")
        # nunique() instead of np.unique: np.unique raises on str/NaN mixtures.
        print(f"There are {self.prepared_data_with_nans['PID'].nunique()} patients.")
        print(f"There are {self.prepared_data['PID'].nunique()} patients after augmentation or filtering.")

        self.normalize_unique()

        self.train_valid_pos = np.array(self.train_normed_data['pos'].values)

    @staticmethod
    def count_aug_cases(pd_data):
        """Count the records full augmentation would produce (not called in this file).

        A row with ``k`` valid features allows ``2**k - 1`` non-empty subsets.

        NOTE(review): ``- 3`` assumes exactly three non-covariate columns per
        row. Frames produced by ``normalize`` carry four ('id', 'PID', 'pos',
        target). Confirm before relying on this count.
        """
        num_of_valid_features = np.array(list(pd_data.count(axis=1))) - 3
        num_of_cases = np.sum(np.power(2, num_of_valid_features) - 1)
        return num_of_cases

    def __len__(self):
        return self.NUM_OF_CASES

    def customize_aug(self, augtype='full'):
        """Build ``self.prepared_data`` from ``self.normed_data``.

        Args:
            augtype: ``'full'`` for masked-subset augmentation, ``'none'`` for
                one record per row.

        Raises:
            ValueError: for any other ``augtype``. Raising here avoids leaving
                ``prepared_data`` unset.
        """
        if augtype == 'full':
            self.prepared_data = self.fully_augment_dataset(self.normed_data)
        elif augtype == 'none':
            self.prepared_data = self.not_augment_dataset(self.normed_data)
        else:
            raise ValueError(f"Unknown augtype {augtype!r}; expected 'full' or 'none'.")

    def get_statistics_of_covariates(self, df_data_train):
        """Fit one normalizer per covariate on per-``id`` values.

        Only the first row of each ``id`` is used, so every scan contributes
        once to the fit (presumably covariates are constant within a scan).
        Those first rows' covariates, ordered by ``id``, are also stored in
        ``self.unique_covariates``.

        Returns:
            dict covariate name -> fitted ``StandardScalerNormalizer``.
        """
        # First row of every id, ordered by id, in one vectorized pass.
        # This is the same selection and order as taking .iloc[0] for each
        # value of np.unique(ids).
        df_unique_covariates = (
            df_data_train.drop_duplicates(subset='id', keep='first')
            .sort_values('id')
            .reset_index(drop=True)
        )
        # .copy(): normalize_unique later writes into this frame.
        self.unique_covariates = df_unique_covariates[self.covariate_names].copy()

        self.dict_covariate_normalizer = {}
        for ith_covariate in self.covariate_names:
            current_cov_val = df_unique_covariates[ith_covariate].dropna().values.reshape(-1, 1)
            normalizer = StandardScalerNormalizer()
            normalizer.fit(current_cov_val)
            self.dict_covariate_normalizer[ith_covariate] = normalizer

        return self.dict_covariate_normalizer

    def normalize_unique(self):
        """Normalize ``self.unique_covariates`` in place with the fitted covariate normalizers.

        NOTE: not idempotent. Calling it twice normalizes twice.
        """
        for ith_cov in self.covariate_names:
            raw = self.unique_covariates[ith_cov].values.reshape(-1, 1)
            normed = self.dict_covariate_normalizer[ith_cov].transform(raw)
            self.unique_covariates[ith_cov] = np.asarray(normed).reshape(-1)

    def get_statistics_of_geo_features(self, df_data_train):
        """Fit normalizers for ``pos`` and the target column on the training rows.

        Returns:
            dict with keys ``'pos'`` and ``self.tgt_var_name``.
        """
        self.dict_geo_normalizer = {}
        for ith_geo_feat in ['pos', self.tgt_var_name]:
            current_geo_val = df_data_train[ith_geo_feat].dropna().values.reshape(-1, 1)
            normalizer = StandardScalerNormalizer()
            normalizer.fit(current_geo_val)
            self.dict_geo_normalizer[ith_geo_feat] = normalizer
        return self.dict_geo_normalizer

    @staticmethod
    def _select_percentile_pos(pos_values, percentile):
        """Return the stored position closest to the ``percentile``-th percentile.

        ``np.percentile`` interpolates, so its result is generally not one of
        the stored positions, and an exact ``==`` filter on it would select
        nothing. Snapping to the nearest stored value avoids that.
        """
        pos_values = np.asarray(pos_values, dtype=float)
        target = np.percentile(pos_values, percentile)
        return pos_values[np.argmin(np.abs(pos_values - target))]

    def normalize(self, df_data, slt_percentile):
        """Apply the fitted normalizers to ``df_data``.

        Args:
            df_data: rows as returned by ``read_data``.
            slt_percentile: optional percentile of ``pos``. If given, only
                rows at the stored position closest to it are kept.

        Returns:
            ``(dict_covariates, arr_covariates, pd_data)``, where:

            * ``dict_covariates`` maps each covariate to its normalized 1-D
              array.
            * ``arr_covariates`` stacks those arrays as columns (n_rows,
              n_covariates).
            * ``pd_data`` has columns ``id``, ``PID``, the covariates,
              ``pos`` and the target, all normalized except the ids.
        """
        if slt_percentile is not None:
            current_pos = self._select_percentile_pos(df_data['pos'].values, slt_percentile)
            df_data = df_data[df_data['pos'] == current_pos]

        dict_covariates = {}
        for ith_cov in self.covariate_names:
            arr_current_cov = np.array(df_data[ith_cov]).reshape(-1, 1)
            arr_normed_cov = self.dict_covariate_normalizer[ith_cov].transform(arr_current_cov)
            # reshape(-1) rather than squeeze(): stays 1-D for single-row frames.
            dict_covariates[ith_cov] = np.asarray(arr_normed_cov).reshape(-1)

        arr_covariates = np.array([dict_covariates[c] for c in self.covariate_names]).T

        csa_values = np.asarray(self.dict_geo_normalizer[self.tgt_var_name].transform(
            np.array(df_data[self.tgt_var_name]).reshape(-1, 1))).reshape(-1)
        pos_values = np.asarray(self.dict_geo_normalizer[self.geo_var_name].transform(
            np.array(df_data[self.geo_var_name]).reshape(-1, 1))).reshape(-1)

        dict_normalized_pd_data = {'id': df_data['id'], 'PID': df_data['PID']}
        dict_normalized_pd_data.update(dict_covariates)
        dict_normalized_pd_data['pos'] = pos_values
        dict_normalized_pd_data[self.tgt_var_name] = csa_values
        pd_data = pd.DataFrame.from_dict(dict_normalized_pd_data)
        return dict_covariates, arr_covariates, pd_data

    def make_cases(self, pd_data):
        """Return one ``[*covariates, pos]`` list per row of ``pd_data``."""
        columns = self.covariate_names + ['pos']
        return [[record[c] for c in columns] for record in pd.DataFrame.to_records(pd_data)]

    def binary_muter(self, N):
        """Return the bits of ``N`` as floats, one per covariate.

        The most significant bit comes first, as produced by ``np.binary_repr``.
        """
        binaires = np.binary_repr(N, width=len(self.covariate_names))
        return np.array(list(binaires)).astype('float')

    def _make_record(self, row, cov_values):
        """Build one output record from a source row and its (possibly masked) covariates.

        The muter flags are derived from the NaNs actually present in
        ``cov_values``. A covariate that was already missing is therefore
        flagged even when no mask bit removed it.
        """
        cov_values = np.asarray(cov_values, dtype=float)
        if self.padding_muter:
            muter = np.isnan(cov_values).astype(float)
            arr = np.concatenate([cov_values, muter], axis=-1)
        else:
            arr = cov_values
        record = self.make_dict_from_arr_data(arr)
        record['pos'] = row['pos']
        record[self.tgt_var_name] = row[self.tgt_var_name]
        record['id'] = row['id']
        record['PID'] = row['PID']
        return record

    def _finalize_records(self, list_data):
        """Turn records into a DataFrame, dropping duplicates and all-NaN-covariate rows."""
        columns = list(self.covariate_names)
        if self.padding_muter:
            columns += [c + '_muter' for c in self.covariate_names]
        columns += ['pos', self.tgt_var_name, 'id', 'PID']
        # Explicit columns keep the schema even when list_data is empty.
        pd_aug = pd.DataFrame(list_data, columns=columns)
        pd_aug = pd_aug.drop_duplicates(subset=self.covariate_names + ['pos', self.tgt_var_name, 'id'])
        return pd_aug.dropna(subset=self.covariate_names, how='all')

    def fully_augment_dataset(self, pd_data_ori):
        """Emit every masked variant of each row of ``pd_data_ori``.

        For mask ``m`` in ``0 .. 2**N - 2`` (N covariates), the covariates
        whose ``binary_muter(m)`` bit is 1 are set to NaN. The all-masked
        pattern is skipped. Identical records (e.g. masking an
        already-missing covariate) and rows with every covariate missing are
        dropped.
        """
        records = pd_data_ori.to_dict('records')
        n_masks = 2 ** len(self.covariate_names) - 1
        # Masks do not depend on the row: build them once.
        masks = [self.binary_muter(m) == 1. for m in range(n_masks)]
        list_data = []
        for data_idx in trange(len(records)):
            row = records[data_idx]
            cov_values = np.array([row[c] for c in self.covariate_names], dtype=float)
            for mask in masks:
                masked = cov_values.copy()
                masked[mask] = np.nan  # 1 = removed, 0 = kept
                list_data.append(self._make_record(row, masked))
        return self._finalize_records(list_data)

    def not_augment_dataset(self, pd_data_ori):
        """Emit one record per row of ``pd_data_ori``, flagging its existing NaN covariates."""
        records = pd_data_ori.to_dict('records')
        list_data = []
        for data_idx in trange(len(records)):
            row = records[data_idx]
            cov_values = np.array([row[c] for c in self.covariate_names], dtype=float)
            list_data.append(self._make_record(row, cov_values))
        return self._finalize_records(list_data)

    def make_dict_from_arr_data(self, arr):
        """Map ``[*covariates]`` or ``[*covariates, *muters]`` onto named keys.

        Muter entries are stored under ``'<covariate>_muter'``.

        Raises:
            ValueError: if ``len(arr)`` is neither N nor 2N (N = number of covariates).
        """
        n_cov = len(self.covariate_names)
        if len(arr) not in (n_cov, 2 * n_cov):
            raise ValueError(
                f"Expected {n_cov} or {2 * n_cov} values, got {len(arr)}.")
        dict_data = {name: arr[idx] for idx, name in enumerate(self.covariate_names)}
        if len(arr) == 2 * n_cov:
            for idx, name in enumerate(self.covariate_names):
                dict_data[name + '_muter'] = arr[idx + n_cov]
        return dict_data

    def read_data(self, split):
        """Return the rows whose ``id`` is in ``split``, restricted to ``used_columns``.

        The source Excel file is read on the first call and cached in
        ``self.df_data``. Rows missing ``id``, the target or ``pos`` are
        dropped.
        """
        if getattr(self, 'df_data', None) is None:
            types = {"id": str, "WEIGHT": float, "AGE": float, "SEX": float, "HEIGHT": float, "csa": float, "pos": float, 'PID': str}
            self.df_data = pd.read_excel(self.filename_datasource, header=0, dtype=types)
        df_data_split = self.df_data.loc[self.df_data['id'].isin(split)]
        df_data_split = df_data_split[self.used_columns]
        return df_data_split.dropna(subset=['id', self.tgt_var_name, 'pos'])

    def load_yaml_as_dict(self, yaml_path):
        """Parse the YAML file at ``yaml_path``.

        Uses ``safe_load``: the split files only need plain data types.
        """
        import yaml
        with open(yaml_path, 'r') as f:
            return yaml.safe_load(f)

    def __getitem__(self, idx):
        """Return ``(model_input, target)`` float32 tensors for row ``idx`` of ``prepared_data``."""
        row = self.prepared_data.iloc[idx]
        values = [row['pos']] + [row[c] for c in self.covariate_names]
        if self.padding_muter:
            values += [row[c + '_muter'] for c in self.covariate_names]
        model_input = torch.tensor([float(v) for v in values], dtype=torch.float32)
        csa = torch.tensor([float(row[self.tgt_var_name])], dtype=torch.float32)
        return model_input, csa


def load_yaml_as_dict(yaml_path):
    """Parse the YAML file at ``yaml_path`` with ``yaml.safe_load`` and return the result."""
    import yaml

    with open(yaml_path, "r") as yaml_file:
        parsed = yaml.safe_load(yaml_file)
    return parsed



def get_airways_for_transport(spec, ds_dataset, split='test_multiple'):
    """Pair each multi-scan patient's youngest scan with its later scans.

    Args:
        spec: mapping whose ``"Split"`` entry is the path of the split YAML file.
        ds_dataset: dataset whose ``prepared_data`` has ``id``, ``PID`` and
            ``AGE`` columns.
        split: key in the split file. Its entries are dicts whose ``'name'``
            is compared with ``PID``.

    Returns:
        list of ``{'patient', 'youngest_scan', 'other_scans'}`` dicts.
        ``other_scans`` holds the ids of strictly older scans, ordered by
        ``AGE``. Patients with fewer than two scans of known ``AGE`` are
        skipped.
    """
    timelines = load_yaml_as_dict(spec["Split"])[split]
    df_data = ds_dataset.prepared_data.drop_duplicates(subset=['id'])
    # Scans without an age cannot be ordered. A NaN age would also make
    # the minimum NaN, and the youngest-scan lookup below would come back empty.
    df_data = df_data[df_data['AGE'].notna()]
    # Hoisted: the string conversion does not depend on the patient.
    pids = df_data['PID'].astype('str')

    list_patient_scans = []
    for patient in timelines:
        df_patient = df_data.loc[pids == patient['name']]
        if len(df_patient) < 2:
            continue
        youngest_age = df_patient['AGE'].min()
        youngest_scan = df_patient.loc[df_patient['AGE'] == youngest_age]
        later = df_patient.loc[df_patient['AGE'] > youngest_age]
        # Stable sort: scans sharing an age keep their row order.
        other_scans = later['id'].values[np.argsort(later['AGE'].values, kind='stable')]
        list_patient_scans.append({'patient': patient['name'],
                                   'youngest_scan': youngest_scan['id'].values[0],
                                   'other_scans': other_scans})
    return list_patient_scans


def get_airways_pairs_for_transport(spec, ds_dataset, split='test_multiple', seed=1117):
    """Build (source scan, later target scan) pairs for every multi-scan patient.

    For each patient, scans are ordered by ``AGE``. Every scan except the
    last becomes a source, paired with one strictly later scan chosen
    uniformly at random.

    Args:
        spec: mapping whose ``"Split"`` entry is the path of the split YAML file.
        ds_dataset: dataset whose ``prepared_data`` has ``id``, ``PID`` and
            ``AGE`` columns.
        split: key in the split file. Its entries are dicts whose ``'name'``
            is compared with ``PID``.
        seed: seed of the private random generator.

    Returns:
        list of ``{'patient', 'src_scan', 'tgt_scan'}`` dicts.

    A private ``RandomState(seed)`` is used instead of seeding the global
    NumPy RNG. It draws the same sequence the original global seeding
    produced, without altering randomness elsewhere in the program.
    """
    rng = np.random.RandomState(seed)
    timelines = load_yaml_as_dict(spec["Split"])[split]
    df_data = ds_dataset.prepared_data.drop_duplicates(subset=['id'])
    # Scans without an age cannot be placed in the ordering.
    df_data = df_data[df_data['AGE'].notna()]
    pids = df_data['PID'].astype('str')

    list_patient_scans = []
    for patient in timelines:
        df_patient = df_data.loc[pids == patient['name']]
        if len(df_patient) < 2:
            continue
        # Address scans by position after a stable sort. Looking them up by
        # age equality would return the same (first) scan for both members
        # of an age tie.
        scan_ids = df_patient.sort_values('AGE', kind='mergesort')['id'].values
        n_scans = len(scan_ids)
        for ith_age in range(n_scans - 1):
            # Offset in [1, n_scans - ith_age - 1]: always a strictly later scan.
            idx_rdm = rng.randint(n_scans - ith_age - 1) + 1
            list_patient_scans.append({'patient': patient['name'],
                                       'src_scan': scan_ids[ith_age],
                                       'tgt_scan': scan_ids[ith_age + idx_rdm]})
    return list_patient_scans





def make_airway_model_input(ds_, ds_prepared_data, tgt_var_name):
    """Batch version of the dataset's per-row input/target construction.

    Args:
        ds_: object providing ``covariate_names`` and ``padding_muter``.
        ds_prepared_data: frame with ``pos``, the covariates, optionally the
            ``<covariate>_muter`` columns, and ``tgt_var_name``.
        tgt_var_name: target column.

    Returns:
        ``(model_input, csa)`` float32 tensors:

        * ``model_input`` has shape (n_rows, n_features), columns
          ``pos, *covariates[, *muters]``.
        * ``csa`` has shape (n_rows, 1).
    """
    def column_tensor(column):
        return torch.tensor(ds_prepared_data[column].values).float()

    covariate_tensors = [column_tensor(name) for name in ds_.covariate_names]
    coords = column_tensor('pos')
    if ds_.padding_muter:
        muter_tensors = [column_tensor(name + '_muter') for name in ds_.covariate_names]
    else:
        muter_tensors = []

    model_input = torch.stack([coords, *covariate_tensors, *muter_tensors], dim=-1)
    csa = column_tensor(tgt_var_name)[..., None]
    return model_input, csa


def get_airway_data_for_id(test_idx, ds_, pos=None):
    """Select one scan's prepared rows and convert them to model tensors.

    Args:
        test_idx: scan ``id`` to select from ``ds_.prepared_data``.
        ds_: dataset providing ``prepared_data``, ``covariate_names``,
            ``padding_muter`` and ``tgt_var_name``.
        pos: optional value compared exactly against ``prepared_data['pos']``.
            For ``PediatricAirwayDataset`` that column holds normalized
            positions, so pass a value taken from the data itself.

    Returns:
        ``(model_input, csa, pd_slt_data)``: the tensors from
        ``make_airway_model_input`` and the selected rows.
    """
    pd_ds = ds_.prepared_data
    pd_slt_data = pd_ds[pd_ds['id'] == test_idx]
    if pos is not None:
        pd_slt_data = pd_slt_data[pd_slt_data['pos'] == pos]
    model_input, csa = make_airway_model_input(ds_, pd_slt_data, ds_.tgt_var_name)
    return model_input, csa, pd_slt_data



if __name__ == "__main__":
    # Manual smoke run: build the "all" split with complete cases only and
    # no muter columns. The paths are machine-specific.
    dataset_kwargs = dict(
        filename_datasource="/playpen-raid/Author/LucidAtlas/data/airways/csa_atlas_20250409.xlsx",
        filename_split="/playpen-raid/Author/LucidAtlas/data/airways/airway_split_with_val.yaml",
        covariate_names=["AGE", "WEIGHT", "HEIGHT"],
        split='all',
        allow_missingness=False,
        padding_muter=False,
    )
    ds_train = PediatricAirwayDataset(**dataset_kwargs)
