from os.path import join
import pickle

from bond_type_prediction.egnn_edge_model import EGNNEdgeModel
from configs.datasets_config import get_dataset_info
from qm9 import dataset

import torch


def get_pp_model(args, dataset_info, device, dataloader, recompute_class_weight):
    """
    Build the post-processing (PP) EGNNEdgeModel and prepare its class weights.

    Args:
        args: hyper-parameter namespace. Reads ``include_charges``,
            ``joint_training`` and the ``*_pp`` model options.
        dataset_info (dict): dataset description; only ``'atom_decoder'`` is
            used here, and its length gives the number of atom types.
        device: torch device handed to the model constructor and to
            ``prepare_class_weights``.
        dataloader: data source forwarded to ``prepare_class_weights``.
        recompute_class_weight (bool): forwarded to ``prepare_class_weights``.

    Returns:
        EGNNEdgeModel: the constructed model, after ``prepare_class_weights``
        has been called on it.
    """
    # TODO: how should formal charges be modelled on the input side, 1 or 3 dims?
    # Currently: with joint training the input carries the charge as a single
    # scalar feature, otherwise as 3 dims. The output always predicts 3 logits,
    # one per charge class (-1, 0, 1).
    if args.joint_training:
        charge_in_dims = 1
    else:
        charge_in_dims = 3
    charge_out_dims = 3

    n_atom_types = len(dataset_info['atom_decoder'])
    charges_on = int(args.include_charges)
    # One extra input feature when the PP model is conditioned on time.
    time_dims = int(args.condition_time_pp)

    model_kwargs = dict(
        # Feature sizes.
        in_node_nf=n_atom_types + charges_on * charge_in_dims + time_dims,
        in_edge_nf=1,
        hidden_nf=args.hidden_nf_pp,
        out_node_nf=n_atom_types + charges_on * charge_out_dims,
        n_classes=5,  # includes aromatic types
        # Message-passing architecture.
        n_layers=args.n_layers_pp,
        inv_sublayers=args.inv_sublayers_pp,
        attention=args.attention_pp,
        act_fn=torch.nn.SiLU(),
        tanh=args.tanh_pp,
        coords_range=15,
        norm_diff=True,
        norm_constant=args.norm_constant_pp,
        sin_embedding=args.sin_embedding_pp,
        normalization_factor=args.normalization_factor_pp,
        aggregation_method=args.aggregation_method_pp,
        # Encoder / edge head.
        encoder=args.encoder_pp,
        edge_head=args.edge_head_pp,
        edge_head_hidden_dim=args.edge_head_hidden_dim_pp,
        # Data and training flags.
        include_charges=args.include_charges,
        modify_h=args.modify_h_pp,
        joint_training=args.joint_training,
        condition_time=args.condition_time_pp,
        # Placement.
        device=device,
    )
    pp_model = EGNNEdgeModel(**model_kwargs)

    pp_model.prepare_class_weights(
        dataloader,
        dataset_info,
        device,
        recompute_class_weight=recompute_class_weight,
    )
    return pp_model


def load_trained_edge_model(model_path):
    """
    Initialize an EGNNEdgeModel and load the weights of an already-trained run.

    Args:
        model_path (str): path to the folder where the model weights
            (``pp_model_ema.npy``) and the training args (``args.pickle``) are
            stored, e.g. ``outputs/edge_model_qm9_egnn/``.

    Returns:
        EGNNEdgeModel: the model with the stored state dict loaded, built for
        CUDA if available (otherwise CPU), and switched to eval mode.
    """
    # NOTE(security): pickle.load (and torch.load below) can execute arbitrary
    # code embedded in the file; only load model folders from trusted sources.
    with open(join(model_path, 'args.pickle'), 'rb') as f:
        args = pickle.load(f)

    # Backward compatibility: args pickled by older runs may lack flags that
    # get_pp_model reads unconditionally. Default them to False (feature off),
    # presumably matching how those runs were trained; a wrong guess surfaces
    # as a shape mismatch in load_state_dict rather than silently.
    if not hasattr(args, 'joint_training'):
        args.joint_training = False
    if not hasattr(args, 'condition_time_pp'):
        args.condition_time_pp = False

    # Initialize Model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.device = device

    # The dataloaders are only needed so get_pp_model can forward the train
    # loader to prepare_class_weights (with recompute_class_weight=False here).
    dataloaders, _charge_scale = dataset.retrieve_dataloaders(args)
    dataset_info = get_dataset_info(args.dataset, args.remove_h)

    pp_model = get_pp_model(args, dataset_info, device, dataloaders['train'], recompute_class_weight=False)

    # Load weights. Despite the .npy extension, the file is read with torch.load
    # and used as a state dict (presumably the EMA copy, per its name).
    model_state_dict = torch.load(join(model_path, 'pp_model_ema.npy'), map_location=device)
    pp_model.load_state_dict(model_state_dict)

    pp_model.eval()
    return pp_model
