import numpy as np
import os.path
from datasets.dataset import *

# Relative paths of the two preprocessed diabetes CSV files: one used by
# load_s (one-hot race columns) and one used by load_r (a single 'race' column).
BASE_URL_SEX, BASE_URL_RACE = (
    os.path.join('datasets', 'diabetes', file_name)
    for file_name in ('diabetes_processed_sex.csv', 'diabetes_processed_race.csv')
)


def load_s(r_train=0.4, r_candidate=0.2, seed=None, include_intercept=True, use_pct=1.0, include_R=False, include_S=False, standardize=False, R0=None, R1=None, shuffle=True):
    """Load the diabetes data in ``BASE_URL_SEX`` as a ``ClassificationDataset``.

    The CSV's first row holds the column labels. Race is one-hot encoded in the
    ``race:is_<name>`` columns, sex is read from ``gender:is_Female`` and the
    target from ``readmitted``. Rows with no race are dropped.

    Args:
        r_train: Fraction of samples used for training (candidate + safety);
            the remaining samples form the test split.
        r_candidate: Fraction of the training samples used as the candidate
            split; the rest of the training samples form the safety split.
        seed: Seed of the ``RandomState`` used for shuffling; also forwarded
            to ``ClassificationDataset``.
        include_intercept: Append an intercept column via ``with_intercept``.
        use_pct: Fraction of rows (rounded up) to keep.
        include_R: Append the race code ``R`` as a feature.
        include_S: Append the sex indicator ``S`` as a feature.
        standardize: Apply ``standardized`` to ``X`` before any R/S/intercept
            columns are appended.
        R0, R1: Optional race names, matched case-insensitively against the
            ``race:is_<name>`` suffixes. With both, only rows of those two
            races are kept and ``R`` is 0 for R0 / 1 for R1. With only R0,
            ``R`` is 0 for R0 and 1 otherwise. With only R1, ``R`` is 1 for
            R1 and 0 otherwise. With neither, ``R`` is the index of the row's
            race in ``meta_information['race_codes']``.
        shuffle: Shuffle rows before truncating to ``use_pct``.

    Returns:
        A ``ClassificationDataset`` built from ``X``, ``Y`` (``readmitted``
        with 0 mapped to -1), ``R`` and ``S``.

    Raises:
        ValueError: If ``R0`` or ``R1`` is not a race present in the file.
    """
    meta_information = {
        'standardized': standardize,
        'include_R': include_R,
        'include_S': include_S,
        'include_intercept': include_intercept
    }
    if R0 is not None or R1 is not None:
        meta_information.update({'R0': R0, 'R1': R1})

    random = np.random.RandomState(seed)

    with open(BASE_URL_SEX, 'r') as f:
        raw = list(f)
    labels, *raw = [d.strip().split(',') for d in raw]
    data = {k: np.array(v).astype(float)
            for k, v in zip(labels, np.array(raw).T)}

    race_keys = [k for k in labels if k.startswith('race:')]
    race_labels = [k.split(':is_')[1] for k in race_keys]
    # One row per record, one boolean column per race_keys entry.
    race_onehot = np.array([data[k] for k in race_keys]).T != 0
    # Some records have no race associated with them (an all-zero row). Keep
    # only rows with exactly one race: a row with none has no code to take.
    I = np.count_nonzero(race_onehot, axis=1) == 1
    data = {k: v[I] for k, v in data.items()}
    # With exactly one True per row, argmax is the index of that race.
    R = np.argmax(race_onehot[I], axis=1)
    S = data['gender:is_Female'].astype(int)
    Y = data['readmitted']
    Y[Y == 0] = -1

    # Race and sex are exposed separately as R and S, so they are not features.
    excluded = {'gender:is_Female', 'readmitted', 'weight:is_?', 'medical_specialty:is_?'}
    feature_keys = [k for k in labels
                    if not k.startswith('race:') and k not in excluded]

    X = np.array([data[k] for k in feature_keys]).T

    meta_information.update({
        'race_codes': race_labels,
        # NOTE(review): S is 1 for 'gender:is_Female'; if consumers index
        # sex_codes by S, 1 maps to 'male' here -- confirm the intended order.
        'sex_codes': ['female', 'male'],
        'feature_labels': feature_keys
    })

    # Reduce the dataset size as needed
    n_keep = int(np.ceil(len(X) * use_pct))
    I = np.arange(len(X))
    if shuffle:
        random.shuffle(I)
    I = I[:n_keep]
    X = X[I]
    Y = Y[I].flatten()
    S = S[I].flatten()
    R = R[I].flatten()

    # Filter out samples for races if R0 and/or R1 is specified
    if R0 is not None or R1 is not None:
        lowered = [l.lower() for l in race_labels]

        def race_index(name):
            # Case-insensitive lookup that reports the valid choices on failure.
            try:
                return lowered.index(name.lower())
            except ValueError:
                raise ValueError('Unknown race {!r}; expected one of {}'.format(
                    name, race_labels)) from None

        if R0 is not None and R1 is not None:
            i0 = race_index(R0)
            i1 = race_index(R1)
            I = np.logical_or(R == i0, R == i1)
            X, Y, S, R = X[I], Y[I], S[I], R[I]
            R = 0*(R == i0) + 1*(R == i1)
        elif R0 is not None:
            i0 = race_index(R0)
            R = 0*(R == i0) + 1*(R != i0)
        else:
            i1 = race_index(R1)
            R = 0*(R != i1) + 1*(R == i1)

    # Compute split sizes
    n_samples = len(X)
    n_train = int(r_train*n_samples)
    n_test = n_samples - n_train
    n_candidate = int(r_candidate*n_train)
    n_safety = n_train - n_candidate

    if standardize:
        X = standardized(X)
    if include_R:
        X = with_feature(X, R)
        meta_information['feature_labels'].append('race')
    if include_S:
        X = with_feature(X, S)
        meta_information['feature_labels'].append('sex')
    if include_intercept:
        X = with_intercept(X)
        meta_information['feature_labels'].append('intercept')

    contents = {'X': X, 'Y': Y, 'R': R, 'S': S}
    # NOTE(review): Y takes values in {-1, 1} after the remap above, while
    # all_labels is [0, 1] -- confirm how ClassificationDataset uses it.
    all_labels = [0, 1]
    return ClassificationDataset(all_labels, n_candidate, n_safety, n_test, seed=seed, meta_information=meta_information, **contents)


def load_r(r_train=0.4, r_candidate=0.2, seed=None, include_intercept=True, use_pct=1.0, include_R=False, include_S=False, standardize=False, shuffle=True):
    """Load the diabetes data in ``BASE_URL_RACE`` as a ``ClassificationDataset``.

    The CSV's first row holds the column labels. Race is read from the integer
    ``race`` column, sex from ``gender:is_Female`` and the target from
    ``readmitted``.

    Args:
        r_train: Fraction of samples used for training (candidate + safety);
            the remaining samples form the test split.
        r_candidate: Fraction of the training samples used as the candidate
            split; the rest of the training samples form the safety split.
        seed: Seed of the ``RandomState`` used for shuffling; also forwarded
            to ``ClassificationDataset``.
        include_intercept: Append an intercept column via ``with_intercept``.
        use_pct: Fraction of rows (rounded up) to keep.
        include_R: Append the race code ``R`` as a feature. Race is otherwise
            not part of ``X``.
        include_S: Append the sex indicator ``S`` as a feature.
        standardize: Apply ``standardized`` to ``X`` before any R/S/intercept
            columns are appended.
        shuffle: Shuffle rows before truncating to ``use_pct``.

    Returns:
        A ``ClassificationDataset`` built from ``X``, ``Y`` (``readmitted``
        with 0 mapped to -1), ``R`` and ``S``.
    """
    meta_information = {
        'standardized': standardize,
        'include_R': include_R,
        'include_S': include_S,
        'include_intercept': include_intercept
    }

    random = np.random.RandomState(seed)

    with open(BASE_URL_RACE, 'r') as f:
        raw = list(f)
    labels, *raw = [d.strip().split(',') for d in raw]
    data = {k: np.array(v).astype(float)
            for k, v in zip(labels, np.array(raw).T)}

    R = data['race'].astype(int)
    S = data['gender:is_Female'].astype(int)
    Y = data['readmitted']
    Y[Y == 0] = -1

    # 'race' is exposed as R and must not leak into X unless include_R is set
    # (matching load_s, which drops its race columns from the features).
    excluded = {'race', 'gender:is_Female', 'readmitted', 'weight:is_?', 'medical_specialty:is_?'}
    feature_keys = [k for k in labels if k not in excluded]

    X = np.array([data[k] for k in feature_keys]).T

    meta_information.update({
        # NOTE(review): hard-coded; presumably index i names race code i in
        # the 'race' column -- confirm against the preprocessing script.
        'race_codes': ['hispanic', 'caucasian', 'africanamerican', 'asian', 'other'],
        # NOTE(review): S is 1 for 'gender:is_Female'; if consumers index
        # sex_codes by S, 1 maps to 'male' here -- confirm the intended order.
        'sex_codes': ['female', 'male'],
        'feature_labels': feature_keys
    })

    # Reduce the dataset size as needed
    n_keep = int(np.ceil(len(X) * use_pct))
    I = np.arange(len(X))
    if shuffle:
        random.shuffle(I)
    I = I[:n_keep]
    X = X[I]
    Y = Y[I].flatten()
    S = S[I].flatten()
    R = R[I].flatten()

    # Compute split sizes
    n_samples = len(X)
    n_train = int(r_train*n_samples)
    n_test = n_samples - n_train
    n_candidate = int(r_candidate*n_train)
    n_safety = n_train - n_candidate

    if standardize:
        X = standardized(X)
    if include_R:
        X = with_feature(X, R)
        meta_information['feature_labels'].append('race')
    if include_S:
        X = with_feature(X, S)
        meta_information['feature_labels'].append('sex')
    if include_intercept:
        X = with_intercept(X)
        meta_information['feature_labels'].append('intercept')

    contents = {'X': X, 'Y': Y, 'R': R, 'S': S}
    # NOTE(review): Y takes values in {-1, 1} after the remap above, while
    # all_labels is [0, 1] -- confirm how ClassificationDataset uses it.
    all_labels = [0, 1]
    return ClassificationDataset(all_labels, n_candidate, n_safety, n_test, seed=seed, meta_information=meta_information, **contents)