import os, sys
sys.path.append(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
import pandas as pd
import pickle
from parse_args_tr import parse_args, load_configs_tr
from train_medical_gruode_rl_synthesis import get_treatment_var_id_suffix, random_split_train_valid_test_ids
from create_language_tr import *
import synthetic_lang_tr
import GRU_ODE.data_utils as data_utils
from scipy.cluster import hierarchy


def merge_all_df(all_data, cat_feat_ls, num_feat_ls):
    """
    Stack the selected feature columns of every per-patient frame into one frame.

    Parameters:
        all_data: Object exposing ``r_data``, a mapping whose values are
            DataFrames (keys presumably patient ids; they are not used here).
        cat_feat_ls (list): Categorical column names to keep.
        num_feat_ls (list): Numeric column names to keep.

    Returns:
        pd.DataFrame: The concatenation of all selected columns (categorical
        first, then numeric) with exact duplicate rows removed.
    """
    selected_frames = [
        patient_frame[[*cat_feat_ls, *num_feat_ls]]
        for _patient_id, patient_frame in all_data.r_data.items()
    ]
    stacked = pd.concat(selected_frames)
    return stacked.drop_duplicates()

class UnionFind:
    """Disjoint-set forest over the integers ``0 .. n-1``.

    Uses path compression in ``find`` and union by rank in ``union``.
    """

    def __init__(self, n):
        # Every element starts as the root of its own singleton set.
        self.parent = list(range(n))
        self.rank = [0] * n

    def find(self, x):
        """Return the root of ``x``'s set, re-pointing every node on the path at it."""
        # First pass: climb to the root.
        root = x
        while self.parent[root] != root:
            root = self.parent[root]
        # Second pass: compress the path so each visited node points at the root.
        node = x
        while node != root:
            next_node = self.parent[node]
            self.parent[node] = root
            node = next_node
        return root

    def union(self, x, y):
        """Merge the sets containing ``x`` and ``y``; the higher-rank root wins."""
        root_x, root_y = self.find(x), self.find(y)
        if root_x == root_y:
            return
        rank_x, rank_y = self.rank[root_x], self.rank[root_y]
        if rank_x < rank_y:
            self.parent[root_x] = root_y
        else:
            # Ties attach root_y under root_x and grow root_x's rank.
            self.parent[root_y] = root_x
            if rank_x == rank_y:
                self.rank[root_x] += 1


def get_correlated_attributes(correlation_matrix, threshold):
    """
    Group the columns of a correlation matrix into clusters of related attributes.

    Each entry is turned into ``|1 - corr|``. NaN correlations (for example
    from constant columns) are replaced by a value larger than every finite
    entry. The resulting square matrix is clustered with Ward hierarchical
    linkage, and the tree is cut with ``fcluster(..., criterion='distance')``.

    Parameters:
        correlation_matrix (pd.DataFrame): Square correlation matrix, e.g. the
            output of ``DataFrame.corr()``; its columns name the attributes.
        threshold (float): Cut-off passed to ``fcluster`` as ``t``. This is a
            linkage distance, not a correlation value.

    Returns:
        list[list]: One list of column names per cluster. Clusters appear in
        the order their first member occurs in ``correlation_matrix.columns``,
        and names within a cluster keep that column order. With fewer than two
        columns each column is returned as its own singleton group, because
        ``linkage`` needs at least two observations.
    """
    columns = list(correlation_matrix.columns)
    if len(columns) < 2:
        return [[col] for col in columns]

    dist = np.abs(1 - correlation_matrix.values)
    nan_mask = np.isnan(dist)
    finite = dist[~nan_mask]
    # Undefined correlations are treated as "farther than anything observed".
    # If every entry is NaN there is no observed maximum, so fall back to 0.
    max_dist = finite.max() if finite.size else 0.0
    dist[nan_mask] = max_dist + 1

    # NOTE(review): a 2-D array passed to `linkage` is interpreted as one
    # observation vector per row (each attribute's profile of distances), NOT
    # as a precomputed pairwise distance matrix. If true pairwise distances
    # were intended, pass `scipy.spatial.distance.squareform(dist, checks=False)`
    # instead -- that changes the clustering, so confirm before switching.
    linkage = hierarchy.linkage(dist, method='ward')
    clusters = hierarchy.fcluster(linkage, t=threshold, criterion='distance')

    grouped_columns = {}
    for col, cluster_id in zip(columns, clusters):
        grouped_columns.setdefault(cluster_id, []).append(col)
    return list(grouped_columns.values())

def find_correlated_attributes(num_feat_ls, cat_feat_ls, all_df):
    """
    Compute correlation matrices separately for the numeric and categorical columns.

    Parameters:
        num_feat_ls (list): Numeric column names present in ``all_df``.
        cat_feat_ls (list): Categorical column names present in ``all_df``.
        all_df (pd.DataFrame): Frame containing both groups of columns.

    Returns:
        tuple: ``(numeric_corr, categorical_corr)``, each the result of
        ``DataFrame.corr()`` (pandas default method) on that column subset.
    """
    numeric_corr, categorical_corr = (
        all_df[feature_group].corr() for feature_group in (num_feat_ls, cat_feat_ls)
    )
    return numeric_corr, categorical_corr


def reduce_feat_space_final(all_correlated_attribute_set_ls, all_df, df_by_pat_mapping, numer_feat_ls, cat_feat_ls, id_attr="PAT_ID", time_attr="num_days", treatment_attr="concat_treatment_label", outcome_attr="label", treatment_id_attr="concat_treatment_label_id"):
    """
    Keep one representative feature per correlated group and project every patient frame onto it.

    The representative of a group is its first element. A representative found
    in ``numer_feat_ls`` is kept as numeric. Otherwise, if it is in
    ``cat_feat_ls``, it is kept as categorical together with its
    ``<name>_onehot`` column. Representatives in neither list are dropped, and
    empty groups are skipped.

    Parameters:
        all_correlated_attribute_set_ls (list): Groups of feature names.
        all_df: Unused; kept for interface compatibility.
        df_by_pat_mapping (dict): Mapping to DataFrames. It is MODIFIED IN
            PLACE: each value is replaced by its column subset.
        numer_feat_ls (list): Numeric feature names.
        cat_feat_ls (list): Categorical feature names.
        id_attr, time_attr, treatment_attr, outcome_attr, treatment_id_attr (str):
            Bookkeeping columns always retained, appended in this order.

    Returns:
        tuple: ``(df_by_pat_mapping, reduced_numer_feat_ls, reduced_cat_feat_ls)``.
        Each frame's columns are ordered: kept categorical, kept numeric, their
        one-hot columns, then the bookkeeping columns.
    """
    numer_feat_set = set(numer_feat_ls)
    cat_feat_set = set(cat_feat_ls)
    reduced_numer_feat_ls = []
    reduced_cat_feat_ls = []
    reduced_cat_onehot_feat_ls = []
    for attr_set in all_correlated_attribute_set_ls:
        try:
            representative = next(iter(attr_set))
        except StopIteration:
            continue  # empty group: nothing to keep
        # Numeric membership is checked first, matching the original precedence.
        if representative in numer_feat_set:
            reduced_numer_feat_ls.append(representative)
        elif representative in cat_feat_set:
            reduced_cat_feat_ls.append(representative)
            reduced_cat_onehot_feat_ls.append(representative + "_onehot")

    kept_columns = [
        *reduced_cat_feat_ls,
        *reduced_numer_feat_ls,
        *reduced_cat_onehot_feat_ls,
        id_attr,
        time_attr,
        treatment_attr,
        outcome_attr,
        treatment_id_attr,
    ]
    # Replacing values (not keys) while iterating the dict is safe.
    for pat in df_by_pat_mapping:
        df_by_pat_mapping[pat] = df_by_pat_mapping[pat][kept_columns]

    return df_by_pat_mapping, reduced_numer_feat_ls, reduced_cat_feat_ls


if __name__ == "__main__":
    args = parse_args()
    program_max_len = args.program_max_len

    treatment_var_ids_str_suffix = get_treatment_var_id_suffix(args)

    def _artifact_path(stem):
        # Every artifact lives in the log folder and carries the treatment-variable suffix.
        return os.path.join(args.log_folder, stem + treatment_var_ids_str_suffix)

    def _load_artifact(stem):
        # NOTE(review): pickle.load can execute arbitrary code; only point this at trusted log folders.
        with open(_artifact_path(stem), "rb") as f:
            return pickle.load(f)

    def _dump_artifact(stem, obj):
        with open(_artifact_path(stem), "wb") as f:
            pickle.dump(obj, f)

    # Loaded in the original order; several of these are not used further below.
    df_by_pat_mapping = _load_artifact("processed_treatment_data")
    feat_to_onehot_embedding = _load_artifact("feat_to_onehot_embedding")
    numer_feat_ls = _load_artifact("numer_feat_ls")
    cat_feat_ls = _load_artifact("cat_feat_ls")
    treatment_var_ls = _load_artifact("treatment_var_ls")
    unique_treatment_label_ls = _load_artifact("unique_treatment_label_ls")
    feat_range_mappings = _load_artifact("feat_range_mappings")

    # Every feature is given the unit range [0, 1] (a fresh list per key).
    normalized_feat_range_mappings = {key: [0, 1] for key in feat_range_mappings}

    val_options = {"T_val": 0.6, "max_val_samples": args.max_val_samples}
    lang = Language(numer_feat_ls, cat_feat_ls, lang=synthetic_lang_tr)
    pat_ids = list(df_by_pat_mapping.keys())
    all_data = data_utils.ODE_Dataset_medical_rl(
        synthetic_lang_tr,
        numer_feat_ls,
        cat_feat_ls,
        df_by_pat_mapping,
        pat_ids,
        feat_range_mappings=normalized_feat_range_mappings,
        test=args.test,
    )

    all_df = merge_all_df(all_data, cat_feat_ls, numer_feat_ls)
    num_feat_corr, cat_feat_corr = find_correlated_attributes(numer_feat_ls, cat_feat_ls, all_df)

    # Numeric features are clustered; every categorical column forms its own singleton group.
    num_correlated_attribute_set_ls = get_correlated_attributes(num_feat_corr, threshold=0.9)
    cat_singleton_groups = [[attr] for attr in cat_feat_corr]
    all_correlated_attribute_set_ls = num_correlated_attribute_set_ls + cat_singleton_groups
    reduced_df_by_pat_mapping, reduced_numer_feat_ls, reduced_cat_feat_ls = reduce_feat_space_final(
        all_correlated_attribute_set_ls, all_df, df_by_pat_mapping, numer_feat_ls, cat_feat_ls
    )

    _dump_artifact("reduced_processed_treatment_data", reduced_df_by_pat_mapping)
    _dump_artifact("reduced_numer_feat_ls", reduced_numer_feat_ls)
    _dump_artifact("reduced_cat_feat_ls", reduced_cat_feat_ls)
    print()