import json
import logging
import os
from pathlib import Path
import pickle
import sys
from typing import Any, Dict, List, Optional, Set, Tuple

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
from scipy.sparse import csr_matrix


def build_start_adjacency_matrix(df: pd.DataFrame, score: str) -> Tuple[csr_matrix, Dict[str, int]]:
    """Build a symmetric sparse matrix of pairwise protein similarity scores.

    Protein names in the ``query`` and ``target`` columns are shortened to the
    text before their first ``'_'``. Every distinct shortened name gets exactly
    one row/column, assigned in alphabetical order (A -> Z).

    Args:
        df: pairwise scores with ``query`` and ``target`` columns.
        score: name of the ``df`` column holding the similarity score
            (TM score, sequence identity, ...).

    Returns:
        - adjacency: square ``csr_matrix``. Entry (i, j) is the larger of the
          (i, j) and (j, i) scores (a missing direction counts as 0); pairs
          that never appear are left empty; the diagonal is 1.
        - protein_to_idx: shortened protein name -> row/column index.

    Raises:
        AssertionError: via ``is_sparse_matrix_symmetric`` if the result is
            not symmetric.
    """
    queries = [q.split('_')[0] for q in df['query']]
    targets = [t.split('_')[0] for t in df['target']]

    # Deduplicate and sort *after* shortening the names. Counting the raw names
    # leaves dead rows whenever two raw names shorten to the same protein, and
    # sorting the raw names does not give alphabetical order of the short ones.
    prots = sorted(set(queries) | set(targets))  # A -> Z
    protein_to_idx = {p: i for i, p in enumerate(prots)}
    n_prots = len(prots)

    # Tuple keys (not "q,t" strings) so names containing ',' cannot be mis-split.
    # If the same (query, target) pair appears several times, the last row wins.
    all_scores = dict(zip(zip(queries, targets), df[score]))

    # Symmetrize: keep the max score of (seq1, seq2) and (seq2, seq1).
    entries: Dict[Tuple[int, int], float] = {}
    for (q, t), s in all_scores.items():
        max_score = max(s, all_scores.get((t, q), 0))
        i1, i2 = protein_to_idx[q], protein_to_idx[t]
        entries[(i1, i2)] = max_score
        entries[(i2, i1)] = max_score
    for i in range(n_prots):
        entries[(i, i)] = 1  # self-similarity, overrides any self-pair score

    # Build the sparse matrix directly instead of a dense n x n array, which
    # needs O(n^2) memory. Keys are unique, so nothing is summed on construction.
    n_entries = len(entries)
    rows = np.fromiter((i for i, _ in entries), dtype=np.int64, count=n_entries)
    cols = np.fromiter((j for _, j in entries), dtype=np.int64, count=n_entries)
    vals = np.fromiter(entries.values(), dtype=float, count=n_entries)
    adjacency = csr_matrix((vals, (rows, cols)), shape=(n_prots, n_prots))
    # A dense -> csr conversion never stores zeros; keep that for 0 scores.
    adjacency.eliminate_zeros()
    is_sparse_matrix_symmetric(adjacency)
    return adjacency, protein_to_idx

def check_inter_clusters_dists(clusters: List[List[str]], thres: float, adjacency: csr_matrix, prot2idx: Dict[str, int],
                               desc: Optional[str] = None) -> Any:
    """Check that no similarity score between different clusters reaches ``thres``.

    For every cluster, the block of ``adjacency`` that links its proteins to
    all other indices in ``prot2idx`` must contain only values strictly below
    ``thres``.

    Args:
        clusters: groups of protein names (keys of ``prot2idx``).
        thres: similarity threshold; any inter-cluster score ``>= thres`` fails.
        adjacency: square similarity matrix indexed through ``prot2idx``.
        prot2idx: protein name -> row/column index in ``adjacency``.
        desc: optional label prefixed to the error message.

    Raises:
        AssertionError: if some inter-cluster score is ``>= thres``.
    """
    all_indices = list(prot2idx.values())  # loop-invariant, computed once
    for cluster in clusters:
        inside = [prot2idx[p] for p in cluster]
        inside_set = set(inside)
        outside = [k for k in all_indices if k not in inside_set]
        if (adjacency[inside, :][:, outside] >= thres).sum():
            prefix = f"{desc}: " if desc else ""
            # The condition is ">=", so the message says ">=" (it used to say ">").
            raise AssertionError(
                f"{prefix}Inter-cluster distances should all be < {thres} but some are >= {thres}!"
            )

def configure_logger(logfile: os.PathLike) -> Any:
    """Route INFO-and-above records of the root logger to stdout and ``logfile``.

    The log file is opened in write mode, so previous content is discarded.
    Records are emitted as the bare message (format ``%(message)s``).
    Per the ``logging.basicConfig`` contract, the call has no effect if the
    root logger already has handlers; the file handler is still created
    (and the file truncated) in that case.
    """
    console_handler = logging.StreamHandler(sys.stdout)
    file_handler = logging.FileHandler(logfile, mode="w")
    logging.basicConfig(
        handlers=[console_handler, file_handler],
        level=logging.INFO,
        format="%(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

def extract_connected_components(graph: nx.Graph) -> List[List[int]]:
    """Return the connected components of ``graph`` as node lists, largest first.

    Each component is converted to a list (the old ``List[Set[int]]``
    annotation was wrong). Components of equal size keep the order in which
    ``nx.connected_components`` yields them, because the sort is stable.
    """
    components = (list(c) for c in nx.connected_components(graph))
    return sorted(components, key=len, reverse=True)

def get_cluster_sizes(clusters: List[List[int]]) -> List[int]:
    """Return the number of members of each cluster, in the same order."""
    return list(map(len, clusters))

def get_dataset_statistics(df: pd.DataFrame, outfile: os.PathLike) -> Any:
    """Write the train/val/test proportions of every split column to a CSV file.

    Every column whose name contains ``"split"`` is treated as a split
    assignment with values ``'train'``, ``'val'`` or ``'test'``; any other
    value is ignored. Percentages are relative to the number of rows carrying
    one of those three labels and are rounded to 2 decimals. A split column
    with no such rows gets NaN statistics instead of raising
    ``ZeroDivisionError``.

    Output columns: "Data split", "Train %", "Val %", "Test %" and
    "Train diff to 80" (absolute distance of the train percentage to 80).
    """
    rows = []
    for col in df.columns:
        if "split" not in col:
            continue
        # Count labels directly instead of materializing filtered DataFrames.
        train = int((df[col] == 'train').sum())
        val = int((df[col] == 'val').sum())
        test = int((df[col] == 'test').sum())
        tot = train + val + test
        if tot == 0:
            train_p = val_p = test_p = diff_to_80 = float('nan')
        else:
            train_p = round(train / tot * 100, 2)
            val_p = round(val / tot * 100, 2)
            test_p = round(test / tot * 100, 2)
            diff_to_80 = round(abs(train_p - 80), 2)
        rows.append((col, train_p, val_p, test_p, diff_to_80))

    df_stats = pd.DataFrame(rows, columns=["Data split", "Train %", "Val %", "Test %", "Train diff to 80"])
    df_stats.to_csv(outfile, index=False)
    
def is_sparse_matrix_symmetric(matrix: csr_matrix) -> bool:
    """Check that a sparse matrix equals its transpose.

    Uses an explicit ``raise`` instead of an ``assert`` statement, so the
    check still runs when Python is started with ``-O``.

    Returns:
        True when the matrix is symmetric.

    Raises:
        AssertionError: if any entry differs from its transposed counterpart.
    """
    if (matrix != matrix.T).nnz != 0:
        raise AssertionError("Sparse matrix is not symmetric")
    return True

def json_load(path: str) -> Any:
    """Read and return the JSON document stored at ``path``.

    The file is decoded as UTF-8, the encoding required by the JSON standard,
    rather than the platform default. JSON object keys are always strings, so
    a mapping saved with integer keys comes back with string keys (the old
    ``Dict[int, int]`` annotation was therefore wrong).
    """
    with open(path, "r", encoding="utf-8") as f:
        out = json.load(f)
    return out

def load_foldseek_tsv(path: str) -> pd.DataFrame:
    """Read a tab-separated Foldseek result table and normalise protein names.

    In the ``query`` and ``target`` columns, everything from
    ``'_model_1_relaxed.pdb'`` and then from ``'_model_1_unrelaxed.pdb'``
    onwards is cut off. The ``lddt`` column is cast to float.
    """
    model_suffixes = ('_model_1_relaxed.pdb', '_model_1_unrelaxed.pdb')

    def strip_model_suffix(name):
        # Cut at each marker in turn, exactly like chained str.split(...)[0].
        for marker in model_suffixes:
            name = name.split(marker)[0]
        return name

    table = pd.read_csv(path, sep='\t')
    for column in ('query', 'target'):
        table[column] = table[column].map(strip_model_suffix)
    table['lddt'] = table['lddt'].astype(float)
    return table

def pickle_dump(path: str, file: Any) -> Any:
    """Serialize the object ``file`` to ``path`` using the highest pickle protocol.

    Any existing file at ``path`` is overwritten.
    """
    protocol = pickle.HIGHEST_PROTOCOL
    with open(path, mode='wb') as handle:
        pickle.dump(file, handle, protocol=protocol)

def pickle_load(path: str) -> Any:
    """Load and return the object pickled at ``path``.

    Security note: unpickling can execute arbitrary code. Only use this on
    files produced by trusted code, never on untrusted input.
    """
    with open(path, mode="rb") as handle:
        return pickle.load(handle)

def plot_cluster_sizes(clusters: List[List[int]], title: str, figs_dir: str, figname: str,
                       save_fig: bool=True) -> Any:
    """Scatter-plot each cluster's size against its 1-based position in ``clusters``.

    When ``save_fig`` is true the figure is written to ``<figs_dir>/<figname>.png``.
    The figure is closed in every case.
    """
    fig = plt.figure()
    cluster_sizes = get_cluster_sizes(clusters)
    ranks = np.arange(1, len(cluster_sizes) + 1)
    plt.scatter(ranks, cluster_sizes, marker=".", s=5, c="royalblue")
    plt.ylabel("size")
    plt.title(title)
    if save_fig:
        target = Path(figs_dir) / f"{figname}.png"
        plt.savefig(target)
    plt.close(fig)

def plot_no_clusters(thresholds: List[float], no_clusters: List[int]) -> Any:
    """Show a bar chart of the number of clusters obtained at each threshold.

    Thresholds are used as categorical (string) bar labels. The figure is
    displayed with ``plt.show()`` and then closed.
    """
    _, ax = plt.subplots(figsize=(4, 4))
    ax.bar(list(map(str, thresholds)), no_clusters)
    plt.show()
    plt.close()

def remove_edges_in_adjacency(adjacency: csr_matrix, prot2idx: Dict[str, int], proteins: List[str]) -> csr_matrix:
    """Zero out every row and column that belongs to one of ``proteins``.

    The matrix keeps its shape, so all indices in ``prot2idx`` stay valid. The
    listed proteins lose all their entries, including the diagonal one. The
    input matrix is not modified.

    Args:
        adjacency: square sparse similarity matrix.
        prot2idx: protein name -> row/column index in ``adjacency``.
        proteins: names of the proteins whose edges are removed.

    Returns:
        A new CSR matrix without the removed proteins' entries.
    """
    keep = np.ones(adjacency.shape[0], dtype=bool)
    keep[[prot2idx[p] for p in proteins]] = False
    # A 1-D mask broadcasts over columns and its column-vector form over rows.
    # Masking both directly (instead of mask -> transpose -> mask) avoids
    # returning the transposed matrix, which is wrong for non-symmetric input.
    masked = adjacency.multiply(keep).multiply(keep[:, np.newaxis]).tocsr()
    # ``multiply`` may leave masked entries stored as explicit zeros; drop them
    # so they cannot turn into zero-weight edges when building a graph.
    masked.eliminate_zeros()
    return masked

def select_subgraph_from_protein_list(proteins: List[int], adjacency: csr_matrix) -> Tuple[nx.Graph, csr_matrix]:
    """Build the graph induced by the given protein indices.

    Args:
        proteins: row/column indices of ``adjacency`` to keep, in the order
            that defines the new node labels.
        adjacency: full sparse similarity matrix.

    Returns:
        ``(graph, sub_adjacency)``. ``sub_adjacency`` is ``adjacency``
        restricted to ``proteins`` on both axes (a sparse matrix, not a
        DataFrame as previously annotated). ``graph`` is built from it, so
        node ``k`` corresponds to ``proteins[k]``.
    """
    # Slice rows first: row selection is the cheap direction for CSR and
    # shrinks the matrix before the column selection.
    sub_adjacency = adjacency[proteins, :][:, proteins]
    graph = nx.from_scipy_sparse_array(sub_adjacency)
    return graph, sub_adjacency

def update_adjacency_with_remaining_proteins(adjacency: csr_matrix, proteins: List[int]) -> csr_matrix:
    """Restrict ``adjacency`` to the rows and columns listed in ``proteins``, in that order."""
    kept_rows = adjacency[proteins, :]
    return kept_rows[:, proteins]

def update_sparse_adjacency_threshold(adj: csr_matrix, threshold: float, use_weights: bool=True) -> csr_matrix:
    """Keep only the entries whose similarity score is at least ``threshold``.

    Two nodes stay connected when their score is ``>= threshold``. With
    ``use_weights`` (the default), the kept entries carry their original
    score, giving a weighted undirected graph. Otherwise the boolean matrix
    ``adj >= threshold`` is returned as is.
    """
    above_threshold = adj >= threshold
    if not use_weights:
        return above_threshold
    return above_threshold.multiply(adj)

def update_protein_mappings(sorted_proteins: List[str]) -> Tuple[Dict[int, str], Dict[str, int]]:
    """Index proteins by their position in ``sorted_proteins``.

    Returns:
        ``(idx_to_protein, protein_to_idx)``: position -> name, then
        name -> position. The annotation previously listed these two in the
        wrong order. If a name occurs more than once, ``protein_to_idx``
        keeps its last position.
    """
    idx_to_protein = dict(enumerate(sorted_proteins))
    protein_to_idx = {protein: idx for idx, protein in idx_to_protein.items()}
    return idx_to_protein, protein_to_idx
