# experiments/verify_by_ablation.py
import argparse
import os
import joblib
import numpy as np
import pandas as pd
from tqdm import tqdm

# --- Your project's imports (reuse from other scripts) ---
from downstream_feature_importance import (
    load_single_pred_dataset,
    _find_smiles_column,
    build_feature_sets,
    infer_task_type_from_y,
    prepare_labels_for_classification,
    concat_sae_across_layers,
    generate_ecfp_features

)
from lmkit.sparse.sae import SAEKit


def parse_feature_str(feature_str: str) -> list[tuple[int, int]]:
    """Parse a feature spec like ``'L3_F1916,L4_F1125'`` into ``(layer, feature)`` pairs.

    Each comma-separated token must have the form ``L<layer>_F<feature>`` with
    non-negative integers. Empty tokens (e.g. from a trailing comma) are
    ignored, matching how the script parses ``--seeds`` and ``--tasks``.

    Args:
        feature_str: Comma-separated feature spec.

    Returns:
        List of ``(layer_id, feature_idx)`` tuples in input order.

    Raises:
        ValueError: if any non-empty token is malformed or has a negative index.
    """
    features = []
    for raw_token in feature_str.split(","):
        part = raw_token.strip()
        if not part:
            continue  # tolerate trailing / doubled commas
        layer_str, sep, feature_idx_str = part.partition("_F")
        if not sep or not layer_str.startswith("L"):
            raise ValueError(f"Invalid feature format: {part}")
        try:
            layer_id = int(layer_str[1:])
            feature_idx = int(feature_idx_str)
        except ValueError as err:
            raise ValueError(f"Invalid feature format: {part}") from err
        # A negative index would silently address a column of a different layer.
        if layer_id < 0 or feature_idx < 0:
            raise ValueError(f"Invalid feature format: {part}")
        features.append((layer_id, feature_idx))
    return features


def map_features_to_global_indices(
    features_to_ablate: list[tuple[int, int]], sae_layers: list[int], sae_kit: "SAEKit"
) -> list[int]:
    """Map ``(layer_id, feature_in_layer)`` pairs to global column indices.

    The global layout is taken to be the per-layer SAE latent blocks
    concatenated in ``sae_layers`` order, each block ``latent_size`` wide.
    NOTE(review): this must match how ``concat_sae_across_layers`` stacks its
    output — confirm there.

    Features whose layer is not in ``sae_layers``, or whose index falls outside
    ``[0, latent_size)`` for that layer, are reported and skipped. An
    out-of-range index would otherwise silently hit a column belonging to
    another layer (or to appended ECFP bits).

    Args:
        features_to_ablate: ``(layer_id, feature_idx)`` pairs.
        sae_layers: Layers in the order their SAE blocks appear in the matrix.
        sae_kit: Object exposing ``sae_configs[layer].latent_size``.

    Returns:
        Global column indices as plain Python ints, in input order.
    """
    layer_dims = [sae_kit.sae_configs[L].latent_size for L in sae_layers]
    # Column at which each layer's block starts in the concatenated matrix.
    layer_offsets = np.cumsum([0] + layer_dims)

    global_indices = []
    for layer_id, feature_idx in features_to_ablate:
        if layer_id not in sae_layers:
            print(
                f"[Warning] Layer {layer_id} for feature to ablate is not in the model's SAE layers. Skipping."
            )
            continue

        layer_list_idx = sae_layers.index(layer_id)
        layer_dim = layer_dims[layer_list_idx]
        if not 0 <= feature_idx < layer_dim:
            print(
                f"[Warning] Feature {feature_idx} is out of range for layer {layer_id} "
                f"(latent size {layer_dim}). Skipping."
            )
            continue

        global_indices.append(int(layer_offsets[layer_list_idx]) + int(feature_idx))

    return global_indices


def _build_test_matrix(args, sae_kit, sae_layers, smiles):
    """Build the test-set feature matrix for ``args.feature_set``.

    The SAE block always comes first, so SAE global column indices stay valid
    in the combined matrix. For ``"SAE ⊕ ECFP"`` the ECFP bits are appended
    after it.

    Raises:
        ValueError: if ``args.feature_set`` is not a supported feature set.
    """
    # SAE features are needed by every supported feature set.
    _, sae_features = concat_sae_across_layers(
        smiles, sae_kit, sae_layers, batch_size=args.batch_size
    )
    if args.feature_set == "SAE Features":
        return sae_features
    if args.feature_set == "SAE ⊕ ECFP":
        ecfp_features = generate_ecfp_features(smiles, n_bits=args.ecfp_bits)
        return np.concatenate([sae_features, ecfp_features], axis=1)
    raise ValueError(f"Unsupported feature_set for ablation: {args.feature_set}")


def _prediction_impact(model, task_kind, X_test, X_ablated, y_true, y_test_enc):
    """Compare model outputs on the original vs. ablated feature matrices.

    Returns:
        ``(columns, sort_key)``. ``columns`` maps result-column name to a
        per-molecule array. ``sort_key`` names the column measuring the damage
        done by the ablation (larger means more affected).
    """
    if task_kind == "classification":
        rows = np.arange(len(y_test_enc))
        # Probability the model assigns to each molecule's true class.
        prob_original = model.predict_proba(X_test)[rows, y_test_enc]
        prob_ablated = model.predict_proba(X_ablated)[rows, y_test_enc]
        columns = {
            "prob_drop": prob_original - prob_ablated,
            "original_prob_correct": prob_original,
            "ablated_prob_correct": prob_ablated,
        }
        return columns, "prob_drop"

    # Regression: how much the absolute error grows once the features are zeroed.
    pred_original = model.predict(X_test)
    pred_ablated = model.predict(X_ablated)
    columns = {
        "error_increase": np.abs(y_true - pred_ablated) - np.abs(y_true - pred_original),
        "original_pred": pred_original,
        "ablated_pred": pred_ablated,
    }
    return columns, "error_increase"


def run_ablation(args):
    """Zero out chosen SAE features and record the per-molecule impact on saved models.

    For every task in ``args.tasks`` and seed in ``args.seeds`` this:

    1. Loads ``{task}_{feature_set without spaces}_seed{seed}.joblib`` from
       ``args.model_load_dir``. Missing models are skipped.
    2. Rebuilds the test-set features.
    3. Zeroes the columns of the features in ``args.ablate_features``.
    4. Compares predictions before and after the ablation.

    One CSV per (task, seed), sorted with the most affected molecule first, is
    written to ``args.out_dir``.

    Args:
        args: Parsed CLI namespace produced by this script's argument parser.

    Raises:
        ValueError: if ``args.feature_set`` is unsupported or
            ``args.ablate_features`` is malformed.
    """
    # Fail fast instead of after computing SAE features for the first seed.
    if args.feature_set not in ("SAE Features", "SAE ⊕ ECFP"):
        raise ValueError(f"Unsupported feature_set for ablation: {args.feature_set}")

    features_to_ablate = parse_feature_str(args.ablate_features)
    print(f"Target features for ablation: {features_to_ablate}")

    sae_kit = SAEKit.load(
        model_dir=args.model_dir, checkpoint_id=args.ckpt_id, sae_dir=args.sae_dir
    )
    # Skip empty tokens consistently for all comma-separated CLI lists.
    sae_layers = [int(x) for x in args.sae_layers.split(",") if x.strip()]
    seeds = [int(s) for s in args.seeds.split(",") if s.strip()]
    tasks = [t.strip() for t in args.tasks.split(",") if t.strip()]

    global_indices_to_ablate = map_features_to_global_indices(
        features_to_ablate, sae_layers, sae_kit
    )
    if not global_indices_to_ablate:
        print("No valid features to ablate. Exiting.")
        return
    print(f"Mapped to global column indices: {global_indices_to_ablate}")

    os.makedirs(args.out_dir, exist_ok=True)
    model_tag = args.feature_set.replace(" ", "")

    for task_name in tasks:
        print(f"\n--- Processing Task: {task_name} ---")
        _, ds = load_single_pred_dataset(task_name)

        for seed in seeds:
            print(f"  - Seed: {seed}")
            model_path = os.path.join(
                args.model_load_dir, f"{task_name}_{model_tag}_seed{seed}.joblib"
            )
            if not os.path.exists(model_path):
                print(f"    Model not found at {model_path}, skipping.")
                continue

            model = joblib.load(model_path)
            # The test split depends on the seed, so features are rebuilt per seed.
            split = ds.get_split(method=args.split_method, seed=seed)
            smi_col = _find_smiles_column(split["train"])
            test_df = split["test"]

            y_true = test_df["Y"].values
            task_kind = infer_task_type_from_y(y_true)

            print("    Generating features for the test set...")
            test_smiles = test_df[smi_col].tolist()
            X_test = _build_test_matrix(args, sae_kit, sae_layers, test_smiles)
            print(f"    Generated feature matrix of shape: {X_test.shape}")

            X_ablated = X_test.copy()
            X_ablated[:, global_indices_to_ablate] = 0.0

            y_test_enc = None
            if task_kind == "classification":
                # Encode test labels together with the train/valid labels. The
                # class -> probability-column mapping then matches training even
                # if the test split lacks a class. NOTE(review): assumes the
                # training script encoded labels from the full split via this
                # same helper; confirm.
                y_train = split["train"]["Y"].values
                y_valid = split["valid"]["Y"].values if "valid" in split else []
                _, _, y_test_enc, _ = prepare_labels_for_classification(
                    y_train, y_valid, y_true
                )

            impact, sort_by = _prediction_impact(
                model, task_kind, X_test, X_ablated, y_true, y_test_enc
            )
            results_df = pd.DataFrame(
                {"SMILES": test_smiles, "y_true": y_true, **impact}
            ).sort_values(by=sort_by, ascending=False)

            output_path = os.path.join(
                args.out_dir, f"ablation_impact_{task_name}_seed{seed}.csv"
            )
            results_df.to_csv(output_path, index=False)
            print(f"    Saved ablation impact analysis to {output_path}")
            print("    Top 5 most affected molecules:")
            print(results_df.head(5).to_string())


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        "Verify feature importance by ablating top features."
    )
    # Paths
    parser.add_argument("--model_dir", required=True)
    parser.add_argument("--ckpt_id", required=True)
    parser.add_argument("--sae_dir", required=True)
    parser.add_argument(
        "--model_load_dir",
        required=True,
        help="Directory where trained XGBoost models are saved.",
    )
    parser.add_argument("--out_dir", default="experiments/ablation_verification")
    # Feature & Data Specs
    parser.add_argument(
        "--sae_layers",
        required=True,
        help="Comma-sep SAE layers used in the saved model.",
    )
    parser.add_argument("--tasks", required=True, help="Comma-sep list of TDC tasks.")
    parser.add_argument("--seeds", default="1,2,3", help="Comma-sep seeds.")
    parser.add_argument("--split_method", default="scaffold")
    parser.add_argument("--ecfp_bits", type=int, default=2048)
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument(
        "--feature_set",
        default="SAE ⊕ ECFP",
        help="Feature set the model was trained on.",
    )
    # Ablation Target
    parser.add_argument(
        "--ablate_features",
        required=True,
        help="Comma-sep list of features to ablate, e.g., 'L5_F145,L3_F1916'",
    )

    args = parser.parse_args()
    run_ablation(args)
