#%%
from folktexts.acs import ACSDataset

from folktexts.acs import ACSTaskMetadata
import pandas as pd
import glest 
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeRegressor

from glest.plot import grouping_diagram_residuals, grouping_diagram_residuals_descriptive
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import os
from scipy.interpolate import griddata
import glob
import re
from matplotlib.patches import Patch
from matplotlib.backends.backend_pdf import PdfPages
from matplotlib import rcParams

from sklearn.metrics import accuracy_score, roc_auc_score
# %%


# Load the merged ACSIncome table: ACS feature columns plus the benchmark's
# "risk_score" (presumably the model's predicted score) and "label" columns,
# which are split off from the features in a later cell.
# Relative path: resolved against the current working directory. The
# commented-out cell that follows shows how this CSV was originally built.
merged_df = pd.read_csv("folktexts-results/folktexts-results/model-Llama-3.3-70B-Instruct_task-ACSIncome/Llama-3.3-70B-Instruct_bench-3309924360/merged_acs_llama.csv")

# DATA_DIR = "notebooks/data"
# TASK_NAME = "ACSIncome"
# task = ACSTaskMetadata.get_task(TASK_NAME, use_numeric_qa=False)

# dataset = ACSDataset.make_from_task(task=task, cache_dir=DATA_DIR)
# # %%
# dataset.data
# # %%

# df = pd.read_csv("folktexts-results/folktexts-results/model-Llama-3.3-70B-Instruct_task-ACSIncome/Llama-3.3-70B-Instruct_bench-3309924360/ACSIncome_full_seed-42_hash-1998608642.test_predictions.csv")

# # %%
# features = dataset.get_features_data()
# # %%
# # Create a mapping between features and df indexes
# matched_features = features.loc[df["Unnamed: 0"]].copy()

# # %%
# test = df.set_index(df["Unnamed: 0"].values)
# # %%
# test.drop(columns=["Unnamed: 0"], inplace=True)
# # %%
# merged_df = pd.concat([matched_features, test], axis=1)
# # %%


# merged_df.to_csv("folktexts-results/folktexts-results/model-Llama-3.3-70B-Instruct_task-ACSIncome/Llama-3.3-70B-Instruct_bench-3309924360/merged_acs_llama.csv", index=False)
# %%
# Pull the model score (`f`) and the ground-truth label (`y`) out of the table,
# then remove both columns so `merged_df` holds only input features.
f, y = merged_df["risk_score"], merged_df["label"]
merged_df.drop(columns=["risk_score", "label"], inplace=True)


#%%
# Narrow the feature set: these ACS columns are removed, so only the remaining
# columns reach `merged_df.values` and hence the residual tree.
merged_df.drop(
    columns=["COW", "MAR", "OCCP", "POBP", "RELP", "WKHP"],
    inplace=True,
)


# ACSIncome_categories = {
#     "RAC1P": {
#         1.0: "White alone",
#         2.0: "Black or African American alone",
#         3.0: "American Indian alone",
#         4.0: "Alaska Native alone",
#         5.0: (
#             "American Indian and Alaska Native tribes specified;"
#             "or American Indian or Alaska Native,"
#             "not specified and no other"
#         ),
#         6.0: "Asian alone",
#         7.0: "Native Hawaiian and Other Pacific Islander alone",
#         8.0: "Some Other Race alone",
#         9.0: "Two or More Races",
#     }
# }
# for col, mapping in ACSIncome_categories.items():
#     if col in merged_df.columns:
#         merged_df[col] = merged_df[col].map(mapping).fillna("Unknown")
#         merged_df[col] = merged_df[col].astype("category")

#%%

# Display the feature frame that remains after the column drops above.
merged_df


# %%
# Recalibrate the raw LLM risk score with a one-feature logistic regression
# (Platt-style), fitted on a held-out calibration split below.
calibrated_classifier = LogisticRegression()
# First split: 50% of rows go to the test set. Features (X), labels (y) and raw
# risk scores (S) are split together so their rows stay aligned.
X_train, X_test, y_train, y_test, S_train, S_test = train_test_split(
    merged_df.values, y.values, f.values, test_size=0.5, random_state=0
)

# Second split: carve a calibration set out of the training half. An int
# test_size is an absolute row count: 20% of the training half, but >= 4000.
# NOTE(review): if the training half has <= 4000 rows this count reaches the
# number of samples and train_test_split should raise — confirm dataset size.
X_train, X_cal, y_train, y_cal, S_train, S_cal = train_test_split(
    X_train, y_train, S_train, test_size=max(int(len(X_train) * 0.2),4000), random_state=0
)

# reshape(-1, 1): the raw score is the calibrator's single input feature.
calibrated_classifier.fit(S_cal.reshape(-1,1), y_cal)

# Calibrated probability of the positive class (column 1 of predict_proba).
c_hat_train = calibrated_classifier.predict_proba(S_train.reshape(-1,1))[:, 1]
c_hat_test = calibrated_classifier.predict_proba(S_test.reshape(-1,1))[:, 1]

# Residual = label minus calibrated probability.
# NOTE(review): residuals_test is not used in the visible cells.
residuals_train = y_train - c_hat_train
residuals_test = y_test - c_hat_test
# Regress training residuals on the features; the tree's leaves (depth <= 5,
# >= 10 training samples each) become the candidate groups.
dt = DecisionTreeRegressor(max_depth = 5, min_samples_leaf= 10)
dt.fit(X_train, residuals_train)
# Leaf node id of every test row -> the partition handed to glest below.
leaf_ids = dt.apply(X_test)


# Fit glest's residual-based estimator on the test split, grouped by leaf.
# NOTE(review): the meaning of the two positional None arguments and of the
# attributes read below (honest_tree_pred, nj) is defined in glest, not here.
gle = glest.core.GLEstimatorResiduals(None, None)
gle.fit(X_test, y_test, y_scores_cal = c_hat_test, partition = leaf_ids)
# %%
# Grouping diagram for the test-set groups, given the tree and feature names.
# NOTE(review): y_cal is passed positionally right after S_test while the
# matching calibration scores go in as f_cal=S_cal — confirm this argument
# order against grouping_diagram_residuals_descriptive's signature.
grouping_diagram_residuals_descriptive(c_hat_test, gle.honest_tree_pred, gle.nj, S_test,y_cal, f_cal = S_cal,leaf_ids=leaf_ids, tree = dt, X_test = X_test, feature_names = merged_df.columns.tolist(), plot_cbar = False)
# %%

# Get the definition of groups created by the tree
def get_group_definitions(tree, X_test, leaf_ids, feature_names):
    """
    Extract the decision rules that define each group (leaf) in the tree.

    For each distinct id in ``leaf_ids``, the root-to-leaf path of one sample
    from that group is read with ``tree.decision_path``. The branch taken at
    every split is read off the tree structure (is the split node's left
    child on the path?) instead of re-comparing the sample's raw value with
    the threshold. The rule therefore always matches the route the tree
    actually took, including for values such as NaN, for which
    ``value <= threshold`` is always False even if the tree sent the sample
    to the left child.

    Consecutive splits on the same feature are merged into one interval.

    Parameters
    ----------
    tree : fitted tree exposing ``tree_`` (``feature``, ``threshold``,
        ``children_left``) and ``decision_path``, e.g. the
        ``DecisionTreeRegressor`` used in this script.
    X_test : 2-D array whose rows were assigned ``leaf_ids``.
    leaf_ids : array-like of node ids, one per row of ``X_test``
        (as returned by ``tree.apply(X_test)``).
    feature_names : sequence mapping column index -> display name.

    Returns
    -------
    dict
        Maps each leaf id to ``{'rules': list[str], 'n_samples': int,
        'sample_indices': ndarray}``. ``rules`` is empty when the path
        contains no split (a tree consisting of a single root leaf).
    """
    leaf_ids = np.asarray(leaf_ids)
    structure = tree.tree_
    feature = structure.feature
    threshold = structure.threshold
    children_left = structure.children_left

    group_definitions = {}

    for leaf_id in np.unique(leaf_ids):
        sample_indices = np.where(leaf_ids == leaf_id)[0]
        first = sample_indices[0]

        # One representative row suffices: all samples in a leaf share the
        # same path. Nodes are visited in ascending node-id order.
        path_nodes = np.sort(tree.decision_path(X_test[first:first + 1]).indices)
        on_path = set(path_nodes.tolist())

        # Tightest bounds per feature: 'max' from "<=" (left) branches and
        # 'min' from ">" (right) branches. The dict keeps first-seen order.
        feature_bounds = {}
        for node_id in path_nodes:
            if feature[node_id] == -2:  # -2: leaf node, no split feature
                continue
            feature_name = feature_names[feature[node_id]]
            value = threshold[node_id]
            bounds = feature_bounds.setdefault(feature_name, {'min': None, 'max': None})
            if children_left[node_id] in on_path:  # went left: x <= threshold
                if bounds['max'] is None or value < bounds['max']:
                    bounds['max'] = value
            else:  # went right: x > threshold
                if bounds['min'] is None or value > bounds['min']:
                    bounds['min'] = value

        # Convert bounds to readable rules
        combined_rules = []
        for feature_name, bounds in feature_bounds.items():
            low, high = bounds['min'], bounds['max']
            if low is not None and high is not None:
                combined_rules.append(f"{low:.3f} < {feature_name} <= {high:.3f}")
            elif low is not None:
                combined_rules.append(f"{feature_name} > {low:.3f}")
            elif high is not None:
                combined_rules.append(f"{feature_name} <= {high:.3f}")

        group_definitions[leaf_id] = {
            'rules': combined_rules,
            'n_samples': len(sample_indices),
            'sample_indices': sample_indices,
        }

    return group_definitions

# Describe every leaf of the fitted residual tree `dt` by the feature ranges
# that route a test sample into it, labelled with `merged_df`'s column names.
group_defs = get_group_definitions(
    dt, X_test, leaf_ids, merged_df.columns.tolist()
)

# One section per leaf: a header with the sample count, then one line per rule.
print("Group Definitions:")
print("=" * 50)
for leaf_id, definition in group_defs.items():
    print(f"\nGroup {leaf_id} ({definition['n_samples']} samples):")
    if not definition['rules']:
        # No split on the path: the leaf is the root itself.
        print("  - Root node (no conditions)")
        continue
    for rule in definition['rules']:
        print(f"  - {rule}")
merged_df
# %%
