import numpy as np
import pandas as pd
from itertools import product
from scipy.special import kl_div
from rset_wrapper import RsetWrapper # Assuming this is available
import json
import os

# --- Configuration ---
# Length of every binary input vector; the input space is {0, 1}^DIMENSIONS.
DIMENSIONS = 4
# Number of training samples handed to the model in each trial.
N_SAMPLES_PER_TRIAL = 100
# Number of independent trials run for each true distribution.
N_TRIALS = 5
# Number of random ensembles drawn per ensemble size m when averaging the KL.
K_ENSEMBLE_SAMPLES = 20
# Probabilities are clipped into [DELTA_CLIP, 1 - DELTA_CLIP].
DELTA_CLIP = 1e-6
# Path of the JSON file that receives the experiment results.
OUTPUT_FILENAME = "experiment_results.json"

# Options passed to RsetWrapper. A "max_depth" entry (e.g. 3) can optionally
# be added to control tree complexity and the number of models found.
TF_CONFIG = dict(
    regularization=0.001,
    rashomon_bound_multiplier=0.03,
)

# --- Helper Functions ---
def get_all_binary_inputs(dim):
    """Enumerate every 0/1 vector of length ``dim``.

    Returns a NumPy array with one row per vector, in the order produced by
    ``itertools.product`` (the last coordinate varies fastest).
    """
    bit_choices = (0, 1)
    rows = list(product(bit_choices, repeat=dim))
    return np.array(rows)

# Column labels x0, x1, ... for the DIMENSIONS binary features.
FEATURE_NAMES = [f'x{i}' for i in range(DIMENSIONS)]
# Every possible input vector, one per row, and how many of them there are.
ALL_X_VECTORS_NP = get_all_binary_inputs(DIMENSIONS)
N_UNIQUE_X = len(ALL_X_VECTORS_NP)
# The same input grid as a DataFrame labelled with FEATURE_NAMES.
ALL_X_VECTORS_DF = pd.DataFrame(data=ALL_X_VECTORS_NP, columns=FEATURE_NAMES)


def sample_data_from_true_dist(all_x_vectors_np, p_y1_map, n_samples):
    """Draw a labelled sample from a conditional distribution P(y=1 | x).

    For every sample an input row is picked uniformly at random from
    ``all_x_vectors_np`` and its label is drawn from a Bernoulli distribution
    with success probability ``p_y1_map[tuple(row)]``. The global
    ``np.random`` generator is used, with one ``randint`` call followed by one
    ``binomial`` call per sample.

    Args:
        all_x_vectors_np: 2-D array holding one candidate input per row.
        p_y1_map: Mapping from ``tuple(row)`` to P(y=1 | row).
        n_samples: Number of (x, y) pairs to draw.

    Returns:
        A pair ``(X, Y)``. ``X`` has shape ``(n_samples, dim)`` and the dtype
        of ``all_x_vectors_np``; ``Y`` has shape ``(n_samples,)`` and integer
        dtype. For ``n_samples <= 0`` both are empty but keep those shapes, so
        ``X`` can still be wrapped in a DataFrame with ``dim`` columns.

    Raises:
        KeyError: If a sampled row has no entry in ``p_y1_map``.
    """
    num_unique_x_configs = all_x_vectors_np.shape[0]
    n_draws = max(n_samples, 0)

    # Preallocate so the outputs have a well-defined shape and dtype even
    # when nothing is drawn.
    row_indices = np.empty(n_draws, dtype=np.intp)
    labels = np.empty(n_draws, dtype=int)

    # The two RNG calls stay interleaved per sample so a given seed yields
    # the same draws as before.
    for draw in range(n_draws):
        idx = np.random.randint(num_unique_x_configs)
        prob_y1 = p_y1_map[tuple(all_x_vectors_np[idx])]
        row_indices[draw] = idx
        labels[draw] = np.random.binomial(1, prob_y1)

    return all_x_vectors_np[row_indices], labels

def define_true_distribution(dist_id, all_x_vectors_np, delta_clip):
    """Build the ground-truth map from input vector to P(y=1 | x).

    The global NumPy RNG is re-seeded with ``dist_id`` on every call, so the
    returned map is reproducible and the global random state is reset as a
    side effect. Each input gets a base probability chosen by ``dist_id``
    (ids other than 0-4 use a flat 0.5), plus uniform noise from
    [-0.05, 0.05], clipped into [delta_clip, 1 - delta_clip].

    Args:
        dist_id: Selects the base probability rule and seeds the RNG.
        all_x_vectors_np: 2-D array with one input vector per row.
        delta_clip: Distance kept between every probability and 0 or 1.

    Returns:
        Dict mapping ``tuple(row)`` to the clipped probability of y=1.
    """

    def base_probability(x_vec):
        # Noise-free probability for one input under the selected rule.
        if dist_id == 0:
            return 0.1 + 0.8 * (np.sum(x_vec) % 2 == 0)
        if dist_id == 1:
            return 0.15 + 0.7 * (x_vec[0] == 1 and x_vec[1] == 1)
        if dist_id == 2:
            return 0.2 + 0.6 * ((x_vec[0] ^ x_vec[1]) == 1 and (x_vec[2] ^ x_vec[3]) == 0)
        if dist_id == 3:
            # Consumes one RNG draw before the noise draw for this input.
            return 0.3 + 0.4 * np.random.rand()
        if dist_id == 4:
            n_ones = np.sum(x_vec)
            if n_ones <= 1:
                return 0.05
            if n_ones >= 3:
                return 0.95
            return 0.5 + (x_vec[0] - 0.5) * 0.4
        return 0.5

    np.random.seed(dist_id)

    upper_bound = 1.0 - delta_clip
    true_p_y1_map = {}
    for row in all_x_vectors_np:
        x_vec = tuple(row)
        noisy_prob = base_probability(x_vec) + np.random.uniform(-0.05, 0.05)
        true_p_y1_map[x_vec] = np.clip(noisy_prob, delta_clip, upper_bound)
    return true_p_y1_map

# --- Main Experimental Loop ---
# Total number of samples drawn per distribution; each trial trains on its
# own consecutive slice of N_SAMPLES_PER_TRIAL of them.
N_TOTAL_SAMPLES_PER_DIST = N_TRIALS * N_SAMPLES_PER_TRIAL

# Top-level structure written to the output file: the parameters of this run
# plus one result record appended per distribution by the main loop.
experiment_data_to_save = dict(
    experiment_params=dict(
        DIMENSIONS=DIMENSIONS,
        N_SAMPLES_PER_TRIAL=N_SAMPLES_PER_TRIAL,
        N_TRIALS=N_TRIALS,
        K_ENSEMBLE_SAMPLES=K_ENSEMBLE_SAMPLES,
        DELTA_CLIP=DELTA_CLIP,
        TF_CONFIG=TF_CONFIG,
    ),
    distribution_results=[],
)

def _mean_binary_kl(p_true_y1, q_y1):
    """Return the mean over inputs of KL(Bernoulli(p_true) || Bernoulli(q)).

    Both arguments are 1-D arrays of P(y=1 | x) aligned on the same ordering
    of inputs. ``kl_div`` is elementwise (x*log(x/y) - x + y); summing its two
    class terms per input gives the KL divergence because the extra
    ``- x + y`` terms cancel for normalized distributions.
    """
    p_dist = np.stack([1.0 - p_true_y1, p_true_y1])
    q_dist = np.stack([1.0 - q_y1, q_y1])
    return np.mean(np.sum(kl_div(p_dist, q_dist), axis=0))


def _distribution_record(dist_number, m_values, mean_kl_values, std_kl_values, accuracy_summary):
    """Build the JSON-ready result record for one distribution."""
    return {
        "label": f'Dist {dist_number}',
        "m_values": m_values,
        "mean_kl_values": mean_kl_values,
        "std_kl_values": std_kl_values,
        "accuracy_summary": accuracy_summary
    }


# For each true distribution: sample data, fit a Rashomon set per trial, and
# measure how the KL divergence between the true conditional distribution and
# the averaged prediction of a random m-model ensemble changes with m.
for dist_idx in range(5):
    print(f"\n--- Processing Distribution {dist_idx + 1}/5 ---")
    current_true_p_y1_map = define_true_distribution(dist_idx, ALL_X_VECTORS_NP, DELTA_CLIP)
    # True P(y=1 | x) as a vector aligned with the rows of ALL_X_VECTORS_NP.
    true_p_y1_for_all_x = np.array(
        [current_true_p_y1_map[tuple(x_row)] for x_row in ALL_X_VECTORS_NP], dtype=float
    )

    X_full_dist_np, Y_full_dist_np = sample_data_from_true_dist(
        ALL_X_VECTORS_NP, current_true_p_y1_map, N_TOTAL_SAMPLES_PER_DIST
    )

    all_m_kl_pairs_for_this_dist = []
    trial_avg_accuracies_for_dist = []

    for trial_idx in range(N_TRIALS):
        print(f"  Distribution {dist_idx + 1}, Trial {trial_idx + 1}/{N_TRIALS}:")

        # Each trial trains on its own disjoint slice of the sampled data.
        start_idx = trial_idx * N_SAMPLES_PER_TRIAL
        end_idx = (trial_idx + 1) * N_SAMPLES_PER_TRIAL
        X_train_np_trial = X_full_dist_np[start_idx:end_idx]
        Y_train_np_trial = Y_full_dist_np[start_idx:end_idx]

        X_train_df_trial = pd.DataFrame(X_train_np_trial, columns=FEATURE_NAMES)
        Y_train_series_trial = pd.Series(Y_train_np_trial, name='target')

        print(f"    Fitting TreeFARMS for Dist {dist_idx + 1}, Trial {trial_idx + 1} (on {N_SAMPLES_PER_TRIAL} samples)...")
        tf_model = RsetWrapper(TF_CONFIG)
        try:
            tf_model.fit(X_train_df_trial, Y_train_series_trial)
        except Exception as e:
            print(f"    Error fitting TreeFARMS for Dist {dist_idx + 1}, Trial {trial_idx + 1}: {e}. Skipping this trial.")
            continue

        N_RASHOMON_MODELS_trial = tf_model.ntrees()

        if N_RASHOMON_MODELS_trial == 0:
            print(f"    TreeFARMS found 0 models for Dist {dist_idx + 1}, Trial {trial_idx + 1}. Skipping this trial's KL & Acc calculation.")
            continue

        print(f"    TreeFARMS found {N_RASHOMON_MODELS_trial} models for this trial.")

        # One clipped P(y=1 | x) vector (aligned with ALL_X_VECTORS_NP) per
        # tree that produced usable predictions.
        trial_rashomon_set_p_y1_predictions = []
        model_accuracies_this_trial = []

        for i_model in range(N_RASHOMON_MODELS_trial):
            tree_output_obj = tf_model.get_tree(i_model)
            # NOTE(review): presumably generate_proba prepares the tree's
            # probability estimates from the training data; confirm whether
            # score/predict_proba actually require it. A failure here skips
            # only this tree instead of aborting the whole experiment.
            try:
                tree_output_obj.generate_proba(X_train_np_trial, Y_train_np_trial)
            except Exception as e:
                print(f"      Error generating probabilities for model {i_model} (Dist {dist_idx+1}, Trial {trial_idx+1}): {e}. Skipping this tree.")
                continue

            # A scoring failure only drops this tree from the accuracy
            # average; its predictions are still used for the KL estimate.
            try:
                accuracy = tree_output_obj.score(X_train_np_trial, Y_train_np_trial)
            except Exception as e:
                print(f"      Error scoring model {i_model} (Dist {dist_idx+1}, Trial {trial_idx+1}): {e}. Skipping accuracy for this model.")
            else:
                model_accuracies_this_trial.append(accuracy)

            try:
                proba_for_all_x = tree_output_obj.predict_proba(ALL_X_VECTORS_NP)
                p_y1_for_all_x = np.clip(
                    np.array(
                        [proba_for_all_x[x_idx_map, 1] for x_idx_map in range(N_UNIQUE_X)],
                        dtype=float,
                    ),
                    DELTA_CLIP,
                    1.0 - DELTA_CLIP,
                )
            except Exception as e:
                print(f"      Error getting predictions for model {i_model} (Dist {dist_idx+1}, Trial {trial_idx+1}): {e}. Skipping this tree for KL.")
                continue
            trial_rashomon_set_p_y1_predictions.append(p_y1_for_all_x)

        if not trial_rashomon_set_p_y1_predictions:  # All models failed to give predictions
            print(f"    No valid model predictions obtained in Trial {trial_idx + 1} for Dist {dist_idx + 1}. Skipping KL for this trial.")
            continue

        # Rows are trees, columns are inputs; only trees with valid
        # predictions count towards the ensemble sizes from here on.
        trial_prediction_matrix = np.vstack(trial_rashomon_set_p_y1_predictions)
        N_RASHOMON_MODELS_trial = trial_prediction_matrix.shape[0]

        if model_accuracies_this_trial:
            avg_trial_accuracy = np.mean(model_accuracies_this_trial)
            trial_avg_accuracies_for_dist.append(avg_trial_accuracy)
            print(f"    Average accuracy for this trial's Rashomon set: {avg_trial_accuracy:.4f}")
        else:
            print(f"    No model accuracies recorded for this trial.")

        # Progress is printed for roughly five evenly spaced ensemble sizes,
        # plus the smallest and the largest.
        report_stride = max(1, N_RASHOMON_MODELS_trial // 5)

        for m_val in range(1, N_RASHOMON_MODELS_trial + 1):
            kl_samples_for_this_m_this_trial = []
            for _ in range(K_ENSEMBLE_SAMPLES):
                # Random ensemble of m distinct trees; its prediction is the
                # plain average of the member predictions.
                indices = np.random.choice(N_RASHOMON_MODELS_trial, size=m_val, replace=False)
                q_pi_y1_ensemble_avg = np.clip(
                    trial_prediction_matrix[indices].mean(axis=0),
                    DELTA_CLIP,
                    1.0 - DELTA_CLIP,
                )
                kl_samples_for_this_m_this_trial.append(
                    _mean_binary_kl(true_p_y1_for_all_x, q_pi_y1_ensemble_avg)
                )

            avg_kl_for_current_m_trial = np.mean(kl_samples_for_this_m_this_trial)
            all_m_kl_pairs_for_this_dist.append((m_val, avg_kl_for_current_m_trial))

            if m_val % report_stride == 0 or m_val in (1, N_RASHOMON_MODELS_trial):
                print(f"      Dist {dist_idx+1}, Trial {trial_idx+1}: m = {m_val}, Avg KL (E_pi E_x [KL]) = {avg_kl_for_current_m_trial:.4f}")

    # --- Data aggregation for this distribution ---
    accuracy_summary_data = {
        "mean_training_accuracy": None,
        "std_dev_training_accuracy": None,
        "num_successful_trials_for_accuracy": 0
    }
    if trial_avg_accuracies_for_dist:
        # Cast to builtin float: NumPy scalar types other than float64 are
        # not JSON serializable.
        mean_acc = float(np.mean(trial_avg_accuracies_for_dist))
        std_acc = float(np.std(trial_avg_accuracies_for_dist))
        accuracy_summary_data["mean_training_accuracy"] = mean_acc
        accuracy_summary_data["std_dev_training_accuracy"] = std_acc
        accuracy_summary_data["num_successful_trials_for_accuracy"] = len(trial_avg_accuracies_for_dist)
        print(f"\n  Accuracy Summary for Distribution {dist_idx + 1}:")
        print(f"    Mean training accuracy across {len(trial_avg_accuracies_for_dist)} successful trials: {mean_acc:.4f}")
        print(f"    Std Dev of training accuracy across trials: {std_acc:.4f}")
    else:
        print(f"\n  No trial accuracies recorded for Distribution {dist_idx + 1} to summarize.")

    if not all_m_kl_pairs_for_this_dist:
        # The distribution is still recorded, just with empty KL curves, so
        # the saved results always contain one entry per distribution.
        print(f"  No KL results from any trial for Distribution {dist_idx + 1}. Saving this distribution with empty KL data.")
        experiment_data_to_save["distribution_results"].append(
            _distribution_record(dist_idx + 1, [], [], [], accuracy_summary_data)
        )
        continue

    # Trials can yield different Rashomon-set sizes, so large m values may be
    # backed by fewer trials; a single observation has no sample std -> 0.
    df_m_kl_dist = pd.DataFrame(all_m_kl_pairs_for_this_dist, columns=['m', 'kl'])
    kl_grouped_by_m = df_m_kl_dist.groupby('m')['kl']
    mean_kl_by_m = kl_grouped_by_m.mean()
    std_kl_by_m = kl_grouped_by_m.std().fillna(0)

    m_values_plot = mean_kl_by_m.index.tolist()
    mean_kl_plot = mean_kl_by_m.values.tolist()
    std_kl_plot = std_kl_by_m.values.tolist()

    print(f"  Distribution {dist_idx + 1} KL Plot Summary: Processed {len(m_values_plot)} unique m values across successful trials for KL plot.")
    experiment_data_to_save["distribution_results"].append(
        _distribution_record(dist_idx + 1, m_values_plot, mean_kl_plot, std_kl_plot, accuracy_summary_data)
    )

# --- Saving All Results ---
def _json_default(obj):
    """Convert NumPy scalars and arrays, which ``json`` cannot encode itself."""
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


try:
    # Serialize completely in memory before touching the file, so that a
    # serialization error cannot leave a truncated results file behind.
    serialized_results = json.dumps(
        experiment_data_to_save,
        indent=4,
        allow_nan=True,  # allow_nan for robustness if NaNs appear
        default=_json_default,
    )
except TypeError as e:
    print(f"Error serializing data to JSON: {e}. Check for non-serializable types.")
else:
    try:
        with open(OUTPUT_FILENAME, 'w') as f:
            f.write(serialized_results)
        print(f"\nExperiment results saved to {OUTPUT_FILENAME}")
    except IOError as e:
        print(f"Error saving results to {OUTPUT_FILENAME}: {e}")

print("\n--- Experiment Run Finished ---")
