from grover.util.nn_utils import initialize_weights
from argparse import Namespace
from typing import List, Dict, Callable

import numpy as np
import torch
from torch import nn as nn

from grover.data import get_atom_fdim, get_bond_fdim
from grover.model.layers import Readout, GTransEncoder
from grover.util.nn_utils import get_activation_function


def build_model(args):
    """
    Construct a GroverContrastive model and initialise its weights.
    :param args: the arguments forwarded to GroverContrastive.
    :return: the freshly constructed model after initialize_weights has run on it.
    """
    net = GroverContrastive(args=args)
    initialize_weights(model=net)
    return net

def load_checkpoint(path):
    """
    Build a model from a checkpoint file and copy in every compatible saved tensor.

    The checkpoint must be a dict with an ``'args'`` entry (passed to build_model)
    and a ``'state_dict'`` entry. A saved tensor is loaded only when the new model
    has a parameter with the same name and shape; every other entry is reported
    with a print and skipped. Model parameters missing from the checkpoint keep
    the values produced by build_model.

    :param path: path of the checkpoint file.
    :return: the constructed model with all compatible pretrained parameters loaded.
    """
    # The map_location lambda returns each storage unchanged, i.e. tensors are
    # deserialised onto CPU regardless of the device they were saved from.
    # NOTE(review): torch.load unpickles the file and can execute arbitrary code,
    # so only load trusted checkpoints. Recent PyTorch releases default to
    # weights_only=True, which may reject the pickled ``args`` object; confirm
    # against the installed torch version.
    state = torch.load(path, map_location=lambda storage, loc: storage)
    args, loaded_state_dict = state['args'], state['state_dict']

    m_gnn = build_model(args=args)
    model_state_dict = m_gnn.state_dict()

    pretrained_state_dict = {}
    for param_name, loaded_param in loaded_state_dict.items():
        if param_name not in model_state_dict:
            print(f'Pretrained parameter "{param_name}" cannot be found in model parameters.')
        elif model_state_dict[param_name].shape != loaded_param.shape:
            print(f'Pretrained parameter "{param_name}" '
                  f'of shape {loaded_param.shape} does not match corresponding '
                  f'model parameter of shape {model_state_dict[param_name].shape}.')
        else:
            print(f'Loading pretrained parameter "{param_name}".')
            pretrained_state_dict[param_name] = loaded_param

    # Overlay only the compatible tensors on the full model state, so the strict
    # load_state_dict call below sees every key the model expects.
    model_state_dict.update(pretrained_state_dict)
    m_gnn.load_state_dict(model_state_dict)

    return m_gnn


class GroverContrastive(nn.Module):
    """
    GroverContrastive class.
    Wraps a GROVEREmbedding encoder plus one shared Readout (rtype="mean") and
    produces molecule-level outputs for the atom_from_bond and atom_from_atom
    embedding views of an input graph batch.
    This class does not load any weights itself; pretrained weights are loaded
    by load_checkpoint.
    """
    def __init__(self, args):
        """
        Init function.
        :param args: the model arguments. Modified in place: ``dropout`` is set
            to 0.0 and ``cuda`` to True before the encoder is built.
        """
        super().__init__()

        # Both overrides are written onto the caller's args object (not a copy)
        # and are then read by GROVEREmbedding when it builds the encoder.
        args.dropout = 0.0
        # NOTE(review): CUDA mode is forced on unconditionally; confirm this is
        # intended for hosts without a GPU.
        args.cuda = True
        self.args = args

        self.grover = GROVEREmbedding(args)
        # NOTE(review): hidden_size is hard-coded; presumably it has to match the
        # encoder's output width (args.hidden_size?) -- confirm before using
        # checkpoints with a different hidden size.
        self.readout = Readout(rtype="mean", hidden_size=1200)

    def forward(self, batch):
        """
        Encode a graph batch and read out one result per molecule for each of the
        two atom-embedding views.
        :param batch: an 8-element graph batch; the whole tuple is passed to the
            encoder and element 5 (a_scope) is passed to the readout.
        :return: tuple (mol_atom_from_bond_output, mol_atom_from_atom_output).
            Requires args.embedding_output_type to be 'atom' or 'both'; with
            'bond' the encoder yields None for both atom views.
        """
        _, _, _, _, _, a_scope, _, _ = batch

        output = self.grover(batch)
        # The same readout module is shared by both views.
        mol_atom_from_bond_output = self.readout(output["atom_from_bond"], a_scope)
        mol_atom_from_atom_output = self.readout(output["atom_from_atom"], a_scope)

        return mol_atom_from_bond_output, mol_atom_from_atom_output



class GROVEREmbedding(nn.Module):
    """
    The GROVER Embedding class. It wraps a GTransEncoder and packages the
    encoder's raw outputs into a dict keyed by embedding view.
    The GTransEncoder can be replaced by any other valid encoder that returns
    its outputs in the same layout.
    """

    def __init__(self, args: Namespace):
        """
        Initialize the GROVEREmbedding class.
        :param args: model arguments; this reads embedding_output_type,
            hidden_size, dropout, activation, num_mt_block, num_attn_head, bias
            and cuda.
        """
        super().__init__()
        self.embedding_output_type = args.embedding_output_type
        # Edge input width is the bond feature size plus the atom feature size.
        edge_dim = get_bond_fdim() + get_atom_fdim()
        node_dim = get_atom_fdim()

        # This encoder was historically called "dualtrans".
        self.encoders = GTransEncoder(args,
                                      hidden_size=args.hidden_size,
                                      edge_fdim=edge_dim,
                                      node_fdim=node_dim,
                                      dropout=args.dropout,
                                      activation=args.activation,
                                      num_mt_block=args.num_mt_block,
                                      num_attn_head=args.num_attn_head,
                                      atom_emb_output=self.embedding_output_type,
                                      bias=args.bias,
                                      cuda=args.cuda)

    def forward(self, graph_batch: List) -> Dict:
        """
        Run the encoder on graph_batch and return its outputs as a dict. Which
        entries are populated depends on self.embedding_output_type; the
        entries that are not produced are None.

        :param graph_batch: the input graph batch generated by MolCollator.
        :return: a dict with keys atom_from_atom, atom_from_bond, bond_from_atom
            and bond_from_bond.
        :raises ValueError: if embedding_output_type is not 'atom', 'bond' or 'both'.
        """
        # Fail fast with a clear message instead of silently returning None,
        # which would only surface later as a TypeError in the caller.
        if self.embedding_output_type not in ('atom', 'bond', 'both'):
            raise ValueError(f"Unsupported embedding_output_type "
                             f"{self.embedding_output_type!r}; "
                             f"expected 'atom', 'bond' or 'both'.")

        output = self.encoders(graph_batch)
        if self.embedding_output_type == 'atom':
            # output = (atom_from_atom, atom_from_bond)
            return {"atom_from_atom": output[0], "atom_from_bond": output[1],
                    "bond_from_atom": None, "bond_from_bond": None}
        if self.embedding_output_type == 'bond':
            # output = (bond_from_atom, bond_from_bond)
            return {"atom_from_atom": None, "atom_from_bond": None,
                    "bond_from_atom": output[0], "bond_from_bond": output[1]}
        # 'both': output = ((atom_from_atom, bond_from_atom),
        #                   (atom_from_bond, bond_from_bond))
        return {"atom_from_atom": output[0][0], "bond_from_atom": output[0][1],
                "atom_from_bond": output[1][0], "bond_from_bond": output[1][1]}
