import pickle
import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt
import argparse
from scipy.spatial.distance import jensenshannon


def _is_joint(feat: str) -> bool:
    return "|" in feat


def _is_backbone(feat: str) -> bool:
    """Return True for a non-joint feature whose name mentions phi, psi or omega.

    The token match is case-insensitive; names flagged by ``_is_joint`` are
    always rejected.
    """
    lowered = feat.lower()
    for token in ("phi", "psi", "omega"):
        if token in lowered:
            # A matching token only counts when the feature is not a joint one.
            return not _is_joint(feat)
    return False


def _is_sidechain(feat: str) -> bool:
    """Return True for a non-joint feature whose name mentions chi.

    The token match is case-insensitive; names flagged by ``_is_joint`` are
    always rejected.
    """
    if "chi" not in feat.lower():
        return False
    return not _is_joint(feat)


def _mean_std(values):
    """Return ``[mean, std]`` of *values*, or ``[nan, nan]`` when it is empty."""
    # np.mean/np.std on an empty list also yield NaN, but emit RuntimeWarnings;
    # returning NaN explicitly keeps the output identical without the noise.
    if len(values) == 0:
        return [np.nan, np.nan]
    return [np.mean(values), np.std(values)]


def _collect_feature_divergences(per_feature, summary_stats, suffix=""):
    """Append per-feature torsion and TICA values to ``summary_stats``.

    ``suffix`` selects the target columns: ``""`` for the JSD columns and
    ``" KL"`` for the forward-KL columns.
    """
    for feat, value in per_feature.items():
        if _is_backbone(feat):
            group = "Torsions (bb)"
        elif _is_sidechain(feat):
            group = "Torsions (sc)"
        else:
            continue
        summary_stats[group + suffix].append(value)
        summary_stats["Torsions (all)" + suffix].append(value)

    for key, label in (("TICA-0", "TICA-0"), ("TICA-0,1", "TICA-0,1 joint")):
        if key in per_feature:
            summary_stats[label + suffix].append(per_feature[key])


def _append_first_present(results, keys, bucket):
    """Append to ``bucket`` the value of the first of ``keys`` found in ``results``."""
    for key in keys:
        if key in results:
            bucket.append(results[key])
            return


def analyze_pickle(pickle_file, directory, name=None):
    """Summarize per-entry metrics from a pickle file into two CSV files.

    The pickle is expected to hold a mapping from an identifier to a dict of
    results. For every metric the mean and standard deviation over all
    collected values are computed, printed, and written to
    ``aa_summary_stats[_<name>].csv`` and ``aa_Neff_summary[_<name>].csv``
    inside ``directory``. Metrics with no collected values appear as NaN in
    the summary CSV and are omitted from the Neff CSV.

    Args:
        pickle_file: File name of the pickle, relative to ``directory``.
        directory: Directory holding the pickle; the CSV files are written
            here as well.
        name: Optional tag appended to the CSV file names.

    Returns:
        None.
    """
    file_path = os.path.join(directory, pickle_file)

    # SECURITY: unpickling executes arbitrary code embedded in the file; only
    # use this on pickle files from a trusted source.
    with open(file_path, "rb") as f:
        data = pickle.load(f)

    # Containers for the raw values behind every summary column.
    summary_stats = {
        "Torsions (bb)": [],
        "Torsions (sc)": [],
        "Torsions (all)": [],
        "TICA-0": [],
        "TICA-0,1 joint": [],
        "Torsions (bb) KL": [],
        "Torsions (sc) KL": [],
        "Torsions (all) KL": [],
        "TICA-0 KL": [],
        "TICA-0,1 joint KL": [],
        "MSM states": [],
        "Flux Spearman": [],
        "Macrostate MAE": [],
        "Gyration Radius Difference": [],
        "Gyration Radius KL": [],
        "Gyration Radius JSD": [],
        "Secondary Structure Difference": [],
        "Secondary Structure KL": [],
        "Secondary Structure JSD": [],
    }
    Neff_stats = {
        "Ours Torsions (BB)": [],
        "Ours Torsions (SC)": [],
        "Ours Torsions (All)": [],
        "Ours TICA": [],
    }

    for pdb_id, results in data.items():
        # Per-feature divergences (torsions and TICA) for JSD and forward KL.
        _collect_feature_divergences(results.get("JSD", {}), summary_stats)
        _collect_feature_divergences(
            results.get("FWD_KL", {}), summary_stats, suffix=" KL"
        )

        # NOTE(review): this grouping differs from _is_backbone/_is_sidechain
        # (omega and joint features land in the side-chain bucket here);
        # kept as-is, confirm whether that is intended.
        for feat, val in results.get("our_Neff", {}).items():
            feat_lower = feat.lower()
            if "phi" in feat_lower or "psi" in feat_lower:
                Neff_stats["Ours Torsions (BB)"].append(val)
                Neff_stats["Ours Torsions (All)"].append(val)
            elif "tica" in feat_lower:
                Neff_stats["Ours TICA"].append(val)
            else:
                Neff_stats["Ours Torsions (SC)"].append(val)
                Neff_stats["Ours Torsions (All)"].append(val)

        # MSM state occupancy comparison.
        # NOTE: scipy's jensenshannon returns the Jensen-Shannon distance
        # (square root of the divergence), not the divergence itself.
        if "msm_pi" in results and "traj_pi" in results:
            ref_pi = np.array(results["msm_pi"])
            traj_pi = np.array(results["traj_pi"])
            summary_stats["MSM states"].append(jensenshannon(ref_pi, traj_pi))

        # Scalar metrics copied straight from the results dict when present.
        for key, label in (
            ("flux_spearman_corr", "Flux Spearman"),
            ("mMAE", "Macrostate MAE"),
            ("gyration_radius_difference", "Gyration Radius Difference"),
            ("ss_difference", "Secondary Structure Difference"),
        ):
            if key in results:
                summary_stats[label].append(results[key])

        # Gyration radius divergences, falling back to the generic keys.
        _append_first_present(
            results,
            ("gyration_radius_KL", "forward_kl_divergence"),
            summary_stats["Gyration Radius KL"],
        )
        _append_first_present(
            results,
            ("gyration_radius_JSD", "jensen_shannon_divergence"),
            summary_stats["Gyration Radius JSD"],
        )

        # Secondary structure divergences. The previous fallback branches
        # tested for the generic keys but then read results["ss_KL"] /
        # results["ss_JSD"], which are known to be missing at that point and
        # therefore always raised KeyError. The generic keys are already used
        # as the gyration-radius fallback, so they are not reused here.
        _append_first_present(
            results, ("ss_KL",), summary_stats["Secondary Structure KL"]
        )
        _append_first_present(
            results, ("ss_JSD",), summary_stats["Secondary Structure JSD"]
        )

    file_tag = "" if name is None else f"_{name}"

    summary_df = pd.DataFrame(
        {key: _mean_std(vals) for key, vals in summary_stats.items()},
        index=["Mean", "Std"],
    )
    print(summary_df)

    csv_path = os.path.join(directory, f"aa_summary_stats{file_tag}.csv")
    summary_df.to_csv(csv_path)
    print(f"Summary statistics saved to {csv_path}")

    # Summarize Neff, skipping groups for which nothing was collected.
    print("Neff_stats", Neff_stats)
    neff_summary_df = pd.DataFrame(
        {key: _mean_std(vals) for key, vals in Neff_stats.items() if len(vals) > 0},
        index=["Mean", "Std"],
    )
    print(neff_summary_df)
    neff_csv_path = os.path.join(directory, f"aa_Neff_summary{file_tag}.csv")
    neff_summary_df.to_csv(neff_csv_path)


if __name__ == "__main__":
    # Command-line entry point: parse the three string options and hand them
    # to analyze_pickle.
    cli = argparse.ArgumentParser(
        description="Analyze a pickle file and generate summary statistics and plots."
    )

    # (flag, default, help text) for every option; all are plain strings.
    option_specs = (
        ("--file", "out.pkl", "Name of the pickle file to analyze."),
        (
            "--dir",
            "workdir/",
            "Directory containing the pickle file and where results will be saved.",
        ),
        ("--name", None, "Name of the output CSV file."),
    )
    for flag, default_value, help_text in option_specs:
        cli.add_argument(flag, type=str, default=default_value, help=help_text)

    options = cli.parse_args()

    analyze_pickle(options.file, options.dir, options.name)
