import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
from .functions import quantize
from .functions import pact

class abs_activation(nn.Module):
    """Parameter-free module that applies ``torch.abs`` element-wise."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return the element-wise absolute value of ``x``."""
        magnitude = torch.abs(x)
        return magnitude

class q_act(nn.Module):
    """Activation quantizer that delegates to the custom autograd op ``quantize``.

    Args:
        bits: bit-width passed through to ``quantize.apply`` on every call.
    """

    def __init__(self, bits):
        super().__init__()
        self.bits = bits

    def forward(self, x, s, zp=0):
        """Quantize ``x`` via ``quantize.apply(x, bits, s, zp)``.

        NOTE(review): ``s`` is presumably the quantization scale and ``zp`` the
        zero point (default 0); confirm against ``.functions.quantize``.
        """
        return quantize.apply(x, self.bits, s, zp)

class PACT(nn.Module):
    """Activation with a single learnable scalar ``alpha`` applied via ``pact``.

    NOTE(review): ``pact.apply`` presumably implements the PACT clipping
    (output clipped to ``[0, alpha]``); confirm against ``.functions.pact``.

    Args:
        alpha: initial value of the learnable ``alpha`` parameter. Defaults to
            6.0, which was previously hard-coded, so existing callers are
            unaffected.
    """

    def __init__(self, alpha=6.0):
        super().__init__()
        # Stored as a 0-dim float32 Parameter so the optimizer can learn it.
        # float() accepts Python numbers as well as 0-dim tensors.
        self.alpha = torch.nn.Parameter(torch.tensor(float(alpha), dtype=torch.float32))

    def forward(self, x):
        """Apply ``pact.apply`` to ``x`` using the learnable ``alpha``."""
        return pact.apply(x, self.alpha)

class cyc_relu(nn.Module):
    """Periodic "cyclic" activation built from a wrap-around and a fold.

    Args:
        bits: bit-width that sets the period (``2**bits`` in scaled units).
    """

    def __init__(self, bits):
        super().__init__()
        self.bits = bits

    def forward(self, x, s):
        """Apply the cyclic activation to ``x`` at scale ``s``.

        Steps, in units of ``x / s``:
          1. Wrap into ``[-2**(bits-1), 2**(bits-1))`` (period ``2**bits``).
          2. Where the value is below ``-2**bits / 3``, replace it with
             ``-2**bits - 2*v``.
          3. Where the value is above ``2**bits / 3``, replace it with
             ``2**bits - 2*v``.
          4. Multiply back by ``s``.
        The result is continuous and is zero at the wrap boundaries.
        """
        half_range = 2 ** (self.bits - 1)
        full_range = 2 ** self.bits
        scaled = x / s
        wrapped = (scaled + half_range) % full_range - half_range
        # Fold the low end, then the high end (the second fold sees the first's output).
        low_folded = torch.max(wrapped, -full_range - 2 * wrapped)
        folded = torch.min(low_folded, full_range - 2 * low_folded)
        return folded * s


if __name__ == "__main__":
    # Placeholder entry point: only prints a marker. Because this module uses
    # relative imports (from .functions ...), it must be launched as part of
    # its package, e.g. `python -m <package>.<module>`.
    print("input")
