import os
import time

from ProtLig_GPCRclassA.datasets.DAVIS_HyperAttentionDTI.preprocess import *
from ProtLig_GPCRclassA.datasets.DAVIS_HyperAttentionDTI.split import *
from ProtLig_GPCRclassA.datasets.DAVIS_HyperAttentionDTI.seqs_postprocess import *
from ProtLig_GPCRclassA.datasets.DAVIS_HyperAttentionDTI.mols_postprocess import *
from ProtLig_GPCRclassA.datasets.DAVIS_HyperAttentionDTI.postprocess_broadness import *
from ProtLig_GPCRclassA.datasets.DAVIS_HyperAttentionDTI.postprocess_class_dist import *
from ProtLig_GPCRclassA.datasets.DAVIS_HyperAttentionDTI.postprocess_weights import *

if __name__ == '__main__':
    # Driver for the DAVIS (HyperAttentionDTI format) pipeline:
    #   1. build the pair dataset from the raw file,
    #   2. discard sequences by length and cut molecules by graph size,
    #   3. draw N_SPLITS random train/valid/test splits and, for each split,
    #      compute broadness weights and class weights.

    # Raw input file. Absolute path: adjust for the machine this runs on.
    RAW_DATA_PATH = '/mnt/ProtLig_GPCRclassA/ProtLig_GPCRclassA/RawData/DAVIS_HyperAttentionDTI/Davis_HyperAttentionDTI.txt'

    # Column names shared by every processing step below.
    SEQ_ID_COL = 'Target_ID'
    SEQ_COL = 'Target'
    MOL_ID_COL = 'Drug_ID'
    MOL_COL = 'Drug'
    LABEL_COL = 'Responsive'

    # Bounds handed to CVPP_LengthDiscard. Downstream steps read the filtered
    # sequence CSV by a name built from these same values, so both are
    # derived from a single source here.
    SEQ_LOWER_BOUND = 'Inf'
    SEQ_UPPER_BOUND = 1736

    # Number of independent random splits to generate.
    N_SPLITS = 5

    DAVIS_pairs = CV_pairs_HyperAttentionDTI(
        base_working_dir=os.path.join('amino_GNN', 'Data', 'DAVIS_HyperAttentionDTI'),
        data_path={'data': RAW_DATA_PATH},
    )
    DAVIS_pairs.CV_data()

    # ----------------
    # Seqs preprocess:
    # ----------------
    seqs_len_discard = CVPP_LengthDiscard(
        data_dir=os.path.join(DAVIS_pairs.working_dir, 'seqs'),
        lower_bound=SEQ_LOWER_BOUND,
        upper_bound=SEQ_UPPER_BOUND,
        seqs_csv=None,
        seq_id_col=SEQ_ID_COL,
        seq_col=SEQ_COL,
        auxiliary_data_path={'stats_data': os.path.join(DAVIS_pairs.working_dir, 'full_data.csv')},
    )
    seqs_len_discard.postprocess()

    # Stem of the length-filtered sequence CSV, e.g. 'seqs_lowerInf_upper1736'.
    # Presumably this is the name CVPP_LengthDiscard writes for these bounds;
    # verify here if that class's naming scheme changes.
    seqs_stem = f'seqs_lower{SEQ_LOWER_BOUND}_upper{SEQ_UPPER_BOUND}'
    seqs_csv = os.path.join(seqs_len_discard.working_dir, f'{seqs_stem}.csv')

    # ----------------
    # Mols preprocess:
    # ----------------
    mols_stem = 'mols'
    mols_dir = os.path.join(DAVIS_pairs.working_dir, 'mols')
    mols_csv = os.path.join(mols_dir, f'{mols_stem}.csv')

    # Get size cut molecules:
    _sizecut_concatGraph = CVPP_SizeCut(
        data_dir=mols_dir,
        n_node_thresholds=[32, 128],
        n_edge_thresholds=[64, 256],
        mols_csv=mols_csv,
        mol_id_col=MOL_ID_COL,
        mol_col=MOL_COL,
        bond_multiplier=2,
    )
    _sizecut_concatGraph.postprocess()

    # Suffix of the broadness files read back below, e.g.
    # 'mols__seqs_lowerInf_upper1736' -> 'seqs_broadness__mols__seqs_lowerInf_upper1736.csv'.
    broadness_suffix = f'{mols_stem}__{seqs_stem}'

    for split_idx in range(N_SPLITS):
        # The seed is int(time.time()), which only has one-second resolution.
        # Sleeping 1 s before each read makes every split's seed distinct.
        time.sleep(1)
        seed = int(time.time())
        # Wall-clock seeds are otherwise lost; log them so a split can be
        # regenerated.
        print(f'[split {split_idx + 1}/{N_SPLITS}] seed = {seed}')

        # Split data:
        split = Random(
            data_dir=DAVIS_pairs.working_dir,
            seed=seed,
            split_kwargs={'valid_ratio': 0.1, 'test_ratio': 0.2},
        )
        split.CV_split()

        # Get broadness weights:
        broadness = CVPP_calculate_broadness(
            data_dir=split.working_dir,
            mols_csv=mols_csv,
            seqs_csv=seqs_csv,
            mol_id_col=MOL_ID_COL,
            seq_id_col=SEQ_ID_COL,
            label_col=LABEL_COL,
        )
        broadness.postprocess()

        addWeights_broadness_clip_nobias = CVPP_addWeights_BroadnessClipNoBias(
            data_dir=split.working_dir,
            label_col=LABEL_COL,
            mol_id_col=MOL_ID_COL,
            seq_id_col=SEQ_ID_COL,
            n_classes=2,
            seq_min_count=50,
            mol_min_count=50,
            tested_enough_count=50,
            seq_clip_min_broadness=0.05,
            mol_clip_min_broadness=0.05,
            auxiliary_data_path={
                'seqs_broadness': os.path.join(split.working_dir, f'seqs_broadness__{broadness_suffix}.csv'),
                'mols_broadness': os.path.join(split.working_dir, f'mols_broadness__{broadness_suffix}.csv'),
            },
        )
        addWeights_broadness_clip_nobias.postprocess()

        # Get class weights:
        class_dist = CVPP_class_dist(
            data_dir=split.working_dir,
            seq_id_col=SEQ_ID_COL,
            mol_id_col=MOL_ID_COL,
            label_col=LABEL_COL,
            seqs_csv=seqs_csv,
            mols_csv=mols_csv,
        )
        class_dist.postprocess()

        addWeights_Class = CVPP_addWeights_Class(
            data_dir=split.working_dir,
            class_col=LABEL_COL,
            auxiliary_data_path={'class_dist': os.path.join(split.working_dir, 'class_dist.json')},
        )
        addWeights_Class.postprocess()