import numpy as np
import torch.nn.functional as F
from torch import nn

import params
import torch
from cuda import use_cuda
from model.encoder import DenseEncoder, DenseQueryEncoder
from model.model import BaseModel

# NOTE(review): module-level constant; its meaning is not visible in this
# module -- presumably consumed by code elsewhere (TODO confirm and document).
K = 2

class Query(BaseModel):
    """Query model placeholder -- not implemented yet.

    NOTE(review): presumably meant to implement the model described in
    "Appendix section C" of the reference work (TODO confirm). The
    constructor takes no arguments; it initializes ``BaseModel`` and then
    raises ``NotImplementedError``, so this class cannot be instantiated.
    """
    def __init__(self):
        super(Query, self).__init__()
        raise NotImplementedError("Query model is not implemented yet")

    def forward(self, x, typ, hard_softmax=False):
        """Run the query model (not implemented).

        Arguments:
            x: model input -- type/shape not defined here (TODO).
            typ: presumably a type selector for the query (TODO confirm).
            hard_softmax (bool): presumably selects a hard (one-hot)
                softmax over a soft one (TODO confirm); currently unused.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("Query.forward is not implemented yet")


class BN1d(nn.Module):
    """Batch normalization over the last dimension of an N-d tensor.

    All axes between the first (batch) axis and the last (feature) axis are
    folded into the "length" axis of ``nn.BatchNorm1d``. As a result, each
    feature of the last dimension is normalized with statistics gathered
    over the batch and every intermediate position.

    Arguments:
        dim (int): size of the last dimension of the input, i.e. the number
            of normalized features.
        **bn_kwargs: optional keyword arguments forwarded to
            ``nn.BatchNorm1d`` (e.g. ``eps``, ``momentum``, ``affine``).
    """
    def __init__(self, dim, **bn_kwargs):
        super(BN1d, self).__init__()
        self.bn = nn.BatchNorm1d(dim, **bn_kwargs)

    def forward(self, x):
        """Normalize ``x`` over its last dimension.

        ``x`` must have at least 2 dimensions: ``(B, ..., dim)``.
        Returns a tensor with the same shape as ``x``.
        """
        size = x.size()
        # (B, d1, ..., dk, C) -> (B, L, C). Use reshape rather than view so
        # that non-contiguous inputs (e.g. a transposed tensor) are accepted;
        # reshape still returns a view whenever one is possible.
        x = x.reshape(size[0], -1, size[-1])
        # BatchNorm1d expects channels on axis 1, so go to (B, C, L).
        x = self.bn(x.permute(0, 2, 1))
        # Back to (B, L, C), then restore the original shape.
        return x.permute(0, 2, 1).reshape(size)

class my_round(torch.autograd.Function):
    """Round to the nearest integer with a straight-through gradient.

    The forward pass applies ``torch.round``. The backward pass returns the
    incoming gradient unchanged, as if the op were the identity, because the
    true derivative of rounding is zero almost everywhere.

    Use as ``my_round.apply(x)``.
    """
    @staticmethod
    def forward(ctx, input):
        # Nothing is stored on ``ctx``: backward does not need the input, and
        # attaching tensors directly to ``ctx`` (instead of using
        # ``ctx.save_for_backward``) keeps them alive as long as the graph.
        return torch.round(input)

    @staticmethod
    def backward(ctx, g):
        # Straight-through estimator: treat d(round(x))/dx as 1.
        return g
