import torch
from torch.nn import Module, Sequential, ModuleList, Linear, Embedding, Parameter
from torch_geometric.nn import MessagePassing
from torch_geometric.nn.models.schnet import CFConv, ShiftedSoftplus, GaussianSmearing
from torch_geometric.nn.inits import uniform
from torch_scatter import scatter
from .common import radius_bond_graph
from .edgecnf import *
from .cnf_edge.spectral_norm import inplace_spectral_norm


class InteractionBlock(torch.nn.Module):
    """One interaction layer: continuous-filter convolution -> activation -> linear map.

    Args:
        hidden_channels: Width of the node features going in and coming out.
        num_gaussians: Width of the edge features consumed by the filter network.
        num_filters: Width of the filter network's hidden and output layers.
        cutoff: Cutoff distance forwarded to ``CFConv``.
    """

    def __init__(self, hidden_channels, num_gaussians, num_filters, cutoff):
        super().__init__()
        # Filter-generating network: edge features -> per-edge filter weights.
        filter_net = Sequential(
            Linear(num_gaussians, num_filters),
            ShiftedSoftplus(),
            Linear(num_filters, num_filters),
        )
        self.conv = CFConv(hidden_channels, hidden_channels, num_filters, filter_net, cutoff)
        self.act = ShiftedSoftplus()
        self.lin = Linear(hidden_channels, hidden_channels)

    def forward(self, x, edge_index, edge_weight, edge_attr):
        """Return the node-feature update; no residual connection is applied here."""
        messages = self.conv(x, edge_index, edge_weight, edge_attr)
        return self.lin(self.act(messages))


class SchNetEncoder(Module):
    """Atom-type embeddings refined by a stack of residual ``InteractionBlock`` layers.

    Args:
        hidden_channels: Node embedding width.
        num_filters: Filter width inside each interaction block.
        num_interactions: Number of interaction blocks.
        edge_channels: Width of the ``edge_attr`` features each block consumes.
        cutoff: Cutoff distance forwarded to every block.
    """

    def __init__(self, hidden_channels=128, num_filters=128,
                 num_interactions=6, edge_channels=50, cutoff=10.0):
        super().__init__()

        self.hidden_channels = hidden_channels
        self.num_filters = num_filters
        self.num_interactions = num_interactions
        self.cutoff = cutoff

        # Vocabulary of 100 atom types; max_norm=10.0 makes lookups rescale
        # (in place) any embedding row whose norm exceeds 10.
        self.embedding = Embedding(100, hidden_channels, max_norm=10.0)

        # Blocks are created one after another, in order.
        self.interactions = ModuleList(
            InteractionBlock(hidden_channels, edge_channels, num_filters, cutoff)
            for _ in range(num_interactions)
        )

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the atom-type embedding.

        NOTE: the interaction blocks are left untouched; ``InteractionBlock``
        defines no ``reset_parameters`` of its own, so they keep the weights
        drawn when they were constructed.
        """
        self.embedding.reset_parameters()

    def forward(self, z, edge_index, edge_length, edge_attr):
        """Embed the 1-D long tensor ``z`` of atom types and apply every block residually."""
        assert z.dim() == 1 and z.dtype == torch.long

        node_feat = self.embedding(z)
        for block in self.interactions:
            update = block(node_feat, edge_index, edge_length, edge_attr)
            node_feat = node_feat + update
        return node_feat


class EdgeEncoder(Module):
    """Extend the bond graph with radius edges and featurize every edge.

    Each edge's features are ``num_gaussians`` Gaussian-smeared distance values
    concatenated with a learned ``num_gaussians``-wide embedding of a binary
    bond flag, so ``edge_attr`` has ``2 * num_gaussians`` columns.

    Args:
        num_gaussians: Number of distance basis functions (also the bond-flag embedding width).
        cutoff: Radius used both for graph construction and as the upper end of the distance basis.
    """

    def __init__(self, num_gaussians=50, cutoff=10.0):
        super().__init__()
        # Edge types strictly between 0 and NUM_BOND_TYPES get flag 1; all others get flag 0.
        self.NUM_BOND_TYPES = 22
        self.cutoff = cutoff
        self.rbf = GaussianSmearing(start=0.0, stop=cutoff, num_gaussians=num_gaussians)
        self.bond_emb = Embedding(2, embedding_dim=num_gaussians)

    def forward(self, pos, edge_index, edge_type, batch):
        """Return ``(edge_index, edge_attr, edge_length)`` for the augmented graph."""
        # unspecified_type_number=0 presumably tags radius-only (non-bond) edges
        # with type 0 -- verify against radius_bond_graph.
        edge_index, edge_type, edge_length = radius_bond_graph(
            pos, edge_index, edge_type, self.cutoff, batch,
            unspecified_type_number=0
        )
        is_bond = (edge_type > 0) & (edge_type < self.NUM_BOND_TYPES)
        bond_flag = is_bond.to(edge_type.dtype)

        distance_feat = self.rbf(edge_length)
        bond_feat = self.bond_emb(bond_flag)
        edge_attr = torch.cat([distance_feat, bond_feat], dim=1)
        return edge_index, edge_attr, edge_length


class AtomwiseEnergy(Module):
    """Per-node energy head: edge featurization -> SchNet encoder -> MLP.

    The MLP ends in ``Linear(dim // 4, 1)``, so ``forward`` yields one scalar
    per node; no pooling over nodes happens here.

    Args:
        dim: Hidden width of the encoder and input width of the MLP.
        num_gaussians: Number of distance basis functions in the edge encoder.
        cutoff: Radius-graph / distance-basis cutoff.
        num_interactions: Number of interaction blocks in the encoder.
    """

    def __init__(self, dim=128, num_gaussians=50, cutoff=10.0, num_interactions=6):
        super().__init__()
        self.cutoff = cutoff
        self.edge_enc = EdgeEncoder(num_gaussians=num_gaussians, cutoff=cutoff)
        self.schnet = SchNetEncoder(
            hidden_channels=dim,
            num_filters=dim,
            num_interactions=num_interactions,
            # EdgeEncoder concatenates distance RBFs and a bond-flag embedding,
            # each num_gaussians wide.
            edge_channels=2 * num_gaussians,
            cutoff=cutoff,
        )
        half, quarter = dim // 2, dim // 4
        self.ener_mlp = Sequential(
            Linear(dim, half, bias=True),
            ShiftedSoftplus(),
            Linear(half, quarter, bias=True),
            ShiftedSoftplus(),
            Linear(quarter, 1),
        )

    def forward(self, node_type, pos, edge_index, edge_type, batch):
        """Return per-node energies of shape ``(num_nodes, 1)``."""
        graph_index, edge_feat, dist = self.edge_enc(pos, edge_index, edge_type, batch)
        node_repr = self.schnet(node_type, graph_index, dist, edge_feat)
        return self.ener_mlp(node_repr)


class EBM(Module):
    """Energy-based model built on ``AtomwiseEnergy`` with CD, NCE and MLE objectives.

    All objectives treat low energy as "data-like": each loss below decreases
    when ``ener_pos`` goes down and/or ``ener_neg`` goes up.

    Args:
        args: Config namespace; reads ``ebm_hidden_dim``, ``ebm_num_gaussians``,
            ``ebm_cutoff`` and ``ebm_num_layers``.
    """

    def __init__(self, args):
        super().__init__()
        self.energy_model = AtomwiseEnergy(
            dim=args.ebm_hidden_dim,
            num_gaussians=args.ebm_num_gaussians,
            cutoff=args.ebm_cutoff,
            num_interactions=args.ebm_num_layers
        )

    def enable_spectral_norm(self, logger=None):
        """Apply ``inplace_spectral_norm`` to every submodule owning a ``weight`` parameter.

        Args:
            logger: Optional logger; each affected module is reported via ``logger.info``.
        """
        def apply_spectral_norm(module):
            # Only parameters registered directly on this module, not on its children.
            if 'weight' in module._parameters:
                if logger is not None:
                    logger.info("Adding spectral norm to {}".format(module))
                inplace_spectral_norm(module, 'weight')
        self.apply(apply_spectral_norm)

    def forward(self, *args, **kwargs):
        """Return the energies predicted by ``energy_model``.

        Accepts either one positional batch object supporting ``data[key]``
        lookup, or exactly five keyword arguments: ``node_type``, ``pos``,
        ``edge_index``, ``edge_type`` and ``batch``.
        """
        assert len(args) == 1 or len(kwargs) == 5
        if len(args) == 1:
            data = args[0]
        elif len(kwargs) == 5:
            data = kwargs
        else:
            # Only reachable when asserts are stripped (python -O); previously
            # this surfaced as an UnboundLocalError on `data`.
            raise TypeError(
                "EBM.forward expects one batch object or the five keyword "
                "arguments node_type, pos, edge_index, edge_type, batch"
            )
        return self.energy_model(
            data['node_type'], data['pos'], data['edge_index'], data['edge_type'], data['batch']
        )

    def get_loss_cd(self, data_pos, data_neg, alpha=0):
        """
            Contrastive Divergence (CD)

        loss = (mean(E_pos) - mean(E_neg)) / 2, plus ``alpha`` times the mean
        squared energy of both sets when ``alpha > 0``.

        Returns:
            ``(loss, {'ener_pos', 'ener_neg'})``.
        """
        ener_pos = self(data_pos)
        ener_neg = self(data_neg)

        loss_pos = ener_pos.mean()
        loss_neg = -1 * ener_neg.mean()
        loss = (loss_pos + loss_neg) / 2

        if alpha > 0:
            reg = ((ener_pos ** 2).mean() + (ener_neg ** 2).mean()) / 2
            loss = loss + alpha * reg

        return loss, {
            'ener_pos': ener_pos,
            'ener_neg': ener_neg
        }

    def get_loss_nce(self, data_pos, data_neg, alpha=0, lambda_term=0):
        """
            Noise Contrastive Estimation (NCE)

        Args:
            data_pos: Batch of positive samples.
            data_neg: Batch of negative samples.
            alpha: Weight of the energy-magnitude ("surface") regularizer; off when 0.
            lambda_term: Weight of the penalty on ``dE/dpos``; off when 0. When
                positive, ``requires_grad`` is switched on in place for
                ``data_pos.pos`` and ``data_neg.pos``.

        Returns:
            ``(loss, info)`` where ``info`` holds ``ener_pos``, ``ener_neg``,
            ``loss_nce``, ``loss_surface_penalty`` and ``loss_gradient_penalty``.
        """
        if lambda_term > 0:
            data_pos.pos.requires_grad_(True)
            data_neg.pos.requires_grad_(True)

        ener_pos = self(data_pos)
        ener_neg = self(data_neg)

        # logsigmoid instead of log(sigmoid(.)): in float32 sigmoid underflows
        # to 0 for large |energy|, which made the loss +inf.
        loss_pos = -1 * torch.nn.functional.logsigmoid(-ener_pos).mean()
        loss_neg = -1 * torch.nn.functional.logsigmoid(ener_neg).mean()
        loss_nce = loss_pos + loss_neg

        loss_surface_penalty = 0
        loss_gradient_penalty = 0

        if alpha > 0:
            reg = ((ener_pos ** 2).mean() + (ener_neg ** 2).mean()) / 2
            loss_surface_penalty = alpha * reg

        if lambda_term > 0:
            # create_graph=True so the penalty itself can be backpropagated.
            grad_pos = torch.autograd.grad(outputs=[ener_pos.sum()], inputs=[data_pos.pos], create_graph=True, retain_graph=True)[0]
            grad_neg = torch.autograd.grad(outputs=[ener_neg.sum()], inputs=[data_neg.pos], create_graph=True, retain_graph=True)[0]
            # Scaled by lambda_term, mirroring how alpha scales the surface penalty.
            loss_gradient_penalty = lambda_term * ((grad_pos ** 2).mean() + (grad_neg ** 2).mean()) / 2

        loss = loss_nce + loss_surface_penalty + loss_gradient_penalty

        return loss, {
            'ener_pos': ener_pos,
            'ener_neg': ener_neg,
            'loss_nce': loss_nce,
            'loss_surface_penalty': loss_surface_penalty,
            'loss_gradient_penalty': loss_gradient_penalty
        }

    def get_loss_mle(self, data_pos, data_neg, alpha=0, check=True):
        """Maximum-likelihood loss with the partition function estimated from negatives.

        ``data_neg`` must contain ``k`` copies of ``data_pos``'s batch laid out
        one copy after another: negative graph ``i`` pairs with positive graph
        ``i % num_graphs``. ``check=True`` asserts this via node counts and SMILES.

        Args:
            data_pos: Batch of positive samples.
            data_neg: ``k`` stacked copies of the positive batch (negatives).
            alpha: Weight of the energy-magnitude regularizer; off when 0.
            check: Verify the pairing between ``data_pos`` and ``data_neg``.

        Returns:
            ``(loss, info)`` where ``info`` holds ``ener_pos`` (N, 1),
            ``ener_neg`` (N, k) aligned row-wise with ``ener_pos``, and
            ``log_prob`` (N, 1), with N = ``ener_pos.size(0)``.
        """
        if check:
            assert data_neg.num_nodes % data_pos.num_nodes == 0
            for i, smiles in enumerate(data_neg.smiles):
                assert smiles == data_pos.smiles[i % data_pos.num_graphs]

        ener_pos = self(data_pos)
        num_pos = ener_pos.size(0)
        # Negatives are copy-major: flat row j * N + i is copy j of positive row i.
        # view(k, N).t() gives (N, k) with each row aligned to its positive;
        # view(N, -1) would instead pack unrelated rows together.
        ener_neg = self(data_neg).view(-1, num_pos).t()

        partition = -1 * torch.cat([ener_pos, ener_neg], dim=1)
        log_prob = (-1 * ener_pos) - torch.logsumexp(partition, dim=1, keepdim=True)
        loss = (-1 * log_prob).mean()

        if alpha > 0:
            reg = ((ener_pos ** 2).mean() + (ener_neg ** 2).mean()) / 2
            loss = loss + alpha * reg

        return loss, {
            'ener_pos': ener_pos,
            'ener_neg': ener_neg,
            'log_prob': log_prob
        }

