import pickle
from os.path import join

import torch

from geo_ldm.init_vae import get_vae
from geo_ldm.latent_diffuser import EnLatentDiffusion
from egnn.models import EGNN_dynamics_QM9
from geo_ldm.mol_optimizer import MoleculeOptimizer
from qm9.models import DistributionProperty, DistributionFingerprint


# Fallback values for options that an older pickled autoencoder config may
# not define. Only attributes that are missing get filled in.
_FIRST_STAGE_DEFAULTS = {
    'normalization_factor': 1,
    'aggregation_method': 'sum',
    'use_ghost_nodes': False,
    'use_rbf': False,
    'noise_sigma_vae': None,
    'n_layers_encoder': 1,
    'n_extra_atomic_features': 0,
    'use_focal_loss': False,
}


def _load_first_stage_args(args):
    """Return the autoencoder config: unpickled from ``args.ae_path`` if set, else ``args``."""
    if args.ae_path is None:
        return args
    # NOTE: unpickling can execute arbitrary code. Only point ``ae_path`` at
    # checkpoint directories that come from a trusted source.
    with open(join(args.ae_path, 'args_vae.pickle'), 'rb') as f:
        return pickle.load(f)


def _fill_first_stage_defaults(first_stage_args, args):
    """Add, in place, the attributes that ``first_stage_args`` is missing."""
    for name, default in _FIRST_STAGE_DEFAULTS.items():
        if not hasattr(first_stage_args, name):
            setattr(first_stage_args, name, default)

    # The decoder depth defaults to the depth stored under ``n_layers_vae``.
    if not hasattr(first_stage_args, 'n_layers_decoder'):
        first_stage_args.n_layers_decoder = first_stage_args.n_layers_vae
    # Unlike the defaults above this is an override: when ``tanh_vae`` exists
    # it always replaces ``tanh`` on the autoencoder config.
    if hasattr(first_stage_args, 'tanh_vae'):
        first_stage_args.tanh = first_stage_args.tanh_vae
    # Inherit from the second-stage args when they define it, else False.
    if not hasattr(first_stage_args, 'encode_h_indep'):
        first_stage_args.encode_h_indep = getattr(args, 'encode_h_indep', False)


def get_latent_diffusion(args, device, dataset_info, dataloader_train, recompute_class_weight=None, mol_optimizer=False):
    """Build the first-stage autoencoder and the latent diffusion model on top of it.

    When ``args.ae_path`` is set, the autoencoder config is read from
    ``args_vae.pickle`` in that directory and its weights from ``vae_ema.npy``
    (if the config's ``ema_decay`` is positive) or ``vae.npy``. Otherwise
    ``args`` itself is used as the autoencoder config and no weights are
    loaded.

    Side effects: missing attributes are added to the autoencoder config in
    place (which is ``args`` itself when ``args.ae_path`` is None), and
    ``args.latent_nf`` is overwritten with the autoencoder's ``latent_nf``.

    Args:
        args: Attribute container with the second-stage options read below
            (``ae_path``, ``conditioning``, ``condition_time``, ``nf``, ...).
        device: Ignored. The device is derived from the autoencoder config's
            ``cuda`` flag; the parameter is kept for interface compatibility.
        dataset_info: Forwarded unchanged to ``get_vae``.
        dataloader_train: Forwarded to ``get_vae`` and, if needed, used to
            build the conditioning property distribution.
        recompute_class_weight: Forwarded unchanged to ``get_vae``.
        mol_optimizer: If true, instantiate ``MoleculeOptimizer`` instead of
            ``EnLatentDiffusion``.

    Returns:
        Tuple ``(vdm, nodes_dist, prop_dist)``: the latent diffusion model,
        and the node and property distributions returned by ``get_vae``
        (``prop_dist`` may instead be built here, see below).

    Raises:
        ValueError: If ``args.probabilistic_model`` is not ``'diffusion'``.
    """
    # Create (and load) the first stage model (Autoencoder).
    first_stage_args = _load_first_stage_args(args)
    _fill_first_stage_defaults(first_stage_args, args)

    # The device follows the autoencoder config rather than the ``device``
    # argument. Fall back to the CPU when that config asks for CUDA on a host
    # without it, so GPU-trained checkpoints remain loadable.
    use_cuda = first_stage_args.cuda
    if use_cuda and not torch.cuda.is_available():
        print('Warning: autoencoder config requests CUDA but it is not available; using CPU.')
        use_cuda = False
    device = torch.device("cuda" if use_cuda else "cpu")

    first_stage_model, nodes_dist, prop_dist = get_vae(
        first_stage_args, device, dataset_info, dataloader_train,
        recompute_class_weight=recompute_class_weight)
    first_stage_model.to(device)

    if prop_dist is None and args.conditioning:
        # The diffusion model is conditioned but the autoencoder is not, so
        # the property distribution has to be built here.
        if 'morgan_fingerprint' in args.conditioning:
            prop_dist = DistributionFingerprint(dataloader_train, args.conditioning)
        else:
            prop_dist = DistributionProperty(dataloader_train, args.conditioning)

    if args.ae_path is not None:
        weights_file = 'vae_ema.npy' if first_stage_args.ema_decay > 0 else 'vae.npy'
        vae_state_dict = torch.load(join(args.ae_path, weights_file),
                                    map_location=device)
        first_stage_model.load_state_dict(vae_state_dict)

    # Create the second stage model (Latent Diffusions).
    args.latent_nf = first_stage_args.latent_nf
    in_node_nf = args.latent_nf

    if args.condition_time:
        # One extra input feature per node when conditioning on time.
        dynamics_in_node_nf = in_node_nf + 1
    else:
        print('Warning: dynamics model is _not_ conditioned on time.')
        dynamics_in_node_nf = in_node_nf

    net_dynamics = EGNN_dynamics_QM9(
        in_node_nf=dynamics_in_node_nf, context_node_nf=args.context_node_nf,
        n_dims=3, device=device, hidden_nf=args.nf,
        act_fn=torch.nn.SiLU(), n_layers=args.n_layers,
        attention=args.attention, tanh=args.tanh, mode=args.model, norm_constant=args.norm_constant,
        inv_sublayers=args.inv_sublayers, sin_embedding=args.sin_embedding,
        normalization_factor=args.normalization_factor, aggregation_method=args.aggregation_method)

    if args.probabilistic_model != 'diffusion':
        raise ValueError(args.probabilistic_model)

    geo_ldm_class = MoleculeOptimizer if mol_optimizer else EnLatentDiffusion
    vdm = geo_ldm_class(
        vae=first_stage_model,
        trainable_ae=args.trainable_ae,
        dynamics=net_dynamics,
        in_node_nf=in_node_nf,
        n_dims=3,
        timesteps=args.diffusion_steps,
        noise_schedule=args.diffusion_noise_schedule,
        noise_precision=args.diffusion_noise_precision,
        loss_type=args.diffusion_loss_type,
        norm_values=args.normalize_factors,
        include_charges=args.include_atomic_numbers,
        joint_training=args.joint_training,
        use_eps_correction=args.use_eps_correction,
        joint_space=args.joint_space,
        lambda_joint_loss=args.lambda_joint_loss,
        loss_weighting=args.loss_weighting,
    )

    return vdm, nodes_dist, prop_dist
