import torch
import torch.nn as nn
import math
from torch.autograd import Variable
import torch.nn.functional as F
from methods.backbone import Linear_fw, Conv2d_fw, BatchNorm2d_fw, BatchNorm1d_fw


# Legacy tensor-type aliases: CUDA tensor types when a GPU is available,
# CPU tensor types otherwise.
if torch.cuda.is_available():
  dtype = torch.cuda.FloatTensor
  dtype_l = torch.cuda.LongTensor
else:
  dtype = torch.FloatTensor
  # CPU fallback must not reference the CUDA long type, which cannot be
  # instantiated on a machine without a GPU.
  dtype_l = torch.LongTensor


def gmul(input):
  """Apply every graph operator stored in W to the node features x.

  Args:
    input: a pair ``(W, x)``. ``W`` is indexed as (bs, N, N, J), i.e. J
      operators stacked along the last axis; ``x`` is (bs, N, num_features).

  Returns:
    Tensor of size (bs, N, J*num_features): the per-operator products
    ``W[..., j] @ x`` concatenated along the feature axis.
  """
  operators, feats = input
  n_nodes = operators.size(-2)
  # Stack the J operators vertically so one bmm handles all of them:
  # (bs, J*N, N) @ (bs, N, num_features) -> (bs, J*N, num_features).
  stacked = torch.cat(torch.unbind(operators, dim=3), dim=1)
  product = torch.bmm(stacked, feats)
  # Cut the result back into J chunks of N rows and lay them side by side.
  per_operator = torch.split(product, n_nodes, dim=1)
  return torch.cat(per_operator, dim=2)


def f_loss_sub(f_few, f_much):
  """Return the mean absolute difference between f_few and f_much."""
  return (f_few - f_much).abs().mean()

def f_loss_sub_var(f_few, f_much):
  """Return mean absolute difference plus the mean variance of the difference.

  The variance is taken along dim 1 of ``f_few - f_much`` and then averaged.
  """
  diff = f_few - f_much
  abs_term = diff.abs().mean()
  var_term = diff.var(dim=1).mean()
  return abs_term + var_term


class Gconv(nn.Module):
  """Graph convolution: aggregate node features with gmul, then apply a linear layer.

  Args:
    nf_input: number of features per node before aggregation.
    nf_output: number of features per node produced by the linear layer.
    J: number of graph operators; the linear layer sees J*nf_input features.
    ft: if True, use the ``*_fw`` layer variants instead of the ``nn`` ones.
    bn_bool: if True, apply BatchNorm1d after the linear layer and return a
      single tensor; if False, skip it and also return the pre-linear features.
    mlp: if True (and bn_bool is False), features are first passed through
      ``self.to_5shot``. This class does not create that submodule; it has to
      be attached by the caller before forward() is used.
    nf: accepted for interface compatibility; not used by this class.
  """

  def __init__(self, nf_input, nf_output, J, ft, bn_bool=True, mlp=False, nf=96):
    super(Gconv, self).__init__()
    self.J = J
    self.ft = ft
    self.num_inputs = J*nf_input
    self.num_outputs = nf_output
    self.mlp = mlp
    self.fc = nn.Linear(self.num_inputs, self.num_outputs) if not self.ft else Linear_fw(self.num_inputs, self.num_outputs)
    self.bn_bool = bn_bool
    if self.bn_bool:
      self.bn = nn.BatchNorm1d(self.num_outputs, track_running_stats=False) if not self.ft else BatchNorm1d_fw(self.num_outputs, track_running_stats=False)

  def forward(self, input):
    """Run the graph convolution on ``input`` (the ``[W, x]`` pair gmul expects).

    Returns:
      ``x`` of size (bs, N, num_outputs) when bn_bool is True; otherwise the
      pair ``(x, x_f)`` where ``x_f`` of size (bs, N, num_inputs) holds the
      features that were fed to the linear layer.

    Raises:
      AttributeError: if mlp is True, bn_bool is False and no ``to_5shot``
        submodule has been attached.
    """
    x = gmul(input) # out has size (bs, N, num_inputs)
    x_size = x.size()
    x = x.contiguous().view(-1, self.num_inputs)
    x_f = None
    if not self.bn_bool:
      if self.mlp:
        # to_5shot is never created in __init__, so fail with an actionable
        # message instead of the generic nn.Module attribute error.
        to_5shot = getattr(self, 'to_5shot', None)
        if to_5shot is None:
          raise AttributeError(
            "Gconv was built with mlp=True but has no 'to_5shot' submodule; "
            "attach one to the layer before calling forward()")
        x = to_5shot(x)
      # Keep the features that go into the linear layer for the caller.
      x_f = x
    x = self.fc(x) # has size (bs*N, num_outputs)
    if self.bn_bool:
      x = self.bn(x)
    x = x.view(*x_size[:-1], self.num_outputs)
    if self.bn_bool:
      return x
    x_f = x_f.view(*x_size[:-1], self.num_inputs)
    return x, x_f


class Wcompute(nn.Module):
  """Learn a pairwise adjacency from node features with a stack of 1x1 convolutions.

  Args:
    input_features: number of features per node.
    nf: base channel width of the convolution stack.
    operator: 'J2' concatenates W_id with the learned adjacency along the last
      axis; 'laplace' returns ``W_id - adjacency``.
    activation: 'softmax', 'sigmoid' or 'none'; how raw scores are normalised.
    ratio: four multipliers; the i-th conv layer has ``int(nf * ratio[i])``
      output channels.
    num_operators: number of output channels of the final convolution.
    drop: if True, apply Dropout(0.3) after the first conv block.
    rest, n_wayw, n_supportw: stored as attributes; not used by this class.
    ft: if True, use the ``*_fw`` layer variants instead of the ``nn`` ones.
  """

  def __init__(self, input_features, nf, operator='J2', activation='softmax', ratio=(2, 2, 1, 1), num_operators=1,
               drop=False, rest=False, n_wayw=5, n_supportw=1, ft=False):

    super(Wcompute, self).__init__()
    self.num_features = nf
    self.operator = operator
    self.rest = rest
    self.ft = ft
    self.n_wayw = n_wayw
    self.n_supportw = n_supportw

    # Channel widths of the four conv blocks, cast uniformly so that
    # non-integer ratios are handled the same way for every layer.
    c1 = int(nf * ratio[0])
    c2 = int(nf * ratio[1])
    c3 = int(nf * ratio[2])
    c4 = int(nf * ratio[3])

    self.conv2d_1 = nn.Conv2d(input_features, c1, 1, stride=1) if not self.ft else Conv2d_fw(input_features, c1, 1, stride=1)
    self.bn_1 = nn.BatchNorm2d(c1, track_running_stats=False) if not self.ft else BatchNorm2d_fw(c1, track_running_stats=False)
    self.drop = drop
    if self.drop:
      self.dropout = nn.Dropout(0.3)
    self.conv2d_2 = nn.Conv2d(c1, c2, 1, stride=1) if not self.ft else Conv2d_fw(c1, c2, 1, stride=1)
    self.bn_2 = nn.BatchNorm2d(c2, track_running_stats=False) if not self.ft else BatchNorm2d_fw(c2, track_running_stats=False)
    self.conv2d_3 = nn.Conv2d(c2, c3, 1, stride=1) if not self.ft else Conv2d_fw(c2, c3, 1, stride=1)
    self.bn_3 = nn.BatchNorm2d(c3, track_running_stats=False) if not self.ft else BatchNorm2d_fw(c3, track_running_stats=False)
    self.conv2d_4 = nn.Conv2d(c3, c4, 1, stride=1) if not self.ft else Conv2d_fw(c3, c4, 1, stride=1)
    self.bn_4 = nn.BatchNorm2d(c4, track_running_stats=False) if not self.ft else BatchNorm2d_fw(c4, track_running_stats=False)

    # The last conv consumes the output of conv2d_4, so its input width must
    # be c4 (equal to nf only when ratio[3] == 1, as with the default ratio).
    self.conv2d_last = nn.Conv2d(c4, num_operators, 1, stride=1) if not self.ft else Conv2d_fw(c4, num_operators, 1, stride=1)

    self.activation = activation

  def forward(self, x, W_id):
    """Compute the adjacency for node features ``x``.

    Args:
      x: node features, size bs x N x num_features.
      W_id: identity adjacency, size bs x N x N x 1; used to mask
        self-connections.

    Returns:
      bs x N x N x (1 + num_operators) for operator 'J2' (W_id concatenated
      with the learned adjacency), or ``W_id - adjacency`` for 'laplace'.

    Raises:
      NotImplementedError: for an unknown activation or operator.
    """
    W1 = x.unsqueeze(2)
    W2 = torch.transpose(W1, 1, 2) #size: bs x 1 x N x num_features

    W_new = torch.abs(W1 - W2) #size: bs x N x N x num_features
    W_new = torch.transpose(W_new, 1, 3) #size: bs x num_features x N x N

    W_new = F.leaky_relu(self.bn_1(self.conv2d_1(W_new)))

    if self.drop:
      W_new = self.dropout(W_new)

    W_new = F.leaky_relu(self.bn_2(self.conv2d_2(W_new)))
    W_new = F.leaky_relu(self.bn_3(self.conv2d_3(W_new)))
    W_new = F.leaky_relu(self.bn_4(self.conv2d_4(W_new)))

    W_new = self.conv2d_last(W_new)
    W_new = torch.transpose(W_new, 1, 3) #size: bs x N x N x num_operators

    if self.activation == 'softmax':
      # Push self-connections to a very negative score so they get ~0 weight.
      W_new = W_new - W_id.expand_as(W_new) * 1e8
      # Normalise over axis 2; same result as the transpose/flatten/softmax
      # round-trip, without the extra copies.
      W_new = F.softmax(W_new, dim=2)
    elif self.activation == 'sigmoid':
      # Out-of-place multiply: sigmoid's output is needed for its backward
      # pass, so it must not be modified in place.
      W_new = torch.sigmoid(W_new) * (1 - W_id)
    elif self.activation == 'none':
      W_new = W_new * (1 - W_id)
    else:
      raise NotImplementedError('unknown activation: {}'.format(self.activation))

    if self.operator == 'laplace':
      W_new = W_id - W_new
    elif self.operator == 'J2':
      W_new = torch.cat([W_id, W_new], 3)
    else:
      raise NotImplementedError('unknown operator: {}'.format(self.operator))

    return W_new


class Mlp(nn.Module):
  """Three Linear+BatchNorm1d+leaky_relu stages followed by a linear classifier.

  The third stage maps back to ``input_features``, so the returned features
  have the same width as the input. ``drop`` and ``rest`` are stored but not
  used by this class. With ``ft=True`` the hidden layers use the ``*_fw``
  variants; the classifier is always a plain ``nn.Linear``.
  """

  def __init__(self, input_features=458, n_way=5, nf=96, ratio=[3, 3, 4, 4], drop=False, rest=False, ft=False):
    super(Mlp, self).__init__()

    self.input_features = input_features
    self.drop = drop
    self.rest = rest
    self.ft = ft
    self.n_way = n_way

    # Only the first two ratio entries determine layer widths here.
    hidden_1 = int(nf * ratio[0])
    hidden_2 = int(nf * ratio[1])
    make_fc = Linear_fw if self.ft else nn.Linear
    make_bn = BatchNorm1d_fw if self.ft else nn.BatchNorm1d

    self.fc_1 = make_fc(input_features, hidden_1)
    self.bn_1 = make_bn(hidden_1, track_running_stats=False)
    self.fc_2 = make_fc(hidden_1, hidden_2)
    self.bn_2 = make_bn(hidden_2, track_running_stats=False)
    self.fc_3 = make_fc(hidden_2, input_features)
    self.bn_3 = make_bn(input_features, track_running_stats=False)
    self.fc_class = nn.Linear(input_features, n_way)

  def forward(self, f):
    """Return ``(features, scores)`` for a batch ``f`` of width input_features."""
    stages = (
      (self.fc_1, self.bn_1),
      (self.fc_2, self.bn_2),
      (self.fc_3, self.bn_3),
    )
    for linear, norm in stages:
      f = F.leaky_relu(norm(linear(f)))
    return f, self.fc_class(f)


class GnnBlock(nn.Module):
  """One adjacency computation followed by ``layers`` densely-connected Gconv layers.

  The adjacency is computed once from the block input and reused by every
  Gconv layer. Each layer appends ``int(nf / 2)`` features to the node
  features, so the output width is ``input_features + int(nf / 2) * layers``.
  """

  def __init__(self, input_features, train_N_way, support, nf, wdrop, ft, rest, layers):
    super(GnnBlock, self).__init__()

    self.input_features = input_features

    self.n_support = support
    self.wdrop = wdrop
    self.ft = ft
    self.rest = rest
    self.num_outputs = train_N_way
    self.layers = layers

    adjacency_net = Wcompute(self.input_features, nf, operator='J2', activation='softmax', ratio=[2, 2, 1, 1],
                             drop=self.wdrop, rest=self.rest, n_wayw=self.num_outputs,
                             n_supportw=self.n_support, ft=self.ft)
    self.add_module('layer_w{}'.format(0), adjacency_net)

    growth = int(nf / 2)  # features added by each Gconv layer
    for idx in range(self.layers):
      gconv = Gconv(self.input_features + growth * idx, growth, 2, ft=self.ft)
      self.add_module('layer_l{}'.format(idx), gconv)

  def forward(self, x):
    """Return ``x`` with the output of every Gconv layer concatenated on axis 2."""
    batch_size, n_nodes = x.size(0), x.size(1)
    # Identity adjacency of size (bs, N, N, 1).
    identity = torch.eye(n_nodes, device=x.device)
    W_init = identity.unsqueeze(0).repeat(batch_size, 1, 1).unsqueeze(3)
    adjacency = getattr(self, 'layer_w{}'.format(0))(x, W_init)
    for idx in range(self.layers):
      gconv = getattr(self, 'layer_l{}'.format(idx))
      grown = F.leaky_relu(gconv([adjacency, x]))
      x = torch.cat([x, grown], 2)
    return x


class GNN_nl(nn.Module):
  """Stack of GnnBlocks followed by either a linear head or a final graph convolution.

  Args:
    input_features: number of features per node fed to the first block.
    nf: base width; every Gconv layer in the trunk adds ``int(nf / 2)`` features.
    train_N_way: number of output scores per node.
    n_supportn: forwarded to the blocks and to the last Wcompute.
    wdrop: forwarded to the blocks (dropout inside their Wcompute).
    ft: if True, use the ``*_fw`` layer variants.
    rest: forwarded to the blocks and to the last Wcompute.
    list_layers: number of Gconv layers in each block; one block per entry.
    fin_fc: if True, the head is a per-node linear layer; otherwise it is a
      Wcompute followed by a Gconv without batch norm.
    mlp: forwarded to the final Gconv (only relevant when fin_fc is False).
  """

  def __init__(self, input_features, nf, train_N_way, n_supportn, wdrop, ft, rest, list_layers, fin_fc, mlp):
    super(GNN_nl, self).__init__()
    self.input_features = input_features
    self.nf = nf
    self.num_outputs = train_N_way
    self.n_supportn = n_supportn
    self.list_layers = list_layers
    self.len_list = len(self.list_layers)
    self.all_layers = sum(self.list_layers)
    self.wdrop = wdrop
    self.ft = ft
    self.rest = rest
    self.fin_fc = fin_fc
    self.mlp = mlp

    growth = int(self.nf / 2)  # features added by each Gconv layer in a block
    trunk = []
    input_block_feature = self.input_features
    for n_layers in self.list_layers:
      B = GnnBlock(input_features=input_block_feature, nf=self.nf, train_N_way=self.num_outputs, support=self.n_supportn,
                  wdrop=self.wdrop, ft=self.ft, rest=self.rest, layers=n_layers)
      trunk.append(B)
      # Track the width exactly as GnnBlock grows it (int(nf / 2) per layer),
      # so consecutive blocks also line up when nf is odd.
      input_block_feature = input_block_feature + growth * n_layers
    self.trunk = nn.Sequential(*trunk)

    trunk_features = self.input_features + growth * self.all_layers
    if self.fin_fc:
      self.fc = nn.Linear(trunk_features, train_N_way) if not self.ft else Linear_fw(trunk_features, train_N_way)
    else:
      self.w_comp_last = Wcompute(trunk_features, nf, operator='J2', activation='softmax',ratio=[2, 2, 1, 1],
                                  rest=self.rest, ft=self.ft, n_wayw=self.num_outputs, n_supportw=self.n_supportn)
      self.layer_last = Gconv(trunk_features, self.num_outputs, 2, ft=self.ft, bn_bool=False, mlp=self.mlp, nf=nf)


  def forward(self, x):
    """Return ``(out, out_f)`` for node features ``x``.

    ``out`` holds ``train_N_way`` scores per node. ``out_f`` holds the
    features that were fed to the final linear layer: the trunk output when
    fin_fc is True, otherwise the second value returned by the final Gconv.
    """
    out = self.trunk(x)
    out_size = out.size()
    if self.fin_fc:
      # Expose the trunk features as out_f so both heads return a pair;
      # previously out_f was never assigned on this path.
      out_f = out.contiguous()
      out_fc = out_f.view(-1, out_size[2])
      out_fc = self.fc(out_fc)
      out = out_fc.view(*out_size[:-1], self.num_outputs)
    else:
      W_init = torch.eye(x.size(1), device=x.device).unsqueeze(0).repeat(x.size(0), 1, 1).unsqueeze(3)
      Wl = self.w_comp_last(out, W_init)
      out, out_f = self.layer_last([Wl, out])
    out = out.view(*out_size[:-1], self.num_outputs)
    out_f = out_f.view(*out_size[:-1], -1)
    return out, out_f

