import torch
import torch.nn as nn
from .normal_ops_withbn import *

# Candidate names for each mixed op. Every name is used as a lookup key into
# the matching registry (NA_OPS / SC_OPS / LA_OPS) when the mixed-op modules
# below are constructed, and the order here fixes the order of the ops (and
# therefore which architecture weight is paired with which op).

# Keys into NA_OPS, consumed by NaMixedOp.
NA_PRIMITIVES = ['gcn', 'gin', 'gat']

# Keys into SC_OPS, consumed by ScMixedOp.
# NOTE: 'skip' is currently disabled; only 'none' is searched over.
SC_PRIMITIVES = ['none']

# Keys into LA_OPS, consumed by LaMixedOp.
LA_PRIMITIVES = ['l_concat']

class NaMixedOp(nn.Module):
  """Weighted mixture of every candidate operation named in NA_PRIMITIVES.

  Each candidate is built as ``NA_OPS[primitive](in_dim, out_dim)``. When
  ``with_linear`` is truthy, every candidate is paired with its own
  ``torch.nn.Linear(in_dim, out_dim)``; the linear output is added to the
  candidate's output before ``F.elu`` is applied.
  """

  def __init__(self, in_dim, out_dim, with_linear):
    """Build the candidate ops and, optionally, one linear layer per op.

    Args:
      in_dim: forwarded to every op constructor and used as the input
        feature size of each linear layer.
      out_dim: forwarded to every op constructor and used as the output
        feature size of each linear layer.
      with_linear: if truthy, create a parallel linear branch for each op.
    """
    super(NaMixedOp, self).__init__()
    self.with_linear = with_linear
    self._ops = nn.ModuleList()

    # The linear list must be created exactly once, before the loop.
    # Re-creating it on every iteration left only the last linear layer
    # registered, so zip() in forward() stopped after the first op and the
    # remaining candidates were silently ignored.
    # NOTE: state dicts saved by the previous code only hold
    # `_ops_linear.0`; loading them now needs strict=False or a re-train.
    if with_linear:
        self._ops_linear = nn.ModuleList()

    for primitive in NA_PRIMITIVES:
        self._ops.append(NA_OPS[primitive](in_dim, out_dim))
        if with_linear:
            # Created right after its op so the two lists stay index-aligned.
            self._ops_linear.append(torch.nn.Linear(in_dim, out_dim))

  def forward(self, x, weights, edge_index):
    """Return the weighted sum of the activated candidate outputs.

    Args:
      x: input passed to every op (and to every linear layer, if enabled).
      weights: iterable of per-op weights, paired with the ops in order.
      edge_index: second argument passed to every op.

    Returns:
      ``sum_i weights[i] * F.elu(op_i(x, edge_index) [+ linear_i(x)])``.
      The result is the int ``0`` when there are no (weight, op) pairs.
    """
    if self.with_linear:
        return sum(
            w * F.elu(op(x, edge_index) + linear(x))
            for w, op, linear in zip(weights, self._ops, self._ops_linear)
        )
    return sum(
        w * F.elu(op(x, edge_index))
        for w, op in zip(weights, self._ops)
    )

class ScMixedOp(nn.Module):
  """Weighted mixture of the operations named in SC_PRIMITIVES.

  One module is instantiated per name via ``SC_OPS[name]()`` and kept, in
  order, in ``self._ops``.
  """

  def __init__(self):
    """Instantiate every candidate op; the constructors take no arguments."""
    super(ScMixedOp, self).__init__()
    candidates = [SC_OPS[name]() for name in SC_PRIMITIVES]
    self._ops = nn.ModuleList(candidates)

  def forward(self, x, weights):
    """Return ``sum_i weights[i] * op_i(x)``.

    Args:
      x: input handed to every candidate op.
      weights: iterable of per-op weights, paired with the ops in order.

    Returns:
      The weighted sum of the op outputs (the int ``0`` if nothing is paired).
    """
    return sum(
        weight * branch(x)
        for weight, branch in zip(weights, self._ops)
    )

class LaMixedOp(nn.Module):
  """Weighted mixture of the operations named in LA_PRIMITIVES.

  One module is instantiated per name via
  ``LA_OPS[name](hidden_size, num_layers)`` and kept, in order, in
  ``self._ops``.
  """

  def __init__(self, hidden_size, num_layers=None):
    """Instantiate every candidate op.

    Args:
      hidden_size: first argument forwarded to each op constructor.
      num_layers: second argument forwarded to each op constructor
        (defaults to None).
    """
    super(LaMixedOp, self).__init__()
    candidates = [
        LA_OPS[name](hidden_size, num_layers) for name in LA_PRIMITIVES
    ]
    self._ops = nn.ModuleList(candidates)

  def forward(self, x, weights):
    """Return ``sum_i weights[i] * F.relu(op_i(x))``.

    Args:
      x: input handed to every candidate op.
      weights: iterable of per-op weights, paired with the ops in order.

    Returns:
      The weighted sum of the ReLU-activated op outputs (the int ``0`` if
      nothing is paired).
    """
    return sum(
        weight * F.relu(branch(x))
        for weight, branch in zip(weights, self._ops)
    )