"""
Run heirarchical clustering on the pairwise distance matrix between all pairs of files
"""

import os, sys
import re
import json
import logging
from glob import glob
from pathlib import Path
import itertools
import argparse
from typing import *
import multiprocessing as mp

import numpy as np
import pandas as pd
import scipy.spatial as sp, scipy.cluster.hierarchy as hc
import seaborn as sns

from train import get_train_valid_test_sets

from foldingdiff import tmalign

# :)
SEED = int(
    float.fromhex("2254616977616e2069732061206672656520636f756e74727922") % 10000
)


def int_getter(x: str) -> int:
    """Fetches integer value out of a string"""
    matches = re.findall(r"[0-9]+", x)
    assert len(matches) == 1
    return int(matches.pop())


def get_pairwise_tmscores(
    fnames: Collection[str], sctm_scores_json: Optional[str] = None
) -> pd.DataFrame:
    """
    Get the pairwise TM scores across all fnames. If sctm_scores_json is given
    then we filter the fnames by passing scTM scores.
    """
    logging.info(f"Computing pairwise distances between {len(fnames)} pdb files")

    bname_getter = lambda x: os.path.splitext(os.path.basename(x))[0]
    if sctm_scores_json:
        with open(sctm_scores_json) as source:
            sctm_scores = json.load(source)
        fnames = [f for f in fnames if sctm_scores[bname_getter(f)] >= 0.5]
        logging.info(f"{len(fnames)} structures have scTM scores >= 0.5")

    # for debugging
    # fnames = fnames[:50]

    pairs = list(itertools.combinations(fnames, 2))
    pool = mp.Pool(mp.cpu_count())
    values = list(pool.starmap(tmalign.run_tmalign, pairs, chunksize=25))
    pool.close()
    pool.join()

    bnames = [bname_getter(f) for f in fnames]
    retval = pd.DataFrame(1.0, index=bnames, columns=bnames)
    for (k, v), val in zip(pairs, values):
        retval.loc[bname_getter(k), bname_getter(v)] = val
        retval.loc[bname_getter(v), bname_getter(k)] = val
    assert np.allclose(retval, retval.T)
    return retval


def build_parser():
    """Build a CLI parser"""
    parser = argparse.ArgumentParser(
        usage=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    g = parser.add_mutually_exclusive_group()
    g.add_argument("--dirname", type=str, help="Directory of PDB files to analyze")
    g.add_argument("--testsubset", type=int, help="Subset of test set sequences to run")
    parser.add_argument(
        "--sctm", type=str, required=False, default="", help="scTM scores to filter by"
    )
    parser.add_argument(
        "-o",
        "--output",
        type=str,
        default="tmscore_hclust.pdf",
        help="PDF file to write output clustering plot",
    )
    return parser


def main():
    """Run the script"""
    parser = build_parser()
    args = parser.parse_args()

    # Get the files
    if args.dirname:
        fnames = sorted(
            glob(os.path.join(args.dirname, "*.pdb")),
            key=lambda x: int_getter(os.path.basename(x)),
        )
        assert fnames, f"{args.dirname} does not contain any pdb files"
    elif args.testsubset:
        # We only care about fnames here
        *_, test_subset = get_train_valid_test_sets(
            max_seq_len=128,
            min_seq_len=50,
            seq_trim_strategy="discard",
        )
        rng = np.random.default_rng(SEED)
        idx = rng.choice(
            len(test_subset.filenames), size=args.testsubset, replace=False
        )
        fnames = [test_subset.filenames[i] for i in idx]
    else:
        raise NotImplementedError

    # TMscore of 1 = perfect match --> 0 distance, so need 1.0 - tmscore
    pdist_df = 1.0 - get_pairwise_tmscores(fnames, sctm_scores_json=args.sctm)

    # https://stackoverflow.com/questions/38705359/how-to-give-sns-clustermap-a-precomputed-distance-matrix
    # https://stackoverflow.com/questions/57308725/pass-distance-matrix-to-seaborn-clustermap
    m = "average"  # Trippe uses average here
    linkage = hc.linkage(
        sp.distance.squareform(pdist_df), method=m, optimal_ordering=False
    )

    c = sns.clustermap(
        pdist_df,
        row_linkage=linkage,
        col_linkage=linkage,
        method=None,
        row_cluster=True,
        col_cluster=True,
        vmin=0.0,
        vmax=1.0,
        xticklabels=False,
        yticklabels=False,
        cbar_kws={"label": r"$d(x, y) = 1 - \mathrm{TMscore}(x, y)$"},
    )
    c.savefig(args.output)


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    main()
