import json
import logging
import os
from typing import Iterable, List, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch.linalg
import torch.nn.functional as F
from Bio import SeqIO
from Bio.PDB import PDBParser, Structure
from Bio.Seq import Seq
from graph_part import stratified_k_fold, train_test_validation_split

"""This file contains various helper functions that are used in other scripts."""


def read_pdb(pdb_path: str, identifier: str) -> Structure:
    """Parse a PDB file into a Bio.PDB structure.

    Args:
        pdb_path (str): Full path to the PDB file
        identifier (str): Id assigned to the structure by Bio.PDB

    Returns:
        Bio.PDB.Structure object parsed from ``pdb_path``

    """
    # A fresh parser is created per call and handed the file directly.
    return PDBParser().get_structure(id=identifier, file=pdb_path)


def get_coords(
    structure: Structure, sequence_length: int, full_backbone: bool = True
) -> np.array:
    """Collect the backbone atom coordinates of every residue in a structure

    Args:
        structure (Bio.PDB Structure): Protein structure
        sequence_length (int): Number of rows to allocate; must be at least the
            number of residues in ``structure``
        full_backbone: If True, read N, CA, C and O; otherwise only N, CA and C

    Returns:
        Array of shape [sequence_length, 4, 3] (full backbone) or
        [sequence_length, 3, 3], holding one xyz triple per backbone atom.
        Rows beyond the last residue of ``structure`` stay zero.
    """

    residue_iter = structure.get_residues()

    # Every residue is expected to contain all of the requested atoms.
    atom_names = ("N", "CA", "C", "O") if full_backbone else ("N", "CA", "C")

    # One row per residue, one xyz triple per backbone atom.
    coords = np.zeros((sequence_length, len(atom_names), 3))
    for row, residue in enumerate(residue_iter):
        for col, atom_name in enumerate(atom_names):
            coords[row, col] = residue[atom_name].get_coord()
    return coords


def pdb_to_coords(
    pdb_path: str, identifier: str, sequence_length: int, full_backbone: bool = True
) -> np.array:
    """Load a PDB file and return its backbone coordinates

    Args:
        pdb_path (str): Full path to the PDB file
        identifier (str): Identifier/id used when parsing the PDB file
        sequence_length (int): Number of residues
        full_backbone: If True, read N, CA, C and O; otherwise only N, CA and C

    Returns:
        Backbone coordinates as produced by ``get_coords`` for the parsed
        structure

    """

    # Parse first, then hand the structure straight to the coordinate extractor.
    return get_coords(
        structure=read_pdb(pdb_path=pdb_path, identifier=identifier),
        sequence_length=sequence_length,
        full_backbone=full_backbone,
    )


def generate_unstratified_splits(
    df: pd.DataFrame,
    dataset: str,
    alignment_mode: str = "needle",
    initial_threshold: float = 0.25,
    threads: int = 10,
) -> Tuple[List, float, int]:
    """Partition sequences with GraphPart, relaxing the constraints until it succeeds.

    The edge file ``data/processed/{dataset}/{dataset}_graphpart_edges.csv`` is
    computed first if it does not exist. Partitioning then starts with 3
    partitions at ``initial_threshold``. Each time GraphPart raises
    ``RuntimeError`` the number of partitions is increased by one, up to 5;
    once 5 partitions have failed too, the threshold is raised by 0.05 and the
    partition count is reset to 3.

    Args:
        df: Data frame with a "sequence" column.
        dataset: Dataset name used to build the edge file path.
        alignment_mode: Alignment mode used when the edge file has to be computed.
        initial_threshold: Threshold of the first partitioning attempt.
        threads: Number of threads used when the edge file has to be computed.

    Returns:
        Tuple of (partition ids as returned by GraphPart, threshold at which
        partitioning succeeded, number of partitions used).

    Raises:
        RuntimeError: If GraphPart still fails with 5 partitions at a threshold
            above 1.0.
    """
    min_partitions = 3
    max_partitions = 5
    threshold_step = 0.05
    # The threshold is reported as a sequence identity below, so it is assumed
    # that values above 1.0 cannot relax the problem any further. Without this
    # limit a persistent GraphPart failure would loop forever.
    threshold_limit = 1.0

    sequences = df["sequence"].tolist()

    # Precompute edges
    checkpoint_path = f"data/processed/{dataset}/{dataset}_graphpart_edges.csv"
    if not os.path.exists(checkpoint_path):
        try:
            stratified_k_fold(
                sequences=sequences,
                alignment_mode=alignment_mode,
                partitions=min_partitions,
                threads=threads,
                threshold=initial_threshold,
                save_checkpoint_path=checkpoint_path,
            )
        except RuntimeError:
            # This call is only made for the edge file it saves (presumably
            # written even when partitioning fails); the result is discarded
            # and the loop below starts again from initial_threshold.
            pass

    threshold = initial_threshold
    n_partitions = min_partitions

    while True:
        try:
            # Run GraphPart on the precomputed edges
            ids = stratified_k_fold(
                sequences=sequences,
                alignment_mode="precomputed",
                partitions=n_partitions,
                threshold=threshold,
                edge_file=checkpoint_path,
                metric_column=2,
            )
        except RuntimeError as err:
            if n_partitions < max_partitions:
                # First try more partitions at the same threshold.
                print(f"\n{'#' * 70}")
                print(
                    f"GraphPart failed. Increasing number of partitions from {n_partitions} to {n_partitions + 1}."
                )
                print(f"{'#' * 70}\n")
                n_partitions += 1
            elif threshold > threshold_limit:
                raise RuntimeError(
                    f"GraphPart could not partition the dataset with up to {max_partitions} "
                    f"partitions at any threshold up to {round(threshold, 3)}."
                ) from err
            else:
                # If GraphPart fails, splits not possible. Increase threshold.
                print(f"\n{'#' * 60}")
                print(
                    f"Partition not possible. \nIncreasing threshold from {round(threshold, 3)} to "
                    f"{round(threshold + threshold_step, 2)}."
                )
                print(f"{'#' * 60}\n")
                threshold += threshold_step
                n_partitions = min_partitions
            continue

        print(f"\n{'#' * 70}")
        print(
            f"Dataset partitioned into {n_partitions} clusters at min. sequence identity {round(threshold, 3)}."
        )
        print(f"{'#' * 70}\n")
        return ids, threshold, n_partitions


def precompute_edges(
    df: pd.DataFrame,
    checkpoint_path: str,
    alignment_mode: str = "needle",
    threads: int = 10,
    verbose: bool = True,
) -> None:
    """Make sure the GraphPart edge file at ``checkpoint_path`` exists.

    Nothing is computed when the file is already present. Otherwise GraphPart
    is run once on the sequences with ``save_checkpoint_path`` set to
    ``checkpoint_path``; a ``RuntimeError`` raised by that run is ignored.

    Args:
        df: Data frame with a "sequence" column.
        checkpoint_path: Location of the edge file.
        alignment_mode: Alignment mode passed to GraphPart.
        threads: Number of threads passed to GraphPart.
        verbose: If True, configure logging at INFO level.
    """
    if verbose:
        logging.basicConfig(level=logging.INFO)

    sequence_list = df["sequence"].tolist()

    if os.path.exists(checkpoint_path):
        logging.info("Using pre-computed distances.")
        return

    logging.info("Pre-computing distances.")
    graphpart_kwargs = dict(
        sequences=sequence_list,
        alignment_mode=alignment_mode,
        partitions=2,
        threads=threads,
        threshold=0.01,
        save_checkpoint_path=checkpoint_path,
    )
    try:
        stratified_k_fold(**graphpart_kwargs)
    except RuntimeError:
        # Partitioning itself is expected to fail here; the run is only made
        # for the edges it saves along the way.
        pass


def generate_CV_partitions(
    df: pd.DataFrame,
    initial_threshold: float,
    dataset: str,
    alignment_mode: str = "needle",
    verbose: bool = True,
    n_partitions: int = 4,
    threads: int = 10,
    min_pp_split: float = 0.2,
    threshold_inc: float = 0.05,
    checkpoint_path: Union[str, None] = None,
) -> Tuple[List[Iterable], float]:
    """Generate label-stratified cross-validation partitions with GraphPart.

    Edges are precomputed (or reused) via ``precompute_edges``. Starting at
    ``initial_threshold``, GraphPart is then run on the precomputed edges. The
    threshold is raised by ``threshold_inc`` whenever GraphPart raises
    ``RuntimeError`` or any partition holds less than ``min_pp_split`` of the
    sequences that were assigned to a partition. The first threshold that
    yields acceptable partitions is used.

    Args:
        df: Data frame with "sequence" and "target_class" columns.
        initial_threshold: Threshold of the first partitioning attempt.
        dataset: Dataset name, used to build the default edge file path.
        alignment_mode: Alignment mode used if the edges have to be computed.
        verbose: If True, configure logging at INFO level.
        n_partitions: Number of partitions to generate.
        threads: Number of threads passed to GraphPart.
        min_pp_split: Minimum fraction of the partitioned sequences that every
            partition has to contain.
        threshold_inc: Amount by which the threshold is raised after a failed
            attempt.
        checkpoint_path: Path of the edge file. Defaults to
            ``data/processed/{dataset}/{dataset}_graphpart_edges.csv``.

    Returns:
        Tuple of (partition ids as returned by GraphPart, threshold rounded to
        three decimals).

    Raises:
        RuntimeError: If no acceptable partitioning is found for any threshold
            up to 1.0.
    """
    # Setup logging
    if verbose:
        logging.basicConfig(level=logging.INFO)

    # Prepare values for GraphPart
    threshold_limit = 1.0
    threshold = initial_threshold

    # Resort to lists as GraphPart bug prevents using pandas directly
    sequences = df["sequence"].tolist()
    labels = df["target_class"].tolist()
    labels_arr = np.array(labels)

    # Precompute edges
    if checkpoint_path is None:
        checkpoint_path = f"data/processed/{dataset}/{dataset}_graphpart_edges.csv"
    precompute_edges(
        df=df,
        checkpoint_path=checkpoint_path,
        alignment_mode=alignment_mode,
        threads=threads,
        verbose=verbose,
    )

    # Proceed to main generation script
    logging.info(f"Generating {n_partitions} partitions.")
    while threshold <= threshold_limit:
        try:
            # Run GraphPart on the precomputed edges
            ids = stratified_k_fold(
                sequences=sequences,
                labels=labels,
                alignment_mode="precomputed",
                threads=threads,
                edge_file=checkpoint_path,
                threshold=threshold,
                metric_column=2,
                partitions=n_partitions,
            )  # metric_column warning due to bug in GraphPart
        except RuntimeError:
            # If GraphPart fails, splits not possible. Increase threshold.
            logging.info(
                f"Partition not possible. Increasing threshold from {round(threshold, 3)} to "
                f"{round(threshold + threshold_inc, 3)}."
            )
            threshold += threshold_inc
            continue

        # Inspect splits: size of each partition relative to all partitioned
        # sequences, and the mean label per partition (logged as p(class=1),
        # which assumes binary 0/1 labels).
        n_eff = sum(len(part) for part in ids)
        n_obs = np.array([len(ids[i]) for i in range(n_partitions)], dtype=int)
        p_obs = n_obs / n_eff
        p_class = np.array(
            [labels_arr[ids[i]].sum() / n_obs[i] for i in range(n_partitions)]
        )

        # Verify that each split has at least min_pp_split % of sequences. If not, increase threshold.
        if (p_obs < min_pp_split).any():
            logging.info(
                f"Partition not possible at threshold {round(threshold, 3)}. Less than {min_pp_split * 100:.0f} "
                f"% of sequences found in a split:"
            )
            for i in range(n_partitions):
                logging.info(
                    f"- {p_obs[i] * 100:.2f} % ({n_obs[i]}/{n_eff}) in split {i + 1}."
                )
            logging.info(
                f"Increasing threshold from {round(threshold, 3)} to "
                f"{round(threshold + threshold_inc, 3)} to achieve balance."
            )
            threshold += threshold_inc
            continue

        # Print split details
        logging.info(f"Dataset successfully split at threshold {round(threshold, 3)}.")
        logging.info("Summary:")
        for i in range(n_partitions):
            logging.info(
                f"- {p_obs[i] * 100:.2f} % ({n_obs[i]}/{n_eff}) in split {i + 1}, with p(class=1) = {p_class[i] * 100:.2f} %."
            )

        return ids, round(threshold, 3)

    # Falling out of the loop used to return None implicitly, which broke
    # callers that unpack the documented (ids, threshold) tuple.
    raise RuntimeError(
        f"Could not generate {n_partitions} partitions with at least "
        f"{min_pp_split * 100:.0f} % of sequences each for any threshold up to {threshold_limit}."
    )


def _holdout_vector_loader(dataset: str, embedding_type: str, suffix, df: pd.DataFrame):
    """Return ``(dim, load_vector)`` for a file-backed embedding type.

    ``load_vector(name)`` reads the stored embedding of the protein ``name``
    and returns the array that is written into one row of a
    ``[n_proteins, dim]`` matrix.

    Raises:
        ValueError: If ``embedding_type`` is not supported.
    """
    root = f"data/processed/{dataset}"

    # NOTE: torch.load unpickles the file; only load embeddings from trusted sources.
    if embedding_type in ("ESM-1B", "ESM-2"):
        # (directory, embedding width, key into "mean_representations")
        folder, dim, layer = (
            ("esm_1b_embeddings", 1280, 33)
            if embedding_type == "ESM-1B"
            else ("esm_2_embeddings", 2560, 36)
        )

        def load_vector(name):
            stored = torch.load(f"{root}/{folder}/{name}.pt")
            return stored["mean_representations"][layer].numpy()

    elif embedding_type == "ESM-IF1":
        dim = 512

        def load_vector(name):
            # Average over the first axis (presumably sequence positions).
            stored = torch.load(f"{root}/esm_if1_embeddings/{name}.pt").numpy()
            return np.mean(stored, axis=0)

    elif embedding_type in ("ONEHOT (MSA)", "ONEHOT"):
        folder = (
            "onehot_msa_encodings"
            if embedding_type == "ONEHOT (MSA)"
            else "onehot_encodings"
        )
        # The width is taken from the first protein in the table; every other
        # encoding has to flatten to the same length.
        dim = len(np.load(f"{root}/{folder}/{df.iloc[0]['name']}.npy").flatten())

        def load_vector(name):
            return np.load(f"{root}/{folder}/{name}.npy").flatten()

    elif embedding_type == "EVE (z)":
        dim = 50

        def load_vector(name):
            stored = torch.load(f"{root}/eve_z/{suffix}/{name}.pt")
            return stored["representation"].numpy()

    elif embedding_type == "AF2":
        dim = 384

        def load_vector(name):
            # Average over the first axis (presumably sequence positions).
            stored = np.load(f"{root}/af2_embeddings/{name}_single_repr_1_model_3.npy")
            return np.mean(stored, axis=0)

    else:
        raise ValueError(f"Unsupported embedding type: {embedding_type!r}")

    return dim, load_vector


def _fill_split_matrix(names, dim: int, load_vector) -> np.ndarray:
    """Stack ``load_vector(name)`` for every name into a ``[len(names), dim]`` array."""
    matrix = np.zeros((len(names), dim))
    for row, name in enumerate(names):
        matrix[row] = load_vector(name)
    return matrix


def extract_holdout_embeddings(
    dataset: str,
    embedding_type: str,
    target: str,
    suffix: str = None,
    split_key: str = None,
    active: bool = False,
):
    """Load embeddings and targets for the train/val/test split of a dataset.

    The table ``data/processed/{dataset}/{dataset}.csv`` is read and rows are
    assigned to a split by the values "train", "val" and "test" in the column
    ``split_key``.

    Args:
        dataset: Dataset name; selects the directory below ``data/processed``.
        embedding_type: One of "ESM-1B", "ESM-2", "ESM-IF1", "ONEHOT",
            "ONEHOT (MSA)", "EVE (z)", "EVE (ELBO)" or "AF2".
        target: Name of the column holding the target values.
        suffix: Suffix selecting the EVE directory / ELBO file; only used for
            the EVE embedding types.
        split_key: Name of the column holding the split assignment. Required.
        active: If True, keep only rows whose "target_class" is truthy.

    Returns:
        Tuple ``(embedding_train, embedding_val, embedding_test, y_train,
        y_val, y_test, train_names, val_names, test_names)``. Embeddings are
        2-D arrays with one row per protein, targets are arrays and names are
        pandas Series.

    Raises:
        ValueError: If ``split_key`` is not given or ``embedding_type`` is not
            supported.
    """
    if split_key is None:
        # The None default can never match a column of the table.
        raise ValueError("split_key must name the column holding the split assignment.")

    # Load CSV
    df = pd.read_csv(f"data/processed/{dataset}/{dataset}.csv", index_col=0)
    if active:
        df = df.loc[df["target_class"].astype(bool)]

    # Extract curated names and targets
    splits = ("train", "val", "test")
    train_names, val_names, test_names = (
        df.loc[df[split_key] == split, "name"] for split in splits
    )
    y_train, y_val, y_test = (
        df.loc[df[split_key] == split, target].values for split in splits
    )

    if embedding_type == "EVE (ELBO)":
        # The ELBO is a single scalar per protein, read from a table instead
        # of one file per protein.
        df_elbo = pd.read_csv(
            f"data/processed/{dataset}/{dataset}_EVE_ELBO_{suffix}.csv"
        )
        df = pd.merge(left=df, right=df_elbo[["name", "ELBO"]], on="name", how="left")
        embedding_train, embedding_val, embedding_test = (
            df.loc[df[split_key] == split, "ELBO"].values.reshape(-1, 1)
            for split in splits
        )
    else:
        dim, load_vector = _holdout_vector_loader(dataset, embedding_type, suffix, df)
        embedding_train, embedding_val, embedding_test = (
            _fill_split_matrix(split_names, dim, load_vector)
            for split_names in (train_names, val_names, test_names)
        )

    return (
        embedding_train,
        embedding_val,
        embedding_test,
        y_train,
        y_val,
        y_test,
        train_names,
        val_names,
        test_names,
    )


def _all_embeddings_vector_loader(dataset_dir: str, embedding_type: str, suffix, names):
    """Return ``(dim, load_vector)`` for a file-backed embedding type.

    ``load_vector(name)`` reads the stored embedding of the protein ``name``
    and returns the array that is written into one row of a
    ``[n_proteins, dim]`` matrix.

    Raises:
        ValueError: If ``embedding_type`` ends with "(aligned)" but is not a
            known aligned embedding.
        NotImplementedError: If ``embedding_type`` is not supported at all.
    """
    aligned_folders = {
        "ESM-1B (aligned)": "esm_1b_aln_embeddings",
        "ESM-IF1 (aligned)": "esm_if1_aln_embeddings",
        "AF2 (aligned)": "af2_aln_embeddings",
    }

    # NOTE: torch.load unpickles the file; only load embeddings from trusted sources.
    if embedding_type in ("ESM-1B", "ESM-2"):
        # (directory, embedding width, key into "mean_representations")
        folder, dim, layer = (
            ("esm_1b_embeddings", 1280, 33)
            if embedding_type == "ESM-1B"
            else ("esm_2_embeddings", 2560, 36)
        )

        def load_vector(name):
            stored = torch.load(f"{dataset_dir}/{folder}/{name}.pt")
            return stored["mean_representations"][layer].numpy()

    elif embedding_type == "ESM-IF1":
        dim = 512

        def load_vector(name):
            # Average over the first axis (presumably sequence positions).
            stored = torch.load(f"{dataset_dir}/esm_if1_embeddings/{name}.pt").numpy()
            return np.mean(stored, axis=0)

    elif embedding_type == "EVE (z)":
        dim = 50

        def load_vector(name):
            stored = torch.load(f"{dataset_dir}/eve_z/{suffix}/{name}.pt")
            return stored["representation"].numpy()

    elif embedding_type == "AF2":
        dim = 384

        def load_vector(name):
            # Average over the first axis (presumably sequence positions).
            stored = np.load(
                f"{dataset_dir}/af2_embeddings/{name}_single_repr_1_model_3.npy"
            )
            return np.mean(stored, axis=0)

    elif embedding_type in ("ONEHOT (MSA)", "ONEHOT") or embedding_type.endswith(
        "(aligned)"
    ):
        if embedding_type == "ONEHOT (MSA)":
            folder = "onehot_msa_encodings"
        elif embedding_type == "ONEHOT":
            folder = "onehot_encodings"
        elif embedding_type in aligned_folders:
            folder = aligned_folders[embedding_type]
        else:
            raise ValueError(f"Unknown aligned embedding type: {embedding_type!r}")

        # The width is taken from the first protein; every other file has to
        # flatten to the same length.
        dim = len(np.load(f"{dataset_dir}/{folder}/{names[0]}.npy").flatten())

        def load_vector(name):
            # Flatten here as well, so that the row matches the width that was
            # measured on the flattened first embedding.
            return np.load(f"{dataset_dir}/{folder}/{name}.npy").flatten()

    else:
        raise NotImplementedError(f"Unsupported embedding type: {embedding_type!r}")

    return dim, load_vector


def extract_all_embeddings(
    dataset: str,
    embedding_type: str,
    target: str,
    suffix: str = None,
    active: bool = False,
):
    """Load embeddings and targets for all cross-validation proteins of a dataset.

    The table ``data/processed/{dataset}/{dataset}.csv`` is read and reduced to
    the rows whose "part_0", "part_1" and "part_2" columns sum to exactly 1.

    Args:
        dataset: Dataset name; selects the directory below ``data/processed``.
        embedding_type: One of "ESM-1B", "ESM-2", "ESM-IF1", "ONEHOT",
            "ONEHOT (MSA)", "EVE (z)", "EVE (ELBO)", "AF2", "ESM-1B (aligned)",
            "ESM-IF1 (aligned)" or "AF2 (aligned)".
        target: Name of the column holding the target values.
        suffix: Suffix selecting the EVE directory / ELBO file; only used for
            the EVE embedding types.
        active: If True, keep only rows whose "target_class" is truthy.

    Returns:
        Tuple ``(embeddings, y, names)``: a 2-D array with one row per protein,
        the array of target values and the list of protein names.

    Raises:
        ValueError: If ``embedding_type`` ends with "(aligned)" but is not a
            known aligned embedding.
        NotImplementedError: If ``embedding_type`` is not supported at all.
    """
    dataset_dir = f"data/processed/{dataset}"

    # Load CSV
    df = pd.read_csv(f"{dataset_dir}/{dataset}.csv", index_col=0)
    if active:
        df = df.loc[df["target_class"].astype(bool)]

    # Use only sequences also used in CV for fair comparison
    df = df.loc[df[["part_0", "part_1", "part_2"]].sum(axis=1) == 1]

    y = df[target].values
    names = df["name"].tolist()

    if embedding_type == "EVE (ELBO)":
        # The ELBO is a single scalar per protein, read from a table instead
        # of one file per protein.
        df_elbo = pd.read_csv(f"{dataset_dir}/{dataset}_EVE_ELBO_{suffix}.csv")
        df = pd.merge(left=df, right=df_elbo[["name", "ELBO"]], on="name", how="left")
        return df["ELBO"].values.reshape(-1, 1), y, names

    dim, load_vector = _all_embeddings_vector_loader(
        dataset_dir, embedding_type, suffix, names
    )
    embeddings = np.zeros((len(names), dim))
    for row, name in enumerate(names):
        embeddings[row] = load_vector(name)
    return embeddings, y, names


if __name__ == "__main__":
    # Manual run: load the aligned ESM-IF1 embeddings of the "cm" dataset
    # together with the "target_reg" column.
    dataset, embedding_type, target = "cm", "ESM-IF1 (aligned)", "target_reg"
    suffix, active = None, False
    embeddings, y, names = extract_all_embeddings(
        dataset=dataset,
        embedding_type=embedding_type,
        target=target,
        suffix=suffix,
        active=active,
    )
