import os
import numpy as np
import tqdm
import deeptime
import torch
import random
from collections import Counter
import mdtraj


from argparse import ArgumentParser

# Command-line interface. Parsed once at import time; the functions below read
# the resulting `args` namespace as a module-level global.
parser = ArgumentParser()
# Input/output locations.
parser.add_argument("--data_dir", type=str, default="anonymous")
parser.add_argument("--cluster_data_dir", type=str, default="anonymous")
parser.add_argument("--input_file", type=str, default="./splits/mdCATH.txt")
parser.add_argument("--error_domains_file", type=str, default="anonymous")
# Lag (in frames of the discrete trajectories) for the MSM count estimation, and
# the temperature that selects which trajectory files are loaded.
parser.add_argument("--lag_msm", type=int, default=1)
parser.add_argument("--temp", type=int, default=320)
# Clustering configuration.
parser.add_argument(
    "--features",
    type=str,
    default="gr,rmsd,secondary",
    help="Comma separated features to use, options: gr, rmsd, secondary",
)
parser.add_argument(
    "--fraction",
    type=str,
    default=None,
    help="Specify fraction as 'numerator/denominator', e.g., '3/8'",
)
parser.add_argument("--n_clusters", type=int, default=5, help="Number of clusters.")
# Run mode.
parser.add_argument(
    "--count_processed", action="store_true", help="Count processed domains"
)
args = parser.parse_args()


def parse_fraction(fraction_str):
    """
    Parse a 'numerator/denominator' string into the [start, end) fraction of the
    work list that this job should handle.

    For example, '3/8' selects the third of eight equal chunks and returns
    (0.25, 0.375).

    Args:
        fraction_str: String of the form 'k/n' with integers 1 <= k <= n.

    Returns:
        Tuple (start_frac, end_frac) of floats with 0 <= start_frac < end_frac <= 1.

    Raises:
        ValueError: If the string is malformed, n is not positive, or k is
            outside [1, n].
    """
    try:
        numerator_str, denominator_str = fraction_str.split("/")
        numerator, denominator = int(numerator_str), int(denominator_str)
    except ValueError as err:
        raise ValueError(
            f"Invalid fraction {fraction_str!r}; expected 'numerator/denominator'"
        ) from err
    # Check the denominator first so we never divide by zero, then reject chunk
    # indices that would select a range outside [0, 1].
    if denominator <= 0 or not 1 <= numerator <= denominator:
        raise ValueError(
            f"Invalid fraction {fraction_str!r}; need 1 <= numerator <= denominator"
        )
    start_frac = (numerator - 1) / denominator
    end_frac = numerator / denominator
    return start_frac, end_frac


def load_trajectories(data_dir, name, temp=320):
    """
    Read the five trajectories of domain `name` at temperature `temp`.

    Files are expected at `<data_dir>/topology/<name>.pdb` (topology) and
    `<data_dir>/trajectory/<name>_<temp>_<i>.xtc` for i in 0..4.
    Returns a list of mdtraj trajectories ordered by i.
    """
    topology_file = os.path.join(data_dir, "topology", f"{name}.pdb")
    trajectory_dir = os.path.join(data_dir, "trajectory")
    return [
        mdtraj.load_xtc(
            os.path.join(trajectory_dir, f"{name}_{temp}_{index}.xtc"),
            top=topology_file,
        )
        for index in range(5)
    ]


def compute_feature(traj, feature, global_reference):
    """
    Evaluate one scalar observable per frame of `traj`.

    Supported `feature` names:
      - "gr": radius of gyration (mdtraj.compute_rg)
      - "secondary": fraction of residues whose DSSP code is H, G, I, E or B
      - "rmsd": RMSD of each frame to `global_reference` (mdtraj.rmsd)

    Returns an array of shape (n_frames, 1); raises ValueError for any other name.
    """

    def _radius_of_gyration():
        return mdtraj.compute_rg(traj)

    def _structured_fraction():
        codes = mdtraj.compute_dssp(traj)
        structured = np.isin(codes, ["H", "G", "I", "E", "B"])
        # Per frame: number of structured residues divided by residue count.
        return np.sum(structured, axis=1) / codes.shape[1]

    def _rmsd_to_reference():
        return mdtraj.rmsd(traj, global_reference)

    calculators = {
        "gr": _radius_of_gyration,
        "secondary": _structured_fraction,
        "rmsd": _rmsd_to_reference,
    }
    if feature not in calculators:
        raise ValueError(f"Unknown feature: {feature}")
    return calculators[feature]().reshape(-1, 1)


def get_combined_features(trajectories, features):
    """
    Build one per-frame feature matrix for each trajectory.

    Every name in `features` is evaluated with `compute_feature`, and the
    resulting (n_frames, 1) columns are placed side by side, giving an
    (n_frames, len(features)) array per trajectory. The first frame of the first
    trajectory is passed to every call as the shared reference (used by "rmsd").
    """
    reference_frame = trajectories[0][0]

    def _feature_matrix(traj):
        columns = [compute_feature(traj, feat, reference_frame) for feat in features]
        return np.hstack(columns)

    return [_feature_matrix(traj) for traj in trajectories]


def get_assignments(feature_list, n_clusters):
    """
    Cluster all frames jointly with k-means (deeptime) and discretize each trajectory.

    Args:
        feature_list: List of per-trajectory feature arrays, each of shape
            (n_frames_i, n_features).
        n_clusters: Number of k-means centers.

    Returns:
        List of discrete trajectories (one cluster label per frame), in the same
        order as `feature_list`.
    """
    # Stack all frames into one (sum_i n_frames_i, n_features) array before fitting.
    # Trajectories may differ in length, so the Python list must not be left for
    # the estimator to coerce into a (possibly ragged) ndarray.
    stacked_features = np.concatenate(feature_list, axis=0)
    kmeans_estimator = deeptime.clustering.KMeans(
        n_clusters=n_clusters, max_iter=100, fixed_seed=2137
    ).fit(stacked_features)
    # Assign each trajectory separately so frame boundaries between
    # trajectories are preserved for the MSM count estimation.
    return [kmeans_estimator.transform(feat) for feat in feature_list]


def print_cluster_stats(all_features, clusters, feature_names):
    """
    Print, per cluster, the count, mean and standard deviation of the first
    feature column, pooled across all trajectories.

    Args:
        all_features: Per-trajectory feature arrays; only column 0 is used.
        clusters: Per-trajectory cluster labels, aligned frame by frame with
            `all_features`.
        feature_names: Feature names; only `feature_names[0]` appears in the output.
    """
    first_column = np.concatenate([arr[:, 0] for arr in all_features])
    labels = np.concatenate(clusters)
    for label in np.unique(labels):
        members = first_column[labels == label]
        print(
            f"Cluster {label} (based on {feature_names[0]}): n={len(members)}, "
            f"mean={np.mean(members):.3f}, std={np.std(members):.3f}"
        )


def check_and_load_files(cluster_data_dir, name, n_clusters, temp, feature_str):
    """
    Resolve the output paths for one domain/feature combination and report
    which of them already exist.

    The transition-matrix filename always encodes `n_clusters`. The five
    assignment filenames (indices 0-4) omit it when n_clusters == 5.

    Returns:
        Tuple (transition_exists, all_assignments_exist, transition_file,
        assignment_files).
    """
    transition_file = os.path.join(
        cluster_data_dir,
        f"{name}_transition_matrix_{feature_str}_{n_clusters}_{temp}.npy",
    )
    # Assignment files carry the cluster count only for non-default values.
    cluster_tag = "" if n_clusters == 5 else f"{n_clusters}_"
    assignment_files = [
        os.path.join(
            cluster_data_dir, f"{name}_{feature_str}_{cluster_tag}{temp}_{index}.npy"
        )
        for index in range(5)
    ]
    return (
        os.path.exists(transition_file),
        np.all([os.path.exists(path) for path in assignment_files]),
        transition_file,
        assignment_files,
    )


def save_files(transition_matrix, transition_file):
    """Write `transition_matrix` to `transition_file` in NumPy .npy format."""
    np.save(file=transition_file, arr=transition_matrix)


def save_assignments(assignments, assignments_files):
    """
    Save each discrete trajectory to its matching path.

    `assignments[k]` is written to `assignments_files[k]` for every path in
    `assignments_files`.
    """
    for position, target_path in enumerate(assignments_files):
        discrete_traj = assignments[position]
        np.save(target_path, discrete_traj)


def get_msm(trajs, lag=1):
    """
    Estimate a reversible maximum-likelihood MSM from discrete trajectories.

    Transitions are counted at lag `lag` with sliding windows, and the
    estimator is allowed to handle disconnected state sets.

    Returns the estimated transition matrix.
    """
    count_model = deeptime.markov.TransitionCountEstimator(
        lag, count_mode="sliding"
    ).fit_fetch(trajs)
    msm_estimator = deeptime.markov.msm.MaximumLikelihoodMSM(
        reversible=True, allow_disconnected=True
    )
    return msm_estimator.fit_fetch(count_model).transition_matrix


# --- Main Processing Functions ---


def _standardize(feature_arrays):
    """
    Z-score feature columns using a mean/std pooled over all trajectories.

    Columns with zero variance (a feature that is constant over every frame) are
    only centered. Dividing them by zero would produce NaN/inf values that break
    k-means.
    """
    stacked = np.concatenate(feature_arrays, axis=0)
    mean = np.mean(stacked, axis=0)
    std = np.std(stacked, axis=0)
    safe_std = np.where(std > 0, std, 1.0)
    return [(f - mean) / safe_std for f in feature_arrays]


def preprocess(name, data_dir, cluster_data_dir, features):
    """
    Cluster one domain's trajectories and build its MSM.

    Steps:
      1. Seed the random generators.
      2. Load the five trajectories at args.temp.
      3. Compute and standardize the requested per-frame features.
      4. Cluster them into args.n_clusters states with k-means.
      5. Save the per-trajectory assignments and the transition matrix
         estimated at lag args.lag_msm.

    Domains whose transition matrix and assignment files all already exist are
    skipped, so an interrupted batch can be resumed.

    NOTE: output filenames do not encode --lag_msm. Remove existing outputs
    before re-running with a different lag.
    """
    random.seed(2137)
    np.random.seed(2137)
    torch.manual_seed(2137)
    torch.cuda.manual_seed(2137)

    # Create a string that uniquely represents the combination of observables
    feature_str = "_".join(features)

    transition_exists, assignments_exist, transition_file, assignments_files = (
        check_and_load_files(
            cluster_data_dir, name, args.n_clusters, args.temp, feature_str
        )
    )
    if transition_exists and assignments_exist:
        print(f"Skipping domain {name}: outputs for features {features} already exist.")
        return

    # Load trajectories once and compute the combined features for each
    trajectories = load_trajectories(data_dir, name, args.temp)
    combined_features = _standardize(get_combined_features(trajectories, features))

    # Perform clustering using the chosen observables
    discrete_trajs = get_assignments(combined_features, n_clusters=args.n_clusters)
    print(f"For domain {name}, clustering with features {features} complete.")
    print_cluster_stats(combined_features, discrete_trajs, features)
    save_assignments(discrete_trajs, assignments_files)

    # Build and save the MSM transition matrix
    transition_matrix = get_msm(discrete_trajs, lag=args.lag_msm)
    save_files(transition_matrix, transition_file)


def preprocess_all(data_dir, cluster_data_dir, features):
    """
    Run `preprocess` on every domain listed in args.input_file, excluding those
    listed in args.error_domains_file.

    When --fraction 'k/n' is given, only the k-th of n contiguous chunks of that
    list is processed, so the work can be split across independent jobs.
    """
    with open(args.error_domains_file, "r") as f:
        error_domains = {line.strip() for line in f if line.strip()}
    with open(args.input_file, "r") as f:
        pdb_ids = [
            line.strip()
            for line in f
            if line.strip() and (line.strip() not in error_domains)
        ]
    num_total = len(pdb_ids)

    if args.fraction is not None:
        start_frac, end_frac = parse_fraction(args.fraction)
        # One chunk's end and the next chunk's start come from the same k/n value,
        # so the chunks tile the list with no gaps or overlaps.
        start_idx = int(start_frac * num_total)
        end_idx = int(end_frac * num_total)
        pdb_ids = pdb_ids[start_idx:end_idx]
        print(
            f"Fraction {args.fraction}: domains [{start_idx}, {end_idx}) of {num_total}"
        )

    print(f"Number of trajectories to process: {len(pdb_ids)}")
    for pdb in tqdm.tqdm(pdb_ids):
        preprocess(pdb, data_dir, cluster_data_dir, features)


def check_status(name, data_dir, cluster_data_dir, features):
    """
    Classify how far processing got for domain `name`.

    Returns:
      - "fully_processed" if the transition matrix and all assignment files exist
      - "features_only" if only the assignment files exist
      - "nothing" otherwise

    `data_dir` is not used here.
    """
    have_transition, have_assignments, _, _ = check_and_load_files(
        cluster_data_dir, name, args.n_clusters, args.temp, "_".join(features)
    )
    if not have_assignments:
        return "nothing"
    return "fully_processed" if have_transition else "features_only"


def count_statuses(data_dir, cluster_data_dir, features):
    """
    Tally the processing status of every domain in args.input_file (minus those
    in args.error_domains_file) and print a summary.
    """
    with open(args.error_domains_file, "r") as handle:
        excluded = {entry.strip() for entry in handle if entry.strip()}
    with open(args.input_file, "r") as handle:
        stripped = (entry.strip() for entry in handle)
        domain_ids = [entry for entry in stripped if entry and entry not in excluded]
    print(f"Number of trajectories: {len(domain_ids)}")

    tally = Counter(
        check_status(domain, data_dir, cluster_data_dir, features)
        for domain in tqdm.tqdm(domain_ids)
    )

    print("\nSummary:")
    print(f"✅ Fully processed: {tally['fully_processed']}")
    print(f"🟡 Features only: {tally['features_only']}")
    print(f"❌ Nothing processed: {tally['nothing']}")


# --- Main Execution ---

if __name__ == "__main__":
    # Parse the features from the command line
    # (e.g. "gr,rmsd,secondary" -> ['gr', 'rmsd', 'secondary']).
    # Empty entries, e.g. from a trailing comma, are dropped.
    features = [feat.strip() for feat in args.features.split(",") if feat.strip()]

    # Fail fast on bad feature names instead of raising only after the first
    # domain's trajectories have been loaded. Keep this set in sync with
    # compute_feature.
    known_features = {"gr", "rmsd", "secondary"}
    unknown_features = sorted(set(features) - known_features)
    if not features:
        parser.error("--features must name at least one feature")
    if unknown_features:
        parser.error(
            f"unknown feature(s): {', '.join(unknown_features)}; "
            "options: gr, rmsd, secondary"
        )

    if not args.count_processed:
        preprocess_all(args.data_dir, args.cluster_data_dir, features)

    count_statuses(args.data_dir, args.cluster_data_dir, features)
