import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.autograd import Variable

from code_utils.model import euclidean_dist


class Flatten(nn.Module):
    """Reshape an input so that everything after dim 0 becomes a single dim.

    An input of shape ``(B, d1, d2, ...)`` comes out as ``(B, d1*d2*...)``.
    ``Tensor.view`` is used, so no data is copied.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        leading = x.size(0)
        return x.view(leading, -1)


class Protonet(nn.Module):
    """Prototypical-network head around an episode encoder.

    Support and query examples are embedded together in one encoder call. The
    prototype is the mean of the support embeddings. ``loss`` returns the
    distances from each query embedding to that prototype.
    """

    def __init__(self, encoder, device):
        super(Protonet, self).__init__()

        self.encoder = encoder
        # NOTE(review): stored but not used by this class; inputs are not moved
        # to this device here. Confirm the encoder handles placement.
        self.device = device

    @staticmethod
    def _merge_hist(support_hist, query_hist):
        """Combine the two-element support/query histories component-wise with ``+``."""
        # Presumably list concatenation, so history rows line up with the
        # support-then-query order of torch.cat in forward. TODO confirm types.
        return [support_hist[0] + query_hist[0], support_hist[1] + query_hist[1]]

    def forward(self, sample):
        """Embed one episode and return ``(z_proto, zq)``.

        Args:
            sample: dict with keys ``'xs'`` (support) and ``'xq'`` (query). Each
                value is a dict holding a NumPy array under ``'triplets'`` and
                two-element histories under ``'s_hist'`` and ``'o_hist'``.

        Returns:
            z_proto: tensor of shape ``(1, z_dim)``, the mean support embedding.
            zq: the encoder output rows that follow the support rows (queries).
        """
        xs = torch.from_numpy(sample['xs']['triplets'])  # support
        xq = torch.from_numpy(sample['xq']['triplets'])  # query

        n_class = 1  # each episode carries a single class
        n_support = xs.size(0)

        # Support rows come first, so the encoder output can be split by n_support.
        x = torch.cat([xs, xq], 0)
        s_hist = self._merge_hist(sample['xs']['s_hist'], sample['xq']['s_hist'])
        o_hist = self._merge_hist(sample['xs']['o_hist'], sample['xq']['o_hist'])
        z = self.encoder.forward(x, s_hist, o_hist, n_support)
        z_dim = z.size(-1)

        z_proto = z[:n_support * n_class].view(n_class, n_support, z_dim).mean(1)
        zq = z[n_class * n_support:]

        return z_proto, zq

    def loss(self, sample):
        """Return ``euclidean_dist(zq, z_proto)`` for the episode in ``sample``.

        Despite the name, no scalar loss is reduced here. Callers receive the raw
        query-to-prototype distances and reduce them themselves.
        """
        z_proto, zq = self.forward(sample)
        return euclidean_dist(zq, z_proto)





