import torch

from geo_ldm.init_latent_diffuser import get_latent_diffusion
from geo_ldm.regressor import EGNN_regressor
from geo_ldm.diffusion_guidance import DiffusionGuidanceModel

def get_diffusion_guidance(args, device, dataset_info, dataloader_train):
    """Build a regressor-guided latent diffusion model.

    Constructs the latent diffusion model via ``get_latent_diffusion``, an
    ``EGNN_regressor`` operating on its latent node features, and wraps both
    in a ``DiffusionGuidanceModel``.

    Args:
        args: Experiment configuration namespace. Must provide a non-``None``
            ``ae_path`` (pretrained autoencoder) plus the model and regressor
            hyperparameters read below (``latent_nf``, ``hidden_nf_vae``,
            ``n_layers``, ``regression_target``, ``max_step_regressor``, ...).
        device: Device forwarded to ``get_latent_diffusion`` and to the
            regressor.
        dataset_info: Dataset metadata forwarded to ``get_latent_diffusion``.
        dataloader_train: Training dataloader forwarded to
            ``get_latent_diffusion``. It is also passed to
            ``DiffusionGuidanceModel.compute_pos_weight`` when the regressor
            uses a classifier head.

    Returns:
        A tuple ``(diffusion_guidance, nodes_dist, prop_dist)``.
        ``nodes_dist`` and ``prop_dist`` are passed through unchanged from
        ``get_latent_diffusion``.

    Raises:
        ValueError: If ``args.ae_path`` is ``None``.
    """
    # Explicit check rather than ``assert``, so that it still fires when
    # Python runs with -O (which strips assert statements).
    if args.ae_path is None:
        raise ValueError("To train a regressor, we need a pretrained VAE")

    ldm, nodes_dist, prop_dist = get_latent_diffusion(args, device, dataset_info, dataloader_train)

    # A single flag decides both the regressor head type and whether a
    # positive-class weight is computed below, so the two cannot disagree.
    classifier_head = args.regression_target in ['morgan_fingerprint']

    # NOTE(review): ``n_props=len(args.regression_target)`` assumes
    # regression_target is a sequence of target names. If it is a plain
    # string, len() counts characters. Conversely, the membership test above
    # only matches a plain string. Confirm the expected type.
    # context_node_nf=0 because the regressor is not trained on context.
    # The hidden width and number of invariant sublayers reuse the VAE
    # hyperparameters (``*_vae``).
    regressor = EGNN_regressor(
        in_node_nf=args.latent_nf, context_node_nf=0,
        n_dims=3, device=device, hidden_nf=args.hidden_nf_vae,
        act_fn=torch.nn.SiLU(), n_layers=args.n_layers,
        attention=args.attention, tanh=args.tanh, mode=args.model, norm_constant=args.norm_constant,
        inv_sublayers=args.inv_sublayers_vae, sin_embedding=args.sin_embedding,
        normalization_factor=args.normalization_factor, aggregation_method=args.aggregation_method,
        include_atomic_numbers=args.include_atomic_numbers, condition_time=args.condition_time_regressor,
        classifier_head=classifier_head, n_props=len(args.regression_target))

    diffusion_guidance = DiffusionGuidanceModel(
        regressor=regressor,
        ldm=ldm,
        n_dims=3,
        include_atomic_numbers=args.include_atomic_numbers,
        max_step=args.max_step_regressor,
    )
    if classifier_head:
        # Only the classifier variant needs the training-set positive weight.
        diffusion_guidance.compute_pos_weight(dataloader_train)

    return diffusion_guidance, nodes_dist, prop_dist