import os
import time
import pandas

from ProtLig_GPCRclassA.datasets.M2OR_concentration.preprocess import *
from ProtLig_GPCRclassA.datasets.M2OR_concentration.split_seqs_ood import *
# from ProtLig_GPCRclassA.datasets.M2OR_concentration.screening_confidence import *

from ProtLig_GPCRclassA.datasets.M2OR_concentration.seqs_blast_distance import *

from ProtLig_GPCRclassA.datasets.M2OR_concentration.seqs_postprocess import *

from ProtLig_GPCRclassA.datasets.M2OR_concentration.mols_postprocess import *
from ProtLig_GPCRclassA.datasets.M2OR_concentration.mols_postprocess_mixture import *
from ProtLig_GPCRclassA.datasets.M2OR_concentration.mols_postprocess_discard import *

# from ProtLig_GPCRclassA.datasets.M2OR_concentration.postprocess_broadness import *
from ProtLig_GPCRclassA.datasets.M2OR_concentration.postprocess_class_dist import *
# from ProtLig_GPCRclassA.datasets.M2OR_concentration.postprocess_data_quality import *
# from ProtLig_GPCRclassA.datasets.M2OR_concentration.postprocess_testing_weights import *
from ProtLig_GPCRclassA.datasets.M2OR_concentration.postprocess_weights import *


if __name__ == '__main__':
    # Driver: build `num_splits` cross-validation splits under `data_dir` and
    # run the class-distribution, class-weight and discard-screening
    # post-processing steps on each successfully created split.

    # NOTE(review): `raw_data_path` is not referenced anywhere in the visible
    # part of this script; kept as-is.
    raw_data_path = '/mnt/ProtLig_GPCRclassA/ProtLig_GPCRclassA/RawData/M2OR_20250501_165200'
    data_dir = '/data_mount/ProtLig_GPCRclassA/ProtLig_GPCRclassA/amino_GNN/Data/m2or_conc_mixDiscard_20250501-165522'

    # Input files shared by every split; the strings are loop-invariant, so
    # they are built once here.
    seqs_csv_path = os.path.join(
        data_dir, 'seqs', 'discard_by_length', 'seqs_lower296_upperInf.csv')
    seqs_similarity_csv_path = os.path.join(
        data_dir, 'seqs', 'discard_by_length',
        'seqs_lower296_upperInf_similarity.csv')
    mols_csv_path = os.path.join(
        data_dir, 'mols', 'discard_by_list_20250501-165622', 'mols_n1.csv')

    num_splits = 5
    n_successful = 0  # only splits whose CV_split() did not raise are counted
    # NOTE(review): there is no cap on the number of attempts; if every
    # attempt raises InadequateTestSetSizeError this loop never terminates.
    while n_successful < num_splits:
        # The seed below has one-second resolution; waiting a second
        # presumably keeps consecutive attempts from sharing a seed.
        time.sleep(1)
        seed = int(time.time())

        # ------
        # Split:
        # ------
        # The option dicts are rebuilt on every attempt so each split object
        # receives its own fresh copies.
        split_options = {
            'valid_ratio': 0.1,
            'min_n_ec50_ligands': 2,
            'min_n_seqs_with_enough_ligands_per_cluster': 2,
            'n_cluster_sample': 5,
            'max_test_size': 0.4,
            'min_test_size': 0.12,
            'seq_id_col': 'seq_id',
            'hdbscan_min_samples': 1,
            'hdbscan_min_cluster_size': 15,
            'auxiliary_data_path': {
                'seqs_csv': seqs_csv_path,
                'seqs_blast_similarity_csv': seqs_similarity_csv_path,
            },
        }
        split = EC50_LeaveClusterOut_OR_keep_screening(
            data_dir=data_dir,
            seed=seed,
            split_kwargs=split_options,
        )

        try:
            split.CV_split()
        except InadequateTestSetSizeError as e:
            # Report the failure and retry with a new seed; the attempt is
            # not counted towards `num_splits`.
            print('\nInadequateTestSetSizeError:   ', e, '\n')
            continue
        n_successful += 1

        # NOTE: Class distribution is using the screening data.
        # Class distribution:
        class_dist = CVPP_class_dist(
            data_dir=split.working_dir,
            seq_id_col='seq_id',
            mol_id_col='mol_id',
            label_col='responsive',
            seqs_csv=seqs_csv_path,
            mols_csv=mols_csv_path,
        )
        class_dist.postprocess()

        # Class weights, fed with the class_dist.json located in the
        # class-distribution step's working directory.
        class_dist_json_path = os.path.join(class_dist.working_dir, 'class_dist.json')
        addWeights_Class = CVPP_addWeights_Class(
            data_dir=split.working_dir,
            auxiliary_data_path={'class_dist': class_dist_json_path},
        )
        addWeights_Class.postprocess()

        # Discard-screening step, run on the split's working directory.
        discard_screening = CVPP_OR_discard_screening(
            data_dir=split.working_dir,
            seq_id_col='seq_id',
        )
        discard_screening.postprocess()