from gen_synth import *
from run_method import run_method_single_subgroup, run_method_multiple_times
import numpy as np
from sklearn.preprocessing import StandardScaler
import argparse
import torch
from syflow import syflow, And_Finder
from causalml.inference.tree import CausalTreeRegressor # Import from causalml
import pysubgroup as ps
import pandas as pd
from utils import *
import timeit 
from configs import *
from econml.grf import CausalForest
from sklearn.metrics import f1_score

def _str2bool(value):
    """Convert a command-line token into a bool.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so every
    non-empty value (including ``"False"``) becomes ``True``. This converter
    interprets the text instead.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in {"true", "t", "yes", "y", "on", "1"}:
        return True
    # The empty string stays False, matching what bool("") returned before.
    if token in {"false", "f", "no", "n", "off", "0", ""}:
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}.")


def parse_args():
    """Parse the command-line options of the synthetic-data experiment.

    Returns:
        argparse.Namespace with the attributes ``method``, ``outpath``,
        ``setting`` (all required), and ``n``, ``d``, ``tau``, ``gamma``, ``c``,
        ``sigma``, ``sg_size``, ``rule_size``, ``n_subgroups``, ``mean_shift``
        (optional, with the defaults declared below).
    """
    parser = argparse.ArgumentParser(description="Run synthetic data experiments.")
    parser.add_argument("--method", type=str, required=True, help="Method to use for subgroup discovery.")
    parser.add_argument("--outpath", type=str, required=True, help="Output path for results.")
    parser.add_argument("--setting", type=str, required=True, help="Setting for the experiment (e.g., 'observational', 'demographic', 'interventional').")
    parser.add_argument("--n", type=int, default=2000, help="Number of samples.")
    parser.add_argument("--d", type=int, default=5, help="Number of features.")
    parser.add_argument("--tau", type=float, default=4.0, help="Treatment effect within the subgroup.")
    parser.add_argument("--gamma", type=float, default=1.0, help="Treatment effect outside the subgroup.")
    parser.add_argument("--c", type=float, default=2.0, help="Confounding effect within the subgroup.")
    parser.add_argument("--sigma", type=float, default=0.5, help="Standard deviation of the noise in Y.")
    parser.add_argument("--sg_size", type=float, default=0.25, help="Size of the subgroup.")
    parser.add_argument("--rule_size", type=int, default=2, help="Number of features in the rule.")
    parser.add_argument("--n_subgroups", type=int, default=3, help="Number of subgroups to generate.")
    # nargs="?" with const=True lets "--mean_shift" be used as a bare flag while
    # still accepting an explicit value such as "--mean_shift False".
    parser.add_argument("--mean_shift", type=_str2bool, nargs="?", const=True, default=False, help="Whether to apply mean shift in treatment assignment.")
    return parser.parse_args()

# Read the command-line configuration once and expose every option as a
# module-level name used by the experiment loop below.
args = parse_args()
method, outpath, setting = args.method, args.outpath, args.setting
n, d = args.n, args.d
tau, gamma, c, sigma = args.tau, args.gamma, args.c, args.sigma
sg_size, rule_size = args.sg_size, args.rule_size
n_subgroups, mean_shift = args.n_subgroups, args.mean_shift

# Per-seed metric accumulators; one value is appended per run.
f1s, accuracies, purities = [], [], []
runtimes = []
precisions, recalls = [], []

def _select_best_by_f1(reference, candidates):
    """Return the candidate membership mask that best matches ``reference``.

    Every candidate is scored with ``f1_score(reference, candidate)``. The
    first candidate reaching the highest score is returned, so ties keep the
    earliest candidate.

    Raises:
        ValueError: if ``candidates`` is empty.
    """
    best_candidate = None
    best_overlap = -1
    for candidate in candidates:
        overlap = f1_score(reference, candidate)
        if overlap > best_overlap:
            best_overlap = overlap
            best_candidate = candidate
    if best_candidate is None:
        raise ValueError("No candidate subgroups to choose from.")
    return best_candidate


# Run the chosen method on 10 synthetic datasets (seeds 0-9). For each run the
# candidate subgroup with the best F1 against the planted subgroup ``s_star``
# is selected, then confusion-matrix metrics and the runtime are recorded.
for seed in range(10):
    if setting == "observational":
        data = gen_observational_trial(n=n, d=d, tau=tau, gamma=gamma, c=c, sigma=sigma, sg_size=sg_size, rule_size=rule_size, seed=seed, mean_shift=mean_shift)
    elif setting == "demographic":
        data = gen_demographic_data(n=n, d=d, tau=tau, gamma=gamma, c=c, sigma=sigma, sg_size=sg_size, rule_size=rule_size, seed=seed, mean_shift=mean_shift)
    elif setting == "interventional":
        data = gen_interventional_trial(n=n, d=d, tau=tau, gamma=gamma, c=c, sigma=sigma, sg_size=sg_size, rule_size=rule_size, seed=seed, mean_shift=mean_shift)
    else:
        raise ValueError(f"Unknown setting: {setting}. Choose 'observational', 'demographic' or 'interventional'.")

    X, s_star, A, Y = data["X"], data["s_star"], data["A"], data["Y"]
    is_discrete = [False for _ in range(X.shape[1])]  # Every feature is flagged as non-discrete
    # Standardise features and outcome; the fitted scalers are handed to subcon.
    scaler_X = StandardScaler()
    X = scaler_X.fit_transform(X)
    scaler_Y = StandardScaler()
    Y = scaler_Y.fit_transform(Y.reshape(-1, 1)).flatten()
    feature_names = [f"X{i}" for i in range(X.shape[1])]

    # The measured runtime covers fitting and the best-candidate selection.
    t = timeit.default_timer()

    if method == "subcon":
        our_config = Subcon_Config().get_setting_config(setting)
        subgroups, rules, models = run_method_multiple_times(X, X[A == 0], X[A == 1], Y, Y[A == 0], Y[A == 1], scaler_X, scaler_Y, feature_names, is_discrete, our_config, discrete_target=False, maximize=True, plot=False, max_reps=3)
        # Positions of the control / treated samples within the full dataset;
        # used to map the per-arm masks back to one mask over all samples.
        A0_indices = np.where(A == 0)[0]
        A1_indices = np.where(A == 1)[0]
        masks = []
        for s0_mask, s1_mask in subgroups:
            candidate_sg = np.zeros(X.shape[0], dtype=int)
            candidate_sg[A0_indices[s0_mask]] = 1
            candidate_sg[A1_indices[s1_mask]] = 1
            masks.append(candidate_sg)
        sg = _select_best_by_f1(s_star, masks)

    elif method == "syflow":
        config = Syflow_Config().get_setting_config(setting)

        # The treatment indicator is appended to the features as an extra column.
        X_cat = np.concatenate([X, A.reshape(-1, 1)], axis=1)
        X_cat = torch.tensor(X_cat, dtype=torch.float64)
        limits = get_data_limits(X_cat)
        model = And_Finder(limits)
        Y = Y.squeeze()
        Y = torch.tensor(Y[:,None], dtype=torch.float64)
        flow_population = None
        flow_subgroup = None
        subgroups = []
        subgroup_priors = []

        # Each round reuses the population flow and passes the previously
        # found subgroup flows as priors.
        for _ in range(n_subgroups):
            (flow_population, flow_subgroup), classifier, sg = syflow(X_cat, Y, model,flow_population=flow_population, subgroup_priors=subgroup_priors, progressbar=False, alpha=config["alpha"], lr_classifier=config["lr_classifier"], subgroup_train_epochs=config["subgroup_train_epochs"])
            sg = sg.numpy()
            subgroups.append(sg)
            subgroup_priors.append(flow_subgroup)
        sg = _select_best_by_f1(s_star, subgroups)

    elif method == "pysubgroup":
        X = pd.DataFrame(X)
        Y = pd.DataFrame(Y)
        config = PySubgroup_Config().get_setting_config(setting)
        Y.columns = ["Y"]
        X.columns = [f"X{i}" for i in range(X.shape[1])]
        data = pd.concat([X,Y],axis=1)
        target = ps.NumericTarget("Y")
        search_space = ps.create_selectors(data, ignore=["Y"],nbins=config["n_bins"], intervals_only=False)
        task = ps.SubgroupDiscoveryTask (
            data, 
            target, 
            search_space, 
            result_set_size=n_subgroups, 
            qf=ps.StandardQFNumeric(config["alpha"]))

        result = ps.BeamSearch(beam_width=config["beam_width"],beam_width_adaptive=False).execute(task)
        # Build the result table once and only walk the rows that exist, so a
        # search returning fewer than n_subgroups results does not fail.
        result_df = result.to_dataframe()
        subgroups = []
        rules = []
        for subgroup_description in result_df["subgroup"].iloc[:n_subgroups]:
            result_string = str(subgroup_description)
            rules.append(replace_feature_names(result_string,feature_names))
            parts = result_string.split(" AND ")
            # Each condition is (column index, low, high, upper bound is strict).
            conditions = []
            for part in parts:
                # parse this: "X0>=0.80" or "X0<0.80" or "X0==value"
                if "==" in part:
                    var, val = part.split("==")
                    var = int(var[1:])
                    val = convert(data,var, val)
                    conditions.append((var, val, val, False))
                    continue
                elif "<" in part:
                    var, high = part.split("<")
                    low = - np.inf
                    var = int(var[1:])
                    high = float(high)
                    upper_open = True  # "<" excludes the threshold itself
                else:
                    var, low = part.split(">=")
                    high = np.inf
                    var = int(var[1:])
                    low = float(low)
                    upper_open = False
                conditions.append((var, low, high, upper_open))

            subgroup_member = np.ones((X.shape[0],),dtype=bool)
            for var, low, high, upper_open in conditions:
                column = X.iloc[:, var]
                within_upper = column < high if upper_open else column <= high
                subgroup_member = np.logical_and(subgroup_member, np.logical_and(column >= low, within_upper))
            subgroups.append(subgroup_member)
        sg = _select_best_by_f1(s_star, subgroups)

    elif method == "causaltree":
        config = CausalTree_Config().get_setting_config(setting)
        min_support = int(config["min_samples_leaf"] * X.shape[0])
        model = CausalTreeRegressor(min_samples_leaf=min_support, max_depth=config["max_depth"])

        # Fit the model using features (X), treatment (A), and outcome (Y).
        model.fit(X, treatment=A, y=Y)

        # Every leaf is a candidate subgroup; keep the one overlapping most
        # with s_star. The mask is squeezed, as in the honesttree branch.
        leaf_ids = model.apply(X)
        sg = _select_best_by_f1(s_star, [(leaf_ids == leaf).squeeze() for leaf in np.unique(leaf_ids)])

    elif method == "honesttree":
        config = HonestTree_Config().get_setting_config(setting)
        # train a tree, i.e. 1 estimator with max_samples=1
        min_support = int(config["min_samples_leaf"] * X.shape[0])
        max_depth = config["max_depth"]
        model = CausalForest(n_estimators=1,max_samples=1., honest=True,min_samples_leaf=min_support, max_depth=max_depth,subforest_size=1,inference=False)
        # Fit the model using features (X), treatment (A), and outcome (Y).
        model.fit(X, A, Y)
        # Every leaf is a candidate subgroup; keep the one overlapping most with s_star.
        leaf_ids = model.apply(X)
        sg = _select_best_by_f1(s_star, [(leaf_ids == leaf).squeeze() for leaf in np.unique(leaf_ids)])

    else:
        # Without this branch an unknown method surfaced later as a NameError on `sg`.
        raise ValueError(f"Unknown method: {method}. Choose 'subcon', 'syflow', 'pysubgroup', 'causaltree' or 'honesttree'.")

    t = timeit.default_timer() - t
    runtimes.append(t)
    # Confusion-matrix counts of the selected subgroup against s_star.
    fp = np.sum((sg == 1) & (s_star == 0))
    fn = np.sum((sg == 0) & (s_star == 1))
    tp = np.sum((sg == 1) & (s_star == 1))
    tn = np.sum((sg == 0) & (s_star == 0))
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    # Guarded like precision/recall: both masks empty would otherwise give 0/0.
    f1 = 2 * tp / (2 * tp + fp + fn) if (2 * tp + fp + fn) > 0 else 0
    accuracy = (tp + tn) / (tp + fp + tn + fn)
    # Purity is computed with the same formula as precision.
    purity = tp / (tp + fp) if (tp + fp) > 0 else 0
    print(f"Method:{method}, run: {seed}, F1: {f1}, Accuracy: {accuracy}, Purity: {purity}, Precision: {precision}, Recall: {recall}, Runtime: {t}")

    f1s.append(f1)
    accuracies.append(accuracy)
    purities.append(purity)
    precisions.append(precision)
    recalls.append(recall)



# Aggregate the per-seed metrics into mean and (population) standard deviation.
f1 = np.mean(f1s)
accuracy = np.mean(accuracies)
purity = np.mean(purities)
runtime = np.mean(runtimes)
precision = np.mean(precisions)
recall = np.mean(recalls)

f1_std = np.std(f1s)
accuracy_std = np.std(accuracies)
purity_std = np.std(purities)
runtime_std = np.std(runtimes)
precision_std = np.std(precisions)
recall_std = np.std(recalls)

# Append one semicolon-separated row:
# method;n;d;tau;f1;f1_std;accuracy;accuracy_std;purity;purity_std;
# runtime;runtime_std;precision;precision_std;recall;recall_std
# The context manager closes the file even if the write raises.
with open(outpath, "a") as resultfile:
    resultfile.write(f"{method};{n};{d};{tau};{f1};{f1_std};{accuracy};{accuracy_std};{purity};{purity_std};{runtime};{runtime_std};{precision};{precision_std};{recall};{recall_std}\n")