import os
import time
import pandas

from ProtLig_GPCRclassA.datasets.M2OR_concentration.preprocess import *
from ProtLig_GPCRclassA.datasets.M2OR_concentration.split_mols_ood import *
# from ProtLig_GPCRclassA.datasets.M2OR_concentration.screening_confidence import *

from ProtLig_GPCRclassA.datasets.M2OR_concentration.mols_create_maps import *

from ProtLig_GPCRclassA.datasets.M2OR_concentration.seqs_postprocess import *

from ProtLig_GPCRclassA.datasets.M2OR_concentration.mols_postprocess import *
from ProtLig_GPCRclassA.datasets.M2OR_concentration.mols_postprocess_mixture import *
from ProtLig_GPCRclassA.datasets.M2OR_concentration.mols_postprocess_discard import *

# from ProtLig_GPCRclassA.datasets.M2OR_concentration.postprocess_broadness import *
from ProtLig_GPCRclassA.datasets.M2OR_concentration.postprocess_class_dist import *
# from ProtLig_GPCRclassA.datasets.M2OR_concentration.postprocess_data_quality import *
# from ProtLig_GPCRclassA.datasets.M2OR_concentration.postprocess_testing_weights import *
from ProtLig_GPCRclassA.datasets.M2OR_concentration.postprocess_weights import *


if __name__ == '__main__':
    # Raw M2OR export this processed dataset was derived from. Kept for provenance
    # only; none of the steps below read it.
    raw_data_path = '/mnt/ProtLig_GPCRclassA/ProtLig_GPCRclassA/RawData/M2OR_20250501_165200'
    data_dir = '/data_mount/ProtLig_GPCRclassA/ProtLig_GPCRclassA/amino_GNN/Data/m2or_conc_mixDiscard_20250501-165522'

    mols_dir = os.path.join(data_dir, 'mols')
    # Molecule table after list-based discarding. The splitter and the
    # class-distribution step must see the same table, so the path is defined once.
    filtered_mols_csv = os.path.join(mols_dir, 'discard_by_list_20250501-165622', 'mols_n1.csv')
    # Sequence table after length-based discarding (used for class distribution).
    filtered_seqs_csv = os.path.join(data_dir, 'seqs', 'discard_by_length', 'seqs_lower296_upperInf.csv')
    # InChIKey -> canonical SMILES map passed to the splitter as auxiliary data.
    # NOTE(review): presumably (re)written by update_maps() below — confirm.
    inchikey_to_smiles_json = os.path.join(mols_dir, 'map_inchikey_to_canonicalSMILES.json')

    # Refresh the molecule mapping files in mols/ from mols/mols.csv before splitting.
    create_molecule_mapping = CreateMoleculeMapping(mols_csv = os.path.join(mols_dir, 'mols.csv'),
                                                    mol_id_col = 'mol_id',
                                                    working_dir = mols_dir,
                                                    inchikey_col = 'inchi_key',
                                                    sep = ';')
    create_molecule_mapping.update_maps()

    num_splits = 5
    # Splits are randomized and may be rejected (InadequateTestSetSizeError) and
    # retried with a new seed. Cap the total number of attempts so that an
    # unsatisfiable configuration fails loudly instead of looping forever.
    max_split_attempts = 100

    n_done = 0
    n_attempts = 0
    while n_done < num_splits:
        if n_attempts >= max_split_attempts:
            raise RuntimeError(
                f'Produced only {n_done}/{num_splits} splits after {n_attempts} attempts; '
                f'every other attempt raised InadequateTestSetSizeError. '
                f'Consider relaxing the split constraints (e.g. min/max_test_size).'
            )
        n_attempts += 1

        # The seed comes from the wall clock at 1 s resolution. Sleeping 1 s first
        # ensures every attempt gets a different seed (assuming the clock does not
        # jump backwards).
        time.sleep(1)
        seed = int(time.time())

        # ------
        # Split:
        # ------
        split = EC50_LeaveClusterOut_Mol_keep_screening(data_dir = data_dir,
                                                        seed = seed,
                                                        split_kwargs = {'valid_ratio' : 0.1, # Test ratio is given by the portion_mols or clustering.
                                                                        'min_n_ec50_ligands' : 2,
                                                                        'min_n_mols_with_enough_actives_per_cluster' : 2,
                                                                        'n_cluster_sample' : 5,
                                                                        'max_test_size' : 0.4,
                                                                        'min_test_size' : 0.12,
                                                                        'hdbscan_min_samples' : 1,
                                                                        'hdbscan_min_cluster_size' : 10,
                                                                        'mol_id_col' : 'mol_id',
                                                                        'auxiliary_data_path' : {'mols_csv' : filtered_mols_csv,
                                                                                                 'map_inchikey_to_canonicalSMILES' : inchikey_to_smiles_json}})
        try:
            split.CV_split()
        except InadequateTestSetSizeError as e:
            # Rejected split: report it and retry with a fresh seed.
            print('\nInadequateTestSetSizeError:   ', e, '\n')
            continue

        n_done += 1
        # Record the seed so this particular split can be reproduced later.
        print(f'Split {n_done}/{num_splits} created with seed={seed} in {split.working_dir}')

        # Post-processing of the accepted split, all operating inside split.working_dir.

        # Class distribution:
        class_dist = CVPP_class_dist(data_dir = split.working_dir,
                                     seq_id_col = 'seq_id',
                                     mol_id_col = 'mol_id',
                                     label_col = 'responsive',
                                     seqs_csv = filtered_seqs_csv,
                                     mols_csv = filtered_mols_csv)
        class_dist.postprocess()

        # Class weights, derived from the class_dist.json written by the step above.
        addWeights_Class = CVPP_addWeights_Class(data_dir = split.working_dir,
                                                 auxiliary_data_path = {'class_dist' : os.path.join(class_dist.working_dir, 'class_dist.json')})
        addWeights_Class.postprocess()

        discard_screening = CVPP_Mol_discard_screening(data_dir = split.working_dir,
                                                       mol_id_col = 'mol_id')
        discard_screening.postprocess()