import logging
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple

import networkx as nx
import numpy as np
import pandas as pd
from scipy.sparse import csr_matrix

from dataset.dataset_utils import (
    configure_logger,
    load_foldseek_tsv,
    build_start_adjacency_matrix,
    update_sparse_adjacency_threshold,
    update_adjacency_with_remaining_proteins,
    extract_connected_components,
    plot_cluster_sizes,
    check_inter_clusters_dists,
    remove_edges_in_adjacency,
    pickle_dump,
    update_protein_mappings,
)


class CommunityDetection:
    """Cluster proteins by Louvain community detection on an alignment-score graph.

    Pairwise scores are read from a Foldseek or MMseqs2 TSV and filtered by target
    coverage (``tcov``). They are turned into a sparse adjacency matrix and thresholded.
    Louvain communities are then separated by deleting proteins that have edges to
    other communities. The connected components of the resulting graph are the final
    clusters.
    """

    def __init__(
        self,
        score: str,
        threshold: float,
        is_foldseek: bool,
        cov: float,
        seed: int,
        resolution: float,
        proteins_to_exclude: Optional[List[str]] = None,
        do_fast: bool = True,
        make_plots: bool = True,
        verbose: bool = True,
    ) -> None:
        """
        Args:
            score: Label of the score; used in output file names and plot titles.
            threshold: Threshold passed to ``update_sparse_adjacency_threshold``; edges
                below it are removed.
            is_foldseek: Read the TSV with ``load_foldseek_tsv`` (True) or as a plain
                tab-separated MMseqs2 file (False).
            cov: Minimum ``tcov`` value an alignment row needs to be kept.
            seed: Random seed for Louvain.
            resolution: Louvain resolution (< 1 favours larger, > 1 smaller communities).
            proteins_to_exclude: Stored for reference; not used by this class yet.
                ``None`` means an empty list. A copy is stored, so callers' lists are
                never shared.
            do_fast: Remove all boundary proteins of a community at once (True) or
                one protein at a time (False).
            make_plots: Produce cluster-size plots.
            verbose: Print per-step progress during community separation.
        """
        self.cov = cov
        self.score = score
        self.thr = threshold
        self.is_foldseek = is_foldseek
        self.seed = seed
        self.resolution = resolution
        self.proteins_to_exclude = [] if proteins_to_exclude is None else list(proteins_to_exclude)
        self.do_fast = do_fast
        self.make_plots = make_plots
        self.verbose = verbose

    @staticmethod
    def _restore_head_order(communities: List[List[str]]) -> None:
        """Slide a just-shrunk head community back into size-descending position, in place.

        Only the head changed size, so moving it past every strictly larger community
        restores the order. Ties keep the head first, which matches a plain swap whenever
        a swap would have been enough.
        """
        if len(communities) < 2 or len(communities[1]) <= len(communities[0]):
            return
        head = communities.pop(0)
        pos = 0
        while pos < len(communities) and len(communities[pos]) > len(head):
            pos += 1
        communities.insert(pos, head)

    def remove_top_nodes_one_by_one(self, adjacency: csr_matrix, communities: List[List[str]], protein_to_idx: Dict[str, int], idx_to_protein: Dict[int, str]
                                    ) -> Tuple[List[List[str]], List[str], csr_matrix, Dict[int, str]]:
        """Separate communities by deleting their most inter-connected protein, one at a time.

        The head (largest) community is processed repeatedly. Its member with the most
        edges to nodes outside it is deleted from ``adjacency`` and from the community.
        Once the head has no outside edges, it becomes a cluster. Processing stops when
        the head is a singleton, and the remaining communities are appended as-is.

        Args:
            adjacency: Current adjacency matrix, indexed by ``protein_to_idx``.
            communities: Communities of protein ids, largest first. Not mutated.
            protein_to_idx: Protein id -> row/column index in ``adjacency``.
            idx_to_protein: Inverse of ``protein_to_idx``.

        Returns:
            ``(clusters, removed, adjacency, idx_to_protein)``. ``removed`` lists the deleted
            proteins in deletion order. ``adjacency`` and ``idx_to_protein`` describe the
            reduced graph.
        """
        # Work on copies: rows are deleted from the head community in place below.
        communities = [list(c) for c in communities]
        clusters, removed = [], []

        # Guard on `communities` too: a last community of size > 1 empties the list.
        while communities and len(communities[0]) > 1:
            community = [protein_to_idx[p] for p in communities[0]]
            out = get_node_with_most_inter_community_edges(adjacency, community)
            if out['n_edges'] == 0:
                # No edges leave the head community any more: it is a final cluster.
                clusters.append(communities[0])
                communities = communities[1:]
                continue

            node = out['node']
            remaining_inds = [i for i in range(adjacency.shape[0]) if i != node]
            adjacency = update_adjacency_with_remaining_proteins(adjacency, remaining_inds)
            communities = remove_protein_from_community(communities, out['index'])
            if self.verbose:
                print(f"{len(removed)} removed")
            removed.append(idx_to_protein[node])
            # Build the surviving protein list from idx_to_protein, not from the order of
            # protein_to_idx's keys, which is not guaranteed to follow the indices.
            remaining_proteins = [idx_to_protein[i] for i in remaining_inds]
            idx_to_protein, protein_to_idx = update_protein_mappings(remaining_proteins)
            # A plain swap of the first two is not enough when 3+ communities tie in size.
            self._restore_head_order(communities)

        # Add all remaining singletons to clusters
        clusters.extend(communities)
        return clusters, removed, adjacency, idx_to_protein

    def remove_top_nodes_together(self, adjacency: csr_matrix, communities: List[List[str]], protein_to_idx: Dict[str, int], idx_to_protein: Dict[int, str]
                                  ) -> Tuple[List[List[str]], List[str], csr_matrix, Dict[int, str]]:
        """Separate communities by deleting all of a community's boundary proteins at once.

        For each community (largest first, until a singleton is reached), every member
        with an edge to a node outside the community is deleted from ``adjacency``. The
        remaining members become a cluster.

        Args:
            adjacency: Current adjacency matrix, indexed by ``protein_to_idx``.
            communities: Communities of protein ids, largest first.
            protein_to_idx: Protein id -> row/column index in ``adjacency``.
            idx_to_protein: Inverse of ``protein_to_idx``.

        Returns:
            ``(clusters, removed, adjacency, idx_to_protein)``, where ``removed`` is a flat
            list of deleted proteins.
        """
        clusters, removed = [], []
        # Guard on `communities` too: a last community of size > 1 empties the list.
        while communities and len(communities[0]) > 1:
            community = [protein_to_idx[p] for p in communities[0]]
            nodes_to_remove = get_all_nodes_with_inter_community_edges(adjacency, community)

            if len(nodes_to_remove) == 0:
                clusters.append([idx_to_protein[i] for i in community])
            else:
                proteins_to_remove = [idx_to_protein[i] for i in nodes_to_remove]
                removed.extend(proteins_to_remove)
                dropped = set(proteins_to_remove)
                remaining_inds = get_remaining_protein_indices(adjacency.shape[0], nodes_to_remove)
                adjacency = update_adjacency_with_remaining_proteins(adjacency, remaining_inds)
                remaining_proteins = [idx_to_protein[i] for i in remaining_inds]
                clusters.append([p for p in communities[0] if p not in dropped])
                # Re-index the surviving proteins 0..n-1 to match the reduced matrix.
                idx_to_protein = dict(enumerate(remaining_proteins))
                protein_to_idx = {p: i for i, p in idx_to_protein.items()}

            if self.verbose:
                print(f"community size: {len(communities)}, new cluster size: {len(clusters[-1])}")
            communities = communities[1:]

        # Add all remaining singletons to clusters
        clusters.extend(communities)
        return clusters, removed, adjacency, idx_to_protein

    def _plot_sizes(self, clusters: List[List[Any]], title: str, figs_dir: Path, name: str) -> None:
        """Call ``plot_cluster_sizes`` only when plotting is enabled."""
        if self.make_plots:
            plot_cluster_sizes(clusters, title, figs_dir, name)

    def run_community_detection(self, path_to_tsv: str, output_dir: str, foldseek_score: str="alntmscore", mmseqs_score: str="fident") -> Any:
        """Run the whole pipeline on one alignment TSV and pickle the resulting clusters.

        Args:
            path_to_tsv: Foldseek TSV (``is_foldseek``) or tab-separated MMseqs2 file.
                It must provide a ``tcov`` column.
            output_dir: Root directory; outputs go to ``<output_dir>/algo_fast`` or
                ``<output_dir>/algo_slow``.
            foldseek_score: Score passed to ``build_start_adjacency_matrix`` for
                Foldseek input (presumably the edge-weight column).
            mmseqs_score: Same, for MMseqs2 input.

        Returns:
            The final clusters (lists of protein ids). They are also pickled to
            ``clusters_<score>_<thr>_seed_<seed>.pkl``.
        """
        t0 = time.time()
        mode = "fast" if self.do_fast else "slow"
        run_tag = f"{self.score}_{self.thr}_seed_{self.seed}"

        Path(output_dir).mkdir(parents=True, exist_ok=True)
        output_dir = Path(output_dir) / f"algo_{mode}"
        # Created explicitly because the figs directory is only made when plotting.
        output_dir.mkdir(parents=True, exist_ok=True)
        output_file = output_dir / f"clusters_{run_tag}.pkl"
        figs_dir = output_dir / "figs" / run_tag
        if self.make_plots:
            figs_dir.mkdir(parents=True, exist_ok=True)
        log_file = output_dir / f"log_{run_tag}.log"

        configure_logger(log_file)

        logging.info("Running community detection for:")
        for k, v in vars(self).items():
            logging.info(f"{k}: {v}")

        # load data
        logging.info("Loading data")
        if self.is_foldseek:
            df = load_foldseek_tsv(path_to_tsv)
            score = foldseek_score
        else:
            df = pd.read_csv(path_to_tsv, sep='\t')
            score = mmseqs_score
        df = df[df.tcov >= self.cov]
        adjacency_start, protein_to_idx_start = build_start_adjacency_matrix(df, score)
        protein_to_idx = protein_to_idx_start.copy()
        idx_to_protein = {j: i for i, j in protein_to_idx.items()}

        # Remove edges below desired threshold
        adjacency = update_sparse_adjacency_threshold(adjacency_start, threshold=self.thr)
        logging.info(f"{adjacency.shape[0]} proteins")

        # Build start graph
        graph = nx.from_scipy_sparse_array(adjacency)
        start_clusters = extract_connected_components(graph)
        self._plot_sizes(start_clusters, f"Initial clusters at {self.score} {self.thr}", figs_dir, f"clusters_start_{mode}")

        # Do community detection
        logging.info(f"Finding communities, {(time.time() - t0):.2f}s")
        communities = find_communities(graph, seed=self.seed, resolution=self.resolution)
        communities = [[idx_to_protein[p] for p in c] for c in communities]
        logging.info(f"{len(communities)} communities")
        self._plot_sizes(communities, f"Initial communities at {self.score} {self.thr}", figs_dir, f"communities_start_{mode}")

        # Remove connections between communities
        logging.info(f"Separating communities, {(time.time() - t0):.2f}s")
        separate = self.remove_top_nodes_together if self.do_fast else self.remove_top_nodes_one_by_one
        clusters, removed_proteins, adjacency, idx_to_protein = separate(adjacency, communities, protein_to_idx, idx_to_protein)

        logging.info(f"{len(clusters)} communities")
        self._plot_sizes(clusters, "Communities after separation", figs_dir, f"communities_end_{mode}")

        # Extract connected components from final graph
        logging.info(f"{adjacency.shape[0]} proteins")
        graph = nx.from_scipy_sparse_array(adjacency)
        clusters = extract_connected_components(graph)
        clusters = [[idx_to_protein[i] for i in c] for c in clusters]
        logging.info(f"End of community separation, {(time.time() - t0):.2f}s")
        logging.info(f"{len(clusters)} clusters")
        self._plot_sizes(clusters, "Final clusters", figs_dir, f"clusters_end_{mode}")
        logging.info(f"{len(removed_proteins)} removed proteins")

        # Sanity check against the unthresholded graph minus the removed proteins' edges.
        adj_test = remove_edges_in_adjacency(adjacency_start, protein_to_idx_start, removed_proteins)
        check_inter_clusters_dists(clusters, self.thr, adj_test, protein_to_idx_start)
        logging.info(f"Total time, {(time.time() - t0):.2f}s")
        pickle_dump(output_file, clusters)
        return clusters

def find_communities(graph: nx.Graph, seed: int, resolution: float) -> List[List[int]]:
    """Detect Louvain communities and return them largest-first.

    Each community is returned as a sorted list of node ids (not a set). Communities of
    equal size keep the order produced by ``louvain_communities``, because the sort is
    stable.

    Resolution: if < 1, the algorithm favors larger communities; > 1, favors smaller communities.
    """
    raw = nx.community.louvain_communities(graph, seed=seed, weight="weight", resolution=resolution)
    return sorted((sorted(nodes) for nodes in raw), key=len, reverse=True)

def get_all_nodes_with_inter_community_edges(adjacency: csr_matrix, community: List[int]) -> np.ndarray:
    """Return indices of all community members with at least one edge leaving the community.

    Args:
        adjacency: Square adjacency matrix (scipy sparse or dense ndarray); any
            positive entry counts as an edge.
        community: Row/column indices of the community members.

    Returns:
        Integer array with the members of ``community`` (in their original order) that
        have a positive entry in some column outside ``community``.
    """
    members = np.asarray(community, dtype=int)
    # Boolean mask instead of a Python loop over all n nodes to build the complement.
    outside = np.ones(adjacency.shape[0], dtype=bool)
    outside[members] = False
    out_inds = np.flatnonzero(outside)
    if members.size == 0 or out_inds.size == 0:
        # Nothing to check, or no node outside the community to connect to.
        return members[:0]

    selected = adjacency[members, :][:, out_inds]
    # ravel() works for the (k, 1) np.matrix that csr_matrix.sum returns and for the
    # 1-D results of sparse arrays / ndarrays; squeeze(-1) fails on the latter.
    summed = np.asarray((selected > 0).sum(axis=1)).ravel()
    return members[summed > 0]

def get_node_with_most_inter_community_edges(adjacency: csr_matrix, community: List[int]) -> Dict[str, int]:
    """Find the community member with the most edges to nodes outside the community.

    Args:
        adjacency: Square adjacency matrix (scipy sparse or dense ndarray); any
            positive entry counts as an edge.
        community: Row/column indices of the community members.

    Returns:
        ``{'n_edges': k}`` where ``k`` is the largest number of outside neighbours of any
        member. When ``k > 0`` the dict also holds ``'node'`` (that member's matrix index)
        and ``'index'`` (its position in ``community``); ties go to the first member.
        An empty community, or one that covers every node, gives ``{'n_edges': 0}``.
    """
    members = np.asarray(community, dtype=int)
    # Boolean mask instead of a Python loop over all n nodes to build the complement.
    outside = np.ones(adjacency.shape[0], dtype=bool)
    outside[members] = False
    out_inds = np.flatnonzero(outside)
    if members.size == 0 or out_inds.size == 0:
        return {'n_edges': 0}

    selected = adjacency[members, :][:, out_inds]
    # count number of inter-community connections; ravel() handles both the (k, 1)
    # np.matrix from csr_matrix.sum and 1-D sums from sparse arrays / ndarrays.
    summed = np.asarray((selected > 0).sum(axis=1)).ravel()
    i = int(summed.argmax())
    n_edges = int(summed[i])
    out = {'n_edges': n_edges}
    if n_edges > 0:
        out['node'] = community[i]
        out['index'] = i
    return out

def get_remaining_protein_indices(n_prots: int, to_remove: np.ndarray) -> List[int]:
    """Return the indices in ``range(n_prots)`` that are not in ``to_remove``, ascending.

    Args:
        n_prots: Number of proteins (rows) in the current adjacency matrix.
        to_remove: Indices to drop; any iterable of ints, e.g. a NumPy integer array.

    Returns:
        A plain Python list of the kept indices.
    """
    dropped = set(to_remove)  # O(1) membership tests
    return [i for i in range(n_prots) if i not in dropped]

def remove_protein_from_community(communities: List[List[int]], i: int) -> List[List[int]]:
    """Delete the ``i``-th member of the first community, in place.

    Args:
        communities: List of communities; ``communities[0]`` is modified.
        i: Position of the member to delete within ``communities[0]``.

    Returns:
        The same ``communities`` object (mutated), for call-chaining convenience.
    """
    del communities[0][i]
    return communities

def update_adjacency_with_remaining_proteins(adjacency: csr_matrix, remaining: List[int]) -> csr_matrix:
    """Restrict a square adjacency matrix to the rows/columns in ``remaining``.

    Works for scipy sparse matrices and dense ndarrays. Row ``k`` / column ``k`` of the
    result corresponds to ``remaining[k]`` in the input.

    NOTE: this module-level definition shadows the function of the same name imported
    from ``dataset.dataset_utils``; callers in this module get this version.
    """
    # Select rows first, then columns; a single adjacency[remaining, remaining] would
    # pick element pairs instead of a submatrix.
    rows = adjacency[remaining]
    return rows[:, remaining]


if __name__ == "__main__":
    # Locations of the DeepFRI-GO alignment results and of the cluster outputs.
    BASE = Path("/ddots/")
    DATA_ROOT = BASE / "data/deepfri_go"
    FOLDSEEK_FILE = DATA_ROOT / "foldseek/s_7.5_covmode_1_cov_0.5_tmthr_0.0_aln_1_exh/align.tsv"
    MMSEQS_FILE = DATA_ROOT / "mmseqs/s_7.5_covmode_1_cov_0.5_idthr_0.0_aln_3/align.tsv"
    CLUSTERS_DIR = DATA_ROOT / 'clusters'
    TM_DIR = CLUSTERS_DIR / "tm_clusters"
    SEQID_DIR = CLUSTERS_DIR / "seqid_clusters"

    # One TM-score clustering run per Louvain seed, 1 through 20 inclusive.
    seeds = list(range(1, 21))
    for seed in seeds:
        runner = CommunityDetection(
            score="TM",
            threshold=0.5,
            is_foldseek=True,
            cov=0.8,
            seed=seed,
            resolution=2,
            proteins_to_exclude=None,
            do_fast=False,
            make_plots=False,
            verbose=True,
        )
        runner.run_community_detection(path_to_tsv=FOLDSEEK_FILE, output_dir=TM_DIR)
  