import pickle
from os.path import join

import torch

from configs.datasets_config import get_dataset_info
from qm9 import dataset
from qm9.models import get_model
from qm9.utils import prepare_context, compute_mean_mad, compute_properties_upper_bounds
from geo_ldm.init_latent_diffuser import get_latent_diffusion
from geo_ldm.init_diffusion_guidance import get_diffusion_guidance
from qm9.models import DistributionProperty


def load_ldm_model(exp_folder: str, mol_optimizer=False, ckpt_prefix=''):
    """Restore a trained latent diffusion model (LDM) plus the helpers needed to sample from it.

    Args:
        exp_folder (str): path to folder containing model checkpoints. e.g. outputs/edm_zinc/
        mol_optimizer (bool): forwarded unchanged to ``get_latent_diffusion``.
        ckpt_prefix (str): prefix that specifies which model to load. '' for the best model in terms of score,
                            'best_fcd_' for the best fcd scoring model, 'last_' for the latest model

    Returns:
        tuple: ``(model, utilities_dict)``. ``model`` is moved to the selected device, has the
        checkpoint weights loaded and is in eval mode. ``utilities_dict`` holds ``nodes_dist``,
        ``args``, ``device``, ``dataset_info``, ``property_norms``, ``conditioning`` (``None``
        when the run was unconditional) and ``prop_dist``.
    """
    base_name = 'diffusion_model'
    args_path = join(exp_folder, f'{ckpt_prefix}args_{base_name}.pickle')
    # NOTE: unpickling can execute arbitrary code; only load trusted experiment folders.
    with open(args_path, 'rb') as handle:
        args = pickle.load(handle)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.device = device

    loaders, _charge_scale = dataset.retrieve_dataloaders(args)
    train_loader = loaders['train']
    dataset_info = get_dataset_info(args.dataset, args.remove_h)

    conditioning = args.conditioning
    property_norms = None
    if conditioning:
        print(f'Conditioning on {conditioning}')
        # One real batch is used only to discover the width of the context tensor.
        sample_batch = next(iter(train_loader))
        property_norms = compute_mean_mad(loaders, conditioning, args.dataset)
        sample_context = prepare_context(conditioning, sample_batch, property_norms,
                                         condition_dropout=args.condition_dropout)
        context_width = sample_context.size(2)
        # The returned upper bounds are not used below.
        compute_properties_upper_bounds(train_loader, conditioning, log_to_wandb=False)
        args.context_node_nf = context_width

    model, nodes_dist, prop_dist = get_latent_diffusion(
        args, device, dataset_info, train_loader, mol_optimizer=mol_optimizer)
    if prop_dist is None:
        # TODO: this fallback hard-codes penalized_logP; replace it with a proper fix.
        print('Warning: creating prop_dist anyways. FIX THIS!')
        prop_dist = DistributionProperty(train_loader, ['penalized_logP'])
        property_norms = compute_mean_mad(loaders, ['penalized_logP'], args.dataset)
    model = model.to(device)
    # prop_dist can no longer be None here, so the normalizer is always attached.
    prop_dist.set_normalizer(property_norms)

    if args.ema_decay > 0:
        weights_name = f'{ckpt_prefix}{base_name}_ema.npy'
    else:
        weights_name = f'{ckpt_prefix}{base_name}.npy'

    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(join(exp_folder, weights_name), map_location=device)
    model.load_state_dict(state_dict)
    model.eval()

    return model, {
        'nodes_dist': nodes_dist,
        'args': args,
        'device': device,
        'dataset_info': dataset_info,
        'property_norms': property_norms,
        'conditioning': args.conditioning if len(args.conditioning) > 0 else None,
        'prop_dist': prop_dist,
    }


def load_regressor(exp_folder, regression_target):
    """Load a trained property regressor and a target distribution for each of its properties.

    Args:
        exp_folder (str): directory containing ``args_<name>.pickle`` and ``<name>_ema.npy``,
            where ``<name>`` is ``regression_model`` followed by ``_<prop>`` for every
            target property (in order).
        regression_target (list[str] | str): property name(s) the regressor was trained on.
            A single string is treated as a one-element list. It must equal the
            ``regression_target`` stored in the pickled args.

    Returns:
        tuple: ``(regressor, prop_dist)``. ``regressor`` is ``model.regressor`` on the
        selected device, in eval mode. ``prop_dist`` maps each property name to a
        ``DistributionProperty`` whose normalizer comes from ``compute_mean_mad`` for that
        property alone.

    Raises:
        ValueError: if ``regression_target`` is empty, or does not match the target stored
            in the checkpoint's args.
    """
    # A bare string would otherwise be iterated character by character below,
    # producing a nonsensical checkpoint name.
    if isinstance(regression_target, str):
        regression_target = [regression_target]
    # Validate before the target is used to build file names.
    if len(regression_target) == 0:
        raise ValueError('regression_target must name at least one property')

    model_name = 'regression_model' + ''.join(f'_{prop}' for prop in regression_target)
    # NOTE: unpickling can execute arbitrary code; only load trusted experiment folders.
    with open(join(exp_folder, f'args_{model_name}.pickle'), 'rb') as f:
        args = pickle.load(f)
    if args.regression_target != regression_target:
        raise ValueError(
            f'regression_target {regression_target!r} does not match the checkpoint args '
            f'({args.regression_target!r})')

    # Fill in regressor options that are missing from the pickled args.
    if not hasattr(args, 'condition_time_regressor'):
        args.condition_time_regressor = False
    if not hasattr(args, 'max_step_regressor'):
        args.max_step_regressor = 1000
    # Force a 4-layer network for the penalized_logP regressor. regression_target is a
    # list at this point, so it must be compared with a list; a bare string never matches.
    if list(regression_target) == ['penalized_logP']:
        args.n_layers = 4

    # Honour args.cuda, but fall back to CPU when CUDA is unavailable (as the other loaders do).
    device = torch.device("cuda" if args.cuda and torch.cuda.is_available() else "cpu")
    args.device = device

    dataloaders, _charge_scale = dataset.retrieve_dataloaders(args, args.debug)
    train_loader = dataloaders['train']

    print(f'Using regression target: {args.regression_target}')
    # The regressor is built with context_node_nf fixed to 0.
    args.context_node_nf = 0

    dataset_info = get_dataset_info(args.dataset, args.remove_h)
    model, _nodes_dist, _prop_dist = get_diffusion_guidance(args, device, dataset_info, train_loader)

    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    model_state_dict = torch.load(join(exp_folder, f'{model_name}_ema.npy'), map_location=device)
    model.load_state_dict(model_state_dict)
    regressor = model.regressor.to(device)
    regressor.eval()

    # One distribution per property, each normalized with its own mean/MAD statistics.
    prop_dist = {}
    for prop in regression_target:
        dist = DistributionProperty(train_loader, [prop])
        dist.set_normalizer(compute_mean_mad(dataloaders, [prop], args.dataset))
        prop_dist[prop] = dist

    property_upperbounds = compute_properties_upper_bounds(train_loader, regression_target, log_to_wandb=False)
    print('property_upperbounds', property_upperbounds)

    return regressor, prop_dist


def load_model(model_path: str):
    """Restore a trained EGNN generative model together with a few sampling helpers.

    Args:
        model_path (str): path to folder containing model checkpoints. e.g. outputs/edm_zinc/

    Returns:
        tuple: ``(model, utilities_dict)``. ``model`` is on the selected device with the
        checkpoint weights loaded and in eval mode. ``utilities_dict`` holds ``nodes_dist``,
        ``args``, ``device`` and ``dataset_info``.
    """
    # NOTE: unpickling can execute arbitrary code; only load trusted experiment folders.
    with open(join(model_path, 'args.pickle'), 'rb') as fh:
        args = pickle.load(fh)

    # Args without a joint_training attribute get it defaulted to False.
    args.joint_training = getattr(args, 'joint_training', False)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.device = device

    loaders, _charge_scale = dataset.retrieve_dataloaders(args)
    dataset_info = get_dataset_info(args.dataset, args.remove_h)

    model, nodes_dist, _prop_dist = get_model(args, device, dataset_info, loaders['train'])
    model = model.to(device)

    if args.ema_decay > 0:
        checkpoint_name = 'generative_model_ema.npy'
    else:
        checkpoint_name = 'generative_model.npy'
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    weights = torch.load(join(model_path, checkpoint_name), map_location=device)
    model.load_state_dict(weights)
    model.eval()

    return model, {
        'nodes_dist': nodes_dist,
        'args': args,
        'device': device,
        'dataset_info': dataset_info,
    }
