import json
import os
from collections import OrderedDict

import numpy as np
import pandas as pd
from sklearn.mixture import GaussianMixture
from src.data.data_utils import generate_CV_partitions, pdb_to_coords
from src.visualization.visualization_funcs import show_CV_split_distribution
from tqdm import tqdm


def parse_ppat() -> pd.DataFrame:
    """Build the processed PPAT dataset: CSV, CV partitions and structure JSON.

    The function takes no arguments; all paths and parameters are defined
    in its body and the paths are relative to the current working directory.
    It runs three stages:

    1. Read the raw Excel sheet and derive the classification targets
       (or load the previously written CSV when regeneration is disabled).
    2. Add cross-validation partition columns and a naive holdout column
       if the DataFrame does not already contain them.
    3. Write a JSON file holding every row together with the coordinates
       parsed from its PDB file, if that JSON file does not exist yet.

    Returns:
        The processed DataFrame, including partition and holdout columns.
    """
    ####################
    # Definitions      #
    ####################

    dataset = "ppat"
    force_creation = True

    # Define input paths
    pdb_dir = "data/raw/ppat/pdb"
    raw_path = "data/raw/ppat/ppat.xlsx"
    sheet = "S12_PPATdata"

    # Define output paths
    out_csv_path = "data/processed/ppat/ppat.csv"
    json_path = "data/processed/ppat/ppat.json"

    # Define GraphPart parameters
    alignment_mode = "needle"
    threads = 10
    initial_threshold = 0.5
    n_partitions = 3
    min_pp_split = 1 / (n_partitions + 1)
    threshold_inc = 0.025

    verbose = True

    # Optional visualization
    show_splits = True

    # Writing the CSV/JSON fails if the output directory is missing (e.g. on
    # a fresh checkout), so create it up front.
    for out_path in (out_csv_path, json_path):
        os.makedirs(os.path.dirname(out_path), exist_ok=True)

    #####################
    # Build DataFrame   #
    #####################

    if force_creation or not os.path.exists(out_csv_path):
        print("Generating CSV.")
        df = _build_ppat_dataframe(raw_path, sheet)
        df.to_csv(out_csv_path, index_label="index")
        print(f"Created {out_csv_path}.")
    else:
        df = pd.read_csv(out_csv_path, index_col=0)
        print("Loaded CSV.")

    #######################
    # Generate splits     #
    #######################

    # NOTE(review): when the partitions are loaded from an existing CSV the
    # threshold is unknown and NaN is passed on to the visualization.
    threshold = np.nan

    if "part_0" not in df:
        ids, threshold = generate_CV_partitions(
            df,
            initial_threshold,
            dataset,
            alignment_mode,
            verbose,
            n_partitions,
            threads,
            min_pp_split,
            threshold_inc,
        )
        _add_partition_columns(df, ids, n_partitions)
        df.to_csv(out_csv_path, index_label="index")

    if show_splits:
        show_CV_split_distribution(df, threshold, dataset, n_partitions)

    #######################
    # Generate JSON       #
    #######################

    # NOTE(review): the JSON is only written when it is missing, so it is not
    # refreshed when `force_creation` regenerates the CSV -- confirm that a
    # stale JSON file is acceptable.
    if not os.path.exists(json_path):
        print("Generating JSON file.")
        _write_structure_json(df, pdb_dir, json_path)
        print(f"Processed JSON-file saved as {json_path}.")

    return df


def _build_ppat_dataframe(raw_path: str, sheet: str) -> pd.DataFrame:
    """Load the raw Excel sheet and derive the regression/classification targets."""
    # Load and clean file with sequences. Rows with a missing value in any
    # column of the sheet are dropped before the columns are selected.
    df = pd.read_excel(raw_path, sheet_name=sheet)
    df = df.dropna().reset_index(drop=True)
    df = df.rename(
        columns={
            "Accession": "name",
            "globalfit14": "target_reg",
            "seq": "sequence",
        }
    )
    df = df[["name", "target_reg", "sequence"]]

    # Fit GMM for binary classification
    gmm = GaussianMixture(n_components=2, random_state=0)
    components = gmm.fit_predict(X=df["target_reg"].values.reshape(-1, 1))

    # The component indices returned by the mixture model are arbitrary.
    # Relabel so that class 1 is always the component with the larger mean,
    # i.e. the "active" cluster that the threshold below relies on.
    active_component = int(np.argmax(gmm.means_.ravel()))
    df["target_class"] = (components == active_component).astype(int)
    df["wildtype"] = 1

    # Compute mean value of active cluster for alternative classification
    class_threshold = df.loc[df["target_class"] == 1, "target_reg"].mean()
    df["target_class_2"] = (df["target_reg"] > class_threshold).astype(int)

    n_high = int(df["target_class_2"].sum())
    print(f"Threshold for high fitness: {class_threshold:.3f}")
    print(f"# high activity fitness: {n_high}")
    print(f"# low activity fitness: {int(len(df) - n_high)}")

    return df


def _add_partition_columns(df: pd.DataFrame, ids, n_partitions: int) -> None:
    """Add one-hot `part_i` columns and a train/val/test `holdout` column in place."""
    # One indicator column per partition; `ids` is iterated in step with the
    # headers and each element is used as a row indexer into `df`.
    headers = [f"part_{i}" for i in range(n_partitions)]
    for header, idx in zip(headers, ids):
        df[header] = 0
        df.loc[idx, header] = 1

    # Create naive holdout partition. This uses exactly the first three
    # partitions, so it requires n_partitions >= 3.
    df["holdout"] = ""
    for header, split in (("part_0", "train"), ("part_1", "val"), ("part_2", "test")):
        df.loc[df[header] == 1, "holdout"] = split


def _write_structure_json(df: pd.DataFrame, pdb_dir: str, json_path: str) -> None:
    """Write every row of `df`, together with its PDB coordinates, to `json_path`."""
    # Store dictionary of sequences in list
    full_list = []
    for index, row in tqdm(df.iterrows(), total=df.shape[0]):
        # Get id to find PDB-file
        seq_id = row["name"]
        # Load and parse structure. The identifier is the accession without
        # its first seven characters.
        # NOTE(review): presumably this strips a fixed prefix -- confirm.
        coords = pdb_to_coords(
            f"{pdb_dir}/{seq_id}_unrelaxed_rank_1_model_3.pdb",
            identifier=seq_id[7:],
            sequence_length=len(row["sequence"]),
        )

        # Create and fill dictionary over sequence
        seq_struct = OrderedDict({"index": index})
        seq_struct.update(row.to_dict())
        seq_struct.update({"coords": coords.tolist()})
        full_list.append(seq_struct)

    # Serialize fully before opening the file, so a serialization error does
    # not leave a partial JSON file behind that later runs would treat as done.
    json_string = json.dumps(full_list)

    with open(json_path, "w", encoding="utf-8") as f:
        f.write(json_string)


if __name__ == "__main__":
    # Script entry point: run the PPAT preprocessing and bind the value
    # returned by parse_ppat() to the module-level name `df`.
    df = parse_ppat()
