
import numpy as np
import pandas as pd
import torch
import argparse
import os
import time
from scipy.stats import chi2, chi2_contingency
from scipy.stats import kruskal, combine_pvalues

import matplotlib.pyplot as plt
import seaborn as sns

import warnings
warnings.filterwarnings("ignore")

import sys
sys.path.append('./src')

from data import PPCI

def get_parser():
    """Build the command-line parser for the t-REx sweep script.

    Every option is a string flag with a default, so the script can be run
    without any arguments.

    Returns:
        argparse.ArgumentParser: parser exposing ``--encoder``, ``--sc``,
        ``--version``, ``--preprocessed`` and ``--task``.
    """
    # (flag, default value, help text) for each supported option.
    options = (
        ("--encoder", "dino", "Encoder to use"),
        ("--sc", "experiment", "Split criteria"),
        ("--version", "hq", "Version of the dataset"),
        ("--preprocessed", "original", "Preprocessed data"),
        ("--task", "or", "Task to perform"),
    )
    cli = argparse.ArgumentParser(description="t-REx")
    for flag, default_value, help_text in options:
        cli.add_argument(flag, type=str, default=default_value, help=help_text)
    return cli


def _label_confounding(p_value, alpha=0.05):
    """Turn a conditional-independence p-value into a categorical plot label.

    Parameters:
    p_value : float or None - p-value from ``conditional_independence_test``
    alpha   : float         - significance threshold

    Returns:
    "Confounded" if the p-value is below ``alpha``, "Unconfounded" if it is
    not, and "Undetermined" when the test could not be performed (None/NaN).
    """
    if pd.isna(p_value):
        return "Undetermined"
    return "Confounded" if p_value < alpha else "Unconfounded"


def _plot_acc_and_terb(data, x, out_path):
    """Plot accuracy and |TERB| grouped by column ``x``, save to ``out_path`` and show."""
    fig, axes = plt.subplots(1, 2, figsize=(10, 5), sharex=True)
    panels = (("acc", "Accuracy"), ("TERB", "|TERB| %"))
    for ax, (column, ylabel) in zip(axes, panels):
        # NOTE(review): `ci` is deprecated in recent seaborn releases in favour
        # of `errorbar="sd"`; kept so older installed versions keep working.
        sns.barplot(x=x, y=column, data=data, ax=ax, ci="sd", alpha=0.8)
        sns.stripplot(x=x, y=column, data=data, ax=ax, color='black', size=5, jitter=True, alpha=0.6)
        ax.set_ylabel(ylabel)
        ax.set_xlabel("")
    # Accuracy is a fraction, so pin its axis to the full [0, 1] range.
    axes[0].set_ylim(0, 1.0)
    plt.tight_layout()
    plt.savefig(out_path, dpi=300, bbox_inches='tight')
    plt.show()


def main(args):
    """Run the full hyper-parameter sweep, save the metrics and plot summaries.

    For every combination of preprocessing variant, batch size, number of
    hidden layers, learning rate, seed and training method, a model is trained
    on the ``PPCI`` dataset and evaluated. Accuracy, balanced accuracy, ATE,
    PPATE and three conditional-independence p-values are recorded per run.

    All runs are written to ``results/ants/trex/<sc>.csv``. Runs with accuracy
    above 0.8 are then summarised in two figures (``<sc>_T.png`` and
    ``<sc>_W.png``), grouped by whether the corresponding
    conditional-independence test rejects at the 0.05 level.

    Parameters:
    args : argparse.Namespace - parsed CLI options from ``get_parser``
           (``encoder``, ``sc``, ``version``, ``task`` are read here)
    """
    batch_sizes = [64, 128, 256]
    preprocesseds = ["original", "preprocessed/0.85"]
    hls = [1, 2]
    lrs = [0.05, 0.01, 0.005, 0.001, 0.0005]
    seeds = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
    methods = ["DERM", "ERM", "IRM", "vREx"]
    results = []
    n_runs = (len(lrs) * len(seeds) * len(methods) * len(hls)
              * len(preprocesseds) * len(batch_sizes))
    i = 0
    start = time.time()
    for preprocessed in preprocesseds:
        print(f"preprocessed: {preprocessed}", flush=True)
        # The dataset is built once per preprocessing variant and reused for
        # every training configuration below.
        dataset = PPCI(encoder=args.encoder,
                       token="class",
                       task=args.task,
                       split_criteria=args.sc,
                       environment="supervised",
                       batch_size=64,
                       num_proc=4,
                       verbose=False,
                       data_dir=f'data/istant_{args.version}',
                       preprocessed=preprocessed,
                       results_dir=f'results/istant_{args.version}')
        for batch_size in batch_sizes:
            print(f"batch_size: {batch_size}", flush=True)
            for hl in hls:
                print(f"hidden_layer: {hl}", flush=True)
                for lr in lrs:
                    print(f"lr: {lr}", flush=True)
                    for seed in seeds:
                        print(f"seed: {seed}", flush=True)
                        for method in methods:
                            print(f"method: {method}", flush=True)
                            if i > 1:
                                # Extrapolate total run time from the mean
                                # duration of the runs completed so far.
                                partial = time.time() - start
                                total = partial / i * n_runs
                                print(f"Run: {i+1}/{n_runs}; Time: {int(partial) // 60}m{int(partial) % 60}s/{int(total) // 60}m{int(total) % 60}s", flush=True)
                            else:
                                print(f"Run: {i+1}/{n_runs}", flush=True)
                            dataset.train(add_pred_env="supervised",
                                          hidden_layers=hl,
                                          hidden_nodes=256,
                                          batch_size=batch_size,
                                          lr=lr,
                                          seed=seed,
                                          num_epochs=15,
                                          save=False,
                                          verbose=False,
                                          gpu=True,
                                          force=False,
                                          cfl=0,
                                          ic_weight=10,
                                          method=method)
                            Y = dataset.supervised["Y"]
                            Y_hat = dataset.supervised["Y_hat"]
                            env = dataset.supervised["source_data"]["experiment"]
                            pos = dataset.supervised["source_data"]["position"]
                            # Fold (experiment, position) into one categorical id.
                            # NOTE(review): the factor 9 assumes position takes
                            # values in 0..8 — confirm against the dataset.
                            W = env * 9 + pos
                            T = dataset.supervised["T"]
                            metrics = dataset.evaluate()
                            results.append({
                                "hl": hl,
                                "preprocessed": preprocessed,
                                "batch_size": batch_size,
                                "lr": lr,
                                "seed": seed,
                                "method": method,
                                "acc": metrics["acc"],
                                "bacc": metrics["bacc"],
                                "ATE": metrics["ATE"],
                                "PPATE": metrics["PPATE"],
                                "cond_ind_T": conditional_independence_test(Y_hat, T, Y),
                                "cond_ind_W": conditional_independence_test(Y_hat, W, Y),
                                "cond_ind_Y": conditional_independence_test(Y_hat, Y, T),
                            })
                            i += 1

    # Persist the raw per-run metrics (with numeric p-values) before any
    # filtering or relabelling is applied for plotting.
    results = pd.DataFrame(results)
    results_dir = "results/ants/trex"
    os.makedirs(results_dir, exist_ok=True)
    results.to_csv(f"{results_dir}/{args.sc}.csv", index=False)

    # Only sufficiently accurate runs are plotted.
    results_ = results[results["acc"] > 0.8].copy()
    # Relative gap between ATE and PPATE in percent; the absolute value of the
    # whole ratio is taken so a negative ATE cannot produce a negative value
    # on the axis labelled "|TERB| %".
    results_["TERB"] = ((results_["ATE"] - results_["PPATE"]) / results_["ATE"]).abs() * 100
    # Both grouping columns must be categorical labels; leaving raw p-values
    # in place would create one bar per distinct float.
    for column in ("cond_ind_T", "cond_ind_W"):
        results_[column] = results_[column].apply(_label_confounding)

    # Set seaborn style
    sns.set(style="whitegrid", palette="colorblind")

    _plot_acc_and_terb(results_, "cond_ind_T", f"{results_dir}/{args.sc}_T.png")
    _plot_acc_and_terb(results_, "cond_ind_W", f"{results_dir}/{args.sc}_W.png")

def conditional_independence_test(y_hat, w, y):
    """
    Tests whether Y_hat ⟂ W | Y.

    Within each stratum y == 0 and y == 1, a Kruskal-Wallis test compares the
    distribution of y_hat across the levels of w. The per-stratum p-values
    are then merged with Fisher's method.

    Parameters:
    y_hat : array-like, shape (n,)      - scores whose distribution is compared
    w     : array-like, shape (n,)      - grouping variable
    y     : array-like, shape (n,)      - conditioning variable; only the
                                          values 0 and 1 are considered

    Returns:
    p_value : float or None - combined p-value testing Y_hat ⟂ W | Y, or None
                              if no stratum allowed a test to be performed
    """
    scores = np.asarray(y_hat)
    factor = np.asarray(w)
    condition = np.asarray(y)

    stratum_p_values = []
    for level in (0, 1):
        in_stratum = condition == level
        if np.count_nonzero(in_stratum) < 2:
            # Too few observations in this stratum to compare anything.
            continue

        frame = pd.DataFrame({'y_hat': scores[in_stratum], 'w': factor[in_stratum]})
        samples = []
        for _, members in frame.groupby('w'):
            values = members['y_hat'].values
            # Keep only groups with at least two observations that are not
            # all identical.
            if len(values) > 1 and len(np.unique(values)) > 1:
                samples.append(values)

        # Kruskal-Wallis needs at least two groups to compare.
        if len(samples) < 2:
            continue
        _, p_value = kruskal(*samples)
        stratum_p_values.append(p_value)

    if not stratum_p_values:
        return None

    _, combined_p = combine_pvalues(stratum_p_values, method='fisher')
    return combined_p


if __name__ == "__main__":
    # Script entry point: parse the CLI flags and launch the sweep.
    main(get_parser().parse_args())