import torch
from torch.nn import Parameter

import numpy as np

from .utils import neg_exp, cal_mmd

class ReferenceLayer(torch.nn.Module):
    """Layer holding learnable reference "atoms" and a learnable kernel parameter.

    The forward pass returns the negated result of ``cal_mmd`` between the atoms
    and the input; ``discriminate_loss`` is a regulariser built from the pairwise
    MMD between the atoms themselves.
    """

    def __init__(self, input_channels, output_channels, num_supp, gamma, init_atoms=None):
        """
        Args
        -----
        input_channels: int
            The dimension of the layer input.
        output_channels: int
            The dimension of the layer output (the number of atoms).
        num_supp: int
            The number of support of the atoms.
        gamma: float
            The hyper-parameter used in gaussian kernel. It is stored as a
            learnable parameter initialised to this value.
        init_atoms: list or torch.Tensor, optional
            The initialization of the atoms. When ``None`` (default) the atoms are
            drawn from a standard gaussian distribution.
            Can be a 3-dimension tensor with shape
            (`output_channels`, `num_supp`, `input_channels`) OR a list with length
            equal to `output_channels` whose tensors all share one shape
            (n_supp, `input_channels`); such a list is stacked into a single tensor.
        """
        super().__init__()
        # Kept on the instance because `init_atoms_param` reads them.
        self.input_channels = input_channels
        self.output_channels = output_channels
        self.num_supp = num_supp
        self.init_gamma = gamma
        self.init_atoms = init_atoms
        self.atoms = Parameter(torch.empty(size=(output_channels, num_supp, input_channels)))
        self.gamma = Parameter(torch.empty(size=(1,)))
        self.reset_parameters()

    def init_atoms_param(self, num_supp, random_init, init_atoms):
        """Build (but do not register) atom parameters.

        Args
        -----
        num_supp: int
            Number of support points per atom; only used when `random_init` is true.
        random_init: bool
            If true, sample the atoms from a standard gaussian distribution.
        init_atoms: list or torch.Tensor
            Values to wrap when `random_init` is false.

        Returns
        -------
        A `Parameter`, or a list of `Parameter` when `init_atoms` is a list.
        """
        if random_init:
            return Parameter(torch.randn(size=(self.output_channels, num_supp, self.input_channels)))
        if isinstance(init_atoms, list):
            return [Parameter(atom) for atom in init_atoms]
        return Parameter(init_atoms)

    def forward(self, x, node_slice):
        """Return the negated ``cal_mmd(self.atoms, x, node_slice, self.gamma)``.

        NOTE(review): the expected layout of `x` and `node_slice` is defined by
        `cal_mmd` in `.utils` and is not visible here -- confirm against it.
        """
        mmd_distance = cal_mmd(self.atoms, x, node_slice, self.gamma)

        return -mmd_distance

    def discriminate_loss(self):
        """Negated average pairwise MMD between the atoms.

        For every unordered pair of atoms the squared MMD is computed from the
        kernel values returned by `neg_exp`, clamped from below at 1e-7 (so the
        square root stays differentiable), and square-rooted. The negated sum is
        divided by the number of pairs, ``n * (n - 1) / 2``.

        Returns
        -------
        A scalar tensor. With fewer than two atoms there are no pairs and a zero
        scalar is returned instead of dividing by zero.
        """
        num_atoms = self.output_channels
        if num_atoms < 2:
            return self.atoms.new_zeros(())

        loss = 0
        for i in range(num_atoms - 1):
            x = self.atoms[i]        # (num_supp, input_channels)
            y = self.atoms[i + 1:]   # (remaining atoms, num_supp, input_channels)
            # Squared euclidean distances between support points, via broadcasting.
            d_xy = ((x[np.newaxis, :, np.newaxis, :] - y[:, np.newaxis, :, :]) ** 2).sum(dim=-1)
            k_xy = neg_exp(d_xy, self.gamma)
            d_xx = ((x[:, np.newaxis, :] - x[np.newaxis, :, :]) ** 2).sum(dim=-1)
            k_xx = neg_exp(d_xx, self.gamma)
            d_yy = ((y[:, :, np.newaxis, :] - y[:, np.newaxis, :, :]) ** 2).sum(dim=-1)
            k_yy = neg_exp(d_yy, self.gamma)
            # One squared-MMD value per remaining atom.
            mmd_distance = k_xx.mean() + k_yy.mean(dim=-1).mean(dim=-1) - 2 * k_xy.mean(dim=-1).mean(dim=-1)
            mmd_distance = mmd_distance.clamp(1e-7)
            mmd_distance = mmd_distance ** 0.5
            loss -= mmd_distance.sum()

        return 2 * loss / ((num_atoms - 1) * num_atoms)

    def reset_parameters(self):
        """(Re-)initialise `atoms` and `gamma`.

        `atoms` is filled from `self.init_atoms` when given, otherwise from a
        standard gaussian; `gamma` is set to `self.init_gamma`.

        Raises
        ------
        ValueError
            If `init_atoms` is a list whose tensors do not all share one shape.
        """
        init = self.init_atoms
        if init is None:
            with torch.no_grad():
                self.atoms.normal_(0, 1)
        elif isinstance(init, Parameter):
            # Already a Parameter: register it directly (previous behaviour).
            self.atoms = init
        else:
            if isinstance(init, (list, tuple)):
                tensors = [torch.as_tensor(atom) for atom in init]
                if len({tuple(t.shape) for t in tensors}) > 1:
                    raise ValueError(
                        "init_atoms given as a list must contain tensors of identical shape"
                    )
                init = torch.stack(tensors)
            values = torch.as_tensor(init).detach().to(
                dtype=self.atoms.dtype, device=self.atoms.device
            )
            if values.shape == self.atoms.shape:
                # In-place copy keeps the Parameter object (and any optimizer
                # reference to it) intact.
                with torch.no_grad():
                    self.atoms.copy_(values)
            else:
                # A plain tensor cannot be assigned to a registered parameter
                # name, so wrap a private copy in a new Parameter.
                self.atoms = Parameter(values.clone())
        with torch.no_grad():
            self.gamma.fill_(self.init_gamma)

