import os
import time

from ProtLig_GPCRclassA.datasets.M2OR.preprocess import *
from ProtLig_GPCRclassA.datasets.M2OR.split import *
from ProtLig_GPCRclassA.datasets.M2OR.screening_confidence import *

from ProtLig_GPCRclassA.datasets.M2OR.seqs_postprocess import *

from ProtLig_GPCRclassA.datasets.M2OR.mols_postprocess import *
from ProtLig_GPCRclassA.datasets.M2OR.mols_postprocess_mixture import *
from ProtLig_GPCRclassA.datasets.M2OR.mols_postprocess_discard import *

from ProtLig_GPCRclassA.datasets.M2OR.postprocess_broadness import *
from ProtLig_GPCRclassA.datasets.M2OR.postprocess_class_dist import *
from ProtLig_GPCRclassA.datasets.M2OR.postprocess_data_quality import *
from ProtLig_GPCRclassA.datasets.M2OR.postprocess_testing_weights import *
from ProtLig_GPCRclassA.datasets.M2OR.postprocess_weights import *



if __name__ == '__main__':
    # M2OR data-preparation driver. Runs, in order: pair preprocessing,
    # screening-confidence estimation, sequence and molecule post-processing,
    # and finally several random EC50 splits, each followed by its
    # class / broadness / data-quality weighting steps.
    # Every class used below is provided by the wildcard imports at the top
    # of this file.

    # data_path = '/mnt/ProtLig_GPCRclassA/ProtLig_GPCRclassA/RawData/M2OR/M2OR_20231124_130000'
    data_path = '/data_mount/ProtLig_GPCRclassA/ProtLig_GPCRclassA/RawData/M2OR_20250501_165200'
    # Shared by the pair preprocessing and the screening-confidence step.
    main_receptors_csv = os.path.join(data_path, 'main_receptors.csv')

    # ----------------
    # Preprocess pairs:
    # ----------------
    pairs = CV_ORpairs(
        base_working_dir='/data_mount/ProtLig_GPCRclassA/ProtLig_GPCRclassA/amino_GNN/Data',
        data_path={'pairs': os.path.join(data_path, 'pairs.csv'),
                   'main_receptors': main_receptors_csv,
                   'main_compounds': os.path.join(data_path, 'main_compounds.csv'),
                   'compounds': os.path.join(data_path, 'compounds.csv'),
                   })
    pairs.CV_data()
    print(pairs.working_dir)

    # -----------------------------------
    # Create screening confidence matrix:
    # -----------------------------------
    screening_confidence = ScreeningConfidence(
        base_working_dir=pairs.working_dir,
        data_path={'experiments': os.path.join(data_path, 'experiments.csv'),
                   'references': os.path.join(data_path, 'references.csv'),
                   'main_receptors': main_receptors_csv})
    screening_confidence.main()

    # Passed as 'stats_data' to every discard step below.
    stats_data_csv = os.path.join(pairs.working_dir, 'full_data.csv')

    # ----------------
    # Seqs preprocess:
    # ----------------
    seqs_len_discard = CVPP_LengthDiscard(
        data_dir=os.path.join(pairs.working_dir, 'seqs'),
        lower_bound=296,
        upper_bound='Inf',
        seqs_csv=None,
        seq_id_col='seq_id',
        seq_col='mutated_sequence',
        auxiliary_data_path={'stats_data': stats_data_csv})
    seqs_len_discard.postprocess()
    # NOTE(review): the file name presumably encodes lower_bound/upper_bound
    # used above - keep both in sync and confirm against CVPP_LengthDiscard.
    seqs_csv = os.path.join(seqs_len_discard.working_dir, 'seqs_lower296_upperInf.csv')

    # ----------------
    # Mols preprocess:
    # ----------------
    # Mols preprocess - discard:
    mols_discard_mixture = CVPP_MixtureDiscard(
        data_dir=os.path.join(pairs.working_dir, 'mols'),
        mols_csv=None,
        mol_id_col='mol_id',
        auxiliary_data_path={'stats_data': stats_data_csv})
    mols_discard_mixture.postprocess()

    mols_discard_by_list = CVPP_MoleculeDiscard(
        data_dir=os.path.join(pairs.working_dir, 'mols', 'discard_mix'),
        mol_keys=['ONKNPOPIGWHAQC-UHFFFAOYSA-N'],
        mols_csv=os.path.join(pairs.working_dir, 'mols', 'discard_mix', 'mols_mix.csv'),
        mol_id_col='mol_id',
        mol_key_col='inchi_key',
        auxiliary_data_path={'stats_data': stats_data_csv})
    mols_discard_by_list.postprocess()

    # All remaining molecule steps and the per-split statistics read the
    # same molecule table from the discard-by-list working directory.
    mols_dir = mols_discard_by_list.working_dir
    mols_csv = os.path.join(mols_dir, 'mols_n1.csv')

    # Mols preprocess - graph construction - racemic:
    mols_mix_racemic = CVPP_Mixture_Racemic(
        data_dir=mols_dir,
        mols_csv=mols_csv,
        mol_id_col='mol_id',
        auxiliary_data_path={'map_inchikey_to_isomericSMILES': None,
                             'map_inchikey_to_canonicalSMILES': None,
                             })
    mols_mix_racemic.postprocess()

    # Mols preprocess - size cut - racemic:
    print('Size cut - SMILES_racemic...')
    mols_sizecut_racemic = CVPP_SizeCut(
        data_dir=mols_dir,
        n_node_thresholds=[32, 128],
        n_edge_thresholds=[64, 256],
        mols_csv=mols_csv,
        mol_id_col='mol_id',
        mol_col='SMILES_racemic',
        bond_multiplier=2)
    mols_sizecut_racemic.postprocess()

    # Mols preprocess - graph construction - concatGraph:
    mols_mix_concatGraph = CVPP_Mixture_ConcatGraph(
        data_dir=mols_dir,
        mols_csv=mols_csv,
        mol_id_col='mol_id',
        auxiliary_data_path={'map_inchikey_to_isomericSMILES': None,
                             'map_inchikey_to_canonicalSMILES': None,
                             })
    mols_mix_concatGraph.postprocess()

    # Mols preprocess - size cut - concatGraph:
    print('Size cut - SMILES_concatGraph...')
    mols_sizecut_concatGraph = CVPP_SizeCut(
        data_dir=mols_dir,
        n_node_thresholds=[32, 128],
        n_edge_thresholds=[64, 256],
        mols_csv=mols_csv,
        mol_id_col='mol_id',
        mol_col='SMILES_concatGraph',
        bond_multiplier=2)
    mols_sizecut_concatGraph.postprocess()

    # --------------------------
    # Random splits and weights:
    # --------------------------
    n_splits = 5
    previous_seed = None
    for _ in range(n_splits):
        # The seed is taken from the wall clock at one-second resolution.
        # If an iteration starts within the same second as the previous one,
        # bump the seed so that every split is guaranteed a distinct seed.
        seed = int(time.time())
        if previous_seed is not None and seed <= previous_seed:
            seed = previous_seed + 1
        previous_seed = seed

        # ------
        # Split:
        # ------
        ec50_random = EC50_Random_test_and_valid(
            data_dir=pairs.working_dir,
            seed=seed,
            split_kwargs={'valid_ratio': 0.1,  # This refers to the percentage of EC50
                          'test_ratio': 0.2,  # This refers to the percentage of EC50
                          })
        ec50_random.CV_split()
        # Every weighting step below works inside this split's directory.
        split_dir = ec50_random.working_dir

        # Class distribution:
        class_dist = CVPP_class_dist(
            data_dir=split_dir,
            seq_id_col='seq_id',
            mol_id_col='mol_id',
            label_col='responsive',
            seqs_csv=seqs_csv,
            mols_csv=mols_csv)
        class_dist.postprocess()

        addWeights_Class = CVPP_addWeights_Class(
            data_dir=split_dir,
            auxiliary_data_path={'class_dist': os.path.join(class_dist.working_dir, 'class_dist.json')})
        addWeights_Class.postprocess()

        # Broadness:
        broadness = CVPP_calculate_broadness(
            data_dir=split_dir,
            mols_csv=mols_csv,
            seqs_csv=seqs_csv,
            mol_id_col='mol_id',
            seq_id_col='seq_id',
            label_col='responsive')
        broadness.postprocess()

        # NOTE(review): the two broadness file names below are presumably
        # what the broadness step above writes for these mols/seqs tables -
        # confirm against CVPP_calculate_broadness.
        root_broadness_clip_bias = CVPP_addWeights_RootBroadnessClipNoBias(
            data_dir=split_dir,
            label_col='responsive',
            mol_id_col='mol_id',
            seq_id_col='seq_id',
            n_classes=2,
            seq_min_count=100,
            mol_min_count=50,
            tested_enough_count=50,
            seq_clip_min_broadness=0.025,
            mol_clip_min_broadness=0.025,
            auxiliary_data_path={'seqs_broadness': os.path.join(split_dir, 'seqs_broadness__mols_n1__seqs_lower296_upperInf.csv'),
                                 'mols_broadness': os.path.join(split_dir, 'mols_broadness__mols_n1__seqs_lower296_upperInf.csv')})
        root_broadness_clip_bias.postprocess()

        # Data quality:
        test_data_quality = CVPP_addWeights_DataQuality(
            data_dir=split_dir,
            label_col='responsive',
            data_quality_col='data_quality',
            auxiliary_data_path={'screening_confidence_probs': os.path.join(pairs.working_dir, 'Screening_confidence_probabilities.json'),
                                 })
        test_data_quality.postprocess()

        # ----------------
        # Combine weights:
        # ----------------
        # Combine weights - class:
        combine_weights_class = CVPP_combineWeights(
            data_dir=split_dir,
            weight_cols=['class_weight', 'data_quality_weight'])
        combine_weights_class.postprocess()

        # Combine weights - root broadness:
        combine_weights_broadness = CVPP_combineWeights(
            data_dir=split_dir,
            weight_cols=['root_broadness_weight', 'data_quality_weight'])
        combine_weights_broadness.postprocess()