"""
This script implements a sequential ablation framework for evaluating language model
metrics using online sequential hypothesis testing with Kernel MMD (Maximum Mean
Discrepancy). It processes precomputed metrics from training, validation, and test
splits, applies normalization and outlier handling, trains simple classifiers,
and runs online sequential tests to detect distributional differences between
data sources. The script outputs detailed CSV reports with statistical metrics,
test outcomes, and performance traces.

How to run:
1. Prepare metric JSON files in the expected structure. The root directory is
   given by --metrics_path (default: ./metrics_ref/EleutherAI_pythia-12b-deduped):
   {metrics_path}/{dataset_name}/train_metrics.json
   {metrics_path}/{dataset_name}/val_metrics.json
   (optional) {metrics_path}/{dataset_name}/test_metrics.json

2. Run the script from the command line:
   python script.py --output_dir ./results --dataset_name wikipedia \
       --model_name EleutherAI/pythia-12b-deduped --num_samples 2000

   Key arguments:
   - --features ["all" | "selected"] : Choose feature subset.
   - --normalize ["no" | "train" | "combined"] : Normalization strategy.
   - --outliers ["keep" | "zero" | "mean" | "clip" | ...] : Outlier handling.
   - --kernel ["poly" | "linear"] : Kernel choice for MMD.
   - --num_random : Number of random seeds/shuffles to run.
   - --output_dir : Directory where result CSV files are saved.

3. Results are saved as CSV files in:
   {output_dir}/{dataset_name}/e_process_info_dataset_{dataset_name}_ONLINE_seed_{i}.csv
"""


import os
import json
import argparse
import numpy as np
import pandas as pd

import torch
import torch.nn as nn

from selected_features import feature_list

import src.kernels as kernels
from src.kernelMMD import kernelMMDprediction
from src.SeqTestsUtils import ONSstrategy, runSequentialTest


# Smallest number of score pairs a stream must contain before the Kernel MMD
# sequential test is run on it; main() skips shorter streams with a warning.
MIN_MMD_POINTS = 20  # enforce minimum of 20 data points for Kernel MMD


def get_args():
    """Parse the command-line options of the sequential-ablation script.

    Returns:
        argparse.Namespace: the parsed options. ``--output_dir`` is the only
        required option; every other option falls back to its default.
    """
    cli = argparse.ArgumentParser(description='Sequential Ablation')

    # Data selection and preprocessing.
    cli.add_argument(
        '--model_name', type=str, default="EleutherAI/pythia-12b-deduped",
        help='Name of the backbone LM (for paths/labels)',
    )
    cli.add_argument(
        '--dataset_name', type=str, default="wikipedia",
        help='Dataset name',
    )
    cli.add_argument(
        '--num_samples', type=int, default=2000,
        help='Number of samples used to build splits',
    )
    cli.add_argument(
        '--normalize', type=str, default="train",
        choices=["no", "train", "combined"],
    )
    cli.add_argument(
        '--outliers', type=str, default="mean",
        choices=["randomize", "keep", "zero", "mean", "clip", "mean+p-value", "p-value"],
    )
    cli.add_argument(
        '--features', type=str, default="selected",
        choices=["all", "selected"],
    )

    # e-process / sequential test arguments
    cli.add_argument(
        '--lambda_max', type=float, default=0.8,
        help="Cap for ONS bets",
    )
    cli.add_argument(
        '--num_trials', type=int, default=250,
        help="Number of random trials in sequential test",
    )
    cli.add_argument(
        '--kernel', type=str, default="poly",
        choices=["poly","linear"],
        help="Kernel for Kernel MMD",
    )
    cli.add_argument(
        '--alpha', type=float, default=0.05,
        help="Significance level",
    )
    cli.add_argument(
        '--post_processing', type=str, default="tanh",
        choices=["tanh", "sinh","arctan","delapena"],
    )

    # Repetitions and input/output locations.
    cli.add_argument(
        '--num_random', type=int, default=50,
        help="Number of random seeds/shuffles",
    )
    cli.add_argument(
        '--metrics_path', type=str,
        default="./metrics_ref/EleutherAI_pythia-12b-deduped",
        help="Path to derived metrics",
    )
    cli.add_argument(
        '--output_dir', type=str, required=True,
        help='Where to save outputs',
    )
    cli.add_argument(
        '--save_file_prefix', type=str, default="",
        help='Prefix for output files',
    )

    # Online scoring options.
    cli.add_argument(
        '--online_epochs', type=int, default=30,
        help='Epochs per online update step',
    )
    cli.add_argument(
        '--warmup_mode', type=str, default="bet0",
        choices=["bet0","centroid","none"],
        help='Warm start for earliest rounds before adequate data (ONLINE only)',
    )
    cli.add_argument(
        '--warmup_K', type=int, default=1,
        help='Number of warmup rounds (ONLINE only)',
    )

    return cli.parse_args()


def get_model(num_features, linear=True):
    """Build the scoring network that maps a feature vector to one logit.

    Args:
        num_features: size of the input feature vector.
        linear: when true (default) return a single linear layer; otherwise
            return a two-layer network with a 10-unit ReLU hidden layer.

    Returns:
        torch.nn.Module: a freshly initialised model with one output unit.
    """
    if not linear:
        layers = [
            nn.Linear(num_features, 10),
            nn.ReLU(),
            nn.Linear(10, 1),
        ]
        return nn.Sequential(*layers)
    return nn.Linear(num_features, 1)


def train_model(inputs, y, num_epochs=100, lr=0.01):
    """Fit the default (linear) classifier with full-batch Adam on a BCE loss.

    Args:
        inputs: 2-D float tensor; the number of columns sets the model's
            input size.
        y: tensor of binary targets with one entry per row of ``inputs``.
        num_epochs: number of full-batch gradient steps.
        lr: Adam learning rate.

    Returns:
        torch.nn.Module: the trained model (outputs are logits).
    """
    num_features = inputs.shape[1]
    model = get_model(num_features)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    targets = y.float()

    for _ in range(num_epochs):
        optimizer.zero_grad()
        # Drop only the trailing output dimension. A bare squeeze() would turn
        # a one-sample batch into a 0-d tensor, which no longer matches the
        # shape of ``targets`` and makes the loss raise.
        logits = model(inputs).squeeze(-1)
        loss = criterion(logits, targets)
        loss.backward()
        optimizer.step()
    return model


def get_predictions(model, X, y):
    """Score ``X`` with ``model`` and report the BCE-with-logits loss against ``y``.

    Args:
        model: module mapping a 2-D input tensor to one logit per row.
        X: 2-D input tensor.
        y: tensor of binary targets, one per row of ``X``.

    Returns:
        tuple: ``(logits, loss)`` where ``logits`` is a 1-D NumPy array with
        one entry per row of ``X`` and ``loss`` is a Python float.
    """
    with torch.no_grad():
        # squeeze(-1) removes only the output dimension, so a single-row input
        # still yields a length-1 vector that matches the shape of ``y``.
        preds = model(X).detach().squeeze(-1)
    loss = nn.BCEWithLogitsLoss()(preds, y.float())
    return preds.numpy(), loss.item()


def get_dataset_splits(_train_metrics, _val_metrics, num_samples):
    """Build labelled fit/held-out tensors from the two metric pools.

    The first ``num_samples`` rows of each pool form the first split and the
    remaining rows form the second. Rows from ``_train_metrics`` get label 0
    and rows from ``_val_metrics`` get label 1.

    Args:
        _train_metrics: 2-D array, one row per sample (label 0 pool).
        _val_metrics: 2-D array, one row per sample (label 1 pool).
        num_samples: number of leading rows of each pool used for the first
            split.

    Returns:
        tuple: ``((train_x, train_y), (val_x, val_y))`` as float32 tensors.
    """
    head_neg, tail_neg = _train_metrics[:num_samples], _train_metrics[num_samples:]
    head_pos, tail_pos = _val_metrics[:num_samples], _val_metrics[num_samples:]

    def _stack_with_labels(neg_rows, pos_rows):
        # Label 0 for the train pool, 1 for the val pool. Plain zeros are used
        # so the labels are +0.0 rather than the -0.0 that `-1 * zeros` yields.
        features = np.concatenate((neg_rows, pos_rows), axis=0)
        labels = np.concatenate((np.zeros(neg_rows.shape[0]), np.ones(pos_rows.shape[0])))
        return (torch.tensor(features, dtype=torch.float32),
                torch.tensor(labels, dtype=torch.float32))

    return _stack_with_labels(head_neg, head_pos), _stack_with_labels(tail_neg, tail_pos)


def normalize_and_stack_fix(train_metrics, val_metrics, normalize="train", epsilon=1e-8):
    """Standardise each feature and stack the surviving features as columns.

    Args:
        train_metrics: sequence of 1-D value sequences, one per feature.
        val_metrics: sequence of 1-D value sequences, paired with
            ``train_metrics`` feature by feature.
        normalize: ``"train"`` uses the train values' mean/std, ``"combined"``
            uses the train and val values together, ``"no"`` leaves the
            values untouched.
        epsilon: replacement for a zero, NaN or too-small standard deviation.

    Returns:
        tuple: ``(train_matrix, val_matrix)``, each with one column per kept
        feature.

    Raises:
        ValueError: if every feature was dropped.
    """
    kept_train, kept_val = [], []

    for raw_train, raw_val in zip(train_metrics, val_metrics):
        feat_train = np.array(raw_train)
        feat_val = np.array(raw_val)

        # A feature with no values on either side cannot be used.
        if len(feat_train) == 0 or len(feat_val) == 0:
            continue

        if normalize == "combined":
            reference = np.concatenate((feat_train, feat_val))
        else:
            reference = feat_train
        center = np.nanmean(reference)
        scale = np.nanstd(reference)
        if np.isnan(scale) or scale == 0 or scale < epsilon:
            scale = epsilon

        if normalize != "no":
            scaled_train = (feat_train - center) / scale
            scaled_val = (feat_val - center) / scale
        else:
            scaled_train, scaled_val = feat_train, feat_val

        # Features that still contain NaN are dropped entirely.
        if np.isnan(scaled_train).any() or np.isnan(scaled_val).any():
            continue

        kept_train.append(scaled_train)
        kept_val.append(scaled_val)

    if not kept_train:
        raise ValueError("No data after normalization!")
    return np.stack(kept_train, axis=1), np.stack(kept_val, axis=1)


def remove_outliers(metrics, remove_frac=0.05, outliers="zero"):
    """Neutralise the most extreme values at both tails of a 1-D array.

    ``k = int(len(metrics) * remove_frac / 2)`` values are treated as outliers
    at each tail. When ``k`` is zero (the array is too short for the requested
    fraction) nothing is an outlier and an unmodified copy is returned.

    Args:
        metrics: 1-D array of values.
        remove_frac: total fraction of values treated as outliers, split
            evenly between the low and the high tail.
        outliers: handling strategy:
            ``"zero"`` sets outliers to 0;
            ``"mean"`` / ``"mean+p-value"`` set them to the mean of all
            values (outliers included);
            ``"clip"`` sets each tail to its innermost outlier value;
            ``"randomize"`` deletes the outliers, shortening the array;
            any other value leaves the data unchanged.

    Returns:
        numpy.ndarray: a new array; the input is not modified.

    Raises:
        ValueError: if the two tails would overlap.
    """
    total = len(metrics)
    k = int(total * remove_frac / 2)
    if k * 2 > total:
        raise ValueError("remove_frac too large")

    trimmed = np.copy(metrics)
    if k <= 0:
        # Guard against `sorted_ids[-0:]`, which would select every element
        # and mark the whole array as outliers.
        return trimmed

    sorted_ids = np.argsort(metrics)
    low = sorted_ids[:k]
    high = sorted_ids[total - k:]
    ids = np.concatenate((low, high))

    if outliers == "zero":
        trimmed[ids] = 0
    elif outliers in ("mean", "mean+p-value"):
        trimmed[ids] = np.mean(trimmed)
    elif outliers == "clip":
        # high[0] is the smallest of the top-k values and low[-1] the largest
        # of the bottom-k values.
        trimmed[high] = trimmed[high[0]]
        trimmed[low] = trimmed[low[-1]]
    elif outliers == "randomize":
        trimmed = np.delete(trimmed, ids)
    return trimmed


# def split_train_val(metrics):
#     keys = list(metrics.keys())
#     n = len(metrics[keys[0]])
#     ids_train = np.random.choice(n, n//2, replace=False)
#     ids_val = np.array([i for i in range(n) if i not in ids_train])
#     mt, mv = {}, {}
#     for k in keys:
#         mt[k] = np.array(metrics[k])[ids_train]
#         mv[k] = np.array(metrics[k])[ids_val]
#     return mt, mv

def split_train_val(metrics):
    """Randomly split every metric into two disjoint halves by sample index.

    Only the first ``min(len(values))`` samples of each metric are used, so
    the same indices can be applied to every key.

    Args:
        metrics: dict mapping metric name to a sequence of per-sample values.

    Returns:
        tuple: ``(first_half, second_half)`` dicts of NumPy arrays. The first
        half holds ``n // 2`` randomly drawn samples (in drawn order); the
        second half holds the remaining samples in ascending index order.
    """
    keys = list(metrics.keys())
    num_elements = min(len(metrics[key]) for key in keys)
    print(f"Using {num_elements} elements")

    # Draw half of the indices at random (global NumPy RNG).
    ids_train = np.random.choice(num_elements, num_elements // 2, replace=False)

    # Complement via a boolean mask: linear time, and the indices come out in
    # ascending order just like a range() scan would produce.
    is_val = np.ones(num_elements, dtype=bool)
    is_val[ids_train] = False
    ids_val = np.flatnonzero(is_val)

    new_metrics_train = {}
    new_metrics_val = {}
    for key in keys:
        values = np.array(metrics[key])
        new_metrics_train[key] = values[ids_train]
        new_metrics_val[key] = values[ids_val]
    return new_metrics_train, new_metrics_val

def merge_metrics_dicts(d1, d2):
    """Concatenate two metric dictionaries key by key.

    Args:
        d1: dict mapping metric name to a sequence of values.
        d2: dict with exactly the same keys as ``d1``.

    Returns:
        dict: for every key of ``d1`` (in ``d1``'s order), the values of
        ``d1`` followed by the values of ``d2``, as a plain list.

    Raises:
        ValueError: if the two dictionaries do not share the same key set.
    """
    if set(d1) != set(d2):
        raise ValueError("Metrics key mismatch.")
    return {
        name: np.concatenate((np.asarray(d1[name]), np.asarray(d2[name])), axis=0).tolist()
        for name in d1
    }


def _pad_with_nan(values, length):
    """Right-pad a 1-D float array with NaN so it has ``length`` entries."""
    return np.pad(values, (0, length - len(values)), constant_values=np.nan)


def save_sequential_results_csv(
    power, stopped, stop_times, W,
    power_fp, stopped_fp, stop_times_fp, W_fp,
    kernel_name, alpha, num_trials, csv_path
):
    """Write the summary and per-sample traces of one run to a single CSV file.

    The file holds two tables separated by a blank line: a one-row summary
    (rates, counts, mean stopping times) followed by the per-index traces,
    NaN-padded to a common length.

    Args:
        power, W: per-index power and wealth traces of the main condition.
        stopped: per-trial booleans, true when the test stopped.
        stop_times: per-trial stopping times (NaN entries are ignored in the
            mean).
        power_fp, stopped_fp, stop_times_fp, W_fp: the same quantities for the
            false-positive condition.
        kernel_name: kernel label written to the summary row.
        alpha: significance level written to the summary row.
        num_trials: number of trials written to the summary row.
        csv_path: destination file; missing parent directories are created.
    """
    # Traces are cast to float so that NaN padding is always representable.
    power = np.asarray(power, dtype=float)
    stopped = np.asarray(stopped, dtype=bool)
    stop_times = np.asarray(stop_times)
    W = np.asarray(W, dtype=float)

    power_fp = np.asarray(power_fp, dtype=float)
    stopped_fp = np.asarray(stopped_fp, dtype=bool)
    stop_times_fp = np.asarray(stop_times_fp)
    W_fp = np.asarray(W_fp, dtype=float)

    # Confusion counts: a stop in the main condition is a true positive, a
    # stop in the false-positive condition is a false positive.
    tp = stopped.sum()
    fn = (~stopped).sum()
    fp = stopped_fp.sum()
    tn = (~stopped_fp).sum()

    tpr = tp/(tp+fn) if (tp+fn) else np.nan
    fpr = fp/(fp+tn) if (fp+tn) else np.nan
    fnr = fn/(tp+fn) if (tp+fn) else np.nan

    max_len = max(len(power), len(W), len(power_fp), len(W_fp))
    df_series = pd.DataFrame({
        "sample_index": np.arange(max_len),
        "power": _pad_with_nan(power, max_len),
        "wealth": _pad_with_nan(W, max_len),
        "power_fp": _pad_with_nan(power_fp, max_len),
        "wealth_fp": _pad_with_nan(W_fp, max_len),
    })

    df_meta = pd.DataFrame({
        "kernel": [kernel_name],
        "alpha": [alpha],
        "num_trials": [num_trials],
        "TPR": [tpr], "FPR": [fpr], "FNR": [fnr],
        "TP": [tp], "FN": [fn], "FP": [fp], "TN": [tn],
        "stopped_mean": [np.mean(stopped)],
        "stopped_fp_mean": [np.mean(stopped_fp)],
        "avg_stop_time": [np.nanmean(stop_times)],
        "avg_stop_time_fp": [np.nanmean(stop_times_fp)],
    })

    # dirname is "" for a bare file name, and os.makedirs("") raises.
    out_dir = os.path.dirname(csv_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(csv_path, "w") as f:
        df_meta.to_csv(f, index=False, float_format="%.17g")
        f.write("\n")
        df_series.to_csv(f, index=False, float_format="%.17g")
    print(f"[INFO] Saved {csv_path}")


def build_stream_pairs_from_split(val_x, val_y):
    """Pair up label-0 and label-1 rows of a split into two aligned streams.

    Args:
        val_x: 2-D tensor of features, one row per sample.
        val_y: 1-D tensor of labels; rows equal to 0 feed the first stream
            and rows equal to 1 feed the second.

    Returns:
        tuple: ``(X_stream_train, X_stream_val)`` NumPy arrays with the same
        number of rows ``T``, where ``T`` is the size of the smaller class.
        Row ``t`` of each stream is the ``t``-th row of its class, in the
        order it appears in ``val_x``.

    Raises:
        ValueError: if either class is absent.
    """
    def rows_with_label(label):
        return torch.where(val_y == label)[0].cpu().numpy()

    neg_rows = rows_with_label(0)
    pos_rows = rows_with_label(1)

    num_pairs = min(len(neg_rows), len(pos_rows))
    if num_pairs == 0:
        raise ValueError("No balanced pairs available in validation split.")

    stream_neg = val_x[neg_rows[:num_pairs]].cpu().numpy()
    stream_pos = val_x[pos_rows[:num_pairs]].cpu().numpy()
    return stream_neg, stream_pos


def build_online_score_sequences(X_train_pairs, X_val_pairs,
                                 num_epochs=100, warmup_mode="bet0", warmup_K=1):
    """Score a stream of pairs with a classifier that is trained online.

    At round ``t`` the model is trained for ``num_epochs`` more steps on all
    pairs seen in rounds ``0..t-1`` (train-side rows labelled 0, val-side rows
    labelled 1) and then scores the current pair. The same model and
    optimizer are reused across rounds.

    Args:
        X_train_pairs: 2-D array ``(T, D)`` of label-0 samples.
        X_val_pairs: 2-D array of label-1 samples, aligned with
            ``X_train_pairs`` round by round.
        num_epochs: gradient steps taken per round.
        warmup_mode: behaviour during the first ``warmup_K`` rounds:
            ``"bet0"`` emits zero scores; ``"centroid"`` scores by projecting
            onto the difference of the class means seen so far (zero when
            nothing has been seen); any other value uses the model from the
            first round on.
        warmup_K: number of warm-up rounds.

    Returns:
        tuple: ``(S_train, S_val)`` float32 arrays of length ``T`` with the
        score of the train-side and val-side sample of every round.
    """
    num_rounds, num_features = X_train_pairs.shape
    cpu = torch.device("cpu")

    scores_train = np.zeros(num_rounds, dtype=np.float32)
    scores_val = np.zeros(num_rounds, dtype=np.float32)

    clf = get_model(num_features).to(cpu)
    bce = nn.BCEWithLogitsLoss()
    adam = torch.optim.Adam(clf.parameters(), lr=0.01)

    # Everything observed so far, as (1, D) rows with their 0/1 labels.
    history_x = []
    history_y = []

    def as_tensor(values):
        return torch.tensor(values, dtype=torch.float32, device=cpu)

    for step in range(num_rounds):
        x_neg = X_train_pairs[step]
        x_pos = X_val_pairs[step]
        in_warmup = step < warmup_K

        if in_warmup and warmup_mode == "bet0":
            # Abstain: both scores keep their initial value of zero.
            pass
        elif in_warmup and warmup_mode == "centroid":
            # With no history yet the scores also stay at zero.
            if history_x:
                past_x = np.vstack(history_x)
                past_y = np.array(history_y)
                neg_mask = past_y == 0.0
                pos_mask = past_y == 1.0
                mu_neg = past_x[neg_mask].mean(axis=0) if np.any(neg_mask) else np.zeros(num_features)
                mu_pos = past_x[pos_mask].mean(axis=0) if np.any(pos_mask) else np.zeros(num_features)
                direction = mu_pos - mu_neg
                scores_train[step] = float(np.dot(x_neg - mu_neg, direction))
                scores_val[step] = float(np.dot(x_pos - mu_neg, direction))
        else:
            if history_x:
                batch_x = as_tensor(np.vstack(history_x))
                batch_y = as_tensor(np.array(history_y, dtype=np.float32))
                for _ in range(num_epochs):
                    adam.zero_grad()
                    batch_loss = bce(clf(batch_x).squeeze(), batch_y)
                    batch_loss.backward()
                    adam.step()
            with torch.no_grad():
                scores_train[step] = clf(as_tensor(x_neg)).item()
                scores_val[step] = clf(as_tensor(x_pos)).item()

        # The current pair becomes training data only after it was scored.
        history_x.append(x_neg[None, :])
        history_y.append(0.0)
        history_x.append(x_pos[None, :])
        history_y.append(1.0)

    return scores_train, scores_val


def source_from_sequences(X_seq, Y_seq):
    """Build a randomized ``Source`` callable over two aligned score sequences.

    Args:
        X_seq: 1-D sequence of scores.
        Y_seq: 1-D sequence of scores, aligned with ``X_seq`` index by index.

    Returns:
        callable: ``_src(Nmax)`` draws ``Nmax`` distinct indices (global NumPy
        RNG) and returns ``(X, Y)`` column vectors of shape ``(Nmax, 1)``
        holding the entries of both sequences at those indices. It raises
        ``ValueError`` when ``Nmax`` exceeds the number of available pairs.
    """
    X_seq = np.asarray(X_seq)
    Y_seq = np.asarray(Y_seq)
    # Only indices valid for BOTH sequences may be drawn. Drawing from
    # len(X_seq) alone would index past the end of a shorter Y_seq.
    num_pairs = min(len(X_seq), len(Y_seq))

    def _src(Nmax):
        if Nmax > num_pairs:
            raise ValueError(
                f"Requested {Nmax} pairs but only {num_pairs} are available."
            )
        idx = np.random.choice(num_pairs, size=Nmax, replace=False)
        X = X_seq[idx].reshape(-1, 1)
        Y = Y_seq[idx].reshape(-1, 1)
        return X, Y

    return _src

def truncate_metrics_to_consistent_length(train_metrics, val_metrics):
    """
    Cut every metric of a set down to the shortest metric of that set.

    The train and val sets are handled independently. A set whose metrics
    already share one length is returned unchanged (same list object);
    otherwise a warning is printed and a new list of truncated metrics is
    returned for that set.

    Args:
        train_metrics: List of numpy arrays for training metrics
        val_metrics: List of numpy arrays for validation metrics

    Returns:
        tuple: (truncated_train_metrics, truncated_val_metrics)
    """
    train_lengths = list(map(len, train_metrics))
    val_lengths = list(map(len, val_metrics))

    # Both minima are computed up front, before anything is printed.
    min_train_length = min(train_lengths)
    min_val_length = min(val_lengths)

    # More than one distinct length means at least one metric is longer than
    # the shortest one.
    if len(set(train_lengths)) > 1:
        print(f"Warning: Train metrics have inconsistent lengths. Shortest length: {min_train_length}")
        print(f"Train lengths: {train_lengths}")
        train_metrics = [column[:min_train_length] for column in train_metrics]

    if len(set(val_lengths)) > 1:
        print(f"Warning: Val metrics have inconsistent lengths. Shortest length: {min_val_length}")
        print(f"Val lengths: {val_lengths}")
        val_metrics = [column[:min_val_length] for column in val_metrics]

    return train_metrics, val_metrics


def _load_json(path):
    """Read and return the JSON document stored at ``path``."""
    with open(path, 'r') as f:
        return json.load(f)


def _run_online_test(tag, scores_train, scores_val, kernel, args):
    """Run one sequential Kernel MMD test on a pair of score streams, or skip it.

    Args:
        tag: short label used in the warning message ("H1" or "FP").
        scores_train: score sequence of the label-0 side.
        scores_val: score sequence of the label-1 side.
        kernel: kernel callable handed to the prediction step.
        args: parsed command-line options.

    Returns:
        tuple: ``(power, stopped, stop_times, wealth)`` as returned by
        ``runSequentialTest``, or empty placeholders when the stream has
        fewer than ``MIN_MMD_POINTS`` pairs.
    """
    n_max = min(len(scores_train), len(scores_val))
    if n_max < MIN_MMD_POINTS:
        print(f"[WARN] ONLINE {tag} has only {n_max} samples (< {MIN_MMD_POINTS}); skipping ONLINE {tag}.")
        return [], np.array([], bool), [], []

    return runSequentialTest(
        Source=source_from_sequences(scores_train, scores_val),
        Prediction=kernelMMDprediction,
        Betting=ONSstrategy,
        alpha=args.alpha,
        pred_params={"kernel": kernel, "post_processing": args.post_processing},
        bet_params={"lambda_max": args.lambda_max},
        Nmax=n_max,
        num_trials=args.num_trials,
        progress_bar=True,
        hedge=False,
        return_wealth=True
    )


def main():
    """Run the ONLINE sequential ablation and write one CSV per repetition.

    Steps:
      1. Load train/val (and optional test) metric files from
         ``{metrics_path}/{dataset_name}``.
      2. Build the false-positive condition by randomly splitting the val
         metrics (merged with the test metrics when present) into two halves.
      3. Apply outlier handling, length truncation and normalization
         (``--normalize``) per feature.
      4. For each of ``--num_random`` repetitions: shuffle the rows, build the
         paired streams, compute online scores, run the sequential test for
         both conditions and save the results.
    """
    args = get_args()

    base_path = os.path.join(args.metrics_path, args.dataset_name)
    metrics_train = _load_json(os.path.join(base_path, "train_metrics.json"))
    metrics_val = _load_json(os.path.join(base_path, "val_metrics.json"))

    # The test file is optional; when present it enlarges the pool used for
    # the false-positive condition.
    test_metrics_path = os.path.join(base_path, "test_metrics.json")
    if os.path.exists(test_metrics_path):
        nonmember_all = merge_metrics_dicts(metrics_val, _load_json(test_metrics_path))
    else:
        nonmember_all = metrics_val
    metrics_train_fp, metrics_val_fp = split_train_val(nonmember_all)

    train_metrics, val_metrics = [], []
    train_fp_metrics, val_fp_metrics = [], []

    for key in metrics_train:
        if args.features == "selected" and key not in feature_list:
            continue

        # Same outlier handling for the main and the false-positive condition.
        for source, bucket in (
            (metrics_train, train_metrics),
            (metrics_val, val_metrics),
            (metrics_train_fp, train_fp_metrics),
            (metrics_val_fp, val_fp_metrics),
        ):
            bucket.append(
                remove_outliers(np.array(source[key]), remove_frac=0.05, outliers=args.outliers)
            )

    train_metrics, val_metrics = truncate_metrics_to_consistent_length(train_metrics, val_metrics)
    train_fp_metrics, val_fp_metrics = truncate_metrics_to_consistent_length(train_fp_metrics, val_fp_metrics)

    # Honour --normalize. Its default ("train") matches the function default,
    # so runs that do not pass the flag behave as before.
    train_metrics, val_metrics = normalize_and_stack_fix(
        train_metrics, val_metrics, normalize=args.normalize
    )
    train_metrics_fp, val_metrics_fp = normalize_and_stack_fix(
        train_fp_metrics, val_fp_metrics, normalize=args.normalize
    )

    if args.kernel == "poly":
        kernel = kernels.polynomial_kernel
        kernel_name = "poly"
    elif hasattr(kernels, "linear_kernel"):
        kernel = kernels.linear_kernel
        kernel_name = "linear"
    else:
        # Best-effort fallback, but make it visible instead of silent.
        print("[WARN] src.kernels has no linear_kernel; falling back to the polynomial kernel.")
        kernel = kernels.polynomial_kernel
        kernel_name = "poly"

    os.makedirs(args.output_dir, exist_ok=True)

    prefix = args.save_file_prefix + ("_" if args.save_file_prefix else "")
    num_samples = args.num_samples
    cond_A = "ONLINE"

    # NOTE(review): output files are named `seed_{i}` but no RNG is seeded in
    # this function, so repetitions are not reproducible across invocations.
    # Confirm whether seeding NumPy/torch per repetition is intended.
    for i in range(args.num_random):
        np.random.shuffle(train_metrics)
        np.random.shuffle(val_metrics)
        np.random.shuffle(train_metrics_fp)
        np.random.shuffle(val_metrics_fp)

        (train_x, train_y), (_, _) = get_dataset_splits(train_metrics, val_metrics, num_samples)
        (train_x_fp, train_y_fp), (_, _) = get_dataset_splits(train_metrics_fp, val_metrics_fp, num_samples)

        # The streams are built from the first `num_samples` rows of each pool
        # (the first split); the remaining rows are not used here.
        X_stream_train, X_stream_val = build_stream_pairs_from_split(train_x, train_y)
        X_stream_train_fp, X_stream_val_fp = build_stream_pairs_from_split(train_x_fp, train_y_fp)

        S_tr_A, S_va_A = build_online_score_sequences(
            X_stream_train, X_stream_val,
            num_epochs=args.online_epochs,
            warmup_mode=args.warmup_mode,
            warmup_K=args.warmup_K
        )
        S_tr_A_fp, S_va_A_fp = build_online_score_sequences(
            X_stream_train_fp, X_stream_val_fp,
            num_epochs=args.online_epochs,
            warmup_mode=args.warmup_mode,
            warmup_K=args.warmup_K
        )

        power_A, stopped_A, stop_times_A, W_A = _run_online_test(
            "H1", S_tr_A, S_va_A, kernel, args
        )
        power_A_fp, stopped_A_fp, stop_times_A_fp, W_A_fp = _run_online_test(
            "FP", S_tr_A_fp, S_va_A_fp, kernel, args
        )

        # os.path.join keeps absolute output directories absolute; prefixing
        # "./" would turn them into paths relative to the working directory.
        baseA_csv = os.path.join(
            args.output_dir,
            args.dataset_name,
            f"{prefix}e_process_info_dataset_{args.dataset_name}_{cond_A}_seed_{i}.csv",
        )

        save_sequential_results_csv(
            power_A, stopped_A, stop_times_A, W_A,
            power_A_fp, stopped_A_fp, stop_times_A_fp, W_A_fp,
            kernel_name, args.alpha, args.num_trials, baseA_csv
        )

    print("[DONE] All runs complete.")


if __name__ == "__main__":
    main()