import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function

class STE(Function):
    """Autograd function that applies ``prox_op`` with a pass-through gradient.

    The forward pass returns ``prox_op(inputs)``.  The backward pass hands the
    incoming gradient to ``inputs`` unchanged, as if ``prox_op`` were the
    identity; ``prox_op`` itself receives no gradient.
    """

    @staticmethod
    def forward(ctx, inputs, prox_op):
        """Return ``prox_op`` applied to ``inputs``."""
        projected = prox_op(inputs)
        return projected

    @staticmethod
    def backward(ctx, grad_output):
        """Return one gradient per ``forward`` argument: ``(inputs, prox_op)``."""
        # The second slot belongs to ``prox_op``, which gets no gradient.
        return (grad_output, None)


class ProxGradLinear(nn.Linear):
    """``nn.Linear`` that applies ``prox_op`` to its weight only in eval mode.

    While ``self.training`` is true the raw weight is used; otherwise the
    weight is passed through ``prox_op`` before the linear map.

    Args:
        *args: Positional arguments forwarded to ``nn.Linear``.
        prox_op: Keyword-only callable applied to the weight tensor.
            ``freeze_prox_op_weights`` additionally requires it to expose a
            ``freeze_weights()`` method.
        **kwargs: Keyword arguments forwarded to ``nn.Linear``.
    """

    def __init__(self, *args, prox_op, **kwargs):
        super().__init__(*args, **kwargs)
        self.prox_op = prox_op

    def mofidy_prox_op(self, prox_op):
        """Replace the operator applied to the weight."""
        self.prox_op = prox_op

    def freeze_weights(self):
        """Overwrite the weight with ``prox_op(weight)`` and stop training it."""
        projected = self.prox_op(self.weight.data)
        self.weight.data = projected
        self.weight.requires_grad = False
        # Workaround: dropping the gradient keeps the optimizer's momentum
        # from modifying the frozen weight.
        self.weight.grad = None

    def freeze_prox_op_weights(self):
        """Delegate to ``prox_op.freeze_weights()``."""
        self.prox_op.freeze_weights()

    def clip_weights(self, _min, _max):
        """Clamp the weight in place to the interval ``[_min, _max]``."""
        self.weight.data.clamp_(_min, _max)

    def forward(self, inputs):
        """Apply the linear map, using ``prox_op(weight)`` outside training."""
        # Outside training, project weights to quantization points for testing.
        weight = self.weight if self.training else self.prox_op(self.weight)
        return F.linear(inputs, weight, self.bias)

class DualAveragingLinear(nn.Linear):
    """``nn.Linear`` whose forward pass always uses ``prox_op(weight)``.

    The operator is applied through ``STE``, so the gradient with respect to
    the projected weight is passed unchanged to the raw ``weight`` parameter.

    Args:
        *args: Positional arguments forwarded to ``nn.Linear``.
        prox_op: Keyword-only callable applied to the weight tensor.
            ``freeze_prox_op_weights`` additionally requires it to expose a
            ``freeze_weights()`` method.
        **kwargs: Keyword arguments forwarded to ``nn.Linear``.
    """

    def __init__(self, *args, prox_op, **kwargs):
        super().__init__(*args, **kwargs)

        self.prox_op = prox_op
        # Autograd functions are used through the class-level ``apply``;
        # instantiating one is deprecated in PyTorch.  The operator is passed
        # on every call in ``forward``, so none is bound here.
        self._ste = STE.apply

    def mofidy_prox_op(self, prox_op):
        """Replace the operator applied to the weight.

        The misspelled name is kept because it is the public interface.
        """
        self.prox_op = prox_op

    def freeze_weights(self):
        """Overwrite the weight with ``prox_op(weight)`` and stop training it."""
        self.weight.data = self.prox_op(self.weight.data)
        self.weight.requires_grad = False
        # Workaround: dropping the gradient keeps the optimizer's momentum
        # from modifying the frozen weight.
        self.weight.grad = None

    def freeze_prox_op_weights(self):
        """Delegate to ``prox_op.freeze_weights()``."""
        self.prox_op.freeze_weights()

    def clip_weights(self, _min, _max):
        """Clamp the weight in place to the interval ``[_min, _max]``."""
        self.weight.data.clamp_(_min, _max)

    def forward(self, inputs):
        """Apply the linear map with the weight passed through ``prox_op``."""
        # ``self.prox_op`` is read at call time, so a replacement installed
        # via ``mofidy_prox_op`` takes effect immediately.
        projected_weight = self._ste(self.weight, self.prox_op)
        return F.linear(inputs, projected_weight, self.bias)
    
class ProxLinear(nn.Linear):
    """``nn.Linear`` whose forward pass always uses ``prox_op(weight)``.

    The operator is applied through ``STE``, so the gradient with respect to
    the projected weight is passed unchanged to the raw ``weight`` parameter.

    Args:
        *args: Positional arguments forwarded to ``nn.Linear``.
        prox_op: Keyword-only callable applied to the weight tensor.
            ``freeze_prox_op_weights`` additionally requires it to expose a
            ``freeze_weights()`` method.
        **kwargs: Keyword arguments forwarded to ``nn.Linear``.
    """

    def __init__(self, *args, prox_op, **kwargs):
        super().__init__(*args, **kwargs)

        self.prox_op = prox_op
        # Autograd functions are used through the class-level ``apply``;
        # instantiating one is deprecated in PyTorch.  The operator is passed
        # on every call in ``forward``, so none is bound here.
        self._ste = STE.apply

    def mofidy_prox_op(self, prox_op):
        """Replace the operator applied to the weight.

        The misspelled name is kept because it is the public interface.
        """
        self.prox_op = prox_op

    def freeze_weights(self):
        """Overwrite the weight with ``prox_op(weight)`` and stop training it."""
        self.weight.data = self.prox_op(self.weight.data)
        self.weight.requires_grad = False
        # Workaround: dropping the gradient keeps the optimizer's momentum
        # from modifying the frozen weight.
        self.weight.grad = None

    def freeze_prox_op_weights(self):
        """Delegate to ``prox_op.freeze_weights()``."""
        self.prox_op.freeze_weights()

    def clip_weights(self, _min, _max):
        """Clamp the weight in place to the interval ``[_min, _max]``."""
        self.weight.data.clamp_(_min, _max)

    def forward(self, inputs):
        """Apply the linear map with the weight passed through ``prox_op``."""
        # ``self.prox_op`` is read at call time, so a replacement installed
        # via ``mofidy_prox_op`` takes effect immediately.
        projected_weight = self._ste(self.weight, self.prox_op)
        return F.linear(inputs, projected_weight, self.bias)

