#!/usr/bin/env python
"""
Combined Analysis Script

For each provided root directory, this script recursively searches for:
  - An alphaFlow pickle file (default name "out.pkl")
  - A basic analysis pickle file (default name "out_basic.pkl")

For every directory that contains at least one of these files, the script runs the
corresponding analysis functions. If one type of file is missing from a directory,
its metrics will be filled with NaNs. Finally, all results are combined into one CSV.

Usage:
    python combined_analysis.py <root_dir1> [<root_dir2> ...]
           [--alpha out.pkl] [--basic out_basic.pkl] [--output combined_analysis.csv]
"""

import os
import sys
import argparse
import pickle
import numpy as np
import pandas as pd
import scipy.stats
import warnings


# ------------------------------- #
# Analysis function for out.pkl   #
# ------------------------------- #
# Summary metrics reported by ``analyze_alpha_flow`` after "count", in output
# order.  All of them are NaN when the pickle holds no targets.
_ALPHA_SUMMARY_KEYS = (
    "MD_pairwise_RMSD",
    "Pairwise_RMSD",
    "Pairwise_RMSD_r",
    "MD_RMSF",
    "RMSF",
    "Global_RMSF_r",
    "Per_target_RMSF_r",
    "RMWD",
    "RMWD_trans",
    "RMWD_var",
    "MD_PCA_W2",
    "Joint_PCA_W2",
    "PC_sim_gt_0.5_percent",
    "Weak_contacts_J",
    "Weak_contacts_nans",
    "Transient_contacts_J",
    "Transient_contacts_nans",
    "Exposed_residue_J",
    "Exposed_MI_matrix_rho",
    "RMSF_JSD",
    "Pairwise_RMSD_JSD",
)


def _rms(value):
    """Return the root-mean-square of *value* (a scalar or array-like)."""
    return np.square(value).mean() ** 0.5


def _iou(mask_a, mask_b):
    """Return |a & b| / |a | b| for boolean masks; NaN when the union is empty."""
    return (mask_a & mask_b).sum() / ((mask_a | mask_b).sum() or np.nan)


def _safe_pearson(x, y):
    """Return Pearson's r over the pairwise-finite entries of *x* and *y*.

    ``scipy.stats.pearsonr`` raises ``ValueError`` for fewer than two samples,
    so NaN is returned in that case instead.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    finite = np.isfinite(x) & np.isfinite(y)
    if finite.sum() < 2:
        return np.nan
    return scipy.stats.pearsonr(x[finite], y[finite])[0]


def _alpha_target_metrics(out):
    """Compute the per-target metric row for one entry of the alpha pickle.

    Every metric group is best-effort: a missing key or malformed value in
    *out* yields NaN for that group rather than aborting the whole file.
    """
    item = {
        "md_pairwise": out.get("ref_mean_pairwise_rmsd", np.nan),
        "af_pairwise": out.get("af_mean_pairwise_rmsd", np.nan),
        "cosine_sim": abs(out.get("cosine_sim", np.nan)),
        "emd_mean": _rms(out.get("emd_mean", np.nan)),
        "emd_var": _rms(out.get("emd_var", np.nan)),
        "jsd_rmsf": out.get("jsd_rmsf", np.nan),
        "jsd_pairwise_rmsd": out.get("jsd_pairwise_rmsd", np.nan),
    }

    # RMSF correlations between the model ("af") and reference profiles.
    try:
        af_rmsf, ref_rmsf = out["af_rmsf"], out["ref_rmsf"]
        pearson = scipy.stats.pearsonr(af_rmsf, ref_rmsf)[0]
        spearman = scipy.stats.spearmanr(af_rmsf, ref_rmsf)[0]
        kendall = scipy.stats.kendalltau(af_rmsf, ref_rmsf)[0]
    except Exception:
        pearson = spearman = kendall = np.nan
    item["rmsf_pearson"] = pearson
    item["rmsf_spearman"] = spearman
    item["rmsf_kendall"] = kendall

    # EMD dicts are stored under "EMD,<kind>" or, in older pickles,
    # "EMD-2,<kind>"; the presence of "EMD,ref" selects the naming scheme.
    emd_prefix = "EMD," if "EMD,ref" in out else "EMD-2,"
    for emd_key in ("ref", "joint"):
        emd_dict = out.get(emd_prefix + emd_key, {})
        try:
            total = emd_dict["ref|af"]
            translation = emd_dict["ref mean|af mean"]
            # Clip round-off negatives: a negative base would give NaN (numpy)
            # or a complex number (Python float) under ** 0.5.
            internal = np.maximum(total**2 - translation**2, 0.0) ** 0.5
        except Exception:
            total = translation = internal = np.nan
        item[emd_key + "emd"] = total
        item[emd_key + "emd_tr"] = translation
        item[emd_key + "emd_int"] = internal

    # Contact metrics: Jaccard overlap of "weak" crystal contacts (present in
    # the crystal but with probability < 0.9) and "transient" contacts (absent
    # in the crystal but with probability > 0.1).
    try:
        crystal_contact = out["crystal_distmat"] < 0.8
        ref_prob, af_prob = out["ref_contact_prob"], out["af_contact_prob"]
        weak_iou = _iou(
            crystal_contact & (ref_prob < 0.9), crystal_contact & (af_prob < 0.9)
        )
        transient_iou = _iou(
            ~crystal_contact & (ref_prob > 0.1), ~crystal_contact & (af_prob > 0.1)
        )
    except Exception:
        weak_iou = transient_iou = np.nan
    item["weak_contacts_iou"] = weak_iou
    item["transient_contacts_iou"] = transient_iou

    # Solvent accessibility: residues buried in the first crystal frame
    # (SASA < 0.02) that become exposed with probability > 0.1.
    num_sasa = sasa_iou = np.nan
    crystal_sasa = out.get("crystal_sasa")
    if crystal_sasa is not None and len(crystal_sasa) > 0:
        try:
            buried = crystal_sasa[0] < 0.02
            ref_sa = (out["ref_sa_prob"] > 0.1) & buried
            af_sa = (out["af_sa_prob"] > 0.1) & buried
            num_sasa, sasa_iou = ref_sa.sum(), _iou(ref_sa, af_sa)
        except Exception:
            num_sasa = sasa_iou = np.nan
    item["num_sasa"] = num_sasa
    item["sasa_iou"] = sasa_iou

    # Mutual-information matrix agreement (flattened, reference vs model).
    try:
        af_mi = out["af_mi_mat"].flatten()
        ref_mi = out["ref_mi_mat"].flatten()
        pearson_mi = scipy.stats.pearsonr(ref_mi, af_mi)[0]
        spearman_mi = scipy.stats.spearmanr(ref_mi, af_mi)[0]
        kendall_mi = scipy.stats.kendalltau(ref_mi, af_mi)[0]
    except Exception:
        pearson_mi = spearman_mi = kendall_mi = np.nan
    item["exposon_mi_pearson"] = pearson_mi
    item["exposon_mi_spearman"] = spearman_mi
    item["exposon_mi_kendall"] = kendall_mi
    return item


def _summarize_alpha(df, data):
    """Aggregate per-target rows *df* (plus raw *data* for RMSF) into one dict."""
    result = {"count": len(df)}
    if len(df) == 0:
        result.update(dict.fromkeys(_ALPHA_SUMMARY_KEYS, np.nan))
        return result

    result["MD_pairwise_RMSD"] = df["md_pairwise"].median()
    result["Pairwise_RMSD"] = df["af_pairwise"].median()
    # A single target cannot give a correlation; report NaN instead of letting
    # pearsonr raise and discard every other metric for this file.
    result["Pairwise_RMSD_r"] = _safe_pearson(df["md_pairwise"], df["af_pairwise"])

    try:
        all_ref_rmsf = np.concatenate([out["ref_rmsf"] for out in data.values()])
        all_af_rmsf = np.concatenate([out["af_rmsf"] for out in data.values()])
        md_rmsf, af_rmsf = np.median(all_ref_rmsf), np.median(all_af_rmsf)
        global_r = scipy.stats.pearsonr(all_ref_rmsf, all_af_rmsf)[0]
    except Exception:
        md_rmsf = af_rmsf = global_r = np.nan
    result["MD_RMSF"] = md_rmsf
    result["RMSF"] = af_rmsf
    result["Global_RMSF_r"] = global_r

    result["Per_target_RMSF_r"] = df["rmsf_pearson"].median()
    result["RMWD"] = np.sqrt(df["emd_mean"] ** 2 + df["emd_var"] ** 2).median()
    result["RMWD_trans"] = df["emd_mean"].median()
    result["RMWD_var"] = df["emd_var"].median()
    result["MD_PCA_W2"] = df["refemd"].median()
    result["Joint_PCA_W2"] = df["jointemd"].median()
    result["PC_sim_gt_0.5_percent"] = (df["cosine_sim"] > 0.5).mean() * 100
    result["Weak_contacts_J"] = df["weak_contacts_iou"].median()
    result["Weak_contacts_nans"] = df["weak_contacts_iou"].isna().mean()
    result["Transient_contacts_J"] = df["transient_contacts_iou"].median()
    result["Transient_contacts_nans"] = df["transient_contacts_iou"].isna().mean()
    result["Exposed_residue_J"] = df["sasa_iou"].median()
    result["Exposed_MI_matrix_rho"] = df["exposon_mi_spearman"].median()
    result["RMSF_JSD"] = df["jsd_rmsf"].mean()
    result["Pairwise_RMSD_JSD"] = df["jsd_pairwise_rmsd"].mean()
    return result


def analyze_alpha_flow(pkl_path):
    """Summarise an alphaFlow-style pickle into a single dict of metrics.

    Args:
        pkl_path: path to a pickle holding ``{target_name: per_target_dict}``.

    Returns:
        dict with ``"count"`` (number of targets) followed by every key in
        ``_ALPHA_SUMMARY_KEYS``; summaries are NaN when there are no targets.

    Note: ``pickle.load`` can execute arbitrary code -- only use trusted files.
    """
    with open(pkl_path, "rb") as f:
        data = pickle.load(f)

    df = pd.DataFrame([_alpha_target_metrics(out) for out in data.values()])
    return _summarize_alpha(df, data)


# ------------------------------------- #
# Analysis function for out_basic.pkl   #
# ------------------------------------- #
# Metric buckets averaged across targets (before the msm entries), in output order.
_BASIC_SUMMARY_METRICS = (
    "Torsions (bb)",
    "Torsions (sc)",
    "Torsions (all)",
    "TICA-0",
    "TICA-0,1 joint",
    "Flux Spearman",
    "Macrostate MAE",
    "Gyration Radius Difference",
    "Gyration Radius KL",
    "Gyration Radius JSD",
    "Secondary Structure Difference",
    "Secondary Structure KL",
    "Secondary Structure JSD",
)

# Top-level per-target keys for gyration / secondary-structure metrics.
_BASIC_STRUCTURE_RENAME = (
    ("gyration_radius_difference", "Gyration Radius Difference"),
    ("gyration_radius_KL", "Gyration Radius KL"),
    ("gyration_radius_JSD", "Gyration Radius JSD"),
    ("ss_difference", "Secondary Structure Difference"),
    ("ss_KL", "Secondary Structure KL"),
    ("ss_JSD", "Secondary Structure JSD"),
)

# Raw keys inside each target's "feature_MSM" dict -> reported "msm ..." names.
_MSM_FEATURE_RENAME = {
    "gr_KL": "msm gr KL",
    "gr_JSD": "msm gr JSD",
    "secondary_KL": "msm 2nd KL",
    "secondary_JSD": "msm 2nd JSD",
    "rmsd_KL": "msm rmsd KL",
    "rmsd_JSD": "msm rmsd JSD",
    "gr,secondary_KL": "msm gr 2nd KL",
    "gr,secondary_JSD": "msm gr 2nd JSD",
    "gr,secondary,rmsd_KL": "msm gr 2nd rmsd KL",
    "gr,secondary,rmsd_JSD": "msm gr 2nd rmsd JSD",
}


def analyze_basic(pkl_path):
    """Summarise a basic-analysis pickle into ``<metric>_mean``/``<metric>_std``.

    Per-target values are pooled across all targets in the pickle, then the
    mean and (population) standard deviation of each pool are reported; empty
    pools give NaN.  MSM feature divergences are reported under names starting
    with ``"msm "``.

    The per-target "JSD" dict mixes torsion features with TICA projections
    ("TICA-0", "TICA-0,1").  TICA entries are reported on their own and are
    excluded from the torsion averages.

    Args:
        pkl_path: path to a pickle holding ``{target_id: per_target_dict}``.

    Note: ``pickle.load`` can execute arbitrary code -- only use trusted files.
    """
    with open(pkl_path, "rb") as f:
        data = pickle.load(f)

    summary_stats = {
        name: []
        for name in (*_BASIC_SUMMARY_METRICS, *_MSM_FEATURE_RENAME.values())
    }
    neff_stats = {
        "Ours Torsions (BB)": [],
        "Ours Torsions (SC)": [],
        "Ours Torsions (All)": [],
        "Ours TICA": [],
    }

    for results in data.values():
        jsd = results.get("JSD", {})

        # Torsion JSDs: backbone = PHI/PSI, side chain = every other torsion.
        for feat, value in jsd.items():
            if "TICA" in str(feat).upper():
                continue  # TICA projections are not torsions (handled below)
            bucket = "Torsions (bb)" if ("PHI" in feat or "PSI" in feat) else "Torsions (sc)"
            summary_stats[bucket].append(value)
            summary_stats["Torsions (all)"].append(value)

        # Effective sample sizes, bucketed like the torsion JSDs.
        for feat, val in results.get("our_Neff", {}).items():
            feat_lower = feat.lower()
            if "phi" in feat_lower or "psi" in feat_lower:
                neff_stats["Ours Torsions (BB)"].append(val)
                neff_stats["Ours Torsions (All)"].append(val)
            elif "tica" in feat_lower:
                neff_stats["Ours TICA"].append(val)
            else:
                neff_stats["Ours Torsions (SC)"].append(val)
                neff_stats["Ours Torsions (All)"].append(val)

        if "TICA-0" in jsd:
            summary_stats["TICA-0"].append(jsd["TICA-0"])
        if "TICA-0,1" in jsd:
            summary_stats["TICA-0,1 joint"].append(jsd["TICA-0,1"])
        if "flux_spearman_corr" in results:
            summary_stats["Flux Spearman"].append(results["flux_spearman_corr"])
        if "mMAE" in results:
            summary_stats["Macrostate MAE"].append(results["mMAE"])

        for raw_key, nice_key in _BASIC_STRUCTURE_RENAME:
            if raw_key in results:
                summary_stats[nice_key].append(results[raw_key])

        feat_msm = results.get("feature_MSM", {})
        for raw_key, nice_key in _MSM_FEATURE_RENAME.items():
            if raw_key in feat_msm:
                summary_stats[nice_key].append(feat_msm[raw_key])

    result = {}
    for stats in (summary_stats, neff_stats):
        for key, vals in stats.items():
            result[f"{key}_mean"] = np.mean(vals) if vals else np.nan
            result[f"{key}_std"] = np.std(vals) if vals else np.nan
    return result


# ------------------------------------------ #
# Helper function: find all matching files   #
# ------------------------------------------ #
def find_all_files(root, target_filename):
    """Collect paths of files named exactly *target_filename* beneath *root*.

    The tree is traversed with ``os.walk`` (top-down), so results follow its
    visiting order; a non-existent *root* simply yields an empty list.
    """
    return [
        os.path.join(dirpath, name)
        for dirpath, _subdirs, names in os.walk(root)
        for name in names
        if name == target_filename
    ]


# ----------------------------------------------------- #
# Main: Build a combined dictionary grouped by directory #
# ----------------------------------------------------- #
# Columns produced by ``analyze_alpha_flow``; filled with NaN when the alpha
# pickle is missing or cannot be analysed.
_ALPHA_COLUMNS = (
    "count",
    "MD_pairwise_RMSD",
    "Pairwise_RMSD",
    "Pairwise_RMSD_r",
    "MD_RMSF",
    "RMSF",
    "Global_RMSF_r",
    "Per_target_RMSF_r",
    "RMWD",
    "RMWD_trans",
    "RMWD_var",
    "MD_PCA_W2",
    "Joint_PCA_W2",
    "PC_sim_gt_0.5_percent",
    "Weak_contacts_J",
    "Weak_contacts_nans",
    "Transient_contacts_J",
    "Transient_contacts_nans",
    "Exposed_residue_J",
    "Exposed_MI_matrix_rho",
    "RMSF_JSD",
    "Pairwise_RMSD_JSD",
)

# Metric stems produced by ``analyze_basic``; each appears as ``<stem>_mean``
# and ``<stem>_std`` and is filled with NaN when the basic pickle is missing
# or cannot be analysed.
_BASIC_COLUMN_STEMS = (
    "Torsions (bb)",
    "Torsions (sc)",
    "Torsions (all)",
    "TICA-0",
    "TICA-0,1 joint",
    "Flux Spearman",
    "Macrostate MAE",
    "Gyration Radius Difference",
    "Gyration Radius KL",
    "Gyration Radius JSD",
    "Secondary Structure Difference",
    "Secondary Structure KL",
    "Secondary Structure JSD",
    "Ours Torsions (BB)",
    "Ours Torsions (SC)",
    "Ours Torsions (All)",
    "Ours TICA",
    "msm gr KL",
    "msm gr JSD",
    "msm 2nd KL",
    "msm 2nd JSD",
    "msm rmsd KL",
    "msm rmsd JSD",
    "msm gr 2nd KL",
    "msm gr 2nd JSD",
    "msm gr 2nd rmsd KL",
    "msm gr 2nd rmsd JSD",
)


def _nan_alpha_row():
    """Return a row fragment with every alpha column set to NaN."""
    return dict.fromkeys(_ALPHA_COLUMNS, np.nan)


def _nan_basic_row():
    """Return a row fragment with every basic ``_mean``/``_std`` column set to NaN."""
    row = {}
    for stem in _BASIC_COLUMN_STEMS:
        row[f"{stem}_mean"] = np.nan
        row[f"{stem}_std"] = np.nan
    return row


def _collect_by_directory(roots, filename):
    """Map each directory containing *filename* (under any root) to its path.

    The path is always ``<directory>/<filename>``, so a directory seen twice
    (overlapping roots) maps to the same path; the first occurrence is kept.
    """
    files = {}
    for root in roots:
        for path in find_all_files(root, filename):
            files.setdefault(os.path.dirname(path), path)
    return files


def main():
    """Parse CLI args, analyse every matching directory, and write one CSV.

    Each CSV row is one directory containing the alpha and/or basic pickle;
    metrics for a missing or failing pickle are NaN.  Exits with an error
    message when no matching file is found under any root.
    """
    parser = argparse.ArgumentParser(
        description="Combined analysis of alphaFlow and basic pickle files."
    )
    parser.add_argument(
        "roots",
        nargs="+",
        help="Root directories to search recursively for pickle files.",
    )
    parser.add_argument(
        "--alpha",
        default="out.pkl",
        help="Filename for alphaFlow analysis (default: out.pkl)",
    )
    parser.add_argument(
        "--basic",
        default="out_basic.pkl",
        help="Filename for basic analysis (default: out_basic.pkl)",
    )
    parser.add_argument(
        "--output", default="combined_analysis.csv", help="Output CSV file name."
    )
    args = parser.parse_args()

    alpha_files = _collect_by_directory(args.roots, args.alpha)
    basic_files = _collect_by_directory(args.roots, args.basic)

    results = []
    for directory in sorted(alpha_files.keys() | basic_files.keys()):
        row = {"directory": directory}

        alpha_path = alpha_files.get(directory)
        if alpha_path is None:
            print(f"Alpha file not found in directory: {directory}")
            alpha_res = _nan_alpha_row()
        else:
            try:
                alpha_res = analyze_alpha_flow(alpha_path)
            except Exception as e:  # keep going: one bad pickle should not stop the run
                print(f"Error analyzing alpha file at {alpha_path}: {e}")
                alpha_res = _nan_alpha_row()
        row.update(alpha_res)

        basic_path = basic_files.get(directory)
        if basic_path is None:
            print(f"Basic file not found in directory: {directory}")
            basic_res = _nan_basic_row()
        else:
            try:
                basic_res = analyze_basic(basic_path)
            except Exception as e:  # keep going: one bad pickle should not stop the run
                print(f"Error analyzing basic file at {basic_path}: {e}")
                # Same column names analyze_basic produces (no extra prefix).
                basic_res = _nan_basic_row()
        row.update(basic_res)

        results.append(row)

    if not results:
        # An empty DataFrame has no "directory" column to index on.
        sys.exit(
            f"No '{args.alpha}' or '{args.basic}' files found under: "
            + ", ".join(args.roots)
        )

    df_final = pd.DataFrame(results)
    df_final.set_index("directory", inplace=True)
    df_final.to_csv(args.output, float_format="%.3f")
    print(f"Combined analysis saved to {args.output}")


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()
