import os
import time

from ProtLig_GPCRclassA.datasets.M2OR_concentration.preprocess import *
from ProtLig_GPCRclassA.datasets.M2OR_concentration.split import *
# from ProtLig_GPCRclassA.datasets.M2OR_concentration.screening_confidence import *

from ProtLig_GPCRclassA.datasets.M2OR_concentration.seqs_postprocess import *

from ProtLig_GPCRclassA.datasets.M2OR_concentration.mols_postprocess import *
from ProtLig_GPCRclassA.datasets.M2OR_concentration.mols_postprocess_mixture import *
from ProtLig_GPCRclassA.datasets.M2OR_concentration.mols_postprocess_discard import *

# from ProtLig_GPCRclassA.datasets.M2OR_concentration.postprocess_broadness import *
from ProtLig_GPCRclassA.datasets.M2OR_concentration.postprocess_class_dist import *
# from ProtLig_GPCRclassA.datasets.M2OR_concentration.postprocess_data_quality import *
# from ProtLig_GPCRclassA.datasets.M2OR_concentration.postprocess_testing_weights import *
from ProtLig_GPCRclassA.datasets.M2OR_concentration.postprocess_weights import *


if __name__ == '__main__':
    # End-to-end M2OR preprocessing pipeline:
    #   raw experiments -> sequence length filter -> molecule discard/mixture/size cut
    #   -> train-only split -> class distribution -> class weights.
    # Each stage consumes the working_dir / CSV outputs of earlier stages, so the
    # order of the calls below matters.

    # Raw M2OR export directory. This assignment used to be commented out, which
    # made the os.path.join(data_path, ...) calls below fail with NameError.
    # Set the M2OR_DATA_PATH environment variable to point at a different export.
    data_path = os.environ.get(
        'M2OR_DATA_PATH',
        '/mnt/ProtLig_GPCRclassA/ProtLig_GPCRclassA/RawData/M2OR/M2OR_20250330_141300',
    )
    base_working_dir = '/mnt/ProtLig_GPCRclassA/ProtLig_GPCRclassA/amino_GNN/Data'

    # Sequence length bounds. They are used both to configure the length filter
    # and to locate the CSV it writes, so the two cannot drift apart.
    # NOTE(review): the filename pattern mirrors the previously hard-coded
    # 'seqs_lower296_upperInf.csv'; confirm it matches CVPP_LengthDiscard's naming.
    seq_lower_bound = 296
    seq_upper_bound = 'Inf'
    seqs_filtered_csv_name = f'seqs_lower{seq_lower_bound}_upper{seq_upper_bound}.csv'

    # ----------------------
    # Preprocess experiments:
    # ----------------------
    experiments = CV_ORconc_DiscardMix(
        base_working_dir=base_working_dir,
        data_path={
            'experiments': os.path.join(data_path, 'experiments.csv'),
            'main_receptors': os.path.join(data_path, 'main_receptors.csv'),
            'main_compounds': os.path.join(data_path, 'main_compounds.csv'),
            'compounds': os.path.join(data_path, 'compounds.csv'),
        },
    )
    experiments.CV_data()
    print(experiments.working_dir)

    # Shared statistics file passed to the postprocessing steps below.
    full_data_csv = os.path.join(experiments.working_dir, 'full_data.csv')

    # ----------------
    # Seqs preprocess:
    # ----------------
    seqs_len_discard = CVPP_LengthDiscard(
        data_dir=os.path.join(experiments.working_dir, 'seqs'),
        lower_bound=seq_lower_bound,
        upper_bound=seq_upper_bound,
        seqs_csv=None,
        seq_id_col='seq_id',
        seq_col='mutated_sequence',
        auxiliary_data_path={'stats_data': full_data_csv},
    )
    seqs_len_discard.postprocess()

    # ----------------
    # Mols preprocess:
    # ----------------
    # Mols preprocess - discard molecules listed by InChIKey:
    mols_discard_by_list = CVPP_MoleculeDiscard(
        data_dir=os.path.join(experiments.working_dir, 'mols'),
        mol_keys=['ONKNPOPIGWHAQC-UHFFFAOYSA-N'],
        mols_csv=os.path.join(experiments.working_dir, 'mols', 'mols.csv'),
        mol_id_col='mol_id',
        mol_key_col='inchi_key',
        auxiliary_data_path={'stats_data': full_data_csv},
    )
    mols_discard_by_list.postprocess()

    # Molecule table read by the following mols steps and by class_dist.
    # Computed after postprocess() to keep the original evaluation order.
    mols_n1_csv = os.path.join(mols_discard_by_list.working_dir, 'mols_n1.csv')

    # Mols preprocess - graph construction:
    mix_racemic = CVPP_Mixture_Racemic(
        data_dir=mols_discard_by_list.working_dir,
        mols_csv=mols_n1_csv,
        mol_id_col='mol_id',
        auxiliary_data_path={
            'map_inchikey_to_isomericSMILES': None,
            'map_inchikey_to_canonicalSMILES': None,
        },
    )
    mix_racemic.postprocess()

    # Mols preprocess - size cut:
    # NOTE(review): reads the 'SMILES_racemic' column from mols_n1.csv; presumably
    # written there by mix_racemic.postprocess() above — confirm.
    mols_sizecut_racemic = CVPP_SizeCut(
        data_dir=mols_discard_by_list.working_dir,
        n_node_thresholds=[32, 128],
        n_edge_thresholds=[64, 256],
        mols_csv=mols_n1_csv,
        mol_id_col='mol_id',
        mol_col='SMILES_racemic',
        bond_multiplier=2,
    )
    mols_sizecut_racemic.postprocess()

    # ------
    # Split:
    # ------
    # The seed is time-based, so splits differ between runs. Print it so a
    # given run can be reproduced.
    split_seed = int(time.time())
    print(f'split seed: {split_seed}')
    split = TrainOnly(
        data_dir=experiments.working_dir,
        seed=split_seed,
        split_kwargs={
            'valid_ratio': None,  # Number of measurements in validation set: valid_ratio*n_ec50
            'test_ratio': None,  # This refers to the percentage of EC50
        },
    )
    split.CV_split()

    # Class distribution:
    class_dist = CVPP_class_dist(
        data_dir=split.working_dir,
        seq_id_col='seq_id',
        mol_id_col='mol_id',
        label_col='responsive',
        seqs_csv=os.path.join(seqs_len_discard.working_dir, seqs_filtered_csv_name),
        mols_csv=mols_n1_csv,
    )
    class_dist.postprocess()

    # Class weights derived from the class distribution JSON written above.
    addWeights_Class = CVPP_addWeights_Class(
        data_dir=split.working_dir,
        auxiliary_data_path={'class_dist': os.path.join(class_dist.working_dir, 'class_dist.json')},
    )
    addWeights_Class.postprocess()