import torch
import torch.nn as nn
import numpy as np
from methods.meta_template_gnn import MetaTemplate
from methods.gnn import GNN_nl
from methods import backbone


class GnnNet(MetaTemplate):
  """Few-shot classifier that scores query samples with a GNN over support + query nodes.

  Each episode is split into ``n_query`` small graphs. Graph ``i`` holds all
  ``n_way * n_support`` support embeddings plus the ``i``-th query embedding of
  every class, i.e. ``n_way * (n_support + 1)`` nodes laid out class by class.
  A node's features are the 128-d output of ``self.fc`` concatenated with a
  one-hot class label (all zeros for the query nodes). ``GNN_nl`` produces
  ``n_way`` scores per node and only the query-node scores are kept.
  """

  def __init__(self, model_func, n_way, n_support, n_layer, wdrop, ft, rest, fin_fc, tf_path=None):
    """Build the 128-d projection head, the GNN and the fixed node-label tensor.

    Args:
      model_func: backbone constructor, forwarded to ``MetaTemplate``.
      n_way: number of classes per episode.
      n_support: number of labelled support samples per class.
      n_layer, wdrop, ft, rest, fin_fc: forwarded to ``MetaTemplate`` and to
        ``GNN_nl`` (``n_layer`` as ``list_layers``); ``ft`` additionally selects
        the ``backbone.*_fw`` layers for the projection head.
      tf_path: forwarded to ``MetaTemplate``.
    """
    super(GnnNet, self).__init__(model_func, n_way, n_support, n_layer, wdrop, ft, rest, fin_fc, tf_path=tf_path)

    self.n_layer = n_layer
    self.wdrop = wdrop
    self.rest = rest
    self.ft = ft
    self.fin_fc = fin_fc

    # loss function
    self.loss_fn = nn.CrossEntropyLoss()
    # metric function: project backbone features (self.feat_dim, presumably set by
    # MetaTemplate) to 128-d; the backbone.*_fw layers are used when ft is set.
    if self.ft:
      self.fc = nn.Sequential(backbone.Linear_fw(self.feat_dim, 128),
                              backbone.BatchNorm1d_fw(128, track_running_stats=False))
    else:
      self.fc = nn.Sequential(nn.Linear(self.feat_dim, 128),
                              nn.BatchNorm1d(128, track_running_stats=False))
    # node features = 128-d embedding + n_way label channels
    self.gnn = GNN_nl(input_features=128 + self.n_way, nf=96, train_N_way=self.n_way, n_supportn=self.n_support,
                      list_layers=self.n_layer, wdrop=self.wdrop, ft=self.ft, rest=self.rest, fin_fc=self.fin_fc)
    self.method = 'GnnNet'

    # fixed label channels for training the metric function: 1 * n_way(n_s + 1) * n_way.
    # Plain tensor attribute (not a parameter/buffer), so cuda() moves it by hand.
    self.support_label = self._build_support_label(self.n_way, self.n_support)

  @staticmethod
  def _build_support_label(n_way, n_support):
    """Return the node-label tensor of shape (1, n_way * (n_support + 1), n_way).

    Nodes are ordered class by class: the ``n_support`` support nodes of class
    ``c`` carry the one-hot vector of ``c`` and are followed by one all-zero row,
    the slot of the (unlabelled) query node of that class.
    """
    # torch.eye avoids the platform-dependent integer dtype of np.repeat
    # (int32 on Windows with NumPy < 2), which Tensor.scatter rejects as an index.
    one_hot = torch.eye(n_way).unsqueeze(1).expand(n_way, n_support, n_way)
    query_slot = torch.zeros(n_way, 1, n_way)
    return torch.cat([one_hot, query_slot], dim=1).view(1, -1, n_way)

  def cuda(self, device=None):
    """Move the backbone, projection head, GNN and node-label tensor to CUDA.

    Args:
      device: optional CUDA device passed to every ``.cuda`` call, as in
        ``nn.Module.cuda``; ``None`` uses the current CUDA device (the
        original behaviour).

    Returns:
      ``self``, so the call can be chained.
    """
    self.feature.cuda(device)
    self.fc.cuda(device)
    self.gnn.cuda(device)
    self.support_label = self.support_label.cuda(device)
    return self

  def set_forward(self, x, is_feature=False):
    """Score every query sample of an episode against the ``n_way`` classes.

    Args:
      x: episode tensor, expected as (n_way, n_support + n_query, ...); raw
        backbone inputs, or pre-extracted features when ``is_feature`` is true.
        It is moved to the device of ``support_label`` (the CUDA device once
        ``cuda()`` has been called).
      is_feature: skip ``self.feature`` and feed ``x`` straight to ``self.fc``.

    Returns:
      Scores of shape (n_way * n_query, n_way), class-major (all queries of
      class 0 first, then class 1, ...).
    """
    x = x.to(self.support_label.device)
    flat = x.view(-1, *x.size()[2:])
    if is_feature:
      assert x.size(1) == self.n_support + self.n_query, 'expected n_support + n_query samples per class'
      feats = flat
    else:
      # get feature using encoder
      feats = self.feature(flat)
    # (n_way * (n_s + n_q), 128) -> (n_way, n_s + n_q, 128)
    z = self.fc(feats)
    z = z.view(self.n_way, -1, z.size(1))

    # stack the features per query index: n_way * (n_s + n_q) * f -> n_q * [1 * n_way(n_s + 1) * f]
    n_s = self.n_support
    support = z[:, :n_s]
    z_stack = [
      torch.cat([support, z[:, n_s + i:n_s + i + 1]], dim=1).view(1, -1, z.size(2))
      for i in range(self.n_query)]
    assert z_stack[0].size(1) == self.n_way * (n_s + 1)
    return self.forward_gnn(z_stack)

  def forward_gnn(self, zs):
    """Run the GNN on the per-query graphs and keep only the query-node scores.

    Args:
      zs: list of ``n_query`` tensors, each (1, n_way * (n_support + 1), f) as
        built by ``set_forward``.

    Returns:
      Scores of shape (n_way * n_query, n_way), class-major.
    """
    # gnn inp: n_q * n_way(n_s + 1) * (f + n_way) -- append the fixed label channels
    nodes = torch.cat([torch.cat([z, self.support_label], dim=2) for z in zs], dim=0)
    scores = self.gnn(nodes)

    # n_q * n_way(n_s + 1) * n_way -> keep the last (query) slot of each class
    # -> n_q * n_way * n_way -> n_way * n_q * n_way -> (n_way * n_q) * n_way
    query_scores = scores.view(self.n_query, self.n_way, self.n_support + 1, self.n_way)[:, :, -1]
    return query_scores.permute(1, 0, 2).contiguous().view(-1, self.n_way)

  def set_forward_loss(self, x):
    """Return ``(scores, loss)`` for one episode.

    Targets follow the class-major order of ``set_forward``: ``n_query``
    copies of 0, then ``n_query`` copies of 1, ... up to ``n_way - 1``.
    """
    scores = self.set_forward(x)
    # torch.arange is int64 on every platform (CrossEntropyLoss needs Long targets;
    # np.repeat gives int32 on Windows with NumPy < 2) and is created on the scores' device.
    y_query = torch.arange(self.n_way, device=scores.device).repeat_interleave(self.n_query)
    loss = self.loss_fn(scores, y_query)
    return scores, loss

