# Load Datasets
import pandas as pd
import numpy as np
import sys
import json
import warnings
from typing import List
import random
import os
warnings.filterwarnings('ignore')
from DRL.constraints_code.parser import parse_constraints_file
sys.path.append(os.path.abspath("."))  # add current dir
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, f1_score, classification_report
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier
from sklearn.utils import shuffle
from collections import Counter
import torch
from tqdm import tqdm
from lgmvae import get_cluster_centroids, reconstruct_from_centroids, LGMVAE
from ucimlrepo import fetch_ucirepo
from sklearn.utils import resample


def get_roundable_data(df):
    """Locate the columns of ``df`` that contain only whole numbers.

    A column qualifies when ``value % 1 == 0`` holds for every one of its
    entries.

    Returns:
        A pair ``(roundable_idx, round_digits)``: the positional indices of
        the qualifying columns, and the result of applying
        ``get_round_decimals`` to each of those columns.
    """
    whole_number_mask = (df % 1 == 0).all(axis=0)
    positions = [df.columns.get_loc(name) for name in df.columns[whole_number_mask]]
    decimals_per_column = df.iloc[:, positions].apply(get_round_decimals)
    return positions, decimals_per_column

def get_round_decimals(col):
    """Return the fewest decimals at which rounding leaves ``col`` unchanged.

    Candidates run from 0 up to ``sys.float_info.dig - 1``. ``None`` is
    returned when rounding to that upper limit already alters a value, or
    when no candidate reproduces the column exactly.
    """
    decimals_cap = sys.float_info.dig - 1

    # Guard: nothing to search for if even the widest rounding changes the data.
    if not (col.round(decimals_cap) == col).all():
        return None

    for n_decimals in range(decimals_cap + 1):
        if (col.round(n_decimals) == col).all():
            return n_decimals
    return None

def single_value_cols(df):
    """Return the names of the columns whose entries all equal the first row's entry."""
    values = df.to_numpy()
    first_row = values[0]
    # Column-wise: does every row match the first row?
    is_constant = (values == first_row).all(axis=0)
    return df.columns[is_constant].to_list()

def read_csv(csv_filename, use_case="", manual_inspection_cat_cols_idx=None):
    """Read a CSV file and describe its categorical and whole-number columns.

    Args:
        csv_filename: Location of the CSV; passed straight to ``pandas.read_csv``.
        use_case: Accepted for call compatibility; not used by this function.
        manual_inspection_cat_cols_idx: Positional indices of the columns to
            treat as categorical. ``None`` (the default) means no columns.

    Returns:
        A 3-tuple ``(data, (cat_cols_names, bin_cols_idx), (roundable_idx, round_digits))``:

        * ``data`` - the loaded DataFrame.
        * ``cat_cols_names`` / ``bin_cols_idx`` - names and positions of the
          requested categorical columns, minus any column that holds a single
          repeated value. Both are ``None`` when nothing remains.
        * ``roundable_idx`` / ``round_digits`` - the output of
          ``get_roundable_data`` restricted to non-categorical columns.
    """
    # A fresh list per call avoids sharing a mutable default between calls.
    if manual_inspection_cat_cols_idx is None:
        manual_inspection_cat_cols_idx = []

    data = pd.read_csv(csv_filename)
    constant_cols = set(single_value_cols(data))
    roundable_idx, round_digits = get_roundable_data(data)

    # Columns holding a single repeated value are dropped from the categorical list.
    requested_names = data.columns[manual_inspection_cat_cols_idx].values.tolist()
    cat_cols_names = [name for name in requested_names if name not in constant_cols]
    bin_cols_idx = [data.columns.get_loc(name) for name in cat_cols_names]

    # Categorical columns are excluded from the rounding information.
    roundable_idx = [i for i in roundable_idx if i not in bin_cols_idx]
    round_digits = round_digits[data.columns[roundable_idx]]

    if not bin_cols_idx:
        return data, (None, None), (roundable_idx, round_digits)
    return data, (cat_cols_names, bin_cols_idx), (roundable_idx, round_digits)

def _load_json(path):
    """Parse the JSON document stored at ``path`` and return the decoded object."""
    with open(path) as handle:
        raw_text = handle.read()
    return json.loads(raw_text)

def _load_constraints(dpath):
    """Parse ``{dpath}/constraints.txt`` and return ``.readable()`` of each constraint as an array."""
    _, constraints_raw = parse_constraints_file(f"{dpath}/constraints.txt")
    return np.array([item.readable() for item in constraints_raw])


def _upsample_and_split(df_large, kept_label, upsampled_label, upsample_factor):
    """Oversample one class of ``df_large`` and split the result 70/15/15.

    Args:
        df_large: DataFrame of features plus a ``'target'`` column.
        kept_label: Target value whose rows are kept as they are.
        upsampled_label: Target value whose rows are resampled with replacement.
        upsample_factor: The resampled class ends up with
            ``int(len(rows) * upsample_factor)`` rows.

    Returns:
        ``(X_train, y_train, X_val, y_val, X_test, y_test)`` as NumPy arrays.

    NOTE(review): resampling with replacement happens *before* the split, so
    copies of the same row can land in different splits. Kept as-is to
    preserve existing results.
    """
    df_kept = df_large[df_large['target'] == kept_label]
    df_to_upsample = df_large[df_large['target'] == upsampled_label]
    df_upsampled_rows = resample(
        df_to_upsample,
        replace=True,
        n_samples=int(len(df_to_upsample) * upsample_factor),
        random_state=42,
    )
    df_balanced = pd.concat([df_kept, df_upsampled_rows])
    # Shuffle so the two classes are interleaved before splitting.
    df_balanced = df_balanced.sample(frac=1, random_state=42).reset_index(drop=True)
    X_upsampled = df_balanced.drop(columns='target').values
    y_upsampled = df_balanced['target'].values

    # 70 % train; the held-out 30 % is then halved into validation and test.
    X_train, X_temp, y_train, y_temp = train_test_split(
        X_upsampled, y_upsampled,
        test_size=0.3,
        random_state=42,
        stratify=y_upsampled,
    )
    X_val, X_test, y_val, y_test = train_test_split(
        X_temp, y_temp,
        test_size=0.5,
        random_state=42,
        stratify=y_temp,
    )
    return X_train, y_train, X_val, y_val, X_test, y_test


def get_dataset(dname):
    """Load one of the supported datasets as train/validation/test arrays.

    Args:
        dname: One of ``'heloc'``, ``'wine'``, ``'adult'`` or ``'compas'``.

    Returns:
        ``(X_train, y_train, X_val, y_val, X_test, y_test, constraints)``.
        ``constraints`` is an array built from ``constraints.txt`` for
        ``'heloc'`` and ``'wine'``, and ``None`` for ``'adult'`` and
        ``'compas'``.

    Raises:
        ValueError: If ``dname`` is not one of the supported names.
    """
    if dname == 'heloc':
        info_path = "./dataset_config.json"
        dataset_info = _load_json(info_path)[dname]
        # NOTE(review): "..." looks like a placeholder - point this at the
        # directory that holds the HELOC CSVs and constraints.txt.
        dpath = "..."
        # Only the DataFrame is needed here; the column metadata is discarded.
        X_train, _, _ = read_csv(
            f"{dpath}/train_data.csv",
            dname,
            dataset_info["manual_inspection_categorical_cols_idx"],
        )
        X_test = pd.read_csv(f"{dpath}/test_data.csv")
        X_val = pd.read_csv(f"{dpath}/val_data.csv")

        # Separate the target column from the features in every split.
        target_col = dataset_info['target_col']
        y_train = X_train[target_col].values
        X_train = X_train.drop(columns=[target_col]).values

        y_val = X_val[target_col].values
        X_val = X_val.drop(columns=[target_col]).values

        y_test = X_test[target_col].values
        X_test = X_test.drop(columns=[target_col]).values

        # The scaler is fitted on the training split only and reused for the others.
        scaler = StandardScaler()
        X_train = scaler.fit_transform(X_train)
        X_val = scaler.transform(X_val)
        X_test = scaler.transform(X_test)

        return X_train, y_train, X_val, y_val, X_test, y_test, _load_constraints(dpath)

    if dname == 'wine':
        wine_quality = fetch_ucirepo(id=186)

        X = wine_quality.data.features
        y = wine_quality.data.targets
        # Binarise the label: quality <= 5 -> 0, otherwise 1.
        y = y['quality'].apply(lambda value: 0 if value <= 5 else 1)

        # NOTE(review): the scaler is fitted on the full dataset, before the split.
        scaler = StandardScaler()
        X = scaler.fit_transform(X)

        df_large = pd.DataFrame(X)
        df_large['target'] = y
        splits = _upsample_and_split(
            df_large, kept_label=1, upsampled_label=0, upsample_factor=1.5
        )
        return (*splits, _load_constraints(f"./{dname}"))

    if dname == 'adult':
        df = pd.read_csv("./adult.csv")
        X = df.drop(columns='income').values
        y = df["income"].values
        # Echo the loaded frame to stdout.
        print(df)

        df_large = pd.DataFrame(X)
        df_large['target'] = y
        splits = _upsample_and_split(
            df_large, kept_label=0, upsampled_label=1, upsample_factor=3
        )
        return (*splits, None)

    if dname == 'compas':
        X = pd.read_csv("./data/compas_x.csv")
        y = pd.read_csv("./data/compas_y.csv")

        # Work on a copy so the 'target' column is not added to X itself.
        df_large = X.copy()
        df_large['target'] = y.values
        splits = _upsample_and_split(
            df_large, kept_label=1, upsampled_label=0, upsample_factor=3
        )
        return (*splits, None)

    raise ValueError(
        f"Unknown dataset name: {dname!r}; "
        "expected one of 'heloc', 'wine', 'adult', 'compas'."
    )

def get_generative_model_config(mname, dname):
    """Build the LGMVAE configured for ``(dname, mname)`` and load its weights.

    The settings are read from ``./generative_model_config.json`` under the
    key ``f"{dname}_{mname}"``. That entry must provide ``INPUT_DIM``,
    ``Z_DIM``, ``C_DIM``, ``Y_DIM``, ``BETA_1``, ``BETA_2`` and ``path``
    (the checkpoint file).

    Returns:
        The ``LGMVAE`` instance, placed on the CPU, with the checkpoint's
        state dict loaded.

    Raises:
        KeyError: If the config file has no entry for this combination.
    """
    this_configs = _load_json('./generative_model_config.json')[f"{dname}_{mname}"]
    model = LGMVAE(
        input_dim=this_configs["INPUT_DIM"],
        z_dim=this_configs["Z_DIM"],
        c_dim=this_configs["C_DIM"],
        y_dim=this_configs["Y_DIM"],
        beta_1=this_configs["BETA_1"],
        beta_2=this_configs["BETA_2"],
    ).to("cpu")
    # The model lives on the CPU, so remap the checkpoint there as well;
    # otherwise a checkpoint saved from a GPU fails to load on a CPU-only host.
    state_dict = torch.load(this_configs["path"], map_location="cpu")
    model.load_state_dict(state_dict)
    return model

def get_models(mname, X_train, y_train, X_val, y_val, X_test, y_test, n_est=25, max_depth=None, h_layers=(50, 30), dname='heloc'):
    """Train a main classifier plus a pool of variants with distinct validation accuracy.

    Args:
        mname: ``"rf"`` selects a random forest; any other value selects an MLP.
        X_train, y_train: Data the main classifier is fitted on.
        X_val, y_val: Data used to score (and de-duplicate) the pool members.
        X_test, y_test: Data the main classifier is reported on.
        n_est, max_depth: Random-forest settings; replaced by 25 and 5 when
            ``dname == 'wine'``.
        h_layers: Hidden layer sizes for the MLP.
        dname: Dataset name; only used for the ``'wine'`` override above.

    Returns:
        ``(main_classifier, collected_classifiers)``. The list holds at most
        20 classifiers gathered over at most 50 attempts.

    Side effects:
        Reseeds NumPy's global RNG with 42 and prints training/test metrics.
    """
    np.random.seed(42)
    if dname == 'wine':
        n_est, max_depth = 25, 5

    def make_classifier(rf_seed, mlp_seed):
        # One factory for both the main model and the pool members.
        if mname == "rf":
            return RandomForestClassifier(
                n_estimators=n_est,
                random_state=rf_seed,
                n_jobs=-1,
                max_depth=max_depth,
            )
        return MLPClassifier(
            hidden_layer_sizes=h_layers,
            max_iter=500,
            random_state=mlp_seed,
            early_stopping=True,
        )

    # --- Main classifier: fit on the full training data and report ---
    rf_classifier = make_classifier(rf_seed=42, mlp_seed=1)
    rf_classifier.fit(X_train, y_train)

    acc_train = accuracy_score(y_train, rf_classifier.predict(X_train))
    print(f"Accuracy on the training set: {acc_train:.4f}")

    y_pred_test = rf_classifier.predict(X_test)
    acc_test = accuracy_score(y_test, y_pred_test)
    f1_test = f1_score(y_test, y_pred_test, average='weighted')
    print(f"Test Accuracy: {acc_test:.4f}")
    print(f"Test F1 Score (Weighted): {f1_test:.4f}")

    print("\n--- Detailed Classification Report on Test Set ---")
    print(classification_report(y_test, y_pred_test))

    # --- Pool of variants, each fitted on a random 95 % subsample ---
    max_attempts = 50
    target_num_classifiers = 20
    pool = []
    seen_val_accuracies = set()

    for attempt in range(max_attempts):
        # No random_state is passed, so this split relies on the global NumPy
        # RNG seeded at the top of the function.
        X_sub, _, y_sub, _ = train_test_split(
            X_train, y_train, train_size=0.95, stratify=y_train
        )

        # NOTE(review): this draw is never used (possibly meant as n_estimators);
        # it is kept so the global `random` stream advances exactly as before.
        random.randint(23, 27)

        candidate = make_classifier(rf_seed=attempt, mlp_seed=attempt * 500 + 2020)
        candidate.fit(X_sub, y_sub)

        # Score on the validation split; keep only previously unseen accuracies.
        val_acc = accuracy_score(y_val, candidate.predict(X_val))
        if val_acc in seen_val_accuracies:
            continue
        seen_val_accuracies.add(val_acc)
        pool.append(candidate)

        if len(pool) >= target_num_classifiers:
            break

    print(f"Total unique classifiers collected: {len(pool)}")    
    return rf_classifier, pool

def evaluate_generative_model_utility(
    gen_model, 
    clf, 
    X_train_new, 
    y_train_new, 
    X_test_new, 
    y_test_new,
    n_runs=10
):
    """Measure how useful ``gen_model``'s samples are for training a classifier.

    Each of the ``n_runs`` repetitions records three accuracies:

    1. A small random forest trained on the real training data, scored on
       the real test data.
    2. The same forest configuration trained on samples drawn from
       ``gen_model`` (as many per class as the real training data has),
       scored on the real test data.
    3. ``clf`` scored on the prototypes reconstructed from the model's
       cluster centroids.

    Returns:
        ``(all_centroids, results)`` where ``all_centroids`` comes from a
        final ``get_cluster_centroids(gen_model)`` call and ``results`` maps
        ``"real_on_real"``, ``"synthetic_on_real"`` and
        ``"clf_on_prototypes"`` to ``(mean, std)`` pairs.

    Side effects:
        Prints a summary table to stdout.
    """
    scores = {"real": [], "synthetic": [], "prototypes": []}

    def forest_test_accuracy(seed, features, labels):
        # Fit a small forest on the given data and score it on the real test set.
        forest = RandomForestClassifier(random_state=seed, n_jobs=-1, n_estimators=5, max_depth=4)
        forest.fit(features, labels)
        return accuracy_score(y_test_new, forest.predict(X_test_new))

    for run in range(n_runs):
        # Train on REAL, test on REAL
        scores["real"].append(forest_test_accuracy(run, X_train_new, y_train_new))

        # Train on SYNTHETIC, test on REAL: mirror the real class counts.
        per_class_count = Counter(y_train_new)
        synthetic_parts = []
        synthetic_labels = []
        with torch.no_grad():
            for cls, n_cls in per_class_count.items():
                drawn = gen_model.sample(y_label=cls, num_samples=n_cls)
                synthetic_parts.append(drawn.cpu().numpy())
                synthetic_labels += [cls] * n_cls
        X_synth = np.concatenate(synthetic_parts, axis=0)
        y_synth = np.array(synthetic_labels)
        X_synth, y_synth = shuffle(X_synth, y_synth, random_state=run)
        scores["synthetic"].append(forest_test_accuracy(run, X_synth, y_synth))

        # Original classifier on the centroid prototypes.
        centroids = get_cluster_centroids(gen_model)
        X_prototypes = reconstruct_from_centroids(gen_model, centroids)
        clusters_per_class = gen_model.c_dim // gen_model.y_dim
        # Assumes the centroids come grouped by class, in class order - TODO confirm.
        expected_labels = np.repeat(np.arange(gen_model.y_dim), repeats=clusters_per_class)
        scores["prototypes"].append(accuracy_score(expected_labels, clf.predict(X_prototypes)))

    banner = "=" * 50
    print("\n" + banner)
    print(f"--- GENERATIVE MODEL UTILITY SUMMARY (after {n_runs} runs) ---")
    print(banner)

    mean_real, std_real = np.mean(scores["real"]), np.std(scores["real"])
    print(f"Train on REAL -> Test on REAL:         Accuracy = {mean_real:.4f} ± {std_real:.4f}")

    mean_synthetic, std_synthetic = np.mean(scores["synthetic"]), np.std(scores["synthetic"])
    print(f"Train on SYNTHETIC -> Test on REAL:    Accuracy = {mean_synthetic:.4f} ± {std_synthetic:.4f}")

    mean_prototypes, std_prototypes = np.mean(scores["prototypes"]), np.std(scores["prototypes"])
    print(f"Original Classifier on CENTROIDS:      Accuracy = {mean_prototypes:.4f} ± {std_prototypes:.4f}")
    print(banner)

    all_centroids = get_cluster_centroids(gen_model)
    results = {
        "real_on_real": (mean_real, std_real),
        "synthetic_on_real": (mean_synthetic, std_synthetic),
        "clf_on_prototypes": (mean_prototypes, std_prototypes),
    }
    return all_centroids, results
