import torch
import torch.nn as nn
import numpy as np
from methods.meta_template import MetaTemplate
from methods.gnn import GNN_nl
from methods import backbone


class GnnNet(MetaTemplate):
    """Few-shot classifier that scores query samples with a graph neural network.

    Backbone features are projected to 128-d by ``self.fc``. For every query
    index ``i`` a separate graph is built from all ``n_way * n_support``
    support embeddings plus the ``i``-th query embedding of each class. Each
    node is its embedding concatenated with an ``n_way`` one-hot label (all
    zeros for the query nodes, whose label is unknown). ``GNN_nl`` produces
    ``n_way`` scores per node, and only the query-node scores are returned.
    """

    # When True, the projection head is built from backbone's fast-weight
    # layers (Linear_fw / BatchNorm1d_fw) instead of plain torch.nn layers.
    maml = False

    def __init__(self, model_func,  n_way, n_support, tf_path=None, device=None):
        """Build the projection head, the GNN and the fixed node labels.

        Args:
            model_func, n_way, n_support, tf_path, device: forwarded unchanged
                to ``MetaTemplate.__init__``. ``device`` is also stored on
                ``self.device`` and used to place incoming batches.
        """
        super(GnnNet, self).__init__(model_func,
                                     n_way, n_support, tf_path=tf_path, device=device)

        self.device = device

        # loss function
        self.loss_fn = nn.CrossEntropyLoss()

        # metric function: 128-d projection of the backbone features
        # (self.feat_dim is presumably set by MetaTemplate)
        if self.maml:
            self.fc = nn.Sequential(
                backbone.Linear_fw(self.feat_dim, 128),
                backbone.BatchNorm1d_fw(128, track_running_stats=False))
        else:
            self.fc = nn.Sequential(
                nn.Linear(self.feat_dim, 128),
                nn.BatchNorm1d(128, track_running_stats=False))
        # GNN input nodes = 128-d embedding + n_way one-hot label
        self.gnn = GNN_nl(128 + self.n_way, 96, self.n_way)
        self.method = 'GnnNet'

        # fixed labels for training the metric function: 1 * n_way(n_s + 1) * n_way
        self.support_label = self._build_support_label(self.n_way, self.n_support)

    @staticmethod
    def _build_support_label(n_way, n_support):
        """Return the fixed one-hot node labels shared by every query graph.

        The result has shape ``(1, n_way * (n_support + 1), n_way)``. For each
        class, its ``n_support`` support nodes carry that class's one-hot
        vector, and the trailing query slot carries all zeros.
        """
        # torch.arange is always int64, which scatter requires for its index
        # (np.repeat would give int32 on some platforms, e.g. Windows).
        class_idx = torch.arange(n_way).repeat_interleave(n_support).unsqueeze(1)
        support_label = torch.zeros(n_way * n_support, n_way).scatter(
            1, class_idx, 1).view(n_way, n_support, n_way)
        # one all-zero label row per class for the (unlabelled) query node
        support_label = torch.cat(
            [support_label, torch.zeros(n_way, 1, n_way)], dim=1)
        return support_label.view(1, -1, n_way)

    def to(self, device=None):
        """Move the feature extractor, head, GNN and node labels to ``device``.

        ``support_label`` is a plain tensor attribute rather than a registered
        buffer, so it is moved explicitly. When a device is given it is also
        stored on ``self.device`` so later batches follow the model.

        Returns:
            ``self``, for chaining.
        """
        self.feature.to(device=device)
        self.fc.to(device=device)
        self.gnn.to(device=device)
        self.support_label = self.support_label.to(device=device)
        if device is not None:
            # keep set_forward's input placement in sync with the parameters
            self.device = device
        return self

    def set_forward(self, x, is_feature=False):
        """Compute query scores for one episode.

        Args:
            x: episode tensor grouped by class along dim 0. Along dim 1 it
                holds the ``n_support`` support samples first, then the
                ``n_query`` queries. It contains raw inputs, or pre-extracted
                backbone features when ``is_feature`` is True.
            is_feature: skip ``self.feature`` and feed ``x`` straight into the
                projection head.

        Returns:
            Scores of shape ``(n_way * n_query, n_way)``, ordered class-major
            (all queries of class 0 first, then class 1, ...).
        """
        x = x.to(device=self.device)

        if is_feature:
            # feature tensor: n_way * (n_s + n_q) * f
            assert x.size(1) == self.n_support + self.n_query, (
                'expected %d samples per class, got %d'
                % (self.n_support + self.n_query, x.size(1)))

        # flatten (n_way, n_s + n_q) into one batch dimension
        x = x.view(-1, *x.size()[2:])
        if not is_feature:
            # get feature using encoder
            x = self.feature(x)
        z = self.fc(x)
        z = z.view(self.n_way, -1, z.size(1))

        # stack the feature for metric function:
        # n_way * (n_s + n_q) * f -> n_q * [1 * n_way(n_s + 1) * f]
        support = z[:, :self.n_support]
        z_stack = [
            torch.cat([support, z[:, self.n_support + i:self.n_support + i + 1]],
                      dim=1).view(1, -1, z.size(2))
            for i in range(self.n_query)
        ]
        assert z_stack[0].size(1) == self.n_way * (self.n_support + 1)
        return self.forward_gnn(z_stack)

    def forward_gnn(self, zs):
        """Run the GNN over the per-query graphs and keep the query-node scores.

        Args:
            zs: list of ``n_query`` tensors, each ``(1, n_way * (n_support + 1), f)``
                laid out as ``[class][support..., query]``.

        Returns:
            Scores of shape ``(n_way * n_query, n_way)``, ordered class-major.
        """
        # gnn inp: n_q * n_way(n_s + 1) * (f + n_way)
        # Align the labels with each z: they are not a registered buffer, so
        # e.g. model.cuda() would otherwise leave them on the CPU.
        nodes = torch.cat(
            [torch.cat([z, self.support_label.to(device=z.device, dtype=z.dtype)], dim=2)
             for z in zs],
            dim=0)
        scores = self.gnn(nodes)

        # n_q * n_way(n_s + 1) * n_way -> (n_way * n_q) * n_way;
        # index -1 picks the query node of each class.
        scores = scores.view(self.n_query, self.n_way, self.n_support + 1, self.n_way)[
            :, :, -1].permute(1, 0, 2).contiguous().view(-1, self.n_way)
        return scores

    def set_forward_loss(self, x):
        """Return ``(scores, loss)`` for one episode.

        Query targets are ``[0] * n_query + [1] * n_query + ...``, which
        matches the class-major order of ``set_forward``'s scores.
        """
        scores = self.set_forward(x)
        # int64 targets regardless of platform (CrossEntropyLoss needs Long),
        # created on the same device as the scores.
        y_query = torch.arange(self.n_way).repeat_interleave(self.n_query)
        y_query = y_query.to(device=scores.device)
        loss = self.loss_fn(scores, y_query)
        return scores, loss
