
from gen_synth import *
from run_method import run_method_single_subgroup, run_method_multiple_times
import numpy as np
from sklearn.preprocessing import StandardScaler
import argparse
import torch
from syflow import syflow, And_Finder
from causalml.inference.tree import CausalTreeRegressor # Import from causalml
import pysubgroup as ps
import pandas as pd
from utils import *
import timeit 
import itertools
from configs import *
from econml.grf import CausalForest
from sklearn.metrics import f1_score

def _str_to_bool(value):
    """Convert a command-line token into a bool.

    argparse's ``type=bool`` calls ``bool()`` on the raw string, so every
    non-empty value (including ``"False"``) becomes ``True``. This converter
    interprets the usual textual spellings instead.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ("true", "t", "yes", "y", "1"):
        return True
    # The empty string is kept as False because bool("") was False before.
    if token in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}.")


def parse_args():
    """Parse the command-line options of the synthetic-data experiment script.

    Returns:
        argparse.Namespace with the attributes ``method``, ``setting``, ``n``,
        ``d``, ``tau``, ``gamma``, ``c``, ``sigma``, ``sg_size``, ``rule_size``,
        ``n_subgroups``, ``mean_shift`` and ``config_id``.
    """
    parser = argparse.ArgumentParser(description="Run synthetic data experiments.")
    parser.add_argument("--method", type=str, required=True, help="Method to use for subgroup discovery.")
    parser.add_argument("--setting", type=str, default="observational", choices=["observational", "demographic", "interventional", "mediator"], help="Setting for the synthetic data generation.")
    parser.add_argument("--n", type=int, default=2000, help="Number of samples.")
    parser.add_argument("--d", type=int, default=5, help="Number of features.")
    parser.add_argument("--tau", type=float, default=4.0, help="Treatment effect within the subgroup.")
    parser.add_argument("--gamma", type=float, default=1, help="Treatment effect outside the subgroup.")
    parser.add_argument("--c", type=float, default=2.0, help="Confounding effect within the subgroup.")
    parser.add_argument("--sigma", type=float, default=0.5, help="Standard deviation of the noise in Y.")
    parser.add_argument("--sg_size", type=float, default=0.25, help="Size of the subgroup.")
    parser.add_argument("--rule_size", type=int, default=2, help="Number of features in the rule.")
    parser.add_argument("--n_subgroups", type=int, default=3, help="Number of subgroups to generate.")
    # Parsed with _str_to_bool so that "--mean_shift False" really disables it.
    parser.add_argument("--mean_shift", type=_str_to_bool, default=True, help="Whether to apply mean shift in treatment assignment.")
    parser.add_argument("--config_id", type=int, default=-1, help="Configuration ID for hyperparameter optimization.")
    return parser.parse_args()

# Parse the command line once and expose each option as a module-level name.
args = parse_args()

# Which method to tune, on which data setting, and (optionally) which config.
method, setting = args.method, args.setting
parsed_config_id = args.config_id

# Parameters forwarded to the synthetic data generators.
n, d = args.n, args.d
tau, gamma = args.tau, args.gamma
c, sigma = args.c, args.sigma
sg_size, rule_size = args.sg_size, args.rule_size
mean_shift = args.mean_shift

# Number of candidate subgroups requested from the discovery methods.
n_subgroups = args.n_subgroups

# Metric accumulators: one value is appended to each list per evaluated seed.
f1s, accuracies, purities = [], [], []
runtimes, precisions, recalls = [], [], []


# Enumerate all hyperparameter combinations for the selected method.
if method == "subcon":
    combinations = Subcon_Config().get_all_configs()
elif method == "syflow":
    combinations = Syflow_Config().get_all_configs()
elif method == "pysubgroup":
    combinations = PySubgroup_Config().get_all_configs()
elif method == "causaltree":
    combinations = CausalTree_Config().get_all_configs()
elif method == "honesttree":
    combinations = HonestTree_Config().get_all_configs()
else:
    # Fail fast: an unknown method would otherwise run zero configurations
    # and, in sweep mode, still truncate the existing results file below.
    raise ValueError(f"Unknown method: {method}. Choose 'subcon', 'syflow', 'pysubgroup', 'causaltree' or 'honesttree'.")

result_path = f"hyperparameters/{method}-{setting}.csv"
if parsed_config_id != -1:
    # Single-configuration mode appends one row to the shared results file.
    # NOTE(review): no header is written in this mode, so a file created by an
    # append-only run has no column names -- confirm readers cope with that.
    resultfile = open(result_path, "a")
else:
    # Full sweep: start a fresh results file with a header row.
    resultfile = open(result_path, "w")
    resultfile.write("ConfigID;n;d;tau;f1;f1_std;accuracy;accuracy_std;purity;purity_std;runtime;runtime_std;precision;precision_std;recall;recall_std\n")

# Running index of the configuration being evaluated and best result so far.
config_id = 0
best_config = None
best_f1 = -1
def _best_f1_index(truth, candidates):
    """Return the index of the candidate mask with the highest F1 against `truth`.

    Ties keep the earliest candidate; -1 is returned if `candidates` is empty.
    """
    best_index = -1
    best_overlap = -1
    for index, candidate in enumerate(candidates):
        overlap = f1_score(truth, candidate)
        if overlap > best_overlap:
            best_overlap = overlap
            best_index = index
    return best_index


# Evaluate each hyperparameter configuration on 10 freshly generated synthetic
# datasets and write the mean/std of every metric as one CSV row.
for params in combinations:
    if parsed_config_id != -1:
        # Single-configuration mode: evaluate only the requested entry.
        config_id = parsed_config_id
        params = combinations[config_id]

    # Start from empty accumulators so the statistics written below describe
    # this configuration only, not every configuration evaluated so far.
    f1s, accuracies, purities = [], [], []
    runtimes, precisions, recalls = [], [], []

    # NOTE(review): seeds are drawn from the unseeded global RNG, so every
    # configuration (and every rerun) sees different datasets.
    for seed in np.random.randint(0, 100000, size=10):

        gen_kwargs = dict(n=n, d=d, tau=tau, gamma=gamma, c=c, sigma=sigma, sg_size=sg_size, rule_size=rule_size, seed=seed, mean_shift=mean_shift)
        if setting == "observational":
            data = gen_observational_trial(**gen_kwargs)
        elif setting == "demographic":
            data = gen_demographic_data(**gen_kwargs)
        elif setting == "interventional":
            data = gen_interventional_trial(**gen_kwargs)
        else:
            raise ValueError(f"Unknown setting: {setting}. Choose 'observational', 'demographic' or 'interventional'.")

        X, s_star, A, Y = data["X"], data["s_star"], data["A"], data["Y"]
        is_discrete = [False for _ in range(X.shape[1])]  # every feature is flagged as non-discrete
        scaler_X = StandardScaler()
        X = scaler_X.fit_transform(X)
        scaler_Y = StandardScaler()
        Y = scaler_Y.fit_transform(Y.reshape(-1, 1)).flatten()
        feature_names = [f"X{i}" for i in range(X.shape[1])]

        # The measured runtime covers fitting and picking the best candidate.
        t = timeit.default_timer()

        if method == "subcon":
            # NOTE(review): this branch uses the fixed per-setting config, not
            # `params`, so every configuration is evaluated with the same
            # hyperparameters -- confirm this is intended.
            our_config = Subcon_Config().get_setting_config(setting)
            subgroups, rules, models = run_method_multiple_times(X, X[A == 0], X[A == 1], Y, Y[A == 0], Y[A == 1], scaler_X, scaler_Y, feature_names, is_discrete, our_config, discrete_target=False, maximize=True, plot=False, max_reps=1)
            # Row positions of each treatment arm, used to map the per-arm
            # masks back onto the full sample.
            A0_indices = np.where(A == 0)[0]
            A1_indices = np.where(A == 1)[0]
            masks = []
            for s0_mask, s1_mask in subgroups:
                candidate_sg = np.zeros(X.shape[0], dtype=int)
                candidate_sg[A0_indices[s0_mask]] = 1
                candidate_sg[A1_indices[s1_mask]] = 1
                masks.append(candidate_sg)
            sg = masks[_best_f1_index(s_star, masks)]

        elif method == "syflow":
            # The treatment indicator is appended as an additional feature.
            X_cat = np.concatenate([X, A.reshape(-1, 1)], axis=1)
            X_cat = torch.tensor(X_cat, dtype=torch.float64)
            limits = get_data_limits(X_cat)
            model = And_Finder(limits)
            Y = Y.squeeze()
            Y = torch.tensor(Y[:, None], dtype=torch.float64)
            flow_population = None
            flow_subgroup = None
            subgroups = []
            subgroup_priors = []
            for _ in range(n_subgroups):
                # The population flow and earlier subgroup flows are passed
                # back in on every subsequent call.
                (flow_population, flow_subgroup), classifier, sg = syflow(X_cat, Y, model, alpha=params["alpha"], lr_classifier=params["lr_classifier"], subgroup_train_epochs=params["subgroup_train_epochs"], flow_population=flow_population, subgroup_priors=subgroup_priors, progressbar=False)
                sg = sg.numpy()
                subgroups.append(sg)
                subgroup_priors.append(flow_subgroup)
            sg = subgroups[_best_f1_index(s_star, subgroups)]

        elif method == "pysubgroup":
            X = pd.DataFrame(X)
            Y = pd.DataFrame(Y)
            Y.columns = ["Y"]
            X.columns = [f"X{i}" for i in range(X.shape[1])]
            data = pd.concat([X, Y], axis=1)
            target = ps.NumericTarget("Y")
            search_space = ps.create_selectors(data, ignore=["Y"], nbins=params["n_bins"], intervals_only=False)
            task = ps.SubgroupDiscoveryTask(
                data,
                target,
                search_space,
                result_set_size=n_subgroups,
                depth=10,
                qf=ps.StandardQFNumeric(params["alpha"]))

            result = ps.BeamSearch(beam_width=params["beam_width"], beam_width_adaptive=False).execute(task)
            # Build the result table once instead of on every iteration.
            result_df = result.to_dataframe()
            subgroups = []
            rules = []
            # Guard against the search returning fewer subgroups than requested.
            for i in range(min(n_subgroups, len(result_df))):
                result_string = str(result_df.iloc[i]["subgroup"])
                rules.append(replace_feature_names(result_string, feature_names))
                conditions = []
                for part in result_string.split(" AND "):
                    # Each part looks like "X0>=0.80", "X0<0.80" or "X0==value".
                    if "==" in part:
                        var, val = part.split("==")
                        var = int(var[1:])
                        val = convert(data, var, val)
                        conditions.append((var, val, val))
                        continue
                    elif "<" in part:
                        var, high = part.split("<")
                        low = -np.inf
                        var = int(var[1:])
                        high = float(high)
                    else:
                        var, low = part.split(">=")
                        high = np.inf
                        var = int(var[1:])
                        low = float(low)
                    conditions.append((var, low, high))

                # A sample is a member if it satisfies every parsed condition.
                subgroup_member = np.ones((X.shape[0],), dtype=bool)
                for var, low, high in conditions:
                    var = int(var)
                    subgroup_member = np.logical_and(subgroup_member, np.logical_and(X.iloc[:, var] >= low, X.iloc[:, var] <= high))
                subgroups.append(subgroup_member)
            sg = subgroups[_best_f1_index(s_star, subgroups)]

        elif method == "causaltree":
            min_support = int(params["min_samples_leaf"] * X.shape[0])
            model = CausalTreeRegressor(min_samples_leaf=min_support, max_depth=params["max_depth"])

            # Fit the model using features (X), treatment (A), and outcome (Y).
            model.fit(X, treatment=A, y=Y)

            # Treat every leaf as a candidate subgroup and keep the one with
            # the highest F1 overlap with the ground truth.
            leaf_ids = model.apply(X)
            leaves = np.unique(leaf_ids)
            best_leaf = leaves[_best_f1_index(s_star, ((leaf_ids == leaf).squeeze() for leaf in leaves))]
            sg = (leaf_ids == best_leaf)

        elif method == "honesttree":
            # NOTE(review): this branch uses the fixed per-setting config, not
            # `params`, so every configuration is evaluated with the same
            # hyperparameters -- confirm this is intended.
            config = HonestTree_Config().get_setting_config(setting)
            # train a tree, i.e. 1 estimator with max_samples=1
            min_support = int(config["min_samples_leaf"] * X.shape[0])
            max_depth = config["max_depth"]
            model = CausalForest(n_estimators=1, max_samples=1., honest=True, min_samples_leaf=min_support, max_depth=max_depth, subforest_size=1, inference=False)
            # Fit the model using features (X), treatment (A), and outcome (Y).
            model.fit(X, A, Y)
            # Treat every leaf as a candidate subgroup and keep the one with
            # the highest F1 overlap with the ground truth.
            leaf_ids = model.apply(X)
            leaves = np.unique(leaf_ids)
            best_leaf = leaves[_best_f1_index(s_star, ((leaf_ids == leaf).squeeze() for leaf in leaves))]
            sg = (leaf_ids == best_leaf).squeeze()

        t = timeit.default_timer() - t
        runtimes.append(t)

        # Confusion-matrix counts of the selected subgroup vs. the ground truth.
        fp = np.sum((sg == 1) & (s_star == 0))
        fn = np.sum((sg == 0) & (s_star == 1))
        tp = np.sum((sg == 1) & (s_star == 1))
        tn = np.sum((sg == 0) & (s_star == 0))
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0
        # Guarded like precision/recall: empty prediction and empty truth give 0.
        f1_denominator = 2 * tp + fp + fn
        f1 = 2 * tp / f1_denominator if f1_denominator > 0 else 0
        accuracy = (tp + tn) / (tp + fp + tn + fn)
        purity = tp / (tp + fp) if (tp + fp) > 0 else 0
        f1s.append(f1)
        accuracies.append(accuracy)
        purities.append(purity)
        precisions.append(precision)
        recalls.append(recall)

        print(f"Config {config_id}: F1: {f1}, Precision: {precision}, Recall: {recall}, Accuracy: {accuracy}, Purity: {purity}, Runtime: {t}")

    # Aggregate this configuration's metrics over all seeds.
    f1 = np.mean(f1s)
    accuracy = np.mean(accuracies)
    purity = np.mean(purities)
    runtime = np.mean(runtimes)
    precision = np.mean(precisions)
    recall = np.mean(recalls)

    f1_std = np.std(f1s)
    accuracy_std = np.std(accuracies)
    purity_std = np.std(purities)
    runtime_std = np.std(runtimes)
    precision_std = np.std(precisions)
    recall_std = np.std(recalls)

    resultfile.write(f"Config{config_id};{n};{d};{tau};{f1};{f1_std};{accuracy};{accuracy_std};{purity};{purity_std};{runtime};{runtime_std};{precision};{precision_std};{recall};{recall_std}\n")
    config_id += 1
    if f1 > best_f1:
        best_f1 = f1
        best_config = params
    if parsed_config_id != -1:
        # Only the requested configuration is evaluated in this mode.
        break
print(f"Best Config: {best_config} with F1: {best_f1} for method {method} and setting {setting}")

resultfile.close()
if parsed_config_id != -1:
    # Single-configuration runs keep one "best so far" file per method and
    # setting. Its first line ends in ": <f1>"; it is overwritten only when
    # this run achieved a strictly higher F1.
    best_params_path = f"hyperparameters/best_params/{method}-{setting}.txt"
    try:
        with open(best_params_path, "r") as best_config_file:
            first_line = best_config_file.readline().strip()
        previous_best_f1 = float(first_line.split(":")[-1])
    except (FileNotFoundError, ValueError):
        # A missing, empty or malformed file counts as "no previous best".
        previous_best_f1 = float("-inf")

    if previous_best_f1 < best_f1:
        with open(best_params_path, "w") as best_config_file:
            best_config_file.write(f"Best Config for {method} with F1: {best_f1}\n")
            best_config_file.write(f"{best_config}\n")