import os
import time
import pandas

from ProtLig_GPCRclassA.datasets.M2OR.preprocess import *
from ProtLig_GPCRclassA.datasets.M2OR.split import *
from ProtLig_GPCRclassA.datasets.M2OR.screening_confidence import *

from ProtLig_GPCRclassA.datasets.M2OR.seqs_postprocess import *

from ProtLig_GPCRclassA.datasets.M2OR.mols_postprocess import *
from ProtLig_GPCRclassA.datasets.M2OR.mols_postprocess_mixture import *
from ProtLig_GPCRclassA.datasets.M2OR.mols_postprocess_discard import *

from ProtLig_GPCRclassA.datasets.M2OR.postprocess_broadness import *
from ProtLig_GPCRclassA.datasets.M2OR.postprocess_class_dist import *
from ProtLig_GPCRclassA.datasets.M2OR.postprocess_data_quality import *
from ProtLig_GPCRclassA.datasets.M2OR.postprocess_testing_weights import *
from ProtLig_GPCRclassA.datasets.M2OR.postprocess_weights import *

from ProtLig_GPCRclassA.datasets.M2OR.split_mols_ood import *
from ProtLig_GPCRclassA.datasets.M2OR.mols_create_maps import *


if __name__ == '__main__':
    # End-to-end driver: build molecule maps, then produce `num_splits` random
    # leave-cluster-out CV splits and run the weighting/post-processing chain on
    # each successful split.
    raw_data_path = '/mnt/ProtLig_GPCRclassA/ProtLig_GPCRclassA/RawData/M2OR_20250501_165200'
    data_dir = '/data_mount/ProtLig_GPCRclassA/ProtLig_GPCRclassA/amino_GNN/Data/m2or_pairs_20250514-125855'

    # Filtered molecule / sequence tables used by several steps below. Defined
    # once so the split, class-distribution and broadness steps cannot drift apart.
    mols_csv = os.path.join(data_dir, 'mols', 'discard_mix', 'discard_by_list_20250514-125915', 'mols_n1.csv')
    seqs_csv = os.path.join(data_dir, 'seqs', 'discard_by_length', 'seqs_lower296_upperInf.csv')
    mols_stem = os.path.splitext(os.path.basename(mols_csv))[0]  # 'mols_n1'
    seqs_stem = os.path.splitext(os.path.basename(seqs_csv))[0]  # 'seqs_lower296_upperInf'

    mol_id_col = 'mol_id'
    seq_id_col = 'seq_id'
    label_col = 'responsive'

    create_molecule_mapping = CreateMoleculeMapping(mols_csv = os.path.join(data_dir, 'mols', 'mols.csv'),
                                                    mol_id_col = mol_id_col,
                                                    working_dir = os.path.join(data_dir, 'mols'),
                                                    inchikey_col = 'inchi_key',
                                                    sep = ';')
    create_molecule_mapping.update_maps()

    num_splits = 5
    # Upper bound on failed split attempts (summed over all splits) so the
    # script terminates when the split constraints can never be satisfied.
    max_failed_attempts = 100
    n_failed = 0
    i = 0
    while i < num_splits:
        # The split seed is int(time.time()); sleeping 1 s guarantees that every
        # attempt (including retries after a failed split) gets a fresh seed.
        time.sleep(1)
        # ------
        # Split:
        # ------
        split = EC50_LeaveClusterOut_Mol_keep_screening(data_dir = data_dir,
                                                        seed = int(time.time()),
                                                        split_kwargs = {'valid_ratio' : 0.1, # Test ratio is given by the portion_mols or clustering.
                                                                        'min_n_ec50_ligands' : 2,
                                                                        'min_n_mols_with_enough_actives_per_cluster' : 2,
                                                                        'n_cluster_sample' : 5,
                                                                        'max_test_size' : 0.4,
                                                                        'min_test_size' : 0.12,
                                                                        'hdbscan_min_samples' : 1,
                                                                        'hdbscan_min_cluster_size' : 10,
                                                                        'mol_id_col' : mol_id_col,
                                                                        'auxiliary_data_path' : {'mols_csv' : mols_csv,
                                                                                                 'map_inchikey_to_canonicalSMILES' : os.path.join(data_dir, 'mols', 'map_inchikey_to_canonicalSMILES.json')}})
        try:
            split.CV_split()
        except InadequateTestSetSizeError as e:
            print('\nInadequateTestSetSizeError:   ', e, '\n')
            n_failed += 1
            if n_failed >= max_failed_attempts:
                raise RuntimeError(
                    f'Gave up after {n_failed} failed split attempts '
                    f'({i} of {num_splits} splits completed).'
                ) from e
            continue
        i += 1

        # Class distribution:
        class_dist = CVPP_class_dist(data_dir = split.working_dir,
                                     seq_id_col = seq_id_col,
                                     mol_id_col = mol_id_col,
                                     label_col = label_col,
                                     seqs_csv = seqs_csv,
                                     mols_csv = mols_csv)
        class_dist.postprocess()

        add_weights_class = CVPP_addWeights_Class(data_dir = split.working_dir,
                                                  auxiliary_data_path = {'class_dist' : os.path.join(class_dist.working_dir, 'class_dist.json')})
        add_weights_class.postprocess()

        # Broadness:
        broadness = CVPP_calculate_broadness(data_dir = split.working_dir,
                                             mols_csv = mols_csv,
                                             seqs_csv = seqs_csv,
                                             mol_id_col = mol_id_col,
                                             seq_id_col = seq_id_col,
                                             label_col = label_col)
        broadness.postprocess()

        # NOTE(review): assumes CVPP_calculate_broadness writes its outputs to
        # split.working_dir as '<kind>_broadness__<mols stem>__<seqs stem>.csv';
        # this reproduces the previously hard-coded names for the current inputs.
        seqs_broadness_csv = os.path.join(split.working_dir, f'seqs_broadness__{mols_stem}__{seqs_stem}.csv')
        mols_broadness_csv = os.path.join(split.working_dir, f'mols_broadness__{mols_stem}__{seqs_stem}.csv')

        root_broadness_clip_no_bias = CVPP_addWeights_RootBroadnessClipNoBias(data_dir = split.working_dir,
                                                                              label_col = label_col,
                                                                              mol_id_col = mol_id_col,
                                                                              seq_id_col = seq_id_col,
                                                                              n_classes = 2,
                                                                              seq_min_count = 100,
                                                                              mol_min_count = 50,
                                                                              tested_enough_count = 50,
                                                                              seq_clip_min_broadness = 0.025,
                                                                              mol_clip_min_broadness = 0.025,
                                                                              auxiliary_data_path = {'seqs_broadness' : seqs_broadness_csv,
                                                                                                     'mols_broadness' : mols_broadness_csv})
        root_broadness_clip_no_bias.postprocess()

        # Data quality:
        data_quality = CVPP_addWeights_DataQuality(data_dir = split.working_dir,
                                                   label_col = label_col,
                                                   data_quality_col = 'data_quality',
                                                   auxiliary_data_path = {'screening_confidence_probs' : os.path.join(data_dir, 'Screening_confidence_probabilities.json')})
        data_quality.postprocess()

        # ----------------
        # Combine weights:
        # ----------------
        # Combine weights - class:
        combine_weights_class = CVPP_combineWeights(data_dir = split.working_dir,
                                                    weight_cols = ['class_weight', 'data_quality_weight'])
        combine_weights_class.postprocess()

        # Combine weights - root broadness:
        combine_weights_broadness = CVPP_combineWeights(data_dir = split.working_dir,
                                                        weight_cols = ['root_broadness_weight', 'data_quality_weight'])
        combine_weights_broadness.postprocess()

        # Discard screening:
        discard_screening = CVPP_Mol_discard_screening(data_dir = split.working_dir,
                                                       mol_id_col = mol_id_col)
        discard_screening.postprocess()