import logging
import random
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Set

import networkx as nx
import pandas as pd
from scipy.sparse import csr_matrix

from dataset.dataset_utils import build_start_adjacency_matrix, check_inter_clusters_dists, configure_logger, extract_connected_components, load_foldseek_tsv, \
      pickle_load, pickle_dump, plot_cluster_sizes, update_adjacency_with_remaining_proteins, update_sparse_adjacency_threshold, update_protein_mappings


class MakeDataSplits:
    """Create val/test/train splits by holding out clusters at several similarity thresholds.

    Calling an instance loads pairwise alignment scores, samples ``n_out``
    clusters at every threshold (removing their proteins before moving on to
    the next threshold), divides each threshold's held-out clusters between
    validation and test, assigns the clusters left at the last threshold to
    train, and pickles the resulting dictionary.

    Args:
        score: Label of the similarity score; used in output names, plot
            titles and log messages.
        is_foldseek: If True the TSV is read with ``load_foldseek_tsv``,
            otherwise with ``pandas.read_csv`` (tab separated).
        cov: Minimum ``tcov`` an alignment row needs in order to be kept.
        thresholds: Similarity thresholds. The first one is paired with the
            pre-computed start clusters; the remaining ones are processed in
            the given order (the in-code comments expect lowest to highest).
        n_out: Number of clusters sampled (held out) at every threshold.
        seed: Seed for the ``random`` module, re-applied before each sampling.
        proteins_to_exclude: Stored on the instance; no method of this class
            consults it. ``None`` (the default) is stored as an empty list.
        make_plots: Stored on the instance; plots are currently produced
            regardless of its value.
        verbose: When True, summary statistics are logged at the end of a run.
    """

    def __init__(
        self,
        score: str,
        is_foldseek: bool,
        cov: float,
        thresholds: List[float],
        n_out: int,
        seed: int,
        proteins_to_exclude: Optional[List[str]] = None,
        make_plots: bool = True,
        verbose: bool = True,
    ) -> None:

        self.score = score
        self.is_foldseek = is_foldseek
        self.cov = cov
        self.thresholds = thresholds
        self.n_out = n_out
        self.seed = seed
        # A fresh list per instance: a literal ``[]`` default would be shared
        # by every instance created without this argument.
        self.proteins_to_exclude = [] if proteins_to_exclude is None else proteins_to_exclude
        self.make_plots = make_plots
        self.verbose = verbose

    def sample_clusters(self, clusters: List[List[int]]) -> List[List[int]]:
        """Sample ``n_out`` clusters without replacement, largest first.

        The global ``random`` generator is reseeded with ``self.seed`` on every
        call, so the same input always yields the same sample.

        Raises:
            ValueError: From ``random.sample`` if there are fewer than
                ``n_out`` clusters.
        """
        random.seed(self.seed, version=2)
        heldout = random.sample(clusters, self.n_out)
        heldout.sort(key=len, reverse=True)
        return heldout

    def assign_heldout_clusters_to_val_and_test(self, heldout: Dict[float, List[List[int]]]) -> Dict[float, Dict[str, List[List[int]]]]:
        """For each similarity threshold, randomly assign held-out clusters to val and test sets.

        ``n_out // 2`` clusters are sampled into ``"val"``; every held-out
        cluster that does not compare equal to a val cluster goes to ``"test"``.
        The global ``random`` state is used as-is (no reseeding here).
        """
        n_val = self.n_out // 2
        splits = {}
        for thr in self.thresholds:
            val_clusters = random.sample(heldout[thr], n_val)
            test_clusters = [c for c in heldout[thr] if c not in val_clusters]
            splits[thr] = {"val": val_clusters, "test": test_clusters}
        return splits

    def assign_remaining_clusters_to_train(self, splits: Dict[float, Dict[str, List[List[int]]]], clusters: List[Set[int]], heldout_clusters: List[List[str]], idx_to_protein: Dict[int, str]
        ) -> Dict[float, Dict[str, List[List[int]]]]:
        """Assign all remaining clusters at highest similarity threshold to train set.

        ``clusters`` hold indices that are translated through ``idx_to_protein``;
        a translated cluster is kept for train unless it equals (same members,
        same order) one of ``heldout_clusters``. ``splits`` is updated in place
        under ``self.thresholds[-1]`` and returned.
        """
        # Tuples are hashable, giving O(1) membership tests instead of scanning
        # the held-out list for every cluster.
        heldout_keys = {tuple(c) for c in heldout_clusters}
        named_clusters = ([idx_to_protein[p] for p in c] for c in clusters)
        train_clusters = [c for c in named_clusters if tuple(c) not in heldout_keys]
        splits[self.thresholds[-1]]["train"] = train_clusters
        return splits

    def check_splits_inter_cluster_dists(self, splits: Dict[float, Dict[str, List[List[str]]]], adjacency: csr_matrix, prot2idx: Dict[str, int]) -> None:
        """Run ``check_inter_clusters_dists`` on every non-train split of every threshold."""
        for thr in self.thresholds:
            for k, vals in splits[thr].items():
                if k == 'train':
                    continue
                check_inter_clusters_dists(vals, thr, adjacency, prot2idx, k)

    def check_splits_protein_nb(self, splits: Dict[float, Dict[str, List[List[str]]]], n: int) -> None:
        """Verify that the splits contain exactly ``n`` unique proteins.

        Raises:
            AssertionError: If the number of unique proteins over all splits
                and thresholds differs from ``n``.
        """
        unique_prots = {
            p
            for thr in self.thresholds
            for clusters in splits[thr].values()
            for c in clusters
            for p in c
        }
        if len(unique_prots) != n:
            # ``n`` is already a count; it must be formatted directly.
            raise AssertionError(f"Final dataset doesn't have same number of unique proteins as initial dataset ({len(unique_prots)} vs {n} proteins, resp.)!")

    def __call__(self, path_to_tsv: str, path_to_start_clusters: str, output_dir: str, foldseek_score: str = "alntmscore", mmseqs_score: str = "fident") -> None:
        """Build the splits and pickle them below ``output_dir``.

        Args:
            path_to_tsv: Alignment table; must provide a ``tcov`` column and
                the selected score column.
            path_to_start_clusters: Pickle with the clusters that belong to
                ``thresholds[0]``.
            output_dir: Root directory; a ``{score}_{thresholds}`` sub-directory
                receives the log file, the figures and the splits pickle.
            foldseek_score: Score column used when ``is_foldseek`` is True.
            mmseqs_score: Score column used otherwise.

        Raises:
            ValueError: If fewer than two thresholds were configured.
            AssertionError: If the final splits do not cover every start protein.
        """
        # The train split is built from the clusters of a threshold above the
        # first one, so a single threshold cannot produce a complete result.
        if len(self.thresholds) < 2:
            raise ValueError("At least two thresholds are required: the one of the start clusters and at least one higher threshold.")

        t0 = time.time()

        Path(output_dir).mkdir(parents=True, exist_ok=True)
        thresholds_tag = '_'.join(str(t) for t in self.thresholds)
        output_dir = Path(output_dir) / f"{self.score}_{thresholds_tag}"
        figs_dir = output_dir / f"figs_n_{self.n_out}" / f"seed_{self.seed}"
        figs_dir.mkdir(parents=True, exist_ok=True)
        log_file = output_dir / f"log_n_{self.n_out}_seed_{self.seed}.log"
        output_file = output_dir / f"splits_{self.score}_n_{self.n_out}_seed_{self.seed}.pkl"

        configure_logger(log_file)

        logging.info("Making data splits for:")
        for k, v in vars(self).items():
            logging.info(f"{k}: {v}")

        # load data
        logging.info("Loading data")
        if self.is_foldseek:
            df = load_foldseek_tsv(path_to_tsv)
            score = foldseek_score
        else:
            df = pd.read_csv(path_to_tsv, sep='\t')
            score = mmseqs_score
        df = df[df.tcov >= self.cov]
        start_adjacency, start_protein_to_idx = build_start_adjacency_matrix(df, score)

        # clusters obtained at lowest threshold from community detection
        # NOTE(review): pickle_load presumably unpickles this file; only point
        # it at trusted cluster files.
        thr = self.thresholds[0]
        start_clusters = pickle_load(path_to_start_clusters)
        plot_cluster_sizes(start_clusters, f"Clusters at {self.score} {thr}", figs_dir, f"clusters_{self.score}_{thr}")
        start_proteins = sorted(p for c in start_clusters for p in c)
        logging.info(f"{len(start_clusters)} clusters at {self.score} {thr}")
        logging.info(f"{len(start_proteins)} start proteins")
        start_protein_inds = [start_protein_to_idx[p] for p in start_proteins]
        start_adjacency = update_adjacency_with_remaining_proteins(start_adjacency, start_protein_inds)
        _, start_protein_to_idx = update_protein_mappings(start_proteins)

        # sample clusters at lowest threshold
        heldout = {}
        sampled_clusters = self.sample_clusters(start_clusters)
        plot_cluster_sizes(sampled_clusters, f"Held-out clusters at {self.score} {thr}", figs_dir, f"heldout_clusters_{self.score}_{thr}")
        heldout[thr] = sampled_clusters

        # update adjacency and protein mapping at lowest threshold
        sampled_proteins = {p for c in sampled_clusters for p in c}
        remaining_proteins = sorted(p for p in start_proteins if p not in sampled_proteins)
        remaining_inds = [start_protein_to_idx[p] for p in remaining_proteins]
        adjacency = update_adjacency_with_remaining_proteins(start_adjacency, remaining_inds)
        idx_to_protein, protein_to_idx = update_protein_mappings(remaining_proteins)

        # iterate over higher thresholds
        for thr in self.thresholds[1:]:
            logging.info(f"{self.score} {thr}")
            adjacency = update_sparse_adjacency_threshold(adjacency, threshold=thr)
            logging.info(f"Adjacency shape: {adjacency.shape}")
            graph = nx.from_scipy_sparse_array(adjacency)
            clusters = extract_connected_components(graph)
            logging.info(f"{len(clusters)} clusters")
            plot_cluster_sizes(clusters, f"Clusters at {self.score} {thr}", figs_dir, f"clusters_{self.score}_{thr}")

            # sample holdout clusters (indices), then translate them to protein names
            sampled_clusters = self.sample_clusters(clusters)
            plot_cluster_sizes(sampled_clusters, f"Held-out clusters at {self.score} {thr}", figs_dir, f"heldout_clusters_{self.score}_{thr}")
            sampled_clusters = [[idx_to_protein[p] for p in c] for c in sampled_clusters]
            heldout[thr] = sampled_clusters

            # get remaining proteins
            sampled_proteins = {p for c in sampled_clusters for p in c}
            logging.info(f"{len(sampled_proteins)} sampled proteins")
            remaining_proteins = sorted(p for p in protein_to_idx if p not in sampled_proteins)
            logging.info(f"{len(remaining_proteins)} remaining proteins")

            # Keep adjacency and mappings of the last threshold untouched: the
            # train assignment below still needs this threshold's indexing.
            if thr == self.thresholds[-1]:
                break

            # retain only remaining proteins in adjacency
            remaining_inds = [protein_to_idx[p] for p in remaining_proteins]
            adjacency = update_adjacency_with_remaining_proteins(adjacency, remaining_inds)
            logging.info(f"Adjacency shape: {adjacency.shape}")
            idx_to_protein, protein_to_idx = update_protein_mappings(remaining_proteins)

        splits = self.assign_heldout_clusters_to_val_and_test(heldout)
        splits = self.assign_remaining_clusters_to_train(splits, clusters, sampled_clusters, idx_to_protein)

        self.check_splits_inter_cluster_dists(splits, start_adjacency, start_protein_to_idx)
        self.check_splits_protein_nb(splits, len(start_proteins))

        # Summary statistics
        if self.verbose:
            n_start = len(start_proteins)
            # ``remaining_proteins`` reflects the last threshold as well, whereas
            # ``remaining_inds`` is not refreshed on the final iteration.
            logging.info(f"Heldout proteins: {n_start - len(remaining_proteins)}")
            val_p = sum(len(c) for t in self.thresholds for c in splits[t]['val']) / n_start
            test_p = sum(len(c) for t in self.thresholds for c in splits[t]['test']) / n_start
            train_p = 1 - val_p - test_p
            logging.info(f"% train in total: {train_p:.2f}")
            logging.info(f"% val in total: {val_p:.2f}")
            logging.info(f"% test in total: {test_p:.2f}")

        pickle_dump(output_file, splits)
        logging.info(f"Total time, {(time.time() - t0):.2f}s")


if __name__ == "__main__":
    # Input/output locations, all relative to the data root.
    BASE = Path("/ddots")
    DATA_ROOT = BASE / "data/deepfri_go"
    FOLDSEEK_FILE = DATA_ROOT / "foldseek/s_7.5_covmode_1_cov_0.5_tmthr_0.0_aln_1_exh/align.tsv"
    MMSEQS_FILE = DATA_ROOT / "mmseqs/s_7.5_covmode_1_cov_0.5_idthr_0.0_aln_3/align.tsv"
    CLUSTERS_DIR = DATA_ROOT / "clusters"
    SPLITS_DIR = CLUSTERS_DIR / "data_splits"
    TM_DIR = CLUSTERS_DIR / "tm_clusters"
    SEQID_DIR = CLUSTERS_DIR / "seqid_clusters"

    n_out = 700
    # The start clusters of every seed live in the same directory.
    ALGO_DIR = SEQID_DIR / "algo_slow"

    # One run per seed (1..20) on the sequence-identity clusters.
    for seed in range(1, 21):
        # TM-score alternative: f"clusters_TM_0.5_seed_{seed}.pkl"
        CLUSTERS_FILE = ALGO_DIR / f"clusters_seqid_0.3_seed_{seed}.pkl"

        makesplits = MakeDataSplits(
            score="seqid",
            is_foldseek=False,
            cov=0.8,
            thresholds=[0.3, 0.4, 0.5, 0.7, 0.9],  # TM-score alternative: [0.5, 0.6, 0.7, 0.8, 0.9]
            n_out=n_out,
            seed=seed,
            proteins_to_exclude=None,
            make_plots=True,
            verbose=True,
        )
        makesplits(
            path_to_tsv=MMSEQS_FILE,
            path_to_start_clusters=CLUSTERS_FILE,
            output_dir=SPLITS_DIR,
        )