import os
import numpy as np
import mdgen.analysis
import tqdm
from multiprocessing import Pool
from functools import partial


def check_and_load_files(data_dir, name, num_pcca_states):
    """Locate the three cached ``.npy`` outputs for one trajectory.

    Despite its name, this function loads nothing: it only builds the
    expected file paths inside ``data_dir`` and reports whether each one is
    present on disk.

    Args:
        data_dir: Directory in which the output files are looked up.
        name: Trajectory identifier used as the file-name prefix.
        num_pcca_states: Number of PCCA states, used as the file-name suffix.

    Returns:
        A 6-tuple ``(cluster_exists, transition_exists, cmsm_exists,
        cluster_file, transition_file, cmsm_file)``: three booleans followed
        by the three corresponding paths, in the same order.
    """
    # Order matters: cluster assignments, transition matrix, MSM clusters.
    file_names = (
        f"{name}_cluster_assignments_{num_pcca_states}.npy",
        f"{name}_transition_matrix_{num_pcca_states}.npy",
        f"{name}_msm_cluster_{num_pcca_states}.npy",
    )
    paths = tuple(os.path.join(data_dir, file_name) for file_name in file_names)
    present = tuple(os.path.exists(path) for path in paths)
    return present + paths


def save_files(
    cluster_assignments,
    transition_matrix,
    msm_cluster,
    cluster_file,
    transition_file,
    cmsm_file,
):
    """Write the three result arrays to disk with ``np.save``.

    Args:
        cluster_assignments: Array stored at ``cluster_file``.
        transition_matrix: Array stored at ``transition_file``.
        msm_cluster: Array stored at ``cmsm_file``.
        cluster_file: Destination path for ``cluster_assignments``.
        transition_file: Destination path for ``transition_matrix``.
        cmsm_file: Destination path for ``msm_cluster``.
    """
    # Pair every destination with its payload; written in this exact order.
    outputs = (
        (cluster_file, cluster_assignments),
        (transition_file, transition_matrix),
        (cmsm_file, msm_cluster),
    )
    for destination, array in outputs:
        np.save(destination, array)


def preprocess(name, data_dir, md_dir, num_pcca_states):
    """Compute and cache the clustering / MSM arrays for one trajectory.

    When all three output files for ``name`` already exist in ``data_dir``
    nothing is recomputed. Otherwise the trajectory addressed by
    ``{md_dir}/{name}/{name}`` is featurized, projected with TICA, clustered
    with k-means and turned into an MSM through the ``mdgen.analysis``
    helpers, and the results are written with ``save_files``.

    Args:
        name: Trajectory identifier (directory name and file-name prefix).
        data_dir: Directory where the output ``.npy`` files live.
        md_dir: Root directory containing one sub-directory per trajectory.
        num_pcca_states: Number of states requested from ``get_msm`` and the
            side length of the saved transition matrix.

    Returns:
        ``(name, ok)`` where ``ok`` is ``True`` if the files already existed
        or were saved, and ``False`` if any exception was raised along the
        way (the error is printed, not re-raised).
    """
    # Fixed seed, set on every call.
    np.random.seed(137)

    # First three entries are the existence flags, the rest are the paths.
    *already_present, cluster_file, transition_file, cmsm_file = check_and_load_files(
        data_dir, name, num_pcca_states
    )
    if all(already_present):
        print(f"Files already exist for {name}, skipping computation.")
        return name, True

    try:
        # Featurization
        _, features = mdgen.analysis.get_featurized_traj(
            f"{md_dir}/{name}/{name}", sidechains=True, cossin=True
        )

        # TICA with the helper's default lag
        tica_model, _ = mdgen.analysis.get_tica(features)
        projected = tica_model.transform(features)

        # K-means clustering of the TICA projection
        _, ref_kmeans = mdgen.analysis.get_kmeans(projected)

        # NOTE(review): the MSM is estimated from ref_kmeans[0] only, while the
        # full ref_kmeans is indexed and saved below — confirm against what
        # mdgen.analysis.get_kmeans returns.
        msm, _, cmsm = mdgen.analysis.get_msm(ref_kmeans[0], nstates=num_pcca_states)

        # Start from the identity so that states outside cmsm.active_set keep
        # a pure self-transition; active states get the estimated entries.
        transition_matrix = np.eye(num_pcca_states)
        coarse_matrix = cmsm.transition_matrix
        active_states = list(enumerate(cmsm.active_set))
        for row, state_from in active_states:
            for col, state_to in active_states:
                transition_matrix[state_from, state_to] = coarse_matrix[row, col]

        # Map every k-means label to its metastable state, then persist.
        discrete = msm.metastable_assignments[ref_kmeans]
        save_files(
            ref_kmeans,
            transition_matrix,
            discrete,
            cluster_file,
            transition_file,
            cmsm_file,
        )
        print(f"Processed {name}: files saved.")
        return name, True

    except Exception as e:
        # Best effort: report the failure and let the caller move on.
        print(f"Error processing {name}: {e}")
        return name, False


def preprocess_all(data_dir, md_dir, pdb_dir, num_workers, num_pcca_states):
    """Run ``preprocess`` on every trajectory directory found in ``pdb_dir``.

    A sub-directory ``<nam>`` of ``pdb_dir`` is processed only if it contains
    both ``<nam>.xtc`` and ``<nam>.pdb``.

    Args:
        data_dir: Directory where ``preprocess`` reads/writes its outputs.
        md_dir: Root directory handed to ``preprocess`` to locate trajectories.
        pdb_dir: Directory scanned for candidate trajectory sub-directories.
        num_workers: Number of worker processes; values ``<= 1`` run serially
            in the current process.
        num_pcca_states: Number of PCCA states forwarded to ``preprocess``.

    Returns:
        dict mapping each trajectory name to the success flag returned by
        ``preprocess``.
    """
    # Keep only directories that ship both the trajectory and its topology.
    pdb_id = [
        nam
        for nam in os.listdir(pdb_dir)
        if os.path.isdir(os.path.join(pdb_dir, nam))
        and os.path.exists(f"{pdb_dir}/{nam}/{nam}.xtc")
        and os.path.exists(f"{pdb_dir}/{nam}/{nam}.pdb")
    ]

    print(f"Number of trajectories: {len(pdb_id)}")

    preprocess_partial = partial(
        preprocess, data_dir=data_dir, md_dir=md_dir, num_pcca_states=num_pcca_states
    )

    # Use the num_workers argument (not a module-level global) so the function
    # behaves the same regardless of where it is called from.
    if num_workers > 1:
        # The context manager tears the pool down even if a task or the
        # progress bar raises; results must be consumed inside the block.
        with Pool(num_workers) as pool:
            out = dict(
                tqdm.tqdm(pool.imap(preprocess_partial, pdb_id), total=len(pdb_id))
            )
    else:
        out = dict(tqdm.tqdm(map(preprocess_partial, pdb_id), total=len(pdb_id)))
    return out


from mdgen.parsing_ti import parse_train_args

# Script entry point. The guard matters because this file starts a
# multiprocessing.Pool: start methods that re-import the main module in each
# worker (spawn, forkserver) would otherwise re-parse the arguments and launch
# the whole preprocessing run again from inside every worker.
if __name__ == "__main__":
    args = parse_train_args()

    # NOTE(review): the same directory is passed as data_dir, md_dir and
    # pdb_dir — confirm this is intended.
    preprocess_all(
        args.data_dir,
        args.data_dir,
        args.data_dir,
        args.num_workers,
        args.num_pcca_states,
    )
