import os
import numpy as np
from causalgraphicalmodels import CausalGraphicalModel
from expl_dist_shift.utils import flatten
from expl_dist_shift.graph_utils import verify_cgm
import pandas as pd
from sklearn.model_selection import train_test_split

class eICU():
    """eICU mortality-prediction data restricted to individual hospitals.

    Rows for a "source" and a "target" hospital are read from
    ``<data_dir>/data_eicu_extract.csv`` and split into train/test sets.
    Columns are grouped by ``VAR_CATEGORIES``, whose keys are also the nodes of
    the causal graph ``GRAPH``.

    ``hparams`` keys read by this class:
        data_seed: seed for the source train/test split (and legacy ``self.w``).
        n: maximum number of source-hospital rows to use.
        test_pct: ``test_size`` passed to ``train_test_split`` for the source.
        source_hospital_id / target_hospital_id: values matched against the
            CSV's ``hospitalid`` column.
        data_dir: directory containing ``data_eicu_extract.csv``.
    """
    TARGET_NAME = 'death'
    # Category name -> CSV column names. Insertion order matters: generate()
    # emits columns in this order and the keys double as graph nodes.
    VAR_CATEGORIES = {
        'Demo': [
            'is_female',
            'race_black',
            'race_hispanic',
            'race_asian',
            'race_other',
        ],  # demographics
        'Vitals': [
            'heartrate',
            'sysbp',
            'temp',
            'bg_pao2fio2ratio',
            'urineoutput',
        ],  # vitals
        'Labs': [
            'bun',
            'sodium',
            'potassium',
            'bicarbonate',
            'bilirubin',
            'wbc',
            'gcs',
        ],  # labs
        'Age': [
            'age',
        ],  # miscellaneous
        'ElectiveSurgery': [
            'electivesurgery',
        ],  # miscellaneous
        'Outcome': ['death']
    }
    TASK_TYPE = 'classification'
    # Causal graph over the categories: every non-outcome category is a parent
    # of 'Outcome'; Demo/Age also drive Vitals and Labs, and Age drives
    # ElectiveSurgery.
    GRAPH = verify_cgm(CausalGraphicalModel(
        nodes=list(VAR_CATEGORIES.keys()),
        edges=[
            ('Demo', 'Outcome'),
            ('Vitals', 'Outcome'),
            ('Labs', 'Outcome'),
            ('Age', 'Outcome'),
            ('ElectiveSurgery', 'Outcome'),
            ('Demo', 'Vitals'),
            ('Age', 'Vitals'),
            # ('ElectiveSurgery', 'Vitals'),
            ('Demo', 'Labs'),
            ('Age', 'Labs'),
            # ('ElectiveSurgery', 'Labs'),
            ('Age', 'ElectiveSurgery'),
        ]
    ))

    def __init__(self, hparams):
        """Store the hyperparameters and derive the training feature list.

        Args:
            hparams: mapping providing the keys listed in the class docstring.
        """
        self.hparams = hparams
        self.data_seed = self.hparams['data_seed']
        self.n = self.hparams['n']
        self.test_pct = self.hparams['test_pct']
        self.source_hospital_id = self.hparams['source_hospital_id']
        self.target_hospital_id = self.hparams['target_hospital_id']
        # Legacy weights from an earlier synthetic data generator. Nothing in
        # this class reads ``self.w``; it is kept so external callers that
        # access it keep working.
        rng = np.random.RandomState(self.data_seed)
        self.w = rng.normal(size=(3, 1))
        # Every non-outcome column, in VAR_CATEGORIES order.
        self.TRAIN_FEATURES = [
            col
            for category, cols in self.VAR_CATEGORIES.items()
            if category != 'Outcome'
            for col in cols
        ]

    def generate(self, n, rng, hospital_id):
        """Load up to ``n`` rows for one hospital from the eICU extract.

        Args:
            n: maximum number of rows to keep; the first ``n`` matching rows
                in file order are used.
            rng: accepted for interface compatibility but currently unused, so
                row selection is deterministic rather than random.
            hospital_id: value compared against the CSV's ``hospitalid``
                column.

        Returns:
            A new DataFrame with exactly the columns listed in
            ``VAR_CATEGORIES`` (features followed by ``death``), in that order.
        """
        path = os.path.join(self.hparams['data_dir'], 'data_eicu_extract.csv')
        df = pd.read_csv(path)
        df = df[df['hospitalid'] == hospital_id]
        # Explicitly positional: plain ``df[:n]`` would slice by label if the
        # index were float-typed.
        df = df.iloc[:n]
        cols = [col for group in self.VAR_CATEGORIES.values() for col in group]
        # Return an independent frame so callers (and the splits built from
        # it) can be modified without pandas chained-assignment pitfalls.
        return df[cols].copy()

    def get_source_train_test(self):
        """Split the source hospital's rows into train and test DataFrames.

        Returns:
            ``[train_df, test_df]`` from ``train_test_split``, shuffled with
            ``self.data_seed`` and using ``self.test_pct`` as ``test_size``.
        """
        # ``generate`` ignores ``rng``; it is built only to satisfy the signature.
        rng = np.random.RandomState(self.data_seed)
        df = self.generate(self.n, rng, hospital_id=self.source_hospital_id)
        return train_test_split(df, random_state=self.data_seed, shuffle=True, test_size=self.test_pct)

    def get_target_train_test(self, shift_hparams):
        """Split the target hospital's rows into train and test DataFrames.

        Args:
            shift_hparams: mapping with 'data_seed' (must differ from the
                source ``data_seed``), 'spu_n' (maximum number of rows) and
                'test_pct' (``test_size`` for the split).

        Returns:
            ``[train_df, test_df]`` from ``train_test_split``, shuffled with
            the target seed.

        Raises:
            AssertionError: if the target seed equals the source seed.
        """
        target_seed = shift_hparams['data_seed']
        assert target_seed != self.hparams['data_seed'], (
            'target data_seed must differ from the source data_seed'
        )
        # ``generate`` ignores ``rng``; it is built only to satisfy the signature.
        rng = np.random.RandomState(target_seed)
        df = self.generate(shift_hparams['spu_n'], rng, hospital_id=self.target_hospital_id)
        return train_test_split(df, random_state=target_seed, shuffle=True, test_size=shift_hparams['test_pct'])
