import math

import torch
import os
import matplotlib.pyplot as plt

class SmallLinearFunction(torch.autograd.Function):
    """Linear map ``input @ weight.T (+ bias)`` that saves a compressed input.

    Instead of keeping the full ``input`` alive for backward, ``forward`` saves
    ``weight.projector.project(input, training)``. ``backward`` then builds
    the weight gradient from that compressed input and from
    ``weight.projector.project(grad_output, True)``.

    ``weight`` must therefore carry a ``projector`` attribute that exposes:

    * ``project(tensor, training)``
    * ``proj_type``
    * for ``'pamm'`` / ``'pamm_epsilon'``: ``indices`` and ``alphas``
    * for ``'pamm_epsilon'`` only: ``selected_indices``

    The feature dimension is always the last one. Any number of leading
    (batch/sequence) dimensions is accepted, including none.
    """

    @staticmethod
    def forward(ctx, input, weight, bias, single_gpu, training):
        """Return ``input @ weight.T`` plus ``bias`` (when given).

        Args:
            ctx: autograd context used to stash tensors for ``backward``.
            input: activations whose last dimension matches ``weight``'s last
                dimension.
            weight: 2-D weight matrix carrying a ``projector`` attribute.
            bias: tensor broadcast-added to the output, or ``None``.
            single_gpu: stored on ``ctx``; not read anywhere in this class.
            training: forwarded to ``weight.projector.project``.

        Returns:
            The linear-layer output.
        """
        output = input @ weight.T
        if bias is not None:
            # In-place is fine: ``output`` was freshly allocated just above.
            output += bias

        # Save the projector's compressed view of the input rather than the
        # input itself; backward only ever sees this compressed tensor.
        compressed_input = weight.projector.project(input, training)
        ctx.save_for_backward(compressed_input, weight, bias)
        ctx.single_gpu = single_gpu
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for ``(input, weight, bias, single_gpu, training)``.

        ``grad_input`` is exact, since it needs only ``weight``.
        ``grad_weight`` is built from the compressed input saved in
        ``forward``. The two flag arguments never receive a gradient.
        """
        compressed_input, weight, bias = ctx.saved_tensors
        projector = weight.projector

        grad_input = grad_weight = grad_bias = None

        if ctx.needs_input_grad[0]:
            grad_input = grad_output @ weight

        if ctx.needs_input_grad[1]:
            x = compressed_input
            if projector.proj_type == 'pamm':
                x = x[projector.indices]
                # NOTE(review): the in-place scale assumes the indexing above
                # returned a copy (index tensor / mask, not a slice); otherwise
                # it would mutate the saved tensor. TODO confirm.
                x *= projector.alphas
            elif projector.proj_type == 'pamm_epsilon':
                x = x[projector.indices]
                x = x[projector.selected_indices]
                x *= projector.alphas[projector.selected_indices]
                # ``selected_indices.sum()`` is used as a count, so it is
                # presumably a boolean mask. The ratio rescales for the rows
                # that were dropped. A mask selecting nothing gives
                # division by zero.
                scaling = projector.indices.shape[0] / projector.selected_indices.sum()
                x *= scaling
            compressed_grad = projector.project(grad_output, True)

            # Flatten every leading dimension so inputs of any rank work.
            # This sums grad[..., j] * x[..., k] over all leading positions,
            # i.e. einsum('...j,...k->jk').
            grad_2d = compressed_grad.reshape(-1, compressed_grad.shape[-1])
            x_2d = x.reshape(-1, x.shape[-1])
            grad_weight = grad_2d.t() @ x_2d

        if bias is not None and ctx.needs_input_grad[2]:
            # Sum over all leading dimensions, leaving one value per output
            # feature (also correct for 1-D / unbatched inputs).
            grad_bias = grad_output.reshape(-1, grad_output.shape[-1]).sum(dim=0)

        return grad_input, grad_weight, grad_bias, None, None


# wrapping the apply function, to allow keyword arguments
def small_linear(input, weight, bias, single_gpu, training):
    """Functional entry point for :class:`SmallLinearFunction`.

    ``SmallLinearFunction.apply`` is invoked positionally here, so callers of
    this function are free to pass any argument by keyword.

    Returns:
        The result of ``SmallLinearFunction.forward``: ``input @ weight.T``
        plus ``bias`` when ``bias`` is not ``None``.
    """
    positional_args = (input, weight, bias, single_gpu, training)
    return SmallLinearFunction.apply(*positional_args)
