import json
import os.path
from collections import OrderedDict

import numpy as np
import pandas as pd
from src.data.data_utils import generate_CV_partitions, pdb_to_coords
from src.visualization.visualization_funcs import show_CV_split_distribution
from tqdm import tqdm


def _build_base_frame(raw_seq_path: str) -> pd.DataFrame:
    """Read the raw sequence CSV and derive the label columns.

    Keeps the ``name``, ``sequence``, ``enzyme_activity`` and ``comment``
    columns, shortens the ``comment`` labels, and adds ``wildtype``,
    ``indicator`` and ``target_class``. ``enzyme_activity`` is renamed to
    ``target_reg``.

    Raises:
        ValueError: If a ``comment`` value is not one of the known categories.
    """
    df_seq = pd.read_csv(raw_seq_path)
    df = df_seq[["name", "sequence", "enzyme_activity", "comment"]].copy()

    # Binary column, wildtype indicator (0 = artificial, 1 = natural).
    # Derived before the comment labels are shortened below.
    df["wildtype"] = 0
    df.loc[df["comment"] == "natural sequences", "wildtype"] = 1

    # Shorten the comment labels. Dict order fixes the integer `indicator`
    # code assigned to every category (0 = natural, ..., 4 = profile model).
    short_labels = {
        "natural sequences": "natural",
        "bmDCA designed sequences, T=0.33": "bmDCA (T=0.33)",
        "bmDCA designed sequences, T=0.66": "bmDCA (T=0.66)",
        "bmDCA designed sequences, T=1": "bmDCA (T=1)",
        "designed sequences by profile model": "profile model",
    }
    for raw_label, short_label in short_labels.items():
        df.loc[df["comment"] == raw_label, "comment"] = short_label

    indicator_codes = {
        label: code for code, label in enumerate(short_labels.values())
    }
    # Labels missing from the mapping become NaN.
    df["indicator"] = df["comment"].map(indicator_codes)
    unmatched = df["indicator"].isna()
    if unmatched.any():
        # Fail with a readable message instead of an opaque NaN-to-int cast error.
        unknown = df.loc[unmatched, "comment"].unique().tolist()
        raise ValueError(
            f"Unrecognized values in 'comment' column of {raw_seq_path}: {unknown}"
        )
    df["indicator"] = df["indicator"].astype(int)

    # Alter names
    df = df.rename(columns={"enzyme_activity": "target_reg"})

    # Binary label active/inactive
    label = df["target_reg"] > 0.42  # See Russ, Figliuzzi et al. (2020)
    df["target_class"] = label.astype(int)
    return df


def _merge_partitions(
    df: pd.DataFrame, df_cv: pd.DataFrame, ids, n_partitions: int
) -> pd.DataFrame:
    """Attach ``part_i`` membership columns and a ``holdout`` column to ``df``.

    Args:
        df: Full table of sequences.
        df_cv: Subset of ``df`` that was partitioned, with the original row
            labels stored in an ``index`` column (as produced by
            ``reset_index()``). Modified in place.
        ids: One ``.loc`` row selector into ``df_cv`` per partition, as
            returned by ``generate_CV_partitions``.
        n_partitions: Number of ``part_i`` columns to create.

    Returns:
        A new frame sorted by ``indicator`` and ``name`` with a fresh integer
        index. Rows that were not part of ``df_cv`` have NaN in the ``part_i``
        columns and an empty ``holdout`` string.
    """
    # Partition headers
    headers = [f"part_{i}" for i in range(n_partitions)]
    for header, idx in zip(headers, ids):
        df_cv[header] = 0
        df_cv.loc[idx, header] = 1

    df_cv = df_cv.set_index("index")

    # Merge with df: partitioned rows come first, so `keep="first"` retains
    # the version that carries the partition columns.
    merged = pd.concat((df_cv, df))
    merged = merged.drop_duplicates(subset=["name"], keep="first")
    merged = merged.sort_values(["indicator", "name"])
    merged = merged.reset_index(drop=True)

    # Create naive holdout partition (requires part_0..part_2 to exist)
    merged["holdout"] = ""
    for header, split in (("part_0", "train"), ("part_1", "val"), ("part_2", "test")):
        merged.loc[merged[header] == 1, "holdout"] = split
    return merged


def _write_json(df: pd.DataFrame, pdb_dir: str, json_path: str) -> None:
    """Write one JSON record per row of ``df``, including structure coordinates.

    Each record holds the row index, all column values and a ``coords`` entry
    obtained from ``pdb_to_coords`` for the matching PDB file in ``pdb_dir``.
    """
    # Store dictionary of sequences in list
    full_list = []
    for index, row in tqdm(df.iterrows(), total=df.shape[0]):
        # Get id to find PDB-file
        seq_id = row["name"]
        # Load and parse structure
        coords = pdb_to_coords(
            f"{pdb_dir}/{seq_id}_unrelaxed_rank_1_model_3.pdb",
            # Drops the first seven characters of the name.
            # NOTE(review): presumably a fixed-length prefix - confirm.
            identifier=seq_id[7:],
            sequence_length=len(row["sequence"]),
        )

        # Create and fill dictionary over sequence
        seq_struct = OrderedDict({"index": index})
        seq_struct.update(row.to_dict())
        seq_struct.update({"coords": coords.tolist()})
        full_list.append(seq_struct)

    # Serialize before opening the file: a failure here must not leave a
    # partial file behind, because the caller skips JSON creation whenever
    # the file already exists.
    json_string = json.dumps(full_list)

    with open(json_path, "w") as f:
        f.write(json_string)


def parse_cm() -> pd.DataFrame:
    """Build the processed "cm" dataset and return it as a DataFrame.

    Steps:
      1. Build the processed CSV from the raw CSV (always, while
         ``force_creation`` is True; otherwise only if it is missing).
      2. Add ``target_class_2``: 1 where ``target_reg`` exceeds the mean
         ``target_reg`` of the rows with ``target_class == 1``.
      3. Generate cross-validation partitions for the rows with indicator
         0, 1 or 2 and derive a train/val/test ``holdout`` column.
      4. Optionally plot the split distribution.
      5. Write a JSON file with per-sequence coordinates if it does not exist.

    Returns:
        The processed DataFrame (also written to the output CSV).
    """
    ####################
    # Definitions      #
    ####################

    dataset = "cm"
    force_creation = True
    verbose = True

    # Define input paths
    raw_seq_path = "data/raw/cm/cm.csv"
    pdb_dir = "data/raw/cm/pdb"

    # Define output paths
    processed_dir = f"data/processed/{dataset}"
    out_csv_path = f"{processed_dir}/{dataset}.csv"
    json_path = f"{processed_dir}/{dataset}.json"
    ckpt_path = f"{processed_dir}/{dataset}_cv_graphpart_edges.csv"

    # Define GraphPart parameters
    alignment_mode = "needle"
    threads = 20
    threshold_inc = 0.05
    n_partitions = 3
    min_pp_split = 1 / (n_partitions + 1)
    initial_threshold = 0.35

    # Optional visualization
    show_splits = True

    ####################
    # Build DataFrame  #
    ####################
    if not os.path.exists(out_csv_path) or force_creation:
        print("Generating CSV.")
        df = _build_base_frame(raw_seq_path)

        # Save as csv file; create the output directory on a fresh checkout.
        os.makedirs(processed_dir, exist_ok=True)
        df.to_csv(out_csv_path, index_label="index")
        print(f"Created {out_csv_path}.\n")
    else:
        df = pd.read_csv(out_csv_path, index_col=0)
        print("Loaded CSV.")

    if "target_class_2" not in df:
        # Compute mean value of active cluster
        class_threshold = df.loc[df["target_class"].astype(bool), "target_reg"].mean()
        label = df["target_reg"] > class_threshold
        df["target_class_2"] = label.astype(int)
        df.to_csv(out_csv_path, index_label="index")

    #######################
    # Generate splits     #
    #######################

    # Stays NaN when the partitions were loaded from an existing CSV.
    threshold = np.nan
    if "part_0" not in df:
        df_cv = df[df["indicator"].isin([0, 1, 2])].reset_index()
        ids, threshold = generate_CV_partitions(
            df_cv,
            initial_threshold,
            dataset,
            alignment_mode,
            verbose,
            n_partitions,
            threads,
            min_pp_split,
            threshold_inc,
            ckpt_path,
        )
        df = _merge_partitions(df, df_cv, ids, n_partitions)
        df.to_csv(out_csv_path, index_label="index")

    if show_splits:
        show_CV_split_distribution(df, threshold, dataset, n_partitions)

    ####################
    # Create JSON      #
    ####################

    # NOTE(review): unlike the CSV, the JSON is not regenerated when
    # `force_creation` is set, so an existing file may be stale - confirm
    # this is intended.
    if not os.path.exists(json_path):
        print("Generating JSON file.")
        _write_json(df, pdb_dir, json_path)
        print(f"Processed JSON-file saved as {json_path}.")

    return df


# Script entry point: run the full preprocessing pipeline. The result is bound
# to a module-level `df` so it stays available after the call.
if __name__ == "__main__":
    df = parse_cm()
