from . import *

import sys
import os

# Absolute, normalized path of the directory holding this file, and of the
# sibling ``data`` directory one level up (``<this dir>/../data``).
current_directory = os.path.dirname(os.path.abspath(__file__))
external_directory = os.path.abspath(
    os.path.join(current_directory, os.pardir, 'data')
)

# Append (lowest search priority) so plain module names inside ``data`` import.
# NOTE(review): the package-style import ``data.dataset_loading`` used below is
# resolved via the directory *containing* ``data``, not ``data`` itself —
# confirm this append is actually what makes that import work.
sys.path.append(external_directory)

# Now you can import the external module
from data.dataset_loading import load_datasets, create_dataloader
from model import init_ldm_model, init_diff_pro_sdf, Diffpro_SDF


class LdmTrainConfig(TrainConfig):
    """Training setup that adds a model, data loaders and an optimizer to ``TrainConfig``.

    Args:
        params: Hyper-parameter object. ``batch_size`` and ``learning_rate``
            are read here; the whole object is also forwarded to
            ``init_ldm_model`` and, when not resuming, to ``init_diff_pro_sdf``.
        output_dir: Forwarded unchanged to ``TrainConfig.__init__``.
        debug_mode: Stored as ``self.debug_mode`` and forwarded to
            ``init_ldm_model``.
        load_chkpt_from: When not ``None``, the model is restored with
            ``Diffpro_SDF.load_trained`` from this value and moved to
            ``self.device``. When ``None``, a fresh model is built with
            ``init_diff_pro_sdf``.
    """

    def __init__(self, params, output_dir, debug_mode=False, load_chkpt_from=None) -> None:
        super().__init__(params, output_dir)
        self.debug_mode = debug_mode

        # The LDM backbone is built first; both construction paths below take it.
        # NOTE(review): ``self.device`` is not assigned in this class —
        # presumably set by TrainConfig.__init__; confirm.
        self.ldm_model = init_ldm_model(params, debug_mode)
        self.model = (
            Diffpro_SDF.load_trained(self.ldm_model, load_chkpt_from).to(self.device)
            if load_chkpt_from is not None
            else init_diff_pro_sdf(self.ldm_model, params, self.device)
        )

        # Data: one dataset feeds both loaders.
        dataset = load_datasets()
        self.train_dl = create_dataloader(params.batch_size, dataset)
        # There is no separate validation split yet, so validation temporarily
        # reuses the training set.
        self.val_dl = create_dataloader(params.batch_size, dataset)

        # Adam over every model parameter at the configured learning rate.
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=params.learning_rate)
