import torch
import torch.nn as nn
import global_time
import time
from torch.nn import functional as F
import sys

# blind function implementation
def blind(mixture, bm):
    """Swap the last chunk of ``mixture`` for Gaussian noise, then mix with ``bm``.

    ``mixture`` is split along dim 0 into pieces of
    ``int(mixture.size(0) / bm.size(1))`` rows.  The final piece is replaced
    by ``global_time.mean + global_time.std * randn(...)``, the pieces are
    concatenated again, flattened to ``bm.size(1)`` rows and left-multiplied
    by ``bm``.

    Args:
        mixture: tensor to blind; chunked along its first dimension.
        bm: blinding matrix; its second dimension gives the chunk count.

    Returns:
        The mixed tensor, reshaped to the size of the re-assembled mixture.
    """
    n_chunks = bm.size()[1]
    rows_per_chunk = int(mixture.size(0) / n_chunks)
    chunks = list(torch.split(mixture, rows_per_chunk, dim=0))

    # The noise takes its shape, device and dtype from the first chunk.
    template = chunks[0]
    noise = torch.randn(template.size(), device=template.device, dtype=template.dtype)
    chunks[-1] = global_time.mean + global_time.std * noise

    noisy = torch.cat(chunks, axis=0)
    flattened = noisy.reshape(n_chunks, -1)
    return bm.matmul(flattened).reshape(noisy.size())

# unblind function implementation
def unblind(mixture, um):
    """Left-multiply the flattened ``mixture`` by ``um`` and restore its shape.

    Args:
        mixture: tensor to unblind; it is viewed as ``um.size(1)`` rows.
        um: unblinding matrix applied from the left.

    Returns:
        ``um @ mixture.reshape(um.size(1), -1)`` reshaped to ``mixture.size()``.
    """
    original_shape = mixture.size()
    n_rows = um.size()[1]
    flattened = mixture.reshape(n_rows, -1)
    mixed = um.matmul(flattened)
    return mixed.reshape(original_shape)

# scrambling function for gradients
def scramble(mixture, gm):
    """Mix all but the last row of the flattened ``mixture`` with ``gm``.

    ``mixture`` is viewed as ``gm.size(0)`` rows; the final row is left out
    of the product, and the result is reshaped back to ``mixture.size()``.

    Args:
        mixture: tensor (a gradient, in this file's usage) to scramble.
        gm: matrix applied from the left; its first dimension sets the row
            count used to flatten ``mixture``.

    Returns:
        The product reshaped to ``mixture.size()``.
    """
    n_rows = gm.size()[0]
    flattened = mixture.reshape(n_rows, -1)
    # Only rows 0 .. n_rows - 2 take part in the product.
    kept = flattened[:n_rows - 1]
    return gm.matmul(kept).reshape(mixture.size())

def bias_add(input, bias):
    """Add a per-channel ``bias`` to ``input``.

    For a 2-D ``input`` the bias is added directly (ordinary broadcasting).
    For any other rank, dimension 1 of ``input`` is treated as the channel
    axis: all trailing dimensions are flattened, the bias is broadcast across
    them, and the result is reshaped back.  This covers the original 4-D case
    and also 3-D, 5-D, ... inputs, which previously failed because the code
    indexed ``input.size(2)`` and ``input.size(3)`` unconditionally.

    Args:
        input: tensor of rank >= 2.
        bias: tensor whose first dimension holds one value per channel.

    Returns:
        A tensor with the same shape as ``input``.

    Raises:
        IndexError: if ``input`` has fewer than two dimensions.
    """
    if input.dim() == 2:
        return input + bias

    # Collapse every trailing dimension so one (C, 1) column broadcasts
    # across them, whatever their number.
    flattened = input.reshape((input.size(0), input.size(1), -1))
    bias_column = bias.reshape((bias.size(0), 1))
    return (flattened + bias_column).reshape(input.size())
# blind function autograd
class Blind(torch.autograd.Function):
    """Autograd function that blinds in forward and unblinds the gradient in backward."""

    @staticmethod
    def forward(ctx, mixture, bm, inv_gm):
        """Return ``blind(mixture, bm)`` and keep ``inv_gm`` for backward.

        Args:
            ctx: autograd context.
            mixture: tensor to blind.
            bm: blinding matrix passed through to ``blind``.
            inv_gm: matrix used in backward to decode the incoming gradient.
        """
        # The torch.cuda.synchronize() calls that used to bracket this body
        # only served a timing accumulation that is disabled; they failed on
        # CPU-only setups and stalled the device on every call, so they were
        # removed.
        blinded = blind(mixture, bm)
        ctx.save_for_backward(inv_gm)
        return blinded

    @staticmethod
    def backward(ctx, grad_output):
        """Decode ``grad_output`` with the saved matrix.

        Returns one gradient per forward argument; ``bm`` and ``inv_gm``
        receive no gradient.
        """
        # Retrieve the saved inverse gradient matrix.
        inverse_grad, = ctx.saved_tensors

        # Decode the blinded gradient.
        clean_gradient = unblind(grad_output, inverse_grad)
        return clean_gradient, None, None

# unblind function autograd
class UnBlind(torch.autograd.Function):
    """Autograd function that unblinds in forward and scrambles the gradient in backward."""

    @staticmethod
    def forward(ctx, mixture, um, gm, bias):
        """Unblind ``mixture`` with ``um`` and add ``bias`` when one is given.

        ``gm`` and ``bias`` are stored on ``ctx`` for the backward pass.
        """
        ctx.save_for_backward(gm, bias)
        unblinded = unblind(mixture, um)
        if bias is None:
            return unblinded
        return bias_add(unblinded, bias)

    @staticmethod
    def backward(ctx, grad_output):
        """Scramble ``grad_output`` with the saved matrix and reduce the bias gradient.

        Returns one entry per forward argument: the scrambled gradient for
        ``mixture``, ``None`` for ``um`` and ``gm``, and the bias gradient
        (``None`` when no bias was supplied).
        """
        saved_gm, bias = ctx.saved_tensors

        # Encode the gradient that flows back to the mixture.
        scrambled = scramble(grad_output, saved_gm)

        if bias is None:
            return scrambled, None, None, None

        # 4-D gradients are summed over every axis except dim 1;
        # anything else is summed over dim 0 only.
        reduce_axes = (0, 2, 3) if grad_output.dim() == 4 else 0
        bias_gradient = torch.sum(grad_output, axis=reduce_axes)
        return scrambled, None, None, bias_gradient


class Conv2d_sp(torch.autograd.Function):
    """2-D convolution with fixed stride=1, padding=1, dilation=1, groups=1.

    Besides the convolution result, forward also returns a detached copy of
    the bias as a second output; whatever gradient arrives for that second
    output is handed back unchanged as the gradient of ``bias``.
    """

    # Shared by forward and backward so the two can never disagree.
    _CONV_ARGS = {"stride": 1, "padding": 1, "dilation": 1, "groups": 1}

    @staticmethod
    def forward(ctx, input, weight, bias=None):
        """Run the convolution and return ``(output, bias_copy)``.

        Args:
            ctx: autograd context.
            input: convolution input.
            weight: convolution kernel.
            bias: optional bias; when ``None`` the second output is ``None``
                (previously this case raised ``AttributeError``).
        """
        output = F.conv2d(input, weight, bias, **Conv2d_sp._CONV_ARGS)
        # Backward only needs the input and the weight; the output is not
        # saved, so it is not kept alive for the backward pass.
        ctx.save_for_backward(input, weight)
        bias_copy = None if bias is None else bias.detach().clone()
        return output, bias_copy

    @staticmethod
    def backward(ctx, grad_weight, grad_bias):
        """Return gradients for ``(input, weight, bias)``.

        Args:
            ctx: autograd context.
            grad_weight: gradient w.r.t. the convolution output (the first
                forward output), despite its name.
            grad_bias: gradient w.r.t. the second forward output; returned
                as-is for ``bias``.
        """
        input, weight = ctx.saved_tensors
        grad_to_input = nn.grad.conv2d_input(
            input.shape, weight, grad_weight, **Conv2d_sp._CONV_ARGS
        )
        grad_to_weight = nn.grad.conv2d_weight(
            input, weight.shape, grad_weight, **Conv2d_sp._CONV_ARGS
        )
        return grad_to_input, grad_to_weight, grad_bias


class Linear_sp(torch.autograd.Function):
    """Linear layer whose forward also returns the bias as a second output.

    The gradient arriving for that second output is passed back unchanged as
    the gradient of ``bias``.
    """

    @staticmethod
    def forward(ctx, input, weight, bias):
        """Return ``(F.linear(input, weight, bias), bias)``; keep input and weight for backward."""
        ctx.save_for_backward(input, weight)
        projected = F.linear(input, weight, bias)
        return projected, bias

    @staticmethod
    def backward(ctx, grad_weight, grad_bias):
        """Return gradients for ``(input, weight, bias)``.

        ``grad_weight`` is the gradient w.r.t. the first forward output,
        despite its name; ``grad_bias`` is forwarded untouched.
        """
        saved_input, saved_weight = ctx.saved_tensors
        # Weight gradient: (grad transposed over dims 0/1) @ input.
        grad_to_weight = torch.matmul(grad_weight.transpose(0, 1), saved_input)
        # Input gradient: grad @ weight.
        grad_to_input = torch.matmul(grad_weight, saved_weight)
        return grad_to_input, grad_to_weight, grad_bias

