import pandas
import os, sys

sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
from mortalty_prediction.full_experiments.parse_args import load_configs
from parse_args_tr import parse_args
import torch
import random
import numpy as np
from mortalty_prediction.datasets.EHR_datasets import EHRDataset, read_cancer_data, dataset_name_mappings
import pandas as pd
from datetime import datetime
# import mortalty_prediction.full_experiments.synthetic_lang_tr as synthetic_lang_tr
import synthetic_lang_tr as synthetic_lang_tr
import torch.nn.functional as F
import pickle
import math
from tqdm import tqdm

# Silence pandas' SettingWithCopyWarning: get_single_pat_data assigns new
# columns onto per-patient slices of the main DataFrame.
pd.options.mode.chained_assignment = None  # default='warn'

# Earlier merge configuration, kept for reference:
# treatment_labels_to_be_merged = [['Platinum chemotherapy_indicator', 'Platinum Chemotherapy_indicator'], ['Immunotherapy - PD1/PDL1_indicator', 'Immunotherapy - PD1/PDL1 inhibitor_indicator', 'Immunotherapy - CTLA-4 inhibitor_indicator'], ['Somatostatin Analogue_indicator', 'Somatostatin analogue_indicator']]
#
# Each inner list is one merge group used by merge_treatment_labels: the first
# entry is the column that receives the merged 0/1 indicator, and every other
# entry is added into it and then dropped. Entries must not repeat the target
# (doing so would drop the target column itself).
treatment_labels_to_be_merged = [['Platinum Chemotherapy_indicator', 'Non-platinum chemotherapy_indicator', 'Other chemotherapy_indicator']]

def read_treatment_data(data_folder, dataset_name):
    """Load the raw treatment CSV for ``dataset_name`` from ``data_folder``.

    The file name is resolved through ``dataset_name_mappings``. The CSV's
    first column is dropped (presumably an index column written by an
    earlier ``to_csv`` -- confirm against the data files).

    Returns the resulting DataFrame.
    """
    csv_file = os.path.join(data_folder, dataset_name_mappings[dataset_name])
    raw_df = pd.read_csv(csv_file)
    leading_col = raw_df.columns[0]
    return raw_df.drop(columns=[leading_col])


def get_numeric_feats(cancer_df):
    """Collect the numeric feature columns and coerce them to float in place.

    A column is numeric when its name contains one of the aggregate markers
    below (or ``"n_is"``). The names in ``synthetic_lang_tr.OTHER_SCORE_FEATS``
    are appended after those. For every collected column the string
    ``"Incomplete"`` is replaced by NaN and the column is cast to float,
    mutating ``cancer_df``.

    Returns the list of numeric column names.
    """
    markers = ("..last", "..first", "..std", "..max", "..min", "..mean", "n_is", "..count")
    numeric_feat_ls = [col for col in cancer_df.columns if any(marker in col for marker in markers)]
    numeric_feat_ls.extend(synthetic_lang_tr.OTHER_SCORE_FEATS)

    for col in numeric_feat_ls:
        cleaned = cancer_df[col].apply(lambda val: np.nan if val == "Incomplete" else val)
        cancer_df[col] = cleaned.astype(float)
    return numeric_feat_ls

def get_categorical_feats(cancer_df):
    """Return the columns of ``cancer_df`` listed in ``synthetic_lang_tr.CAT_FEATS``.

    The result keeps ``cancer_df``'s column order. Every configured
    categorical feature that is absent from ``cancer_df`` is reported on
    stdout as ``"<feat>  not exists"``.
    """
    configured = set(synthetic_lang_tr.CAT_FEATS)
    cat_feat_ls = [cln for cln in cancer_df.columns if cln in configured]

    # Check the configured list against the data, not the other way round:
    # non-categorical columns are expected, so reporting each of them as
    # "not exists" was both noisy and misleading.
    present = set(cancer_df.columns)
    for feat in synthetic_lang_tr.CAT_FEATS:
        if feat not in present:
            print(feat, " not exists")

    return cat_feat_ls

def get_categorical_feats_onehot(cancer_df, cat_feat_ls):
    """Build a one-hot vocabulary for every categorical feature.

    For each feature in ``cat_feat_ls`` the vocabulary is the distinct values
    of ``cancer_df[feat]`` in order of first appearance (``pd.unique``), with
    missing values (NaN/None) and the literal string ``"None"`` removed.

    Returns:
        (feat_to_curr_onehot_embedding, feat_to_unique_vals_mappings):
        per feature, a dict ``value -> one-hot list of floats`` (1.0 at the
        value's vocabulary position, 0.0 elsewhere) and the vocabulary list.
    """
    feat_to_curr_onehot_embedding = dict()
    feat_to_unique_vals_mappings = dict()
    for feat in cat_feat_ls:
        # pd.isna instead of ``np.nan in list``: list membership relies on
        # identity for NaN, so NaN scalars other than the np.nan singleton
        # (e.g. np.float64 NaN from a float column) and Python None slipped
        # into the vocabulary as bogus categories.
        unique_vals_curr_feat = [
            val for val in pd.unique(cancer_df[feat])
            if not pd.isna(val) and val != "None"
        ]

        width = len(unique_vals_curr_feat)
        curr_onehot_embedding = dict()
        for pos, val in enumerate(unique_vals_curr_feat):
            vec = [0.0] * width
            vec[pos] = 1.0
            curr_onehot_embedding[val] = vec
        feat_to_curr_onehot_embedding[feat] = curr_onehot_embedding
        feat_to_unique_vals_mappings[feat] = unique_vals_curr_feat
    return feat_to_curr_onehot_embedding, feat_to_unique_vals_mappings


def transform_cat_feat_to_one_hot(curr_df, cat_feat_ls, feat_to_curr_onehot_embedding, feat_to_unique_vals_mappings):
    """Add a ``<feat>_onehot`` column for each categorical feature.

    Each cell holds the one-hot list for that row's value. Values that are
    missing (None/NaN) or not in ``feat_to_unique_vals_mappings[feat]`` get
    a list of ``None`` whose length is the vocabulary size.

    ``curr_df`` is modified in place and also returned.
    """
    for feat in cat_feat_ls:
        embedding = feat_to_curr_onehot_embedding[feat]
        vocab = feat_to_unique_vals_mappings[feat]
        known = set(vocab)  # O(1) membership instead of a list scan per row
        # Width from the vocabulary itself, so an empty vocabulary no longer
        # raises IndexError.
        width = len(vocab)

        onehots = []
        for val in curr_df[feat]:
            if pd.isna(val) or val not in known:
                onehots.append([None] * width)
            else:
                # Copy so DataFrame cells never alias the shared embedding lists.
                onehots.append(list(embedding[val]))
        curr_df[feat + "_onehot"] = onehots
    return curr_df

def rescale_numeric_feats(curr_df, feat_to_min, feat_to_max):
    """Min-max scale each feature in ``feat_to_min`` using the given bounds.

    Observed values become ``(x - min) / (max - min)``; NaN stays NaN. For a
    constant feature (``max == min``) observed values are mapped to 0.0
    instead of the NaN that ``0 / 0`` would produce, so they are not confused
    with missing data downstream.

    ``curr_df`` is modified in place and also returned.
    """
    for feat, lo in feat_to_min.items():
        span = feat_to_max[feat] - lo
        observed = curr_df[feat].notna()
        if span == 0:
            curr_df.loc[observed, feat] = 0.0
        else:
            curr_df.loc[observed, feat] = (curr_df.loc[observed, feat] - lo) / span
    return curr_df

def get_single_feat_range(df, numeric_feat_ls):
    """Return the per-feature minimum and maximum of ``df`` (NaN skipped).

    Returns three dicts keyed by feature name: the minima, the maxima, and
    ``[min, max]`` pairs.
    """
    feat_to_min, feat_to_max, feat_range_mappings = {}, {}, {}
    for feat in numeric_feat_ls:
        column = df[feat]
        lo, hi = column.min(), column.max()
        feat_to_min[feat] = lo
        feat_to_max[feat] = hi
        feat_range_mappings[feat] = [lo, hi]
    return feat_to_min, feat_to_max, feat_range_mappings

def get_treatment_feats(cancer_df):
    """Return the names of all columns ending in ``"indicator"``, in column order."""
    return [col for col in cancer_df.columns if col.endswith("indicator")]

def get_subset_pid(pat_ids, ratio=0.1):
    """Return a random ``ratio`` fraction (rounded down) of ``pat_ids``.

    The ids are sorted before shuffling with the global ``random`` module,
    so the selection depends only on the RNG seed and not on the iteration
    order of the input (e.g. a set).
    """
    ordered = sorted(pat_ids)
    random.shuffle(ordered)
    n_keep = int(len(ordered) * ratio)
    return ordered[:n_keep]


def _parse_day(timestamp):
    """Parse the ``YYYY-MM-DD`` part of a ``"YYYY-MM-DD[ HH:MM:SS]"`` string."""
    return datetime.strptime(timestamp.split(" ")[0].strip(), "%Y-%m-%d")


def get_single_pat_data(cancer_df, treatment_var_ids=(1,), ratio=0.1, type = "treat_VS_non_treat"):
    """Build per-patient, time-bucketed sequences from the raw treatment table.

    A random ``ratio`` fraction of patients (see get_subset_pid) is processed.
    For each patient:

    * ``num_days`` is the number of days since the patient's first ``date``;
      after all patients are processed it is divided by the maximum over them.
    * ``label`` is 1 for rows whose ``date`` is at most 180 days before the
      patient's first non-null ``DEATH_DATE``, else 0 (0 everywhere when no
      death date is recorded).
    * categorical features get ``<feat>_onehot`` columns and numeric features
      are min-max scaled with bounds taken from the whole table.
    * rows are aggregated into time buckets by merge_along_time_dimension.

    Args:
        cancer_df: raw table. It is mutated: ``PAT_ID`` is cast to str and the
            numeric columns are cast to float (by get_numeric_feats).
        treatment_var_ids: None selects every ``*indicator`` column; otherwise
            the first ``len(treatment_var_ids)`` indicator columns are used.
        ratio: fraction of patients to keep.
        type: ``"treat_VS_non_treat"`` derives the treatment label as
            ``1 - No_treatment``; any other value derives it from the
            selected indicator columns.

    Returns:
        (df_by_pat, feat_to_curr_onehot_embedding, numer_feat_ls, cat_feat_ls,
        treatment_var, unique_treatment_label_ls, feat_range_mappings)
    """
    cancer_df["PAT_ID"] = cancer_df["PAT_ID"].astype(str)
    subset_pat_ids = get_subset_pid(set(cancer_df["PAT_ID"]), ratio=ratio)

    numer_feat_ls = get_numeric_feats(cancer_df)
    cat_feat_ls = get_categorical_feats(cancer_df)
    cat_feat_onehot_ls = [feat + "_onehot" for feat in cat_feat_ls]
    feat_to_curr_onehot_embedding, feat_to_unique_vals_mappings = get_categorical_feats_onehot(cancer_df, cat_feat_ls)
    feat_to_min, feat_to_max, feat_range_mappings = get_single_feat_range(cancer_df, numer_feat_ls)

    treatment_feats = get_treatment_feats(cancer_df)
    if treatment_var_ids is not None:
        # NOTE(review): only the *number* of ids is used here; the id values
        # themselves are not used as indices into treatment_feats -- confirm
        # this is intended.
        treatment_var = [treatment_feats[idx] for idx in range(len(treatment_var_ids))]
    else:
        # treatment VS non-treatment: every indicator column. Copy the list so
        # the "No_treatment" append below does not also land in treatment_var.
        treatment_var = list(treatment_feats)

    no_treatment_label = "No_treatment"
    if no_treatment_label in cancer_df.columns:
        treatment_feats.append(no_treatment_label)

    # Every column must belong to one of the known feature groups.
    expect_feats = set(numer_feat_ls).union(
        cat_feat_ls,
        synthetic_lang_tr.DROP_FEATS,
        synthetic_lang_tr.DATE_FEATS,
        synthetic_lang_tr.UNKNOWN,
        synthetic_lang_tr.OTHER_FEATS,
        synthetic_lang_tr.Prob_FEATS,
        treatment_feats,
    )
    remaining_feats = set(cancer_df.columns).difference(expect_feats)
    assert len(remaining_feats) == 0, f"Unaccounted-for columns: {remaining_feats}"

    df_by_pat = dict()
    all_max_day = -np.inf
    unique_treatment_label_set = set()
    for pid in tqdm(subset_pat_ids):
        curr_df = cancer_df[cancer_df["PAT_ID"] == pid]

        # Parse each visit date once (previously the date column was rebuilt
        # as a list for every element, which is quadratic per patient).
        visit_dates = [_parse_day(d) for d in curr_df["date"]]
        first_date = min(visit_dates)
        curr_df["num_days"] = [(d - first_date).days for d in visit_dates]

        death_dates = curr_df["DEATH_DATE"].dropna()
        if death_dates.empty:
            curr_df["label"] = [0] * len(curr_df)
        else:
            death_date = _parse_day(death_dates.iloc[0])
            curr_df["label"] = [0 if (death_date - d).days > 180 else 1 for d in visit_dates]

        curr_df = transform_cat_feat_to_one_hot(curr_df, cat_feat_ls, feat_to_curr_onehot_embedding, feat_to_unique_vals_mappings)
        curr_df = rescale_numeric_feats(curr_df, feat_to_min, feat_to_max)

        if type == "treat_VS_non_treat":
            curr_df["concat_treatment_label_ls"] = 1 - curr_df[no_treatment_label]
            curr_df["concat_treatment_label"] = 1 - curr_df[no_treatment_label]
        else:
            curr_df["concat_treatment_label_ls"] = curr_df[treatment_var].apply(lambda x: [x[var] for var in treatment_var], axis=1)
            if treatment_var_ids is not None:
                # Encode the selected indicators as one integer: bit idx <-> treatment_var[idx].
                curr_df["concat_treatment_label"] = curr_df[treatment_var].apply(lambda x: sum(x[var] * 2 ** idx for idx, var in enumerate(treatment_var)), axis=1)
            else:
                # treatment VS non-treatment: 1 if any indicator is set.
                any_treatment = curr_df[treatment_var].apply(lambda x: sum(x[var] for var in treatment_var) >= 1, axis=1)
                curr_df["concat_treatment_label"] = any_treatment.astype(int)

        unique_treatment_label_set.update(curr_df["concat_treatment_label"].unique())
        df_by_pat[pid] = pd.concat([curr_df["PAT_ID"], curr_df["num_days"], curr_df[numer_feat_ls], curr_df[cat_feat_ls], curr_df[cat_feat_onehot_ls], curr_df["concat_treatment_label"], curr_df["label"]], axis=1)
        all_max_day = max(all_max_day, curr_df["num_days"].max())

    print("all_max_day::", all_max_day)
    unique_treatment_label_ls = list(unique_treatment_label_set)
    unique_treatment_label_mappings = {label: idx for idx, label in enumerate(unique_treatment_label_ls)}

    # If every processed patient has a single visit day, all_max_day is 0;
    # dividing by it would turn every time into NaN and
    # merge_along_time_dimension would then drop all rows.
    time_scale = all_max_day if all_max_day > 0 else 1

    for pid in tqdm(df_by_pat):
        pat_seq = df_by_pat[pid]
        pat_seq["num_days"] = pat_seq["num_days"] / time_scale

        merged = merge_along_time_dimension(pat_seq, pid, "PAT_ID", "num_days", numer_feat_ls, cat_feat_ls, cat_feat_onehot_ls, feat_to_curr_onehot_embedding, "concat_treatment_label", "label")
        if type == "treat_VS_non_treat":
            merged["concat_treatment_label_id"] = merged["concat_treatment_label"]
        else:
            # NOTE(review): merge_along_time_dimension stores the treatment
            # column as 0/1, so binary-encoded multi-treatment labels are
            # collapsed before this mapping -- confirm intended.
            merged["concat_treatment_label_id"] = merged["concat_treatment_label"].apply(lambda x: unique_treatment_label_mappings[x])
        df_by_pat[pid] = merged

    return df_by_pat, feat_to_curr_onehot_embedding, numer_feat_ls, cat_feat_ls, treatment_var, unique_treatment_label_ls, feat_range_mappings


def get_treatment_var_id_suffix(args):
    """Build the file-name suffix identifying the selected treatment variables.

    Returns ``"_all"`` when ``args.treatment_var_ids`` is None. Otherwise it
    returns the ids in ascending order, each prefixed by ``"_"``
    (e.g. ``[3, 1]`` -> ``"_1_3"``), and ``""`` for an empty list.
    ``args`` is not modified.
    """
    if args.treatment_var_ids is None:
        return "_all"
    # sorted() instead of list.sort(): the in-place sort silently reordered
    # the caller's args.treatment_var_ids.
    return "".join("_" + str(var_id) for var_id in sorted(args.treatment_var_ids))

def merge_treatment_labels(cancer_df, label_groups=None):
    """Merge groups of treatment-indicator columns into a single indicator.

    For each group, the first column is the target. Every other column in the
    group is added to the target, the running sum is clipped to a 0/1 int
    indicator (``>= 1``), and that source column is dropped.

    Args:
        cancer_df: table holding the indicator columns. The target column is
            assigned on the passed-in frame before the first drop, so the
            caller's frame may be modified in place.
        label_groups: list of column-name groups; defaults to the module-level
            ``treatment_labels_to_be_merged``.

    Returns:
        The frame with the merged-away source columns removed.
    """
    if label_groups is None:
        label_groups = treatment_labels_to_be_merged
    for group in label_groups:
        if not group:
            continue
        target = group[0]
        # De-duplicate and skip the target itself: a group that repeats its
        # target would otherwise drop the target column, and the next
        # addition would fail with a KeyError.
        sources = [col for col in dict.fromkeys(group[1:]) if col != target]
        for src in sources:
            cancer_df[target] = cancer_df[target] + cancer_df[src]
            cancer_df[target] = (cancer_df[target].values >= 1).astype(int)
            cancer_df = cancer_df.drop(src, axis=1)
    return cancer_df


def merge_along_time_dimension(df, pat_idx, id_attr, time_attr, num_attr_ls, cat_attr_ls, cat_onehot_attr_ls, feat_to_curr_onehot_embedding, treatment_attr, outcome_attr, gap=0.1):
    """Aggregate one patient's rows into fixed-width buckets along ``time_attr``.

    ``df[time_attr]`` is expected to lie in [0, 1]; the caller in this file
    divides ``num_days`` by its maximum over all patients. The axis is cut
    into ``int(1 / gap)`` buckets ``[k * gap, (k + 1) * gap)``. The last
    bucket is closed on the right and reaches at least 1.0, so a time of
    exactly 1.0 is kept. For every non-empty bucket:

    * each numeric feature is the mean of its non-NaN values (NaN if none);
    * each categorical feature takes its most frequent value, ignoring
      missing values and the string ``"None"``. Its ``<feat>_onehot`` column
      gets the matching vector from ``feat_to_curr_onehot_embedding``. When
      no such value exists, only the one-hot column is filled, with ``None``
      repeated vocabulary-size times;
    * ``treatment_attr`` / ``outcome_attr`` become 1 if their sum is > 0,
      else 0.

    Returns:
        DataFrame indexed by bucket number with empty buckets removed. Its
        ``time_attr`` column holds each bucket's lower edge ``k * gap`` and
        ``id_attr`` is set to ``pat_idx``.
    """
    n_bins = int(1 / gap)
    all_cols = [id_attr, time_attr, *num_attr_ls, *cat_attr_ls, *cat_onehot_attr_ls, treatment_attr, outcome_attr]

    res_df = pd.DataFrame(columns=all_cols, index=range(n_bins))
    res_df[id_attr] = pat_idx
    res_df[time_attr] = [k * gap for k in range(n_bins)]

    times = df[time_attr]
    empty_bins = []
    for idx in range(n_bins):
        # Derive the window from idx. A running total that was only advanced
        # for non-empty buckets used to freeze on the first empty bucket and
        # silently drop every later one.
        lower = idx * gap
        upper = (idx + 1) * gap
        if idx == n_bins - 1:
            in_bin = (times >= lower) & (times <= max(upper, 1.0))
        else:
            in_bin = (times >= lower) & (times < upper)
        curr_sub_df = df[in_bin]
        if len(curr_sub_df) == 0:
            empty_bins.append(idx)
            continue

        # Zero out NaNs before summing: ``values * mask`` keeps NaN
        # (NaN * 0 == NaN), which made one missing value poison the mean.
        num_vals = curr_sub_df[num_attr_ls].to_numpy(dtype=float)
        observed = ~np.isnan(num_vals)
        with np.errstate(invalid="ignore", divide="ignore"):
            mean_num_feats = np.where(observed, num_vals, 0.0).sum(axis=0) / observed.sum(axis=0)
        res_df.loc[idx, num_attr_ls] = mean_num_feats

        for feat in cat_attr_ls:
            cat_vals = curr_sub_df[feat].dropna()
            cat_vals = cat_vals[cat_vals != "None"]
            if len(cat_vals) > 0:
                mode_val = cat_vals.mode()[0]
                res_df.at[idx, feat + "_onehot"] = feat_to_curr_onehot_embedding[feat][mode_val]
                res_df.at[idx, feat] = mode_val
            else:
                # Vocabulary size = number of embedding entries; also safe
                # when the vocabulary is empty.
                width = len(feat_to_curr_onehot_embedding[feat])
                res_df.at[idx, feat + "_onehot"] = [None] * width

        res_df.loc[idx, treatment_attr] = 1 if np.sum(curr_sub_df[treatment_attr].values) > 0 else 0
        res_df.loc[idx, outcome_attr] = 1 if np.sum(curr_sub_df[outcome_attr].values) > 0 else 0

    return res_df.drop(index=empty_bins)

if __name__ == "__main__":
    args = parse_args()

    # Seed every RNG in play; get_subset_pid relies on the `random` module.
    seed = args.seed
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ["PYTHONHASHSEED"] = str(seed)

    # The returned configs are not used further in this script.
    rl_config, model_config = load_configs(args)

    cancer_df = read_treatment_data(args.data_folder, args.dataset_name)
    # For chemotherapies, merge related indicator columns first:
    # cancer_df = merge_treatment_labels(cancer_df)

    (
        df_by_pat_mapping,
        feat_to_curr_onehot_embedding,
        numer_feat_ls,
        cat_feat_ls,
        treatment_var_ls,
        unique_treatment_label_ls,
        feat_range_mappings,
    ) = get_single_pat_data(cancer_df, treatment_var_ids=args.treatment_var_ids, ratio=args.ratio)

    suffix = get_treatment_var_id_suffix(args)

    # (file-name prefix, object) pairs, pickled in this order into log_folder.
    artifacts = [
        ("processed_treatment_data", df_by_pat_mapping),
        ("feat_to_onehot_embedding", feat_to_curr_onehot_embedding),
        ("numer_feat_ls", numer_feat_ls),
        ("cat_feat_ls", cat_feat_ls),
        ("treatment_var_ls", treatment_var_ls),
        ("unique_treatment_label_ls", unique_treatment_label_ls),
        ("feat_range_mappings", feat_range_mappings),
    ]
    for prefix, payload in artifacts:
        with open(os.path.join(args.log_folder, prefix + suffix), "wb") as f:
            pickle.dump(payload, f)

    # train_data, valid_data, test_data, _ = read_cancer_data(args.data_folder, dataset_name=args.dataset_name)
    print()
