import os
import torch
import pandas as pd
import numpy as np
from preprocessing.imputation import Imputer
from sklearn.preprocessing import LabelEncoder
from .data_module import DataModule
from tqdm import tqdm
import pickle
from types import SimpleNamespace
from typing import Tuple, List, Dict
# from pandarallel import pandarallel

class KamirDataModule(DataModule):
    """Data module for the KAMIR-V spreadsheet ('KAMIR-V 1년 F-U DATA' sheet).

    Loads the raw Excel sheet (caching it as ``data.pt`` next to the source
    file), blanks implausible measurements, derives the target selected by
    ``config.data.target``, drops overly-missing columns, imputes, one-hot
    expands the columns listed in ``config.data.onehot_cols`` and label-encodes
    the remaining categorical columns.
    """

    def __init__(self, config: SimpleNamespace) -> None:
        """Initialise the base ``DataModule`` with ``config``."""
        super().__init__(config)

    def load_data(self,
                  sheet_name: str = 'KAMIR-V 1년 F-U DATA',
                  skiprows: List[int] = None
                  ) -> pd.DataFrame:
        """Load the raw sheet referenced by ``config.data.file_path``.

        The parsed DataFrame is cached with ``torch.save`` as ``data.pt`` in the
        same directory as the source file and reused on later calls.

        Args:
            sheet_name: Excel sheet to read.
            skiprows: Row indices to skip while reading; defaults to ``[0]``.

        NOTE(review): the cache does not record ``sheet_name``/``skiprows``;
        delete ``data.pt`` after changing either of them.
        """
        if skiprows is None:
            skiprows = [0]

        file_path = self.config.data.file_path
        # os.path.join('', 'data.pt') == 'data.pt', so a bare file name stays
        # relative instead of resolving to '/data.pt'.
        cache_path = os.path.join(os.path.dirname(file_path), 'data.pt')

        if os.path.isfile(cache_path):
            # The cache holds a pickled DataFrame written by this method, which
            # requires weights_only=False on torch versions that default it to True.
            # Only load caches you created yourself: unpickling can execute code.
            try:
                data = torch.load(cache_path, weights_only=False)
            except TypeError:  # older torch without the weights_only argument
                data = torch.load(cache_path)
        else:
            data = pd.read_excel(file_path, sheet_name=sheet_name, skiprows=skiprows, engine='openpyxl')
            torch.save(data, cache_path)

        return data

    def prepare_data(self) -> Tuple[pd.DataFrame, pd.Series, List[str], List[str]]:
        """Build the model-ready feature table and its label.

        Returns:
            ``(data, label, numeric_cols, category_cols)``. ``label`` is the
            target chosen by ``config.data.target``. For some targets (e.g.
            ``'12M'``) it covers only a subset of ``data``'s rows; rows are
            matched by index.
        """
        # NOTE(review): loading the pickled dataset cache is deliberately switched
        # off. If re-enabled, only load trusted files: pickle.load can execute code.
        use_cached_dataset = False
        if use_cached_dataset and os.path.exists(self.config.data.dataset_path):
            with open(self.config.data.dataset_path, 'rb') as f:
                dataset = pickle.load(f)
            return dataset['data'], dataset['label'], dataset['numeric_cols'], dataset['category_cols']

        data = self.load_data()

        category_cols = data.columns[self.config.data.category_cols_idx].to_list()
        binary_cols = data.columns[self.config.data.binary_cols_idx].to_list()
        numeric_cols = data.columns[self.config.data.numeric_cols_idx].to_list()

        data = self.str_to_float(data, numeric_cols)
        data = self.abnormal2nan(data)
        # The label is derived from the full sheet, before the feature subset below.
        label = self.get_label(data)

        data = self.organize_initial_diagnosis(data)
        data = self.organize_sex(data)

        data = data[category_cols + binary_cols + numeric_cols]
        data, drop_cols = self.drop_over_missing(data)
        binary_cols, numeric_cols, category_cols = self.get_remained_cols(
            binary_cols, numeric_cols, category_cols, drop_cols)

        label = self._filter_label_by_lvef(data, label)

        imputer = Imputer()
        imputed = imputer.impute(data, binary_cols, numeric_cols, category_cols)

        data, category_cols = self.to_onehot(imputed, category_cols)

        le = LabelEncoder()
        for col in category_cols:
            data[col] = le.fit_transform(data[col])

        if hasattr(self.config.data, 'rename_cols'):
            rename_map = self.config.data.rename_cols
            data.rename(columns=rename_map, inplace=True)
            numeric_cols = [rename_map.get(c, c) for c in numeric_cols]
            category_cols = [rename_map.get(c, c) for c in category_cols]

        if self.config.runner_option.save_data:
            self.save_data(data, label, numeric_cols, category_cols)

        return data, label, numeric_cols, category_cols

    def _filter_label_by_lvef(self,
                              data: pd.DataFrame,
                              label: pd.Series
                              ) -> pd.Series:
        """Restrict ``label`` to the LVEF subgroup chosen by ``config.data.LVEF_40``.

        ``'all'`` keeps every row, ``'less'`` keeps LVEF < 40 and any other value
        keeps LVEF >= 40. Rows with missing LVEF fail both comparisons and are
        dropped by either filter.
        """
        lvef_group = self.config.data.LVEF_40
        if lvef_group == 'all':
            return label
        # Align LVEF to the label's index: the label may cover only some rows.
        lvef = data['LVEF'].reindex(label.index)
        keep = lvef < 40 if lvef_group == 'less' else lvef >= 40
        return label[keep]

    def get_6M(self, data: pd.DataFrame) -> pd.Series:
        """Return 1 where column position 320 equals ``'Death'``, else 0 (int64).

        NOTE(review): position 320 is hard-coded and presumably holds the
        6-month outcome; confirm it still matches the sheet layout.
        """
        return (data.iloc[:, 320] == 'Death').astype(np.int64)

    def get_12M(self, data: pd.DataFrame) -> pd.Series:
        """Return 1 for followed-up rows with a 12-month cardiac or non-cardiac death.

        Only rows whose ``'1 Year Follow-up'`` value is truthy are returned.
        NOTE(review): ``bool(NaN)`` is True, so missing follow-up values count as
        followed up — confirm the column's encoding.
        """
        followed_up = data['1 Year Follow-up'].apply(bool)
        died = ((data.loc[followed_up, '12M_Cardiac death'] == 1)
                | (data.loc[followed_up, '12M_Non-cardiac death'] == 1))
        return died.astype(np.int64)

    def get_late(self, data: pd.DataFrame) -> pd.Series:
        """Return 1 for 12-month deaths among followed-up rows with a 6-month label of 0.

        The result has dtype int8 and covers only the followed-up rows whose
        ``get_6M`` label is 0.
        """
        label_6M = self.get_6M(data)
        followed_up = data['1 Year Follow-up'].apply(bool)
        late_idx = followed_up & (label_6M == 0)
        died = ((data.loc[late_idx, '12M_Cardiac death'] == 1)
                | (data.loc[late_idx, '12M_Non-cardiac death'] == 1))
        # pandarallel is not imported, so plain vectorised ops replace parallel_apply.
        return died.astype(np.int8)

    def get_reverse_remodeling(self, data: pd.DataFrame) -> pd.Series:
        """Return a 0/1 reverse-remodeling flag from baseline vs 12-month echo values.

        The flag is computed only for rows where at least one of the LVEF,
        LVEDV or LVESV baseline/12-month pairs is fully present. A row is
        flagged if LVEF rose by >= 10, or if LVEDV or LVESV changed by <= 10.

        NOTE(review): ``<= 10`` for LVEDV/LVESV also accepts increases of up to
        10; confirm whether a decrease threshold (e.g. ``<= -10``) was intended.
        """
        LVEF = self.percentage_to_float(data, 'LVEF')
        LVEDV = self.percentage_to_float(data, 'LVEDV')
        LVESV = self.percentage_to_float(data, 'LVESV')

        LVEF_12M = self.percentage_to_float(data, '12M_LVEF')
        LVEDV_12M = self.percentage_to_float(data, '12M_LVEDV')
        LVESV_12M = self.percentage_to_float(data, '12M_LVESV')

        has_pair = ((LVEF - LVEF_12M).notna()
                    | (LVEDV - LVEDV_12M).notna()
                    | (LVESV - LVESV_12M).notna())

        # NaN differences compare False, so they contribute 0.
        LVEF_gap = ((LVEF_12M - LVEF)[has_pair] >= 10).astype(np.int64)
        LVEDV_gap = ((LVEDV_12M - LVEDV)[has_pair] <= 10).astype(np.int64)
        LVESV_gap = ((LVESV_12M - LVESV)[has_pair] <= 10).astype(np.int64)

        return (LVEF_gap + LVEDV_gap + LVESV_gap).clip(upper=1)

    def get_label(self, data: pd.DataFrame) -> pd.Series:
        """Dispatch to the label builder named by ``config.data.target``.

        Raises:
            ValueError: if the target is not one of ``'6M'``, ``'12M'``,
                ``'late'`` or ``'reverse_remodeling'``.
        """
        builders = {
            '6M': self.get_6M,
            '12M': self.get_12M,
            'late': self.get_late,
            'reverse_remodeling': self.get_reverse_remodeling,
        }
        target = self.config.data.target
        try:
            builder = builders[target]
        except KeyError:
            raise ValueError(
                f"Unknown target {target!r}; expected one of {sorted(builders)}"
            ) from None
        return builder(data)

    @staticmethod
    def _parse_pipe_number(value):
        """Turn a '|'-separated number string into a float, e.g. '12|5' -> 12.5.

        Each piece is stripped of whitespace and surrounding dots before the
        pieces are joined with '.'. Non-string values are returned unchanged.
        """
        if isinstance(value, str):
            value = value.split('|')
        if isinstance(value, list):
            return float('.'.join(part.strip().strip('.') for part in value))
        return value

    def str_to_float(self,
                     data: pd.DataFrame,
                     target_cols: List[str]
                     ) -> pd.DataFrame:
        """Parse '|'-separated number strings in ``target_cols`` into floats (in place)."""
        for col in target_cols:
            data[col] = data[col].apply(self._parse_pipe_number)
        return data

    def percentage_to_float(self,
                            data: pd.DataFrame,
                            target_col: str
                            ) -> pd.Series:
        """Return ``target_col`` with '|'-separated number strings parsed into floats."""
        return data[target_col].apply(self._parse_pipe_number)

    def get_bmi(self, data: pd.DataFrame) -> pd.Series:
        """Return ``WT / (HT / 100) ** 2``.

        The division by 100 suggests HT is stored in cm (presumably with WT in
        kg) — TODO confirm.
        """
        return data['WT'] / (data['HT'] / 100) ** 2

    def abnormal2nan(self, data: pd.DataFrame) -> pd.DataFrame:
        """Replace implausible measurements with NaN (modifies ``data`` in place).

        - Each column in ``config.data.abnormal_numerics`` outside its
          ``(low, high)`` limits.
        - WT and HT where the derived BMI is < 10 or > 50.
        - LVEDD/LVESD when LVEDD <= LVESD or either is outside [0, 100].
        - LVEDV/LVESV when LVEDV <= LVESV or either is outside [0, 500].
        """
        # .loc writes back into `data`; chained `data[col][mask] = ...` does not
        # under pandas Copy-on-Write.
        for key, limits in self.config.data.abnormal_numerics.items():
            out_of_range = (data[key] < limits[0]) | (data[key] > limits[1])
            data.loc[out_of_range, key] = np.nan

        bmi = self.get_bmi(data)
        implausible_bmi = (bmi < 10) | (bmi > 50)
        data.loc[implausible_bmi, ['WT', 'HT']] = np.nan

        wrong_diameters = ((data['LVEDD'] <= data['LVESD'])
                           | (data['LVEDD'] > 100) | (data['LVEDD'] < 0)
                           | (data['LVESD'] > 100) | (data['LVESD'] < 0))
        data.loc[wrong_diameters, ['LVEDD', 'LVESD']] = np.nan

        # Volume checks blank the volume columns (not the diameters).
        wrong_volumes = ((data['LVEDV'] <= data['LVESV'])
                         | (data['LVEDV'] > 500) | (data['LVEDV'] < 0)
                         | (data['LVESV'] > 500) | (data['LVESV'] < 0))
        data.loc[wrong_volumes, ['LVEDV', 'LVESV']] = np.nan

        return data

    def organize_initial_diagnosis(self, data: pd.DataFrame) -> pd.DataFrame:
        """Fill the missing complementary STEMI/NSTEMI flag with 'N'.

        If ``'Initial diagnosis'`` is "NSTEMI" and STEMI is missing, STEMI
        becomes 'N', and vice versa.
        """
        # (column to fill, diagnosis that implies it is 'N'). The two updates touch
        # different columns, so their order does not matter.
        for target_col, diagnosis in (('STEMI', 'NSTEMI'), ('NSTEMI', 'STEMI')):
            fill_mask = (data['Initial diagnosis'] == diagnosis) & data[target_col].isna()
            if fill_mask.any():
                # Cast to object first so 'N' can go into an all-NaN float column.
                data[target_col] = data[target_col].astype(object).mask(fill_mask, 'N')
        return data

    def organize_sex(self, data: pd.DataFrame) -> pd.DataFrame:
        """Encode ``Sex`` as 0 for '여' (Korean for "female") and 1 for everything else.

        Missing values therefore also become 1.
        """
        data['Sex'] = (data['Sex'] != '여').astype(np.int64)
        return data

    def get_missing_ratio(self, data: pd.DataFrame) -> Dict[str, float]:
        """Return ``{column: fraction of null values}`` for every column of ``data``."""
        n_rows = len(data)
        return {col: data[col].isnull().sum() / n_rows for col in data.columns}

    def drop_over_missing(self,
                          data: pd.DataFrame,
                          col_dict: Dict[str, str] = None
                          ) -> Tuple[pd.DataFrame, List[str]]:
        """Drop columns whose missing ratio exceeds ``config.data.allowed_missing``.

        Args:
            data: Input frame (not modified).
            col_dict: Optional column -> description map. When given, each
                dropped column is printed with its description and missing percentage.

        Returns:
            ``(data without the dropped columns, dropped column names)``.
        """
        missing_ratio = self.get_missing_ratio(data)
        allowed = self.config.data.allowed_missing

        drop_cols = [col for col, ratio in missing_ratio.items() if ratio > allowed]
        if col_dict is not None:
            for col in drop_cols:
                print(col, col_dict.get(col), "| %d%% |" % (round(missing_ratio[col] * 100)))

        return data.drop(drop_cols, axis=1), drop_cols

    def get_remained_cols(self,
                          binary_cols: List[str],
                          numeric_cols: List[str],
                          category_cols: List[str],
                          drop_cols: List[str]
                          ) -> Tuple[List[str], List[str], List[str]]:
        """Return the binary, numeric and categorical columns that survived ``drop_cols``.

        Each returned list is new and keeps its original order.
        """
        dropped = set(drop_cols)
        return (
            [c for c in binary_cols if c not in dropped],
            [c for c in numeric_cols if c not in dropped],
            [c for c in category_cols if c not in dropped],
        )

    def to_onehot(self,
                  data: pd.DataFrame,
                  category_cols: List[str]
                  ) -> Tuple[pd.DataFrame, List[str]]:
        """Expand the columns in ``config.data.onehot_cols`` into 0/1 indicator columns.

        ``config.data.onehot_cols`` maps a source column to the option strings
        to look for. For each option a column ``'<source>|-|<option>'`` (int8)
        is added. It is 1 where the option occurs as a substring of that
        source column's cell.

        The source columns are then dropped and removed from ``category_cols``
        (in place). The remaining categorical columns are cast to ``str``.

        Returns:
            ``(data with indicator columns appended, category_cols)``.
        """
        separator = '|-|'
        # Build a fresh mapping so the config lists are never rewritten; rewriting
        # them would re-prefix the names on every call. Source columns absent from
        # `data` (e.g. dropped for missingness) are skipped.
        onehot_cols = {
            src: [str(option) for option in options]
            for src, options in self.config.data.onehot_cols.items()
            if src in data.columns
        }

        indicators = {}
        for src, options in tqdm(onehot_cols.items()):
            cells = data[src].tolist()
            for option in options:
                # Match only against the indicator's own source column.
                # NOTE(review): this is a substring test ('A' also matches 'AB');
                # confirm that is intended for the option vocabularies in the config.
                indicators[f'{src}{separator}{option}'] = np.array(
                    [1 if isinstance(cell, str) and option in cell else 0 for cell in cells],
                    dtype=np.int8,
                )
        temp = pd.DataFrame(indicators, index=data.index)

        data = data.merge(temp, left_index=True, right_index=True)
        data = data.drop(columns=list(onehot_cols))

        for src in onehot_cols:
            if src in category_cols:
                category_cols.remove(src)

        data[category_cols] = data[category_cols].astype("str")

        return data, category_cols