from gen_synth import *
from run_method import run_method_single_subgroup, run_method_multiple_times
import numpy as np
from sklearn.preprocessing import StandardScaler
import argparse
import torch
from syflow import syflow, And_Finder
from causalml.inference.tree import CausalTreeRegressor # Import from causalml
import pysubgroup as ps
import pandas as pd
from utils import *
import timeit 
from configs import *
from econml.grf import CausalForest
from sklearn.metrics import f1_score

def _str2bool(value):
    """Convert a command-line string such as "True"/"false"/"1"/"no" to a bool.

    argparse's ``type=bool`` calls ``bool(str)``, which is True for every
    non-empty string (including "False"), so an explicit converter is needed.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}.")


def parse_args(argv=None):
    """Parse the command-line options for the synthetic-data experiments.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, matching the previous behaviour.

    Returns:
        argparse.Namespace with the experiment parameters.
    """
    parser = argparse.ArgumentParser(description="Run synthetic data experiments.")
    parser.add_argument("--n", type=int, default=2000, help="Number of samples.")
    parser.add_argument("--d", type=int, default=5, help="Number of features.")
    parser.add_argument("--tau", type=float, default=4.0, help="Treatment effect within the subgroup.")
    parser.add_argument("--gamma", type=float, default=1.0, help="Treatment effect outside the subgroup.")
    parser.add_argument("--c", type=float, default=2.0, help="Confounding effect within the subgroup.")
    parser.add_argument("--sigma", type=float, default=0.5, help="Standard deviation of the noise in Y.")
    parser.add_argument("--sg_size", type=float, default=0.25, help="Size of the subgroup.")
    parser.add_argument("--rule_size", type=int, default=2, help="Number of features in the rule.")
    parser.add_argument("--n_subgroups", type=int, default=3, help="Number of subgroups to generate.")
    # `--mean_shift True|False` is parsed correctly; a bare `--mean_shift` means True.
    parser.add_argument("--mean_shift", type=_str2bool, nargs="?", const=True, default=False, help="Whether to apply mean shift in treatment assignment.")
    return parser.parse_args(argv)

args = parse_args()
method = "subcon"
n = args.n
d = args.d
tau = args.tau
gamma = args.gamma
c = args.c
sigma = args.sigma
sg_size = args.sg_size
rule_size = args.rule_size
n_subgroups = args.n_subgroups
mean_shift = args.mean_shift

# Values swept for each method hyperparameter, and seeds averaged per value.
LAMBDA_VALUES = [0.0, 0.1, 0.2, 0.3, 0.5, 0.7, 1.0]
GAMMA_VALUES = [0.0, 0.1, 0.2, 0.3, 0.5, 0.7, 1.0]
N_SEEDS = 10
RESULT_HEADER = "F1;F1_std;Accuracy;Accuracy_std;Purity;Purity_std;Runtime;Runtime_std;Precision;Precision_std;Recall;Recall_std\n"


def _generate_data(setting, seed, data_gamma):
    """Generate one synthetic dataset for `setting` from the CLI parameters.

    `data_gamma` is passed to the generator as `gamma`; it is taken as an
    argument so a sweep never has to rebind the module-level `gamma`.
    """
    if setting == "observational":
        generator = gen_observational_trial
    elif setting == "demographic":
        generator = gen_demographic_data
    elif setting == "interventional":
        generator = gen_interventional_trial
    else:
        raise ValueError(f"Unknown setting: {setting}. Choose 'observational', 'demographic' or 'interventional'.")
    return generator(n=n, d=d, tau=tau, gamma=data_gamma, c=c, sigma=sigma, sg_size=sg_size, rule_size=rule_size, seed=seed, mean_shift=mean_shift)


def _best_subgroup_mask(subgroups, A, s_star, n_samples):
    """Return the candidate 0/1 mask (length `n_samples`) with the highest F1 vs `s_star`.

    Each element of `subgroups` is a pair (s0_mask, s1_mask) indexing the
    A == 0 and A == 1 rows respectively; both are scattered back into one
    full-length mask. Ties keep the earliest candidate.

    Raises:
        RuntimeError: if `subgroups` is empty.
    """
    # Invariant across candidates, so computed once.
    a0_indices = np.where(A == 0)[0]
    a1_indices = np.where(A == 1)[0]
    best_mask = None
    best_overlap = -1
    for s0_mask, s1_mask in subgroups:
        candidate = np.zeros(n_samples, dtype=int)
        candidate[a0_indices[s0_mask]] = 1
        candidate[a1_indices[s1_mask]] = 1
        overlap = f1_score(s_star, candidate)
        if overlap > best_overlap:
            best_overlap = overlap
            best_mask = candidate
    if best_mask is None:
        raise RuntimeError("run_method_multiple_times returned no subgroups.")
    return best_mask


def _confusion_metrics(sg, s_star):
    """Return (f1, accuracy, purity, precision, recall) of mask `sg` against `s_star`.

    Every ratio with a zero denominator is reported as 0 instead of
    producing NaN / a division warning.
    """
    fp = np.sum((sg == 1) & (s_star == 0))
    fn = np.sum((sg == 0) & (s_star == 1))
    tp = np.sum((sg == 1) & (s_star == 1))
    tn = np.sum((sg == 0) & (s_star == 0))
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1_denom = 2 * tp + fp + fn
    f1 = 2 * tp / f1_denom if f1_denom > 0 else 0
    accuracy = (tp + tn) / (tp + fp + tn + fn)
    # Purity is defined identically to precision; kept as a separate CSV column.
    purity = precision
    return f1, accuracy, purity, precision, recall


def _write_summary(outpath, f1s, accuracies, purities, runtimes, precisions, recalls):
    """Write mean/std of each metric list as a one-row ';'-separated CSV at `outpath`."""
    import os

    # Create the output directory if needed; open() would otherwise fail.
    os.makedirs(os.path.dirname(outpath), exist_ok=True)
    cells = []
    for values in (f1s, accuracies, purities, runtimes, precisions, recalls):
        cells.append(f"{np.mean(values)}")
        cells.append(f"{np.std(values)}")
    with open(outpath, "w") as resultfile:
        resultfile.write(RESULT_HEADER)
        resultfile.write(";".join(cells) + "\n")


def _run_sweep(setting, config_key, file_tag, values, max_reps, vary_data_gamma=False):
    """Sweep one method hyperparameter for `setting` and write one CSV per value.

    Args:
        setting: Data-generation setting name.
        config_key: Key overridden in the Subcon config (e.g. "lambd", "gamma").
        file_tag: Name used in the output file, "results/sensitivity/{setting}_{file_tag}_{value}.csv".
        values: Hyperparameter values to try.
        max_reps: Forwarded to run_method_multiple_times.
        vary_data_gamma: If True, the swept value is also used as the data
            generator's `gamma`; otherwise the CLI `--gamma` value is used.
    """
    for value in values:
        f1s, accuracies, purities, runtimes, precisions, recalls = [], [], [], [], [], []
        data_gamma = value if vary_data_gamma else gamma
        for seed in range(N_SEEDS):
            data = _generate_data(setting, seed, data_gamma)
            X, s_star, A, Y = data["X"], data["s_star"], data["A"], data["Y"]
            # All features are treated as continuous (not discrete).
            is_discrete = [False for _ in range(X.shape[1])]
            scaler_X = StandardScaler()
            X = scaler_X.fit_transform(X)
            scaler_Y = StandardScaler()
            Y = scaler_Y.fit_transform(Y.reshape(-1, 1)).flatten()
            feature_names = [f"X{i}" for i in range(X.shape[1])]

            # Timing covers method fitting plus best-subgroup selection.
            t = timeit.default_timer()
            if method == "subcon":
                our_config = Subcon_Config().get_setting_config(setting)
                our_config[config_key] = value
                subgroups, _rules, _models = run_method_multiple_times(X, X[A == 0], X[A == 1], Y, Y[A == 0], Y[A == 1], scaler_X, scaler_Y, feature_names, is_discrete, our_config, discrete_target=False, maximize=True, plot=False, max_reps=max_reps)
                sg = _best_subgroup_mask(subgroups, A, s_star, X.shape[0])
            else:
                raise ValueError(f"Unsupported method: {method}")
            t = timeit.default_timer() - t

            f1, accuracy, purity, precision, recall = _confusion_metrics(sg, s_star)
            print(f"Method:{method}, run: {seed}, F1: {f1}, Accuracy: {accuracy}, Purity: {purity}, Precision: {precision}, Recall: {recall}, Runtime: {t}")

            f1s.append(f1)
            accuracies.append(accuracy)
            purities.append(purity)
            runtimes.append(t)  # exactly once per seed
            precisions.append(precision)
            recalls.append(recall)

        outpath = f"results/sensitivity/{setting}_{file_tag}_{value}.csv"
        _write_summary(outpath, f1s, accuracies, purities, runtimes, precisions, recalls)


for setting in ["demographic", "interventional", "observational"]:
    _run_sweep(setting, "lambd", "lambda", LAMBDA_VALUES, max_reps=1)
    # NOTE(review): the gamma sweep feeds the swept value both to the method
    # config and to the data generator, and uses max_reps=3 vs 1 for lambda.
    # Both are preserved from the original; confirm they are intended.
    _run_sweep(setting, "gamma", "gamma", GAMMA_VALUES, max_reps=3, vary_data_gamma=True)