import os
import time

from ProtLig_GPCRclassA.datasets.M2OR.preprocess import *
from ProtLig_GPCRclassA.datasets.M2OR.split import *
from ProtLig_GPCRclassA.datasets.M2OR.screening_confidence import *

from ProtLig_GPCRclassA.datasets.M2OR.seqs_postprocess import *

from ProtLig_GPCRclassA.datasets.M2OR.mols_postprocess import *
from ProtLig_GPCRclassA.datasets.M2OR.mols_postprocess_mixture import *
from ProtLig_GPCRclassA.datasets.M2OR.mols_postprocess_discard import *

from ProtLig_GPCRclassA.datasets.M2OR.postprocess_broadness import *
from ProtLig_GPCRclassA.datasets.M2OR.postprocess_class_dist import *
from ProtLig_GPCRclassA.datasets.M2OR.postprocess_data_quality import *
from ProtLig_GPCRclassA.datasets.M2OR.postprocess_testing_weights import *
from ProtLig_GPCRclassA.datasets.M2OR.postprocess_weights import *



if __name__ == '__main__':
    # Script entry point: runs the M2OR preparation steps one after another.
    # Every step is constructed from the `working_dir` (or an output file) of
    # an earlier step, so the statement order below must be preserved.
    raw_dir = '/mnt/ProtLig_GPCRclassA/ProtLig_GPCRclassA/RawData/M2OR/M2OR_20250330_141300'

    # Column names shared by several post-processing steps.
    seq_id_column = 'seq_id'
    mol_id_column = 'mol_id'
    label_column = 'responsive'

    # -----------------
    # Preprocess pairs:
    # -----------------
    # Each listed table is expected as '<name>.csv' inside raw_dir.
    pair_tables = ('pairs', 'main_receptors', 'main_compounds', 'compounds')
    pair_builder = CV_ORpairs(
        base_working_dir='/mnt/ProtLig_GPCRclassA/ProtLig_GPCRclassA/amino_GNN/Data',
        data_path={table: os.path.join(raw_dir, table + '.csv') for table in pair_tables},
    )
    pair_builder.CV_data()
    work_dir = pair_builder.working_dir
    print(work_dir)

    # Passed as 'stats_data' to the sequence and molecule discard steps.
    full_data_csv = os.path.join(work_dir, 'full_data.csv')

    # -----------------------------------
    # Create screening confidence matrix:
    # -----------------------------------
    confidence_builder = ScreeningConfidence(
        base_working_dir=work_dir,
        data_path={
            'experiments': os.path.join(raw_dir, 'experiments.csv'),
            'references': os.path.join(raw_dir, 'references.csv'),
            'main_receptors': os.path.join(raw_dir, 'main_receptors.csv'),
        },
    )
    confidence_builder.main()

    # ----------------
    # Seqs preprocess:
    # ----------------
    # Length-based discard with bounds 296 .. 'Inf'; the upper bound is passed
    # as a string (presumably it ends up in the output file name
    # 'seqs_lower296_upperInf.csv' used further down -- confirm in the class).
    seq_length_filter = CVPP_LengthDiscard(
        data_dir=os.path.join(work_dir, 'seqs'),
        lower_bound=296,
        upper_bound='Inf',
        seqs_csv=None,
        seq_id_col=seq_id_column,
        seq_col='mutated_sequence',
        auxiliary_data_path={'stats_data': full_data_csv},
    )
    seq_length_filter.postprocess()

    # ----------------
    # Mols preprocess:
    # ----------------
    mols_dir = os.path.join(work_dir, 'mols')
    discard_mix_dir = os.path.join(mols_dir, 'discard_mix')

    # Mols preprocess - discard mixtures:
    mixture_filter = CVPP_MixtureDiscard(
        data_dir=mols_dir,
        mols_csv=None,
        mol_id_col=mol_id_column,
        auxiliary_data_path={'stats_data': full_data_csv},
    )
    mixture_filter.postprocess()

    # Mols preprocess - discard an explicit list of molecules (by InChIKey):
    listed_mol_filter = CVPP_MoleculeDiscard(
        data_dir=discard_mix_dir,
        mol_keys=['ONKNPOPIGWHAQC-UHFFFAOYSA-N'],
        mols_csv=os.path.join(discard_mix_dir, 'mols_mix.csv'),
        mol_id_col=mol_id_column,
        mol_key_col='inchi_key',
        auxiliary_data_path={'stats_data': full_data_csv},
    )
    listed_mol_filter.postprocess()

    # All later molecule steps read from the list-discard step's output.
    filtered_mols_dir = listed_mol_filter.working_dir
    mols_n1_csv = os.path.join(filtered_mols_dir, 'mols_n1.csv')

    # Mols preprocess - graph construction:
    racemic_builder = CVPP_Mixture_Racemic(
        data_dir=filtered_mols_dir,
        mols_csv=mols_n1_csv,
        mol_id_col=mol_id_column,
        auxiliary_data_path={
            'map_inchikey_to_isomericSMILES': None,
            'map_inchikey_to_canonicalSMILES': None,
        },
    )
    racemic_builder.postprocess()

    # Mols preprocess - size cut:
    print('SMILES_racemic...')
    racemic_size_cut = CVPP_SizeCut(
        data_dir=filtered_mols_dir,
        n_node_thresholds=[32, 128],
        n_edge_thresholds=[64, 256],
        mols_csv=mols_n1_csv,
        mol_id_col=mol_id_column,
        mol_col='SMILES_racemic',
        bond_multiplier=2,
    )
    racemic_size_cut.postprocess()

    # ------
    # Split:
    # ------
    # NOTE: No split! Both ratios are None and the seed is taken from the
    # wall clock, so it differs from run to run.
    train_only_split = TrainOnly(
        data_dir=work_dir,
        seed=int(time.time()),
        split_kwargs={
            'valid_ratio': None,  # Number of measurements in validation set: valid_ratio*n_ec50
            'test_ratio': None,  # This refers to the percentage of EC50
        },
    )
    train_only_split.CV_split()
    split_dir = train_only_split.working_dir

    # Output of the sequence length filter, consumed by the steps below.
    seqs_filtered_csv = os.path.join(seq_length_filter.working_dir, 'seqs_lower296_upperInf.csv')

    # Class distribution:
    class_distribution = CVPP_class_dist(
        data_dir=split_dir,
        seq_id_col=seq_id_column,
        mol_id_col=mol_id_column,
        label_col=label_column,
        seqs_csv=seqs_filtered_csv,
        mols_csv=mols_n1_csv,
    )
    class_distribution.postprocess()

    class_weights = CVPP_addWeights_Class(
        data_dir=split_dir,
        auxiliary_data_path={
            'class_dist': os.path.join(class_distribution.working_dir, 'class_dist.json'),
        },
    )
    class_weights.postprocess()

    # Broadness:
    broadness_calc = CVPP_calculate_broadness(
        data_dir=split_dir,
        mols_csv=mols_n1_csv,
        seqs_csv=seqs_filtered_csv,
        mol_id_col=mol_id_column,
        seq_id_col=seq_id_column,
        label_col=label_column,
    )
    broadness_calc.postprocess()

    # The two broadness tables are looked up by fixed file name in split_dir.
    broadness_weights = CVPP_addWeights_BroadnessClipNoBias(
        data_dir=split_dir,
        label_col=label_column,
        mol_id_col=mol_id_column,
        seq_id_col=seq_id_column,
        n_classes=2,
        seq_min_count=100,
        mol_min_count=50,
        tested_enough_count=50,
        seq_clip_min_broadness=0.05,
        mol_clip_min_broadness=0.05,
        auxiliary_data_path={
            'seqs_broadness': os.path.join(split_dir, 'seqs_broadness__mols_n1__seqs_lower296_upperInf.csv'),
            'mols_broadness': os.path.join(split_dir, 'mols_broadness__mols_n1__seqs_lower296_upperInf.csv'),
        },
    )
    broadness_weights.postprocess()

    # Data quality:
    quality_weights = CVPP_addWeights_DataQuality(
        data_dir=split_dir,
        label_col=label_column,
        data_quality_col='data_quality',
        auxiliary_data_path={
            'screening_confidence_probs': os.path.join(work_dir, 'Screening_confidence_probabilities.json'),
        },
    )
    quality_weights.postprocess()

    # ----------------
    # Combine weights:
    # ----------------
    # Class weights first, then broadness weights; each is combined with the
    # data-quality weight column.
    for primary_weight_col in ('class_weight', 'broadness_weight'):
        weight_combiner = CVPP_combineWeights(
            data_dir=split_dir,
            weight_cols=[primary_weight_col, 'data_quality_weight'],
        )
        weight_combiner.postprocess()