import os
import time

from ProtLig_GPCRclassA.datasets.M2OR_concentration.preprocess import *
from ProtLig_GPCRclassA.datasets.M2OR_concentration.split import *
# from ProtLig_GPCRclassA.datasets.M2OR_concentration.screening_confidence import *

from ProtLig_GPCRclassA.datasets.M2OR_concentration.seqs_postprocess import *

from ProtLig_GPCRclassA.datasets.M2OR_concentration.mols_postprocess import *
from ProtLig_GPCRclassA.datasets.M2OR_concentration.mols_postprocess_mixture import *
from ProtLig_GPCRclassA.datasets.M2OR_concentration.mols_postprocess_discard import *

# from ProtLig_GPCRclassA.datasets.M2OR_concentration.postprocess_broadness import *
from ProtLig_GPCRclassA.datasets.M2OR_concentration.postprocess_class_dist import *
# from ProtLig_GPCRclassA.datasets.M2OR_concentration.postprocess_data_quality import *
# from ProtLig_GPCRclassA.datasets.M2OR_concentration.postprocess_testing_weights import *
from ProtLig_GPCRclassA.datasets.M2OR_concentration.postprocess_weights import *


if __name__ == '__main__':
    # Driver script for preparing the M2OR concentration dataset.
    # Steps, in order: experiment preprocessing, sequence length filtering,
    # molecule discarding / racemic-mixture handling / size cut, and then
    # five repeated splits, each followed by class-distribution statistics
    # and class-weight assignment.
    raw_dir = '/data_mount/ProtLig_GPCRclassA/ProtLig_GPCRclassA/RawData/M2OR_20250501_165200'
    base_dir = '/data_mount/ProtLig_GPCRclassA/ProtLig_GPCRclassA/amino_GNN/Data'

    # ------------------------
    # Experiments preprocess:
    # ------------------------
    # Every raw table is stored as '<table name>.csv' inside raw_dir.
    raw_tables = {
        table: os.path.join(raw_dir, table + '.csv')
        for table in ('experiments', 'main_receptors', 'main_compounds', 'compounds')
    }
    experiments = CV_ORconc_DiscardMix(base_working_dir = base_dir,
                                       data_path = raw_tables)
    experiments.CV_data()
    exp_dir = experiments.working_dir
    print(exp_dir)

    # Shared by the sequence and molecule discard steps as 'stats_data'.
    full_data_csv = os.path.join(exp_dir, 'full_data.csv')

    # ----------------
    # Seqs preprocess:
    # ----------------
    # Length-based discard with lower bound 296 and no upper bound ('Inf').
    seqs_len_discard = CVPP_LengthDiscard(
        data_dir = os.path.join(exp_dir, 'seqs'),
        lower_bound = 296,
        upper_bound = 'Inf',
        seqs_csv = None,
        seq_id_col = 'seq_id',
        seq_col = 'mutated_sequence',
        auxiliary_data_path = {'stats_data' : full_data_csv},
    )
    seqs_len_discard.postprocess()
    # NOTE(review): the file name presumably encodes the bounds passed above
    # (296 / 'Inf') - keep it in sync if the bounds change; TODO confirm.
    seqs_csv = os.path.join(seqs_len_discard.working_dir, 'seqs_lower296_upperInf.csv')

    # ----------------
    # Mols preprocess:
    # ----------------
    # Mols preprocess - discard molecules listed by InChIKey:
    mols_dir = os.path.join(exp_dir, 'mols')
    mols_discard_by_list = CVPP_MoleculeDiscard(
        data_dir = mols_dir,
        mol_keys = ['ONKNPOPIGWHAQC-UHFFFAOYSA-N'],
        mols_csv = os.path.join(mols_dir, 'mols.csv'),
        mol_id_col = 'mol_id',
        mol_key_col = 'inchi_key',
        auxiliary_data_path = {'stats_data' : full_data_csv},
    )
    mols_discard_by_list.postprocess()

    # All later molecule steps read the output of the discard step.
    # NOTE(review): 'mols_n1.csv' presumably reflects the single discarded
    # key above - TODO confirm against CVPP_MoleculeDiscard.
    discard_dir = mols_discard_by_list.working_dir
    discarded_mols_csv = os.path.join(discard_dir, 'mols_n1.csv')

    # Mols preprocess - racemic mixture handling (no InChIKey->SMILES maps supplied):
    smiles_maps = {'map_inchikey_to_isomericSMILES' : None,
                   'map_inchikey_to_canonicalSMILES' : None}
    mix_racemic = CVPP_Mixture_Racemic(data_dir = discard_dir,
                                       mols_csv = discarded_mols_csv,
                                       mol_id_col = 'mol_id',
                                       auxiliary_data_path = smiles_maps)
    mix_racemic.postprocess()

    # Mols preprocess - size cut on the 'SMILES_racemic' column:
    mols_sizecut_racemic = CVPP_SizeCut(data_dir = discard_dir,
                                        n_node_thresholds = [32, 128],
                                        n_edge_thresholds = [64, 256],
                                        mols_csv = discarded_mols_csv,
                                        mol_id_col = 'mol_id',
                                        mol_col = 'SMILES_racemic',
                                        bond_multiplier = 2)
    mols_sizecut_racemic.postprocess()

    # ------
    # Split:
    # ------
    # Alternative split strategies, kept for reference:
    #
    # split = Random_split(data_dir = experiments.working_dir,
    #                     seed = int(time.time()),
    #                     split_kwargs = {'valid_ratio' : 0.1,
    #                                     'test_ratio' : 0.1,
    #                                     })
    # split.CV_split()
    #
    # split = EC50_Random(data_dir = experiments.working_dir,
    #                     seed = int(time.time()),
    #                     split_kwargs = {'valid_ratio' : 0.8, # Number of measurements in validation set: valid_ratio*n_ec50
    #                                     'test_ratio' : 0.3, # This refers to the percentage of EC50
    #                                     })
    # split.CV_split()
    n_repeats = 5
    for _ in range(n_repeats):
        # The seed is the wall-clock time truncated to whole seconds; the
        # one-second pause presumably guarantees a new seed per repeat.
        time.sleep(1)
        split_seed = int(time.time())

        split_options = {
            'valid_ratio' : 0.1,  # Number of measurements in validation set: valid_ratio*n_ec50
            'test_ratio' : None,  # This refers to the percentage of EC50
        }
        split = EC50_Random_valid_only(data_dir = exp_dir,
                                       seed = split_seed,
                                       split_kwargs = split_options)
        split.CV_split()
        split_dir = split.working_dir

        # Class distribution for this split:
        class_dist = CVPP_class_dist(data_dir = split_dir,
                                     seq_id_col = 'seq_id',
                                     mol_id_col = 'mol_id',
                                     label_col = 'responsive',
                                     seqs_csv = seqs_csv,
                                     mols_csv = discarded_mols_csv)
        class_dist.postprocess()

        # Class weights, driven by the class distribution written above:
        class_dist_json = os.path.join(class_dist.working_dir, 'class_dist.json')
        addWeights_Class = CVPP_addWeights_Class(
            data_dir = split_dir,
            auxiliary_data_path = {'class_dist' : class_dist_json},
        )
        addWeights_Class.postprocess()