import os
import time

from ProtLig_GPCRclassA.datasets.KIBA_HyperAttentionDTI.preprocess import *
from ProtLig_GPCRclassA.datasets.KIBA_HyperAttentionDTI.split import *
from ProtLig_GPCRclassA.datasets.KIBA_HyperAttentionDTI.seqs_postprocess import *
from ProtLig_GPCRclassA.datasets.KIBA_HyperAttentionDTI.mols_postprocess import *
from ProtLig_GPCRclassA.datasets.KIBA_HyperAttentionDTI.postprocess_broadness import *
from ProtLig_GPCRclassA.datasets.KIBA_HyperAttentionDTI.postprocess_class_dist import *
from ProtLig_GPCRclassA.datasets.KIBA_HyperAttentionDTI.postprocess_weights import *

# Raw KIBA (HyperAttentionDTI) interaction file fed to CV_pairs_HyperAttentionDTI.
KIBA_RAW_DATA = '/mnt/ProtLig_GPCRclassA/ProtLig_GPCRclassA/RawData/KIBA_HyperAttentionDTI/KIBA_HyperAttentionDTI.txt'
# Sequence-length window passed to CVPP_LengthDiscard.
# NOTE(review): 'Inf' as a *lower* bound presumably means "no lower bound" — confirm
# against CVPP_LengthDiscard.
SEQ_LOWER_BOUND = 'Inf'
SEQ_UPPER_BOUND = 1408
# Number of independent random train/valid/test splits to generate.
N_SPLITS = 5


def _split_and_weight(working_dir, seed, mols_csv, seqs_csv, seqs_stem):
    """Create one random split of *working_dir* and attach sample weights to it.

    Runs, in order: a ``Random`` split (10 % valid, 20 % test) seeded with
    *seed*, the broadness calculation, broadness-based weights
    (``CVPP_addWeights_BroadnessClipNoBias``), then the class distribution and
    class-based weights. Every step after the split works in the split's
    ``working_dir``.

    Args:
        working_dir: Working directory of the CV pairs to split.
        seed: Integer seed for the random split.
        mols_csv: Path to the molecules table.
        seqs_csv: Path to the length-filtered sequences table.
        seqs_stem: File stem of *seqs_csv*, used to locate broadness outputs.
    """
    split = Random(data_dir=working_dir,
                   seed=seed,
                   split_kwargs={'valid_ratio': 0.1,
                                 'test_ratio': 0.2},
                   )
    split.CV_split()

    # Broadness per molecule and per sequence.
    broadness = CVPP_calculate_broadness(data_dir=split.working_dir,
                                         mols_csv=mols_csv,
                                         seqs_csv=seqs_csv,
                                         mol_id_col='Drug_ID',
                                         seq_id_col='Target_ID',
                                         label_col='Responsive')
    broadness.postprocess()

    # NOTE(review): broadness outputs appear to be named
    # "<kind>_broadness__<mols stem>__<seqs stem>.csv". This reproduces the
    # previously hard-coded names exactly; confirm if that naming scheme changes.
    broadness_suffix = '__mols__' + seqs_stem + '.csv'
    add_weights_broadness = CVPP_addWeights_BroadnessClipNoBias(
        data_dir=split.working_dir,
        label_col='Responsive',
        mol_id_col='Drug_ID',
        seq_id_col='Target_ID',
        n_classes=2,
        seq_min_count=50,
        mol_min_count=50,
        tested_enough_count=50,
        seq_clip_min_broadness=0.05,
        mol_clip_min_broadness=0.05,
        auxiliary_data_path={
            'seqs_broadness': os.path.join(split.working_dir, 'seqs_broadness' + broadness_suffix),
            'mols_broadness': os.path.join(split.working_dir, 'mols_broadness' + broadness_suffix),
        })
    add_weights_broadness.postprocess()

    # Class weights:
    class_dist = CVPP_class_dist(data_dir=split.working_dir,
                                 seq_id_col='Target_ID',
                                 mol_id_col='Drug_ID',
                                 label_col='Responsive',
                                 seqs_csv=seqs_csv,
                                 mols_csv=mols_csv)
    class_dist.postprocess()

    add_weights_class = CVPP_addWeights_Class(
        data_dir=split.working_dir,
        class_col='Responsive',
        auxiliary_data_path={'class_dist': os.path.join(split.working_dir, 'class_dist.json')})
    add_weights_class.postprocess()


def main():
    """Run the full KIBA (HyperAttentionDTI) preprocessing pipeline.

    Steps, in order:

    1. Build CV pairs from ``KIBA_RAW_DATA``.
    2. Discard sequences outside the length window.
    3. Size-cut the molecules.
    4. Generate ``N_SPLITS`` random splits, each with broadness- and
       class-based sample weights.
    """
    kiba_pairs = CV_pairs_HyperAttentionDTI(
        base_working_dir=os.path.join('amino_GNN', 'Data', 'KIBA_HyperAttentionDTI'),
        data_path={'data': KIBA_RAW_DATA})
    kiba_pairs.CV_data()

    mols_csv = os.path.join(kiba_pairs.working_dir, 'mols', 'mols.csv')

    # ----------------
    # Seqs preprocess:
    # ----------------
    seqs_len_discard = CVPP_LengthDiscard(
        data_dir=os.path.join(kiba_pairs.working_dir, 'seqs'),
        lower_bound=SEQ_LOWER_BOUND,
        upper_bound=SEQ_UPPER_BOUND,
        seqs_csv=None,
        seq_id_col='Target_ID',
        seq_col='Target',
        auxiliary_data_path={'stats_data': os.path.join(kiba_pairs.working_dir, 'full_data.csv')})
    seqs_len_discard.postprocess()

    # The output name is derived from the same bounds passed above, so it cannot
    # drift out of sync with them. For the defaults this is
    # 'seqs_lowerInf_upper1408'.
    # NOTE(review): assumes CVPP_LengthDiscard names its output this way — confirm.
    seqs_stem = f'seqs_lower{SEQ_LOWER_BOUND}_upper{SEQ_UPPER_BOUND}'
    seqs_csv = os.path.join(seqs_len_discard.working_dir, seqs_stem + '.csv')

    # ----------------
    # Mols preprocess:
    # ----------------
    # Get size cut molecules:
    sizecut_concat_graph = CVPP_SizeCut(data_dir=os.path.join(kiba_pairs.working_dir, 'mols'),
                                        n_node_thresholds=[32, 128],
                                        n_edge_thresholds=[64, 256],
                                        mols_csv=mols_csv,
                                        mol_id_col='Drug_ID',
                                        mol_col='Drug',
                                        bond_multiplier=2)
    sizecut_concat_graph.postprocess()

    # Seeds are wall-clock seconds. The sleep normally makes them distinct;
    # the bump below also guarantees distinct seeds if the clock jumps backwards.
    prev_seed = None
    for _ in range(N_SPLITS):
        time.sleep(1)
        seed = int(time.time())
        if prev_seed is not None and seed <= prev_seed:
            seed = prev_seed + 1
        prev_seed = seed
        _split_and_weight(kiba_pairs.working_dir, seed, mols_csv, seqs_csv, seqs_stem)


if __name__ == '__main__':
    main()