import torch
import torch.nn as nn
import ipdb
import torch.nn.functional as F
import numpy as np

from models.hyper import OutNet


class MetaFun(nn.Module):
    """Kernel-weighted set model that maps a support set to per-query outputs.

    Each support pair ``(x_s, y_s)`` is encoded by ``h`` into a ``hidden``-dim
    vector. A learned kernel between embedded support inputs (``a_support``)
    and embedded query inputs (``a_query``) weights those encodings for every
    query point, and the weighted sum is passed through ``out_net``.

    Args:
        input_dim: Feature size of the support inputs ``x_s``.
        output_dim: Feature size of the support targets ``y_s``.
        shift_input_dim: Feature size of the query inputs ``x_q``.
        shift_output_dim: Output size passed to ``OutNet``.
        k: Stored as ``self.k``; not read anywhere in this class.
        kernel_type: ``'rbf'`` for a squared-exponential kernel on the
            embeddings; any other value selects scaled dot-product attention.
        hidden: Width of every hidden layer and of the shared embedding space.
    """

    def __init__(self, input_dim, output_dim, shift_input_dim, shift_output_dim,
                 k, kernel_type, hidden=128):
        super().__init__()

        # NOTE(review): `k` is unused in this class — confirm whether callers read it.
        self.k = k
        self.kernel_type = kernel_type
        self.hidden = hidden

        # Encoder for concatenated (input, target) support pairs.
        self.h = self._mlp(input_dim + output_dim, hidden)
        # Kernel embeddings; support and query inputs may differ in size.
        self.a_support = self._mlp(input_dim, hidden)
        self.a_query = self._mlp(shift_input_dim, hidden)

        self.out_net = OutNet(hidden, shift_output_dim)

        self.num_params = sum(
            self.get_n_params(module)
            for module in (self.h, self.a_support, self.a_query, self.out_net)
        )

    @staticmethod
    def _mlp(in_dim, hidden):
        """Build a three-layer ReLU MLP ``in_dim -> hidden -> hidden -> hidden``."""
        return nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
        )

    def kernel(self, x_s, x_q):
        """Return the support-by-query kernel matrix.

        Args:
            x_s: Support inputs, shape ``(num_support, input_dim)``.
            x_q: Query inputs, shape ``(num_query, shift_input_dim)``.

        Returns:
            Tensor of shape ``(num_support, num_query)``. For ``'rbf'`` the
            entries are ``exp(-||a_s - a_q||^2 / 2)``. Otherwise each column is
            a softmax over the support set of scaled dot products, so every
            query's weights sum to 1 and do not depend on the other queries.
        """
        a_s = self.a_support(x_s)
        a_q = self.a_query(x_q)
        if self.kernel_type == 'rbf':
            # cdist returns Euclidean distance; the RBF kernel needs its square.
            sq_dist = torch.cdist(a_s, a_q).pow(2)
            return torch.exp(-sq_dist / 2)
        # Scale the logits *before* the softmax (scaled dot-product attention)
        # and normalise over supports (dim 0) so each query attends to the set.
        logits = a_s @ a_q.T / (self.hidden ** 0.5)
        return F.softmax(logits, dim=0)

    def get_n_params(self, model):
        """Return the total number of scalar parameters in ``model``."""
        return sum(p.numel() for p in model.parameters())

    def forward(self, x_s, y_s, x_q):
        """Predict outputs for ``x_q`` from the support set ``(x_s, y_s)``.

        Returns:
            ``out_net`` applied to the ``(num_query, hidden)`` tensor of
            kernel-weighted support encodings (no separate query encoder).
        """
        r = self.h(torch.cat((x_s, y_s), -1))  # (num_support, hidden)
        weights = self.kernel(x_s, x_q)        # (num_support, num_query)
        return self.out_net(weights.T @ r)     # out_net on (num_query, hidden)




if "__main__" == __name__:
    # Intentionally a no-op when this module is executed directly.
    ...