#!/usr/bin/env python3
# Copyright 2004-present Facebook. All Rights Reserved.

import glob
import logging
import numpy as np
import os
import random
import torch
import torch.utils.data
import pandas as pd
import pyvista as pv
import progressbar as pb
from multiprocessing import *
import blosc
from collections import defaultdict
from tqdm import trange
def get_train_split(dict_):
    '''
    Return the training id list from a parsed split dictionary.

    The combined ``'train_val'`` entry is preferred when present; otherwise
    the ``'train'`` entry is returned.

    Raises:
        KeyError: if ``dict_`` has neither ``'train_val'`` nor ``'train'``.
    '''
    # Direct membership test on the mapping; no throwaway list of keys.
    if 'train_val' in dict_:
        return dict_['train_val']
    return dict_['train']

class OODAirwayDataset(torch.utils.data.Dataset):
    '''
    Airway dataset pairing an atlas cohort with an out-of-distribution cohort.

    Normalization statistics (mean/std of covariates, ``pos`` and the target)
    are computed on the atlas *training* split only; the samples served by
    ``__getitem__`` come from the requested ``split`` of the OOD cohort.
    Each sample is ``(model_input, target)`` where ``model_input`` is
    ``[pos, *covariates]`` (optionally followed by one muter flag per
    covariate, 1.0 meaning "value missing") and ``target`` is a length-1
    tensor holding the normalized target variable.

    This is for making training cases with missingness.
    '''
    def __init__(
            self,
            filename_atlas_datasource,
            filename_atlas_split,
            filename_ood_datasource,
            filename_ood_split,

            slt_percentile=None,
            covariate_names=None,
            tgt_var_name='csa',
            split='train',
            augtype='none',
            allow_missingness=True,
            padding_muter=True,
    ):
        '''
        Args:
            filename_atlas_datasource: Excel file with the atlas records.
            filename_atlas_split: YAML file mapping split names to atlas ids.
            filename_ood_datasource: Excel file with the OOD records.
            filename_ood_split: YAML file mapping split names to OOD ids.
            slt_percentile: if not None, keep only rows whose ``pos`` equals
                that percentile of ``pos`` (see ``normalize``).
            covariate_names: covariate columns; defaults to
                ``['AGE', 'WEIGHT', 'HEIGHT', 'SEX']``.
            tgt_var_name: target column.
            split: which split of the OOD cohort to serve.
            augtype: ``'full'`` (all covariate-masking variants) or ``'none'``.
            allow_missingness: only honoured when ``'train'`` is in ``split``;
                if True missing covariates are zero-filled, otherwise rows
                with any NaN are dropped.
            padding_muter: append one missingness flag per covariate to the
                model input.
        '''
        if covariate_names is None:
            covariate_names = ['AGE', 'WEIGHT', 'HEIGHT', 'SEX']
        self.DATASETNANE = 'Airway'
        # Correctly spelled alias; the original attribute name is kept for callers.
        self.DATASETNAME = self.DATASETNANE
        # Copy so the instance never shares the caller's (or a default) list.
        self.covariate_names = list(covariate_names)  # e.g. ['weight', 'age', 'height', 'sex', 'pos']
        self.tgt_var_name = tgt_var_name
        self.used_columns = self.covariate_names + [tgt_var_name] + ['pos', 'id', 'PID']
        self.allow_missingness = allow_missingness if 'train' in split else False
        self.padding_muter = padding_muter

        # 1. read split ids and records, filter out rows with nan id/target/pos
        # 1.1 atlas (the split file is parsed once and used for both lookups)
        atlas_split_dict = self.load_yaml_as_dict(filename_atlas_split)
        self.atlas_split = atlas_split_dict[split]
        self.atlas_train_split = get_train_split(atlas_split_dict)
        self.atlas_filename_datasource = filename_atlas_datasource
        # 1.2 OOD
        self.ood_split = self.load_yaml_as_dict(filename_ood_split)[split]
        self.ood_filename_datasource = filename_ood_datasource

        df_ood_data = self.read_data(self.ood_filename_datasource, self.ood_split)
        df_atlas_data_train = self.read_data(self.atlas_filename_datasource, self.atlas_train_split)

        # 2. global mean/std, taken from the atlas training split only
        self.dict_covariate_mean, self.dict_covariate_std = self.get_statistics_of_covariates(df_atlas_data_train)
        self.dict_geo_feat_mean, self.dict_geo_feat_std = self.get_statistics_of_geo_features(df_atlas_data_train)
        self.dict_feat_mean = {**self.dict_geo_feat_mean, **self.dict_covariate_mean}
        self.dict_feat_std = {**self.dict_geo_feat_std, **self.dict_covariate_std}

        # 3. normalize data
        (self.train_valid_dict_features,
         self.train_valid_arr_features,
         self.train_normed_data) = self.normalize(df_atlas_data_train, slt_percentile)

        (self.valid_dict_features,
         self.valid_arr_features,
         self.normed_data) = self.normalize(df_ood_data, slt_percentile)

        # 4. augment the (OOD) dataset
        self.customize_aug(augtype=augtype)
        self.prepared_data_with_nans = self.prepared_data.copy()

        # 5. process missingness
        # If allow_missingness is True the dataset yields samples with missing
        # covariates (zero-filled; the muter columns say which are missing).
        # Otherwise only complete samples are kept.
        if self.allow_missingness:
            self.prepared_data = self.prepared_data.fillna(0)
        else:
            self.prepared_data = self.prepared_data.dropna()
        self.NUM_OF_CASES = len(self.prepared_data)
        # nunique(dropna=False) also copes with NaN / zero-filled ids mixed
        # with strings, where np.unique would fail on the comparison.
        print("There are " + str(len(self.normed_data)) + " records.")
        print("There are " + str(self.NUM_OF_CASES) + " records after augmentation or filtering.")

        print("There are " + str(self.prepared_data_with_nans['PID'].nunique(dropna=False)) + " patients.")
        print("There are " + str(self.prepared_data['PID'].nunique(dropna=False)) + " patients after augmentation or filtering.")

        print("There are " + str(self.prepared_data_with_nans['id'].nunique(dropna=False)) + " shapes.")
        print("There are " + str(self.prepared_data['id'].nunique(dropna=False)) + " shapes after augmentation or filtering.")

        # 6. unique covariate combinations of the served (OOD) split
        self.unique_arr_covariate_pairs = np.unique(self.valid_arr_features, axis=0)
        self.unique_dict_covariate_pairs = {}
        for i, name in enumerate(self.covariate_names):
            self.unique_dict_covariate_pairs[name] = self.unique_arr_covariate_pairs[:, [i]]
        self.train_valid_pos = np.array(self.train_normed_data['pos'].values)

    @staticmethod
    def count_aug_cases(pd_data):
        '''
        Sum, over rows, of ``2**k - 1`` where ``k`` is the row's number of
        non-null values minus 3.

        NOTE(review): the ``- 3`` assumes three non-covariate columns per row;
        ``normalize`` output carries four (id, PID, pos, target) — confirm
        the offset before relying on this count.
        '''
        num_of_valid_features = np.array(list(pd_data.count(axis=1))) - 3
        num_of_cases = np.sum(np.power(2, num_of_valid_features) - 1)
        return num_of_cases

    def __len__(self):
        '''Number of prepared samples (after augmentation and NaN handling).'''
        return self.NUM_OF_CASES

    def customize_aug(self, augtype='full'):
        '''
        Build ``self.prepared_data`` from ``self.normed_data``.

        Args:
            augtype: ``'full'`` to emit every covariate-masking variant of each
                record, ``'none'`` to keep each record once.

        Raises:
            ValueError: for any other ``augtype``.
        '''
        if augtype == 'full':
            self.prepared_data = self.fully_augment_dataset(self.normed_data)
        elif augtype == 'none':
            self.prepared_data = self.not_augment_dataset(self.normed_data)
        else:
            raise ValueError(f"Unknown augtype {augtype!r}; expected 'full' or 'none'.")

    def get_statistics_of_covariates(self, df_data_train):
        '''
        Per-covariate mean and std (numpy default, ddof=0) over unique shapes.

        Only the first record of every ``id`` is used, so each shape counts
        once regardless of how many ``pos`` samples it has. NaNs are ignored.

        Returns:
            (dict_covariate_mean, dict_covariate_std), also stored on ``self``.
        '''
        # First record of each id (single pass instead of one mask per id).
        df_unique_covariates = df_data_train.drop_duplicates(subset='id', keep='first')

        self.dict_covariate_mean = {}
        self.dict_covariate_std = {}
        for ith_covariate in self.covariate_names:
            values = df_unique_covariates[ith_covariate].dropna().values
            self.dict_covariate_mean[ith_covariate] = values.mean()
            self.dict_covariate_std[ith_covariate] = values.std()

        return self.dict_covariate_mean, self.dict_covariate_std

    def get_statistics_of_geo_features(self, df_data_train):
        '''
        Mean and std (ddof=0) of ``pos`` and the target over all rows, NaNs ignored.

        Returns:
            (dict_geo_feat_mean, dict_geo_feat_std), also stored on ``self``.
        '''
        self.dict_geo_feat_mean = {}
        self.dict_geo_feat_std = {}
        for ith_geo_feat in ['pos', self.tgt_var_name]:
            values = df_data_train[ith_geo_feat].dropna().values
            self.dict_geo_feat_mean[ith_geo_feat] = values.mean()
            self.dict_geo_feat_std[ith_geo_feat] = values.std()
        return self.dict_geo_feat_mean, self.dict_geo_feat_std

    def normalize(self, df_data, slt_percentile):
        '''
        Z-score covariates, ``pos`` and the target with the stored statistics.

        Args:
            df_data: records as returned by ``read_data``.
            slt_percentile: if not None, keep only rows whose ``pos`` equals
                ``np.percentile(pos, slt_percentile)``.
                NOTE(review): the percentile may interpolate to a value no row
                has exactly, which yields an empty frame — confirm intent.

        Returns:
            (dict_covariates, arr_covariates, pd_data): normalized array per
            covariate, the same stacked as (n_rows, n_covariates), and a
            DataFrame with ``id``, ``PID``, the normalized covariates, ``pos``
            and the target.
        '''
        if slt_percentile is not None:
            current_pos = np.percentile(df_data['pos'].values, slt_percentile)
            df_data = df_data[df_data['pos'] == current_pos]

        # Every covariate (SEX included) is z-scored.
        list_covariates = []
        dict_covariates = {}
        for ith_cov in self.covariate_names:
            arr_current_cov = np.array(df_data[ith_cov])
            arr_normed_cov = (arr_current_cov - self.dict_covariate_mean[ith_cov]) / self.dict_covariate_std[ith_cov]
            list_covariates.append(arr_normed_cov)
            dict_covariates[ith_cov] = arr_normed_cov

        arr_covariates = np.array(list_covariates).T

        csa_values = (np.array(df_data[self.tgt_var_name]) - self.dict_geo_feat_mean[self.tgt_var_name]) / self.dict_geo_feat_std[self.tgt_var_name]
        pos_values = (np.array(df_data['pos']) - self.dict_geo_feat_mean['pos']) / self.dict_geo_feat_std['pos']

        dict_normalized_pd_data = {'id': df_data['id'], 'PID': df_data['PID']}
        dict_normalized_pd_data.update(dict_covariates)
        dict_normalized_pd_data.update({'pos': pos_values})
        dict_normalized_pd_data.update({self.tgt_var_name: csa_values})
        pd_data = pd.DataFrame.from_dict(dict_normalized_pd_data)
        return dict_covariates, arr_covariates, pd_data

    def make_cases(self, pd_data):
        '''Return one ``[*covariates, pos]`` list per row of ``pd_data``.'''
        records = pd.DataFrame.to_records(pd_data)
        return [
            [record[name] for name in self.covariate_names] + [record['pos']]
            for record in records
        ]

    def binary_muter(self, N):
        '''
        Return ``N`` as a float 0/1 array of ``len(self.covariate_names)``
        binary digits, most-significant digit first (``np.binary_repr``).
        '''
        binaires = np.binary_repr(N, width=len(self.covariate_names))
        arr = np.array(list(binaires)).astype('float')
        return arr

    def _make_aug_record(self, arr_with_muter, source):
        '''Build one output record: covariates (+muters), then pos, target, id, PID.'''
        record = self.make_dict_from_arr_data(arr_with_muter)
        record['pos'] = source['pos']
        record[self.tgt_var_name] = source[self.tgt_var_name]
        record['id'] = source['id']
        record['PID'] = source['PID']
        return record

    def _finalize_aug(self, list_data):
        '''DataFrame of the records, minus duplicates and rows with no covariate.'''
        pd_aug = pd.DataFrame.from_records(list_data)
        pd_aug = pd_aug.drop_duplicates(subset=self.covariate_names + ['pos', self.tgt_var_name, 'id'])
        pd_aug = pd_aug.dropna(subset=self.covariate_names, how='all')
        return pd_aug

    def fully_augment_dataset(self, pd_data_ori):
        '''
        Expand every record into its covariate-masking variants.

        For N covariates, mask patterns ``0 .. 2**N - 2`` are applied (the
        all-ones pattern, which would hide every covariate, is skipped).
        A 1 at position i of the pattern's binary string replaces covariate i
        with NaN. The muter of a covariate is 1.0 whenever its value is NaN in
        the variant — masked here or already missing in the record — the same
        rule as ``not_augment_dataset``.
        '''
        list_data_dict = pd_data_ori.to_dict('records')
        list_data = []
        N_AUGMENT = 2 ** len(self.covariate_names) - 1
        for data_idx in trange(len(list_data_dict)):
            ith_dict_data = list_data_dict[data_idx]
            ith_arr_data = np.array([ith_dict_data[name] for name in self.covariate_names])

            for ith_aug in range(N_AUGMENT):
                BINARY_MUTATION = self.binary_muter(ith_aug)
                current_aug_data = ith_arr_data.copy()
                current_aug_data[BINARY_MUTATION == 1.] = np.nan  # 1 is nan and 0 is good
                if self.padding_muter:
                    muter = np.isnan(current_aug_data)
                    current_aug_data_with_muter = np.concatenate([current_aug_data, muter], axis=-1)
                else:
                    current_aug_data_with_muter = current_aug_data
                list_data.append(self._make_aug_record(current_aug_data_with_muter, ith_dict_data))

        return self._finalize_aug(list_data)

    def not_augment_dataset(self, pd_data_ori):
        '''
        Keep every record once; muters flag the covariates that are NaN.

        Duplicate records (same covariates, ``pos``, target and ``id``) and
        records without any covariate are dropped.
        '''
        list_data_dict = pd_data_ori.to_dict('records')
        list_data = []
        for data_idx in trange(len(list_data_dict)):
            ith_dict_data = list_data_dict[data_idx]
            current_aug_data = np.array([ith_dict_data[name] for name in self.covariate_names])
            if self.padding_muter:
                BINARY_MUTATION = np.isnan(current_aug_data)
                current_aug_data_with_muter = np.concatenate([current_aug_data, BINARY_MUTATION], axis=-1)
            else:
                current_aug_data_with_muter = current_aug_data
            list_data.append(self._make_aug_record(current_aug_data_with_muter, ith_dict_data))

        return self._finalize_aug(list_data)

    def make_dict_from_arr_data(self, arr):
        '''
        Map a flat covariate array (optionally followed by muters) to a dict.

        Args:
            arr: length N (covariates) or 2N (covariates then muters), where
                N is ``len(self.covariate_names)``.

        Returns:
            ``{name: value}``, plus ``{name + '_muter': flag}`` when 2N long.

        Raises:
            ValueError: if ``len(arr)`` is neither N nor 2N.
        '''
        n_cov = len(self.covariate_names)
        if len(arr) not in (n_cov, 2 * n_cov):
            raise ValueError(f"Expected {n_cov} or {2 * n_cov} values, got {len(arr)}.")
        dict_data = {name: arr[idx] for idx, name in enumerate(self.covariate_names)}
        if len(arr) == 2 * n_cov:
            for idx, name in enumerate(self.covariate_names):
                dict_data[name + '_muter'] = arr[idx + n_cov]
        return dict_data

    def read_data(self, filename_datasource, split):
        '''
        Read an Excel sheet and keep the rows/columns this dataset uses.

        Args:
            filename_datasource: path of the Excel file (first row is header).
            split: collection of ``id`` values to keep.

        Returns:
            DataFrame restricted to ``self.used_columns`` and to rows whose
            ``id``, target and ``pos`` are all present. The full sheet is also
            left in ``self.df_data`` (overwritten on every call).
        '''
        types = {"id": str, "WEIGHT": float, "AGE": float, "SEX": float, "HEIGHT": float, "csa": float, "pos": float, 'PID': str}
        self.df_data = pd.read_excel(filename_datasource, header=0, dtype=types)
        # select needed columns and rows
        df_data_split = self.df_data.loc[self.df_data['id'].isin(split)]
        df_data_split = df_data_split[self.used_columns]
        df_data_split = df_data_split.dropna(subset=['id', self.tgt_var_name, 'pos'])
        return df_data_split

    def load_yaml_as_dict(self, yaml_path):
        '''Parse the YAML file at ``yaml_path`` and return its content.'''
        import yaml
        # NOTE(review): FullLoader can build some Python objects; if split files
        # may come from untrusted sources, yaml.safe_load is the safer choice.
        with open(yaml_path, 'r') as f:
            config_dict = yaml.load(f, Loader=yaml.FullLoader)
        return config_dict

    def __getitem__(self, idx):
        '''
        Return the ``idx``-th prepared sample.

        Returns:
            model_input: float32 tensor ``[pos, *covariates]``, followed by one
                muter flag per covariate when ``padding_muter`` is set.
            csa: float32 tensor of shape (1,) with the (normalized) target.
        '''
        # Fetch the row once instead of indexing every column separately.
        row = self.prepared_data.iloc[idx]
        values = [row['pos']] + [row[name] for name in self.covariate_names]
        if self.padding_muter:
            values += [row[name + '_muter'] for name in self.covariate_names]
        model_input = torch.tensor(np.asarray(values, dtype=np.float64), dtype=torch.float32)

        csa = torch.tensor([row[self.tgt_var_name]], dtype=torch.float32)

        return model_input, csa





if __name__ == "__main__":
    # Manual smoke run: build the OOD test split with three covariates,
    # no missingness and no muter padding, then print a marker.
    dataset_kwargs = dict(
        filename_atlas_datasource="/playpen-raid/Author/LucidAtlas/data/airways/csa_20250118.xlsx",
        filename_atlas_split="/playpen-raid/Author/LucidAtlas/data/airways/airway_split_with_val.yaml",
        filename_ood_datasource="/playpen-raid/Author/LucidAtlas/data/airways/csa_sgs.xlsx",
        filename_ood_split="/playpen-raid/Author/LucidAtlas/data/airways/sgs_split_test.yaml",
        covariate_names=["AGE", "WEIGHT", "HEIGHT"],
        split='test',
        allow_missingness=False,
        padding_muter=False,
    )
    ds_train = OODAirwayDataset(**dataset_kwargs)
    print('1')
