import torch

from qm9.models import DistributionNodes, DistributionProperty, DistributionFingerprint
from geo_ldm.encoder import EGNN_encoder
from geo_ldm.decoder import EGNN_decoder
from geo_ldm.vae import EnHierarchicalVAE


def get_vae(args, device, dataset_info, dataloader_train, recompute_class_weight=None):
    """Build the EGNN-based hierarchical VAE plus its node-count and property distributions.

    Args:
        args: Config namespace. The attributes read here include ``include_atomic_numbers``,
            ``conditioning``, ``context_node_nf``, ``latent_nf``, ``hidden_nf_vae``,
            ``n_layers_encoder``/``n_layers_decoder``, ``attention``, ``tanh``, ``model``,
            ``norm_constant``, ``inv_sublayers_vae``, ``sin_embedding``,
            ``normalization_factor``, ``aggregation_method``, ``n_extra_atomic_features``,
            ``encode_h_indep``, ``use_rbf``, ``kl_weight``, ``normalize_factors``,
            ``noise_sigma_vae``, ``use_focal_loss`` and (only when
            ``recompute_class_weight`` is None) ``use_ghost_nodes``.
        device: Passed to the encoder, the decoder and ``vae.prepare_class_weights``.
        dataset_info: Dict with at least the keys ``'n_nodes'`` (node-count histogram),
            ``'atom_decoder'`` and ``'formal_charges'``.
        dataloader_train: Only used to build the property distribution when
            ``args.conditioning`` is non-empty.
        recompute_class_weight: Optional override. When None it is derived from
            ``args.use_ghost_nodes``. NOTE(review): the resolved value is currently NOT
            forwarded anywhere (see below), so this argument has no effect.

    Returns:
        Tuple ``(vae, nodes_dist, prop_dist)``. ``prop_dist`` is None when no
        conditioning properties are requested.
    """
    histogram = dataset_info['n_nodes']
    n_formal_charges = len(dataset_info['formal_charges'])
    # Node features: atom-type one-hot, optional atomic-number channel, and a
    # formal-charge one-hot (always included).
    in_node_nf = (
        len(dataset_info['atom_decoder'])
        + int(args.include_atomic_numbers)
        + n_formal_charges
    )
    num_edge_types = 5  # no_edge, single, double, triple, aromatic
    nodes_dist = DistributionNodes(histogram)

    prop_dist = None
    # Truthiness check also tolerates conditioning=None (len() would raise).
    if args.conditioning:
        if 'morgan_fingerprint' in args.conditioning:
            prop_dist = DistributionFingerprint(dataloader_train, args.conditioning)
        else:
            prop_dist = DistributionProperty(dataloader_train, args.conditioning)

    # The autoencoder input width is in_node_nf; no extra time channel is added.
    print('Autoencoder models are _not_ conditioned on time.')

    encoder = EGNN_encoder(
        in_node_nf=in_node_nf, context_node_nf=args.context_node_nf, out_node_nf=args.latent_nf,
        n_dims=3, device=device, hidden_nf=args.hidden_nf_vae,
        act_fn=torch.nn.SiLU(), n_layers=args.n_layers_encoder,
        attention=args.attention, tanh=args.tanh, mode=args.model, norm_constant=args.norm_constant,
        inv_sublayers=args.inv_sublayers_vae, sin_embedding=args.sin_embedding,
        normalization_factor=args.normalization_factor, aggregation_method=args.aggregation_method,
        include_atomic_numbers=args.include_atomic_numbers, n_extra_atomic_features=args.n_extra_atomic_features,
        encode_h_indep=args.encode_h_indep,
    )

    # The decoder mirrors the encoder: latent features in, node features out.
    decoder = EGNN_decoder(
        in_node_nf=args.latent_nf, context_node_nf=args.context_node_nf, out_node_nf=in_node_nf,
        n_dims=3, device=device, hidden_nf=args.hidden_nf_vae,
        act_fn=torch.nn.SiLU(), n_layers=args.n_layers_decoder,
        attention=args.attention, tanh=args.tanh, mode=args.model, norm_constant=args.norm_constant,
        inv_sublayers=args.inv_sublayers_vae, sin_embedding=args.sin_embedding,
        normalization_factor=args.normalization_factor, aggregation_method=args.aggregation_method,
        include_atomic_numbers=args.include_atomic_numbers, num_edge_types=num_edge_types,
        use_rbf=args.use_rbf,
    )

    vae = EnHierarchicalVAE(
        encoder=encoder,
        decoder=decoder,
        in_node_nf=in_node_nf,
        n_dims=3,
        latent_node_nf=args.latent_nf,
        kl_weight=args.kl_weight,
        norm_values=args.normalize_factors,
        include_atomic_numbers=args.include_atomic_numbers,
        num_edge_types=num_edge_types,
        noise_sigma=args.noise_sigma_vae,
        use_focal_loss=args.use_focal_loss,
        n_formal_charges=n_formal_charges,
    )

    if recompute_class_weight is None:
        # TODO: refactor this -- ghost nodes imply class weights must be recomputed.
        recompute_class_weight = bool(args.use_ghost_nodes)
    # NOTE(review): recompute_class_weight is resolved above but never passed to
    # prepare_class_weights, so it currently has no effect. Confirm the callee's
    # signature and forward it (or drop the parameter) intentionally.
    vae.prepare_class_weights(dataset_info, device)

    return vae, nodes_dist, prop_dist
