### Library imports
import pandas as pd
import numpy as np
import sklearn
import json
import argparse
import time
from sklearn.ensemble import RandomForestClassifier
from verification.toolbox import KantchModel

from datasets_infos import datasets_ohe_vectors, predictions
from utils import * 

# Command line: a single integer, --expe_id (default 0), which is used below
# as the index of the configuration to run.
parser = argparse.ArgumentParser(
    description='Dataset reconstruction from random forest',
)
parser.add_argument('--expe_id', default=0, type=int)
args = parser.parse_args()
expe_id = args.expe_id

# Manually set parameters (edit in place; not exposed on the command line)
rank = expe_id  # position of this run inside the parameter grid
debug = True  # enables the progress/diagnostic prints and the OHE sanity check
print_logs = 1
n_threads = -1  # -1 = all threads available

# Experiment-wide constants.
time_out = 5 * 60 * 60  # forwarded as `timeout` to the reconstruction call
n_random_sols = 100  # number of random solutions drawn for the baseline error
datasets = [
    "adult",
    "bank-marketing",
    "compas",
    "default_credit",
    "diabetes",
    "fico",
]
sample_size = 1500  # rows sampled from each dataset's CSV
train_size = 100  # rows kept for training out of the sampled rows
method = "cp-sat"
model = "rset"

if method in ('cp-sat', 'milp', 'bench'):
    # Forest sizes: every value in 1..49, then steps of 5 over 50..110 and
    # 120..150 (115 is intentionally absent, as in the original list),
    # and finally 175 and 200.
    val_trees = [
        *range(1, 50),
        *range(50, 111, 5),
        *range(120, 151, 5),
        175,
        200,
    ]
    val_depths = [4]
    val_seed = list(range(5))

    rule_list = ["closest", "random", "farthest", "increment", "densest", "sparsest"]

    # Full cartesian grid. The iteration order (rule outermost, dataset
    # innermost) defines which configuration a given --expe_id maps to, so it
    # must not be changed.
    params_list = [
        [rule_name, n_trees, depth, seed_value, dataset_name]
        for rule_name in rule_list
        for n_trees in val_trees
        for depth in val_depths
        for seed_value in val_seed
        for dataset_name in datasets
    ]

bagging = False

# Select the configuration for this run.
# The index is validated explicitly: a plain `params_list[rank]` would accept
# negative values (Python wraps them around to the end of the list) and thus
# silently run an unintended configuration, while a too-large value would only
# produce an uninformative "list index out of range".
if not 0 <= rank < len(params_list):
    raise IndexError(
        f"--expe_id must be in [0, {len(params_list) - 1}], got {rank}"
    )

# Each grid entry is [rule, n_trees, depth, seed, dataset].
rule, n_estimators, max_depth_t, seed, dataset = params_list[rank]

# Seed NumPy's legacy global generator for this configuration.
np.random.seed(seed)

# Echo the selected configuration so that each job log is self-describing.
if debug:
    print("#params: ", len(params_list))
    print("Local Config: ")
    for label, value in (
        (" dataset = ", dataset),
        (" method = ", method),
        (" model = ", model),
        (" rule = ", rule),
        (" n_estimators = ", n_estimators),
        (" max_depth_t = ", max_depth_t),
        (" seed = ", seed),
    ):
        print(label, value)
    # Flush once at the end so the header shows up even if the job dies later.
    print("--------------------", flush=True)

# Prepare data
# One-hot-encoding description of the selected dataset (project-provided).
ohe_vector = datasets_ohe_vectors[dataset]

# Load the dataset and keep a reproducible random subset of `sample_size` rows.
df = pd.read_csv("data/%s.csv" % dataset)
df = df.sample(n=sample_size, random_state=seed, ignore_index=True)

# Split, then convert every part to a NumPy array.
# NOTE(review): the third argument (sample_size - train_size) is presumably the
# test-set size expected by `data_splitting` — confirm against utils.
X_train, X_test, y_train, y_test = (
    part.to_numpy()
    for part in data_splitting(df, predictions[dataset], sample_size - train_size, seed)
)

if debug:
    print("Using dataset %s, training set size is %d with %d attributes." %(dataset, train_size, X_test.shape[1]))
    checked_ohe = check_ohe(X_train, ohe_vector)
    print("OHE verified: ", checked_ohe)
    if not checked_ohe:
        # Abort with a non-zero exit status and a message on stderr. The bare
        # `exit()` used previously is a helper installed by the `site` module
        # for interactive sessions and terminates with status 0, which makes a
        # failed sanity check look like a successful job.
        raise SystemExit(
            f"OHE verification failed for dataset {dataset!r}; aborting."
        )

from rset_wrapper import RsetWrapper
from rf_wrapper import RFWrapper
# Rashomon-set hyper-parameters.
eps = 0.02
# Per-dataset regularisation; any dataset not listed here uses 0.01.
lamb = {
    "fico": 0.02,
    "default_credit": 0.0125,
    "bank-marketing": 0.013,
    "diabetes": 0.0165,
}.get(dataset, 0.01)
config = dict(
    regularization=lamb,
    # NOTE(review): the +1 presumably accounts for the wrapped library counting
    # the root as one level of depth — confirm against RsetWrapper.
    depth_budget=max_depth_t + 1,
    rashomon_bound_adder=eps,
)

# Fit the Rashomon set on the training data.
rset_clf = RsetWrapper(config)
rset_clf.fit(X_train, y_train)

# The same generator is handed first to the special-tree search and then to
# the forest wrapper, so the order of these calls must stay as is.
rng = np.random.default_rng(seed)
rset_clf.find_special_tree(rng)
clf = RFWrapper(rset_clf, n_estimators, rule, rng, X_train, y_train)


import matplotlib.pyplot as plt
import numpy as np

# Score the wrapped forest on the training split first, then on the test split.
accuracy_train, accuracy_test = (
    clf.score(features, labels)
    for features, labels in ((X_train, y_train), (X_test, y_test))
)
if debug:
    print("accuracy_train=", accuracy_train, "accuracy_test=", accuracy_test)

# Perform the reconstruction
from DRAFT import DRAFT

extractor = DRAFT(clf, datasets_ohe_vectors[dataset])
# The returned dict is read further down through the keys 'status',
# 'duration', 'reconstructed_data' and 'max_max_depth'.
fit_options = {
    "bagging": bagging,
    "method": method,
    "timeout": time_out,
    "verbosity": False,
    "n_jobs": n_threads,
    "seed": seed,
}
dict_res = extractor.fit(**fit_options)

# Adversarial evaluation: build adversarial versions of the test points from
# the first estimator of the forest, then score every estimator on them.
optimal_model = clf.estimators_[0]
model_json = optimal_model.tree_.to_json()
# NOTE(review): the meaning of the second argument (2) is not visible in this
# file — confirm against verification.toolbox.KantchModel.
kantch_model = KantchModel(model_json, 2)
adv_options = {"disable_progress_bar": True, "epsilon": 1}
X_adv = kantch_model.adversarial_examples(X_test, y_test, order=0, options=adv_options)

# Each estimator is scored on the adversarial inputs against the test labels.
accuracies = [
    clf.estimators_[tree_idx].score(X_adv, y_test)
    for tree_idx in range(n_estimators)
]
mean_valid_accuracy = np.mean(accuracies)
median_valid_accuracy = np.median(accuracies)
best_valid_accuracy = max(accuracies)


# Unpack the fields of the extractor result that are used below.
solve_status, duration, x_sol, max_max_depth = (
    dict_res[key]
    for key in ('status', 'duration', 'reconstructed_data', 'max_max_depth')
)

# Evaluate the reconstruction
x_train_list = X_train.tolist()  # ground truth
e_mean, list_matching = average_error(x_sol, x_train_list)

# Random baseline: mean error of `n_random_sols` randomly generated solutions
# measured against the same ground truth.
liste_random_sol = generate_random_sols(
    len(x_train_list),
    X_test.shape[1],
    dataset_ohe_groups=datasets_ohe_vectors[dataset],
    n_sols=n_random_sols,
    seed=seed,
)
random_errors = [
    average_error(random_sol, x_train_list)[0]
    for random_sol in liste_random_sol
]
# Plain left-to-right accumulation starting from 0, so the floating-point
# result is exactly the one a running total produces.
moyenne_rand = 0
for random_error in random_errors:
    moyenne_rand += random_error
moyenne_rand = moyenne_rand / len(liste_random_sol)

# When the solver status is 'UNKNOWN', report the random baseline as the
# reconstruction error instead of the value computed from x_sol.
if solve_status == 'UNKNOWN':
    e_mean = moyenne_rand

import os

# All metrics of this run, stored under a single "values" key.
sol_dict = {
    "values": {
        "n_trees": n_estimators,
        "max_depth": max_depth_t,
        "real_max_depth": max_max_depth,
        "seed": seed,
        "accuracy test": accuracy_test,
        "accuracy train": accuracy_train,
        "solve_status": solve_status,
        "mean-error": e_mean,
        "random_error": moyenne_rand,
        "solve_duration_time": duration,
        "best_adv_accuracy": best_valid_accuracy,
        "median_adv_accuracy": median_valid_accuracy,
        "mean_adv_accuracy": mean_valid_accuracy,
    },
}

# One result file per configuration: <dataset>_<n_trees>_<depth>_<seed>_<rule>
filename = "_".join(
    str(part) for part in (dataset, n_estimators, max_depth_t, seed, rule)
)

os.makedirs("./results", exist_ok=True)
with open(f"./results/{filename}.json", 'w') as results_file:
    json.dump(sol_dict, results_file, indent=4)

# Final human-readable summary of the run.
if debug:
    for label, value in (
        ("Complete solving duration :", duration),
        ("Reconstruction Error: ", e_mean),
        ("Baseline (Random) Error: ", moyenne_rand),
        ("Best adversarial accuracy: ", best_valid_accuracy),
    ):
        print(label, value)
