import torch
import torch.nn as nn
import global_time
import time
from torch.nn import functional as F
import sys

# Fixed-point encoding parameters shared by the layers in this module.
L = 15          # exponent of the fixed-point scale
C = 2 ** L      # encode: values are turned into round(x * C)
D = 2 ** (-L)   # decode: multiplying by D undoes one factor of C
P = 4293999991  # modulus applied with torch.fmod after encoding
# blind function implementation
def blind(mixture, bm):
    """Blind ``mixture`` by mixing its row-groups with the matrix ``bm``.

    ``mixture`` is split along dim 0 into chunks of
    ``int(mixture.size(0) / bm.size(1))`` rows. The last chunk is replaced by
    ``global_time.mean + global_time.std * randn`` noise shaped like the first
    chunk. The chunks are concatenated again, viewed as ``bm.size(1)`` rows and
    left-multiplied by ``bm``. The product is reshaped to the concatenated size.
    """
    n_groups = bm.size()[1]
    rows_per_group = int(mixture.size(0) / n_groups)
    chunks = list(torch.split(mixture, rows_per_group, dim=0))
    first = chunks[0]
    # Overwrite the final group with fresh noise (same shape/device/dtype as group 0).
    chunks[-1] = global_time.mean + global_time.std * torch.randn(
        first.size(), device=first.device, dtype=first.dtype
    )
    stacked = torch.cat(chunks, dim=0)
    flat = stacked.reshape(n_groups, -1)
    return bm.matmul(flat).reshape(stacked.size())

# unblind function implementation
def unblind(mixture, um):
    """Undo a blinding by left-multiplying the row-grouped ``mixture`` by ``um``.

    ``mixture`` is viewed as ``um.size(1)`` rows, multiplied by ``um``, and the
    product is reshaped back to ``mixture``'s original shape.
    """
    n_rows = um.size(1)
    grouped = mixture.reshape(n_rows, -1)
    return (um @ grouped).reshape(mixture.shape)

# scrambling function for gradients
def scramble(mixture, gm):
    """Scramble a gradient by mixing all but its last row-group with ``gm``.

    ``mixture`` is viewed as ``gm.size(0)`` rows. The final row is dropped, the
    remaining rows are left-multiplied by ``gm``, and the product is reshaped back
    to ``mixture``'s shape. For the matmul to line up, ``gm`` must therefore have
    ``gm.size(0) - 1`` columns.
    """
    n_rows = gm.size(0)
    kept = mixture.reshape(n_rows, -1)[:-1]
    return (gm @ kept).reshape(mixture.shape)

def bias_add(input, bias):
    """Add a per-channel ``bias`` to ``input``.

    For input of rank <= 2 (e.g. ``(N, C)``) the bias is added with ordinary
    broadcasting. For higher ranks the channel axis is dim 1 (``(N, C, L)``,
    ``(N, C, H, W)``, ``(N, C, D, H, W)``, ...), and the bias is broadcast over
    every trailing position.

    Args:
        input: tensor whose dim 1 has ``bias.size(0)`` entries.
        bias: 1-D tensor of per-channel offsets.

    Returns:
        ``input`` plus ``bias`` broadcast along the channel axis.
    """
    if input.dim() <= 2:
        return input + bias
    # View the bias as (C, 1, ..., 1) so it lines up with dim 1 and broadcasts
    # across all trailing dims. The previous expand/reshape only worked for 4-D.
    trailing = (1,) * (input.dim() - 2)
    return input + bias.reshape((bias.size(0),) + trailing)
# blind function autograd
class Blind(torch.autograd.Function):
    """Autograd wrapper around ``blind``.

    The forward pass blinds ``mixture`` with ``bm``. The backward pass decodes the
    incoming (blinded) gradient with ``inv_gm`` via ``unblind``. ``bm`` and
    ``inv_gm`` receive no gradient.
    """

    @staticmethod
    def forward(ctx, mixture, bm, inv_gm):
        # No torch.cuda.synchronize()/time.time() bracketing here: the timing
        # result was never recorded, each call stalled the GPU twice, and
        # synchronize() raises on CPU-only builds.
        blinded = blind(mixture, bm)
        # The matrix that decodes the blinded gradient in backward.
        ctx.save_for_backward(inv_gm)
        return blinded

    @staticmethod
    def backward(ctx, grad_output):
        # Retrieve the saved inverse-gradient matrix.
        inverse_grad, = ctx.saved_tensors
        # Decode the blinded gradient; only ``mixture`` gets a gradient.
        clean_gradient = unblind(grad_output, inverse_grad)
        return clean_gradient, None, None

# unblind function autograd
class UnBlind(torch.autograd.Function):
    """Autograd wrapper that unblinds a mixture and optionally adds a bias.

    Forward: ``unblind(mixture, um)``, then ``bias_add`` when ``bias`` is given.
    Backward: the incoming gradient is scrambled with ``gm`` to form the gradient
    for ``mixture``. The bias gradient is the unscrambled gradient summed over
    every axis except the channel axis (dim 1).
    """

    @staticmethod
    def forward(ctx, mixture, um, gm, bias):
        # gm re-encodes the gradient in backward; bias (possibly None) decides
        # whether a bias gradient is produced.
        ctx.save_for_backward(gm, bias)
        retval = unblind(mixture, um)
        if bias is not None:
            retval = bias_add(retval, bias)
        return retval

    @staticmethod
    def backward(ctx, grad_output):
        saved_gm, bias = ctx.saved_tensors
        # Gradient sent back towards ``mixture`` is re-encoded with gm.
        shuffled_grad_output = scramble(grad_output, saved_gm)
        bias_gradient = None
        if bias is not None:
            # The bias is added along dim 1, so reduce over dim 0 and every
            # trailing dim: (0,) for (N, C), (0, 2) for (N, C, L),
            # (0, 2, 3) for (N, C, H, W), and so on.
            axes = (0,) + tuple(range(2, grad_output.dim()))
            bias_gradient = torch.sum(grad_output, dim=axes)
        return shuffled_grad_output, None, None, bias_gradient


class Conv2d_sp(torch.autograd.Function):
    """2-D convolution (stride 1, padding 1, dilation 1, groups 1) on a masked,
    fixed-point encoded batch.

    Forward:
        1. Draw one random mask ``r`` (``torch.rand``, shifted and scaled by
           ``global_time.mean`` / ``global_time.std``).
        2. Add ``r`` to every sample and append ``r`` as an extra batch row.
        3. Convolve the encodings ``fmod(round(x * C), P)``.
        4. Subtract the mask row's result and drop that row, so that (up to
           rounding) only the contribution of ``input`` remains.

    NOTE(review): the output is left at the product scale ``C * C`` (no ``D``
    rescaling is applied). Because the encoded bias is added to every row,
    including the mask row, it cancels in the subtraction. ``bias`` is instead
    returned as a second output and its gradient is passed straight through.
    Confirm the caller rescales the output and adds the bias.
    """

    @staticmethod
    def forward(ctx, input, weight, bias=None):
        stride, padding, dilation, groups = 1, 1, 1, 1
        r = global_time.mean + global_time.std * torch.rand(
            1, input.shape[1], input.shape[2], input.shape[3], device=input.device)
        # Masked samples followed by the bare mask as the last batch row.
        input_1 = torch.cat((input + r, r), 0)
        # Activations and weights are encoded at scale C; the bias at C*C to
        # match the scale of their products.
        input_q = torch.fmod(torch.round(input_1 * C), P)
        weight_q = torch.fmod(torch.round(weight * C), P)
        bias_q = None if bias is None else torch.fmod(torch.round(bias * C * C), P)
        temp = F.conv2d(input_q, weight_q, bias_q, stride, padding, dilation, groups)
        # Remove the mask's contribution (last row), then drop the mask row.
        output = temp - temp[temp.shape[0] - 1]
        output = torch.narrow(output, 0, 0, output.shape[0] - 1)
        # Only input and weight are read in backward.
        ctx.save_for_backward(input, weight)
        return output, (None if bias is None else bias.detach().clone())

    @staticmethod
    def backward(ctx, grad_output, grad_bias):
        stride, padding, dilation, groups = 1, 1, 1, 1
        input, weight = ctx.saved_tensors
        # Encode the incoming gradient, weights and inputs at scale C.
        grad_q = torch.fmod(torch.round(grad_output * C), P)
        weight_q = torch.fmod(torch.round(weight * C), P)
        input_q = torch.fmod(torch.round(input * C), P)
        grad_to_input = nn.grad.conv2d_input(
            input.shape, weight_q, grad_q,
            stride=stride, padding=padding, dilation=dilation, groups=groups)
        # Products are at scale C*C: rescale by D, reduce mod P, rescale by D.
        grad_to_input = torch.fmod(torch.round(grad_to_input * D), P) * D
        grad_to_weight = nn.grad.conv2d_weight(
            input_q, weight.shape, grad_q,
            stride=stride, padding=padding, dilation=dilation, groups=groups)
        grad_to_weight = torch.fmod(torch.round(grad_to_weight * D), P) * D
        # The bias was returned unchanged as the second output.
        return grad_to_input, grad_to_weight, grad_bias
        #return grad_to_input, grad_to_weight



class Conv2d_mob(torch.autograd.Function):
    """2-D convolution evaluated on fixed-point encodings (scale ``C``, modulus ``P``).

    ``stride``, ``padding`` and ``dilation`` accept whatever
    ``torch.nn.functional.conv2d`` accepts (an int or a pair); ``groups`` is an int.
    The backward pass uses the same hyper-parameters as the forward pass. The
    ``bias`` argument is ignored and receives no gradient.
    """

    @staticmethod
    def forward(ctx, input, weight, bias, stride, padding, dilation, groups):
        # NOTE(review): bias is overwritten with None, so any bias passed by the
        # caller is ignored -- confirm this is intended.
        bias = None
        input_q = torch.fmod(torch.round(input * C), P)
        weight_q = torch.fmod(torch.round(weight * C), P)
        output = F.conv2d(input_q, weight_q, bias=bias, stride=stride,
                          padding=padding, dilation=dilation, groups=groups)
        # Products are at scale C*C: rescale by D, reduce mod P, rescale by D.
        output = torch.fmod(torch.round(output * D), P) * D
        ctx.save_for_backward(input, weight)
        # Keep the conv hyper-parameters as plain Python values on ctx.
        # torch.IntTensor(n) for an int n allocates an uninitialised tensor of
        # length n instead of storing n.
        ctx.stride = stride
        ctx.padding = padding
        ctx.dilation = dilation
        ctx.groups = groups
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        # Same hyper-parameters as the forward conv, including groups, so that
        # grouped/depthwise and non-square convolutions get correct gradients.
        conv_kwargs = dict(stride=ctx.stride, padding=ctx.padding,
                           dilation=ctx.dilation, groups=ctx.groups)
        grad_q = torch.fmod(torch.round(grad_output * C), P)
        weight_q = torch.fmod(torch.round(weight * C), P)
        input_q = torch.fmod(torch.round(input * C), P)
        grad_to_input = nn.grad.conv2d_input(input.shape, weight_q, grad_q, **conv_kwargs)
        grad_to_input = torch.fmod(torch.round(grad_to_input * D), P) * D
        grad_to_weight = nn.grad.conv2d_weight(input_q, weight.shape, grad_q, **conv_kwargs)
        grad_to_weight = torch.fmod(torch.round(grad_to_weight * D), P) * D
        # bias, stride, padding, dilation and groups receive no gradient.
        return grad_to_input, grad_to_weight, None, None, None, None, None

class Conv2d_res(torch.autograd.Function):
    """2-D convolution computed on a masked batch in ordinary floating point.

    Forward adds one random mask ``r`` to every sample and appends ``r`` as an
    extra batch row. After the convolution it subtracts the mask row's result, so
    that (up to floating-point rounding) only the convolution of ``input`` is
    returned. ``bias`` is ignored and receives no gradient. Backward uses the
    unmasked ``input`` and the forward hyper-parameters, including ``groups``.
    """

    @staticmethod
    def forward(ctx, input, weight, bias, stride, padding, dilation, groups):
        # NOTE(review): bias is overwritten with None, so any bias passed by the
        # caller is ignored -- confirm this is intended.
        bias = None
        r = global_time.mean + global_time.std * torch.rand(
            1, input.shape[1], input.shape[2], input.shape[3], device=input.device)
        # Masked samples followed by the bare mask as the last batch row.
        input_1 = torch.cat((input + r, r), 0)
        temp = F.conv2d(input_1, weight, bias=bias, stride=stride,
                        padding=padding, dilation=dilation, groups=groups)
        # Remove the mask's contribution (last row), then drop the mask row.
        output = temp - temp[temp.shape[0] - 1]
        output = torch.narrow(output, 0, 0, output.shape[0] - 1)
        ctx.save_for_backward(input, weight)
        # Plain Python values on ctx; torch.IntTensor(n) for an int n would
        # allocate an uninitialised tensor of length n instead of storing n.
        ctx.stride = stride
        ctx.padding = padding
        ctx.dilation = dilation
        ctx.groups = groups
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        conv_kwargs = dict(stride=ctx.stride, padding=ctx.padding,
                           dilation=ctx.dilation, groups=ctx.groups)
        grad_to_input = nn.grad.conv2d_input(input.shape, weight, grad_output, **conv_kwargs)
        grad_to_weight = nn.grad.conv2d_weight(input, weight.shape, grad_output, **conv_kwargs)
        # bias, stride, padding, dilation and groups receive no gradient.
        return grad_to_input, grad_to_weight, None, None, None, None, None


def quantize(x):
    """Round ``x`` to the nearest integer, with halves rounded up (towards +inf).

    ``x`` is expected to be a NumPy array or scalar; the result is produced with
    ``.astype(int)``. ``x + 0.5`` is floored rather than truncated, because
    ``astype(int)`` truncates toward zero and would round negative values the
    wrong way (e.g. -0.7 would become 0 instead of -1).
    """
    import numpy as np  # local import: the module does not import NumPy at top level

    return np.floor(x + 0.5).astype(int)

def rounding(weight):
    """Encode ``weight`` as fixed-point integers: ``quantize(weight * 2**L)``."""
    # The body used to reference the misspelled name ``wieght``, so every call
    # raised NameError.
    return quantize(weight * (2**L))

def unrounding(x):
    """Return ``rounding(x * 2**-L) * 2**-L``.

    ``x`` is scaled down by ``2**L`` and passed through ``rounding``; the result
    is scaled down by ``2**L`` again.
    """
    inv_scale = 2 ** (-L)
    return rounding(x * inv_scale) * inv_scale

class Linear_sp(torch.autograd.Function):
    """Linear layer evaluated on fixed-point encodings (scale ``C``, modulus ``P``).

    Inputs and weights are encoded as ``fmod(round(x * C), P)`` and the bias at
    scale ``C * C``. The product is rescaled by ``D`` twice to return to the real
    scale. ``bias`` (which may be None) is also returned unchanged as a second
    output, and its gradient is passed straight through in backward.
    """

    @staticmethod
    def forward(ctx, input, weight, bias):
        ctx.save_for_backward(input, weight)
        input_q = torch.fmod(torch.round(input * C), P)
        weight_q = torch.fmod(torch.round(weight * C), P)
        # Bias at scale C*C matches the input*weight products; a missing bias is
        # allowed (``None * C`` used to raise TypeError).
        bias_q = None if bias is None else torch.fmod(torch.round(bias * C * C), P)
        output = F.linear(input_q, weight_q, bias_q)
        # Products are at scale C*C: rescale by D, reduce mod P, rescale by D.
        y = torch.fmod(torch.round(output * D), P) * D
        return y, bias

    @staticmethod
    def backward(ctx, grad_output, grad_bias):
        input, weight = ctx.saved_tensors
        grad_q = torch.fmod(torch.round(grad_output * C), P)
        grad_to_input = grad_q.matmul(torch.fmod(torch.round(weight * C), P))
        # Reduce mod P at the integer scale and apply the final D afterwards --
        # the same order used for grad_to_weight below and in the conv layers.
        grad_to_input = torch.fmod(torch.round(grad_to_input * D), P) * D
        grad_to_weight = grad_q.transpose(0, 1).matmul(torch.fmod(torch.round(input * C), P))
        grad_to_weight = torch.fmod(torch.round(grad_to_weight * D), P) * D
        return grad_to_input, grad_to_weight, grad_bias

class Linear_spr(torch.autograd.Function):
    """Linear layer evaluated on a masked batch in ordinary floating point.

    A single random mask row (``global_time.mean + global_time.std * randn``) is
    added to every sample and also appended as an extra batch row. After
    ``F.linear``, the mask row's output is subtracted from every sample row and
    the mask row is dropped. Because the bias is added to the mask row too, it
    cancels in that subtraction; ``bias`` is returned unchanged as a second
    output and its gradient is passed straight through.
    """

    @staticmethod
    def forward(ctx, input, weight, bias):
        ctx.save_for_backward(input, weight)
        mask = global_time.mean + global_time.std * torch.randn(1, input.shape[1], device=input.device)
        stacked = torch.cat((input + mask, mask), 0)
        projected = F.linear(stacked, weight, bias)
        # Every row but the last, minus the last (mask) row.
        return projected[:-1] - projected[-1], bias

    @staticmethod
    def backward(ctx, grad_output, grad_bias):
        input, weight = ctx.saved_tensors
        grad_input = grad_output @ weight
        grad_weight = grad_output.transpose(0, 1) @ input
        return grad_input, grad_weight, grad_bias

