import argparse
import ast
from multiprocessing import Pool

import numpy as np
import pandas as pd
from scipy.stats import pearsonr
from sklearn.feature_selection import SelectKBest, f_regression
from sklearn.linear_model import Ridge
from sklearn.model_selection import KFold
from tqdm import tqdm

import wandb

# Request the wandb "core" feature at import time, i.e. before wandb.init is called.
# NOTE(review): newer wandb releases may enable "core" by default — confirm this call is still needed.
wandb.require("core")


def compute_lambda(Sigma):
    """Return the sum-to-one combination weights ``Sigma^{-1} 1 / (1^T Sigma^{-1} 1)``.

    For a positive-definite covariance matrix these are the weights that minimise
    ``lambda^T Sigma lambda`` subject to ``sum(lambda) == 1``.

    If ``np.linalg.inv`` reports the matrix as singular, a damping term
    ``1e-6 * I`` is added to a *copy* of the matrix before inverting. The caller's
    array is never modified.

    Args:
        Sigma: Square covariance matrix (2-D array-like).

    Returns:
        1-D array of weights, one per row of ``Sigma``.
    """
    Sigma = np.asarray(Sigma, dtype=float)
    try:
        Sigma_inv = np.linalg.inv(Sigma)
    except np.linalg.LinAlgError:
        print("Singular matrix encountered, applying damping.")
        # Damp a new array rather than `+=`, so the caller's matrix stays untouched.
        Sigma_inv = np.linalg.inv(Sigma + 1e-6 * np.eye(Sigma.shape[0]))
    ones = np.ones(Sigma.shape[0])
    unnormalised = Sigma_inv @ ones
    return unnormalised / (ones @ unnormalised)


def compute_aipw_cf(X, Y, T, kf, pi, alpha_ridge):
    """Compute cross-fitted AIPW scores using per-arm ridge outcome models.

    For every split produced by ``kf.split(X)``, two separate
    ``Ridge(alpha=alpha_ridge)`` models are fitted: one on the treated units of the
    training part and one on its control units. The held-out units are then scored as

        mu1(x) - mu0(x) + T (Y - mu1(x)) / pi - (1 - T) (Y - mu0(x)) / (1 - pi)

    Args:
        X: Covariate matrix, one row per unit.
        Y: Observed outcomes, aligned with ``X``.
        T: Binary (0/1) treatment indicators, aligned with ``X``.
        kf: Splitter exposing ``split(X)`` (e.g. ``KFold``). Units that never appear
            in a test split keep a score of 0.
        pi: Treatment probability used in the inverse-propensity terms.
        alpha_ridge: Ridge regularisation strength.

    Returns:
        ``(mean_score, per_unit_scores)``.

    Raises:
        ValueError: if a training split contains no treated or no control units.
    """
    psi_rct = np.zeros(X.shape[0])
    for train_idx, test_idx in kf.split(X):
        X_train, X_test = X[train_idx], X[test_idx]
        Y_train, T_train = Y[train_idx], T[train_idx]
        Y_test, T_test = Y[test_idx], T[test_idx]

        treated = T_train == 1
        control = T_train == 0
        if not treated.any() or not control.any():
            raise ValueError(
                "compute_aipw_cf: a training split has no treated or no control units; "
                "shuffle the data before splitting or use fewer folds."
            )

        mu1_rct = Ridge(alpha=alpha_ridge).fit(X_train[treated], Y_train[treated])
        mu0_rct = Ridge(alpha=alpha_ridge).fit(X_train[control], Y_train[control])

        # Each outcome model predicts the held-out units exactly once.
        mu1_hat = mu1_rct.predict(X_test)
        mu0_hat = mu0_rct.predict(X_test)

        psi_rct[test_idx] = (
            mu1_hat - mu0_hat
            + (T_test * (Y_test - mu1_hat)) / pi
            - ((1 - T_test) * (Y_test - mu0_hat)) / (1 - pi)
        )
    return psi_rct.mean(), psi_rct


def compute_haipw(X_rct, T_rct, Y_rct, df, model_list):
    """Compute per-unit AIPW scores that plug in external outcome predictions.

    For every name ``m`` in ``model_list``, the columns ``"Y1_" + m`` and
    ``"Y0_" + m`` of ``df`` are used as the treated and control outcome predictions.
    The score below is evaluated for every unit, where ``pi`` is the observed share
    of treated units:

        (y1 - y0) + T (Y - y1) / pi - (1 - T) (Y - y0) / (1 - pi)

    Args:
        X_rct: Covariate matrix. Not used here; kept for a uniform signature.
        T_rct: Binary (0/1) treatment indicators (NumPy array).
        Y_rct: Observed outcomes, aligned positionally with ``T_rct``.
        df: DataFrame holding the prediction columns, rows aligned positionally
            with ``Y_rct``.
        model_list: List of model names.

    Returns:
        ``(mean_estimates, all_estimates)``. ``all_estimates`` has one row of per-unit
        scores per model, and ``mean_estimates`` is the mean of each row.
    """
    pi = T_rct.mean()
    score_rows = []
    for model in model_list:
        y1_hat = df["Y1_" + model].to_numpy()
        y0_hat = df["Y0_" + model].to_numpy()
        score_rows.append(
            (T_rct * (Y_rct - y1_hat)) / pi
            - ((1 - T_rct) * (Y_rct - y0_hat)) / (1 - pi)
            + y1_hat - y0_hat
        )
    all_estimates = np.array(score_rows)
    mean_estimates = np.mean(all_estimates, axis=1)
    return mean_estimates, all_estimates


def _safe_pearsonr(x, y):
    """Return the Pearson correlation of ``x`` and ``y``, or 0.0 when it is undefined.

    The correlation is undefined for fewer than two points or for constant input.
    In those cases ``pearsonr`` raises or returns NaN. Treating the correlation as
    zero removes that arm's contribution to the PPI weight instead of turning the
    whole estimate into NaN.
    """
    if len(x) < 2 or np.std(x) == 0 or np.std(y) == 0:
        return 0.0
    r = pearsonr(x, y)[0]
    return float(r) if np.isfinite(r) else 0.0


def compute_ppi(X_rct, T_rct, Y_rct, df, model="gpt4o"):
    """Compute a prediction-adjusted difference-in-means estimate and its variance.

    The control-outcome prediction column ``"Y0_" + model`` of ``df`` is used as a
    covariate ``f``. If that column is absent, the first column whose name starts
    with ``"Y0_"`` is used instead. The estimate is

        mean_treated(Y - lambda f) - mean_control(Y - lambda f)

    with ``lambda = (pi_c sigma_t rho_t + pi_t sigma_c rho_c) / (sigma_f + 1e-3)``.
    The returned variance is

        sigma_t^2 / pi_t + sigma_c^2 / pi_c - lambda^2 sigma_f^2 (1/pi_c + 1/pi_t)

    This is an n-scaled asymptotic variance: divide it by the sample size before
    taking a square root to get a standard error.

    Args:
        X_rct: Covariate matrix. Not used here; kept for a uniform signature.
        T_rct: Binary (0/1) treatment indicators (NumPy array).
        Y_rct: Observed outcomes, aligned positionally with ``T_rct``.
        df: DataFrame with the prediction column(s), rows aligned positionally with
            ``Y_rct``.
        model: Suffix of the preferred prediction column (default ``"gpt4o"``).

    Returns:
        ``(ppi_est, ppi_var)``.

    Raises:
        KeyError: if ``df`` has neither ``"Y0_" + model`` nor any other ``"Y0_"``
            column.
    """
    column = "Y0_" + model
    if column not in df.columns:
        fallback = [col for col in df.columns if col.startswith("Y0_")]
        if not fallback:
            raise KeyError(f"df has no '{column}' column and no other 'Y0_' prediction column")
        column = fallback[0]
    f = df[column].to_numpy()

    treated = T_rct == 1
    control = T_rct == 0
    pi_t = T_rct.mean()
    pi_c = 1 - pi_t

    sigma_t_sq = np.var(Y_rct[treated], ddof=1)
    sigma_c_sq = np.var(Y_rct[control], ddof=1)
    sigma_f_sq = np.var(f, ddof=1)
    sigma_t = np.sqrt(sigma_t_sq)
    sigma_c = np.sqrt(sigma_c_sq)
    sigma_f = np.sqrt(sigma_f_sq)

    rho_t = _safe_pearsonr(f[treated], Y_rct[treated])
    rho_c = _safe_pearsonr(f[control], Y_rct[control])

    # Approximately variance-minimising weight on f. The 1e-3 keeps the ratio
    # finite when f has no spread.
    optimal_lambda = (pi_c * sigma_t * rho_t + pi_t * sigma_c * rho_c) / (sigma_f + 1e-3)
    adjusted = Y_rct - optimal_lambda * f
    ppi_est = np.mean(adjusted[treated]) - np.mean(adjusted[control])

    ppi_var = (sigma_t_sq / pi_t + sigma_c_sq / pi_c
               - optimal_lambda**2 * sigma_f_sq * (1 / pi_c + 1 / pi_t))
    return ppi_est, ppi_var


def _ci_covers(estimate, variance, n, gt, zalpha=1.96):
    """Return 100 if ``gt`` lies strictly inside ``estimate +/- zalpha * sqrt(variance / n)``, else 0.

    A NaN bound (e.g. from a negative ``variance``) makes the comparison false, so
    such an interval counts as not covering.
    """
    half_width = zalpha * np.sqrt(variance / n)
    return int(estimate - half_width < gt < estimate + half_width) * 100


def run_experiment(df_augmented, models, n_features, n_rct, n_folds, alpha_ridge, gt, seed):
    """Run one subsampled RCT and report variances and 95% CI coverage of four estimators.

    Draws ``n_rct // 2`` units per treatment arm from ``df_augmented``. It then
    computes four estimates:

    - difference-in-means (DM);
    - cross-fitted AIPW;
    - PPI;
    - HAIPW, a ``compute_lambda``-weighted combination of the AIPW scores and the
      per-model prediction-based scores.

    Args:
        df_augmented: DataFrame with columns ``T``, ``Y``, ``Y1_<model>`` /
            ``Y0_<model>`` for every model, optionally ``Unnamed: 0``, and the
            covariates (all remaining columns).
        models: Model names whose prediction columns are present.
        n_features: Number of covariates kept by ``SelectKBest(f_regression)``.
        n_rct: Requested RCT size; ``n_rct // 2`` units are drawn per arm.
        n_folds: Number of cross-fitting folds for AIPW.
        alpha_ridge: Ridge penalty for the AIPW outcome models.
        gt: Ground-truth effect used for the coverage check.
        seed: Seed for both the subsample and the fold shuffle.

    Returns:
        ``(haipw_var, aipw_var, ppi_var, dm_var, coverage_haipw, coverage_aipw,
        coverage_ppi, coverage_dm)``. Variances are n-scaled; each coverage is 0 or
        100.
    """
    # Subsample the RCT: n_rct // 2 units from each arm.
    df_augmented = df_augmented.groupby('T').sample(n=n_rct // 2, random_state=seed)
    Y_rct = df_augmented["Y"].to_numpy()
    T_rct = df_augmented["T"].to_numpy()
    # Actual sample size; it is n_rct - 1 when n_rct is odd.
    n_obs = len(Y_rct)

    non_covariates = (["T", "Y"] + [f"Y1_{model}" for model in models]
                      + [f"Y0_{model}" for model in models])
    # The saved-index column exists only when the CSV was written with its index.
    if "Unnamed: 0" in df_augmented.columns:
        non_covariates.append("Unnamed: 0")
    X_rct = df_augmented.drop(columns=non_covariates).to_numpy()

    # Keep the n_features covariates most associated with the outcome.
    feature_selector = SelectKBest(score_func=f_regression, k=n_features)
    X_rct = feature_selector.fit_transform(X_rct, Y_rct)
    # groupby().sample concatenates the per-arm draws, so rows are ordered by arm.
    # Shuffle before splitting so no training split ends up without one of the arms.
    kf = KFold(n_splits=n_folds, shuffle=True, random_state=seed)

    # DiM estimate; pi is the empirical treatment share of the subsample.
    pi = T_rct.mean()
    dm_i = Y_rct * T_rct / pi - Y_rct * (1 - T_rct) / (1 - pi)
    dm_est = dm_i.mean()
    dm_var = np.var(Y_rct[T_rct == 1], ddof=1) / pi + np.var(Y_rct[T_rct == 0], ddof=1) / (1 - pi)

    # AIPW and PPI estimates
    _, psi1 = compute_aipw_cf(X_rct, Y_rct, T_rct, kf, pi, alpha_ridge)
    aipw_var = np.var(psi1, ddof=1)
    aipw_est = np.mean(psi1)
    ppi_est, ppi_var = compute_ppi(X_rct=X_rct, T_rct=T_rct, Y_rct=Y_rct, df=df_augmented)

    # HAIPW: combine the AIPW scores with each model's scores using the weights
    # from compute_lambda.
    _, psi_haipw = compute_haipw(X_rct=X_rct, T_rct=T_rct, Y_rct=Y_rct, df=df_augmented, model_list=models)
    scores = np.vstack((psi1, psi_haipw))
    Sigma = np.cov(scores)
    lambda_star = compute_lambda(Sigma)
    haipw_if = (scores * lambda_star[:, np.newaxis]).sum(axis=0)
    haipw_var = lambda_star.T @ Sigma @ lambda_star
    haipw_est = np.mean(haipw_if)

    # Two-sided 95% normal-approximation intervals.
    zalpha = 1.96
    coverage_aipw = _ci_covers(aipw_est, aipw_var, n_obs, gt, zalpha)
    coverage_dm = _ci_covers(dm_est, dm_var, n_obs, gt, zalpha)
    coverage_haipw = _ci_covers(haipw_est, haipw_var, n_obs, gt, zalpha)
    coverage_ppi = _ci_covers(ppi_est, ppi_var, n_obs, gt, zalpha)

    return haipw_var, aipw_var, ppi_var, dm_var, coverage_haipw, coverage_aipw, coverage_ppi, coverage_dm


if __name__ == "__main__":
    from functools import partial

    # Command-line configuration.
    parser = argparse.ArgumentParser(description="Run")
    parser.add_argument("--n_rct", type=int, default=30, help="Number of samples in RCT")
    parser.add_argument("--n_features", type=int, default=5)
    parser.add_argument("--n_folds", type=int, default=20)
    parser.add_argument("--alpha_ridge", type=float, default=0.1)
    parser.add_argument("--study", type=str, required=True)
    parser.add_argument("--model", nargs="+", type=str, required=True, help="e.g., 'gpt4o llama claude_haiku'")
    parser.add_argument("--n_seeds", type=int, default=10)
    parser.add_argument("--n_prompts", type=int, default=1)
    parser.add_argument("--name_wandb", default=None)
    parser.add_argument("--n_jobs", type=int, default=50, help="Maximum number of worker processes")

    args = parser.parse_args()
    if args.n_seeds < 1:
        # Zero results would make the unpacking step below fail.
        parser.error("--n_seeds must be at least 1")
    wandb.init(entity="test", project=f"{args.name_wandb}", reinit=True, config=args)

    # Read the full study data; the ground truth is its difference in means.
    df = pd.read_csv(f"{args.study}/df_processed.csv")
    Y = df["Y"].to_numpy()
    T = df["T"].to_numpy()
    gt = Y[T == 1].mean() - Y[T == 0].mean()

    # Augment with per-model predictions. Each response cell is parsed with
    # ast.literal_eval, and the first n_prompts values are averaged.
    df_augmented = df.copy()
    for model in args.model:
        responses = pd.read_csv(f"{args.study}/df_{model}.csv")[["Y0hat_responses", "Y1hat_responses"]]
        df_model = pd.DataFrame(
            {
                f"Y0_{model}": [np.mean(ast.literal_eval(cell)[:args.n_prompts])
                                for cell in responses["Y0hat_responses"]],
                f"Y1_{model}": [np.mean(ast.literal_eval(cell)[:args.n_prompts])
                                for cell in responses["Y1hat_responses"]],
            },
            index=responses.index,
        )
        df_augmented = pd.concat([df_model, df_augmented], axis=1)

    # Never start more workers than there are seeds to run.
    n_jobs = max(1, min(args.n_jobs, args.n_seeds))
    print(f'running n_jobs: {n_jobs}')

    # A partial of a top-level function pickles under every multiprocessing start
    # method. A function defined inside this `if __name__ == "__main__"` block only
    # resolves in workers created with fork.
    run_single_exp = partial(run_experiment, df_augmented, args.model, args.n_features, args.n_rct,
                             args.n_folds, args.alpha_ridge, gt)

    with Pool(n_jobs) as pool:
        results = list(tqdm(pool.imap(run_single_exp, range(args.n_seeds)), total=args.n_seeds))

    # Average each returned quantity across seeds (order matches run_experiment's return).
    result_names = ('haipw_var', 'aipw_var', 'ppi_var', 'dm_var',
                    'coverage_haipw', 'coverage_aipw', 'coverage_ppi', 'coverage_dm')
    averages = {name: np.mean(values) for name, values in zip(result_names, zip(*results))}

    wandb.log({**averages, 'ground_truth': gt})
    wandb.finish()
