import json
import os
import os.path
from collections import OrderedDict

import numpy as np
import pandas as pd
from Bio import SeqIO
from sklearn.mixture import GaussianMixture
from src.data.data_utils import generate_CV_partitions, pdb_to_coords
from src.visualization.visualization_funcs import show_CV_split_distribution
from tqdm import tqdm


def parse_tim() -> pd.DataFrame:
    """Build (or load) the processed TIM-barrel dataset and its derived files.

    Steps:

    1. Unless the CSV already exists (and ``force_creation`` is off), build a
       DataFrame from the meltome TSV restricted to proteins that have a
       domain structure in ``pdb_dir`` and an entry in both domain
       alignments. Attach the domain sequence and a binary ``target_class``
       from a 2-component Gaussian mixture fitted on ``target_reg``, then
       write it to ``out_csv_path``.
    2. If the DataFrame has no ``part_0`` column, compute GraphPart
       cross-validation partitions. Store them as ``part_<i>`` 0/1 columns
       plus a naive ``holdout`` column (``train``/``val``/``test``), and
       re-save the CSV.
    3. Optionally plot the split distribution.
    4. Unless the JSON already exists (and ``force_creation`` is off), write
       one record per row. Each record holds the row's columns plus the
       coordinates returned by ``pdb_to_coords`` for that protein's structure.

    All input/output paths and partitioning parameters are hard-coded below.

    Returns:
        The processed DataFrame.

    Raises:
        ValueError: If a domain sequence from ``dom_fasta_path`` is not a
            substring of its protein's full sequence.
    """
    ####################
    # Definitions      #
    ####################

    dataset = "tim"
    # Rebuild the CSV (and therefore the JSON) even if they already exist.
    force_creation = True

    # Directory with TIM barrel domain structures
    pdb_dir = "data/raw/tim/pdb_dom"
    # Structure files are expected to be named "<name><pdb_suffix>".
    pdb_suffix = "_unrelaxed_rank_1_model_3.pdb"
    # File containing full protein sequences and corresponding target values
    raw_path = "data/raw/tim/meltome_w_topt.tsv"
    dom_fasta_path = "data/raw/tim/tim_dom.fasta"
    dom_msa_path = "data/processed/tim/tim_local.aln.fasta"
    dom_family_fasta_path = "data/raw/tim/tim_family.aln.fasta"
    # Proteins whose full sequence has this many residues or more are dropped.
    max_seq_len = 800

    # Define output paths
    out_csv_path = "data/processed/tim/tim.csv"
    json_path = "data/processed/tim/tim.json"

    # Define GraphPart parameters
    alignment_mode = "needle"
    threads = 10
    initial_threshold = 0.5
    n_partitions = 3
    min_pp_split = 1 / (n_partitions + 1)
    threshold_inc = 0.025

    verbose = True

    # Optional visualization
    show_splits = True

    #####################
    # Build DataFrame   #
    #####################

    if force_creation or not os.path.exists(out_csv_path):
        print("Generating CSV.")
        df = _build_tim_dataframe(
            raw_path,
            pdb_dir,
            pdb_suffix,
            dom_fasta_path,
            dom_msa_path,
            dom_family_fasta_path,
            max_seq_len,
        )
        df.to_csv(out_csv_path, index_label="index")
        print(f"Created {out_csv_path}.")
    else:
        df = pd.read_csv(out_csv_path, index_col=0)
        print("Loaded CSV.")

    #######################
    # Generate splits     #
    #######################

    # Stays NaN when partitions were loaded from disk rather than computed.
    threshold = np.nan
    if "part_0" not in df:
        ids, threshold = generate_CV_partitions(
            df,
            initial_threshold,
            dataset,
            alignment_mode,
            verbose,
            n_partitions,
            threads,
            min_pp_split,
            threshold_inc,
        )
        _add_partition_columns(df, ids, n_partitions)
        df.to_csv(out_csv_path, index_label="index")

    if show_splits:
        show_CV_split_distribution(df, threshold, dataset, n_partitions)

    ####################
    # Create JSON      #
    ####################

    # The JSON embeds every CSV column (including partitions), so it must be
    # rebuilt whenever the CSV is force-rebuilt to avoid stale split labels.
    if force_creation or not os.path.exists(json_path):
        print("Generating JSON file.")
        _write_structure_json(df, pdb_dir, pdb_suffix, json_path)
        print(f"Processed JSON-file saved as {json_path}.")

    return df


def _read_fasta_ids(path: str) -> set:
    """Return the set of record ids found in the FASTA file at ``path``."""
    return {record.id for record in SeqIO.parse(path, "fasta")}


def _build_tim_dataframe(
    raw_path: str,
    pdb_dir: str,
    pdb_suffix: str,
    dom_fasta_path: str,
    dom_msa_path: str,
    dom_family_fasta_path: str,
    max_seq_len: int,
) -> pd.DataFrame:
    """Create the TIM DataFrame with domain sequences and GMM class labels.

    Resulting columns: ``name``, ``full_sequence``, ``target_reg``,
    ``sequence`` (domain sequence, ``""`` if absent from ``dom_fasta_path``)
    and ``target_class``.
    """
    # Read full meltome dataset (sequences + target values)
    df = pd.read_csv(raw_path, sep="\t")
    df = df.rename(
        columns={"key": "name", "tm": "target_reg", "sequence": "full_sequence"}
    )
    df = df[["name", "full_sequence", "target_reg"]]

    # Keep proteins present in the structure directory, the local domain MSA
    # and the family alignment. Only files matching the naming scheme used to
    # load structures later are considered.
    pdb_names = {
        filename[: -len(pdb_suffix)]
        for filename in os.listdir(pdb_dir)
        if filename.endswith(pdb_suffix)
    }
    names = (
        pdb_names
        & _read_fasta_ids(dom_msa_path)
        & _read_fasta_ids(dom_family_fasta_path)
    )
    df = df.loc[df["name"].isin(names)]
    # Drop proteins whose full sequence has max_seq_len residues or more.
    df = df[df["full_sequence"].str.len() < max_seq_len].reset_index(drop=True)

    # Attach the domain sequence. A dict lookup replaces a per-record scan of
    # the whole frame.
    kept_names = set(df["name"])
    domain_seqs = {
        record.id: str(record.seq)
        for record in SeqIO.parse(dom_fasta_path, "fasta")
        if record.id in kept_names
    }
    df["sequence"] = df["name"].map(domain_seqs).fillna("")
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    for name, full_seq, dom_seq in zip(
        df["name"], df["full_sequence"], df["sequence"]
    ):
        if dom_seq not in full_seq:
            raise ValueError(
                f"Domain sequence of {name!r} is not contained in its full sequence."
            )

    # Fit GMM for binary classification.
    # NOTE: mixture component indices are arbitrary, so whether 1 marks the
    # higher-target cluster is not guaranteed; random_state=0 only makes the
    # assignment reproducible.
    labels = GaussianMixture(n_components=2, random_state=0).fit_predict(
        X=df["target_reg"].values.reshape(-1, 1)
    )
    df["target_class"] = labels.astype(int)
    return df


def _add_partition_columns(df: pd.DataFrame, ids, n_partitions: int) -> None:
    """Add ``part_<i>`` membership columns and a naive ``holdout`` column in place.

    ``ids`` holds one collection of ``df`` index labels per partition, as
    returned by ``generate_CV_partitions``. The first three partitions map to
    ``train``/``val``/``test``. Rows in none of them get ``""``.
    """
    headers = [f"part_{i}" for i in range(n_partitions)]
    for header, idx in zip(headers, ids):
        df[header] = 0
        df.loc[idx, header] = 1

    df["holdout"] = ""
    # zip() stops at the shorter sequence, so fewer than three partitions no
    # longer raises KeyError and any extra partitions are simply left unnamed.
    for header, split in zip(headers, ("train", "val", "test")):
        if header in df:
            df.loc[df[header] == 1, "holdout"] = split


def _write_structure_json(
    df: pd.DataFrame, pdb_dir: str, pdb_suffix: str, json_path: str
) -> None:
    """Write one JSON record per row: its index, all columns and its coords."""
    records = []
    for index, row in tqdm(df.iterrows(), total=df.shape[0]):
        seq_id = row["name"]
        coords = pdb_to_coords(
            f"{pdb_dir}/{seq_id}{pdb_suffix}",
            # NOTE(review): drops the first 7 characters of the name; the
            # expected identifier format is defined by pdb_to_coords — confirm.
            identifier=seq_id[7:],
            sequence_length=len(row["sequence"]),
        )

        record = OrderedDict({"index": index})
        record.update(row.to_dict())
        record["coords"] = coords.tolist()
        records.append(record)

    with open(json_path, "w") as f:
        json.dump(records, f)


if __name__ == "__main__":
    # Executed as a script: run the full TIM preprocessing pipeline and keep
    # the resulting DataFrame bound to the module-level name ``df``.
    df = parse_tim()
