import torch
import torch.nn as nn
from torch.autograd import Function
import time
import os
# Limit math libraries to a single thread, presumably so that the timings
# recorded below are not affected by multi-threaded BLAS / intra-op parallelism.
# The variable honored by OpenBLAS is OPENBLAS_NUM_THREADS (the previous
# "OPENLABS_NUM_THREADS" spelling is not read by any library).
# NOTE(review): BLAS libraries usually read this variable when they are loaded,
# and torch is already imported above -- confirm it has the intended effect.
os.environ["OPENBLAS_NUM_THREADS"] = "1"
torch.set_num_threads(1)
#######################################################################

class Pointwise_conv_op(Function):
    """1x1 convolution (stride 1, no padding, groups 1) computed as a matmul.

    Only the matrix multiplications are timed (with ``time.time()``); the
    elapsed seconds are appended to the caller-supplied ``forward_time`` and
    ``backward_time`` lists.

    NOTE(review): stride/padding/groups configured on an owning module are NOT
    applied here; only a plain pointwise convolution is computed.
    """

    @staticmethod
    def forward(ctx, *args):
        """Apply ``weight`` independently at every spatial position of ``input``.

        Positional args:
            input: 4-D tensor ``(b, c_i, h, w)``.
            weight: ``(c_o, c_i, 1, 1)`` kernel (a 2-D ``(c_o, c_i)`` matrix is
                also accepted).
            bias: ``(c_o,)`` tensor or ``None``.
            backward_time: list that ``backward`` appends its timing to.
            forward_time: list the forward matmul timing is appended to.

        Returns:
            Tensor of shape ``(b, c_o, h, w)``.
        """
        input, weight, bias, backward_time, forward_time = args

        # Keep the original kernel shape so the weight gradient can be returned
        # with exactly that shape.
        weight_shape = weight.shape
        # flatten(1) instead of squeeze(): squeeze would also drop a channel
        # dimension of size 1 (c_o == 1 or c_i == 1) and break the unpacking.
        weight = weight.flatten(1)
        c_o, c_i = weight.shape
        b, _, h, w = input.shape
        # reshape (not view) so non-contiguous inputs are accepted.
        input = input.reshape(b, c_i, h * w).transpose(1, 2)  # (b, h*w, c_i)

        start_f = time.time()
        output = input @ weight.t()  # (b, h*w, c_o)
        end_f = time.time()
        forward_time.append(end_f - start_f)

        if bias is not None:
            output += bias  # broadcasts over the leading (b, h*w) dims

        output = output.transpose(1, 2).reshape(b, c_o, h, w)

        ctx.save_for_backward(input, weight)
        # Non-tensor metadata is kept as plain attributes rather than packed
        # into a saved tensor.
        ctx.spatial = (h, w)
        ctx.weight_shape = weight_shape
        ctx.backward_time = backward_time
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for ``(input, weight, bias)``; ``None`` for the lists.

        Only the gradient matmuls are timed; the elapsed seconds are appended to
        the ``backward_time`` list captured in ``forward``.
        """
        input, weight = ctx.saved_tensors  # (b, h*w, c_i), (c_o, c_i)
        c_o, c_i = weight.shape
        h, w = ctx.spatial
        b = grad_output.shape[0]

        # (b, c_o, h, w) -> (b, h*w, c_o); reshape tolerates any grad layout.
        grad_output = grad_output.reshape(b, c_o, h * w).transpose(1, 2)

        grad_input = grad_weight = grad_bias = None

        start = time.time()
        if ctx.needs_input_grad[0]:
            grad_input = grad_output @ weight  # (b, h*w, c_i)
        if ctx.needs_input_grad[1]:
            grad_weight = torch.einsum('blo,bli->oi', grad_output, input)
        end = time.time()
        ctx.backward_time.append(end - start)

        if grad_input is not None:
            grad_input = grad_input.transpose(1, 2).reshape(b, c_i, h, w)
        if grad_weight is not None:
            grad_weight = grad_weight.reshape(ctx.weight_shape)
        # needs_input_grad[2] is False whenever bias was passed as None.
        if ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(dim=(0, 1))
        # Exactly one entry per forward input:
        # (input, weight, bias, backward_time, forward_time).
        return grad_input, grad_weight, grad_bias, None, None

class Linear4_op(Function):
    """Linear map ``y = x @ W^T + b`` for 4-D inputs, with matmul timing.

    ``input`` must be 4-D ``(b, h, w, in_features)`` (required by the einsum in
    ``backward``); ``weight`` is ``(out_features, in_features)``. Only the
    matrix multiplications are timed; elapsed seconds are appended to the
    caller-supplied ``forward_time`` / ``backward_time`` lists.
    """

    @staticmethod
    def forward(ctx, *args):
        """Positional args: ``input, weight, bias, backward_time, forward_time``."""
        input, weight, bias, backward_time, forward_time = args

        start_f = time.time()
        output = input @ weight.t()
        end_f = time.time()
        forward_time.append(end_f - start_f)

        if bias is not None:
            output += bias  # broadcasts over the leading (b, h, w) dims

        ctx.save_for_backward(input, weight, bias)
        ctx.backward_time = backward_time
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for ``(input, weight, bias)``; ``None`` for the lists."""
        input, weight, bias = ctx.saved_tensors

        grad_input = grad_weight = grad_bias = None

        start = time.time()
        if ctx.needs_input_grad[0]:
            grad_input = grad_output @ weight
        if ctx.needs_input_grad[1]:
            grad_weight = torch.einsum('bhwc,bhwd->dc', input, grad_output)
        end = time.time()
        ctx.backward_time.append(end - start)

        if bias is not None and ctx.needs_input_grad[2]:
            # Reduce over every leading dim so the gradient has bias's
            # (out_features,) shape; sum(0) alone would leave (h, w, out).
            grad_bias = grad_output.sum(dim=(0, 1, 2))
        return grad_input, grad_weight, grad_bias, None, None
    
class Linear3_op(Function):
    """Linear map ``y = x @ W^T + b`` for 3-D inputs, with matmul timing.

    ``input`` must be 3-D ``(b, l, in_features)`` (required by the einsum in
    ``backward``); ``weight`` is ``(out_features, in_features)``. As in the
    other ops of this module, only the matrix multiplications are timed; the
    elapsed seconds are appended to ``forward_time`` / ``backward_time``.
    """

    @staticmethod
    def forward(ctx, *args):
        """Positional args: ``input, weight, bias, backward_time, forward_time``."""
        input, weight, bias, backward_time, forward_time = args

        start_f = time.time()
        output = input @ weight.t()
        end_f = time.time()
        forward_time.append(end_f - start_f)

        if bias is not None:
            output += bias  # broadcasts over the leading (b, l) dims

        ctx.save_for_backward(input, weight, bias)
        ctx.backward_time = backward_time
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for ``(input, weight, bias)``; ``None`` for the lists."""
        input, weight, bias = ctx.saved_tensors

        grad_input = grad_weight = grad_bias = None

        start = time.time()
        if ctx.needs_input_grad[0]:
            grad_input = grad_output @ weight
        if ctx.needs_input_grad[1]:
            grad_weight = torch.einsum('bli,blo->oi', input, grad_output)
        end = time.time()
        ctx.backward_time.append(end - start)

        if bias is not None and ctx.needs_input_grad[2]:
            # Reduce over both leading dims so the gradient has bias's
            # (out_features,) shape; sum(0) alone would leave (l, out).
            grad_bias = grad_output.sum(dim=(0, 1))
        return grad_input, grad_weight, grad_bias, None, None
    
class Linear(nn.Linear):
    """``nn.Linear`` that records per-call forward/backward/inference timings.

    With autograd enabled, 4-D and 3-D inputs are routed through
    ``Linear4_op`` / ``Linear3_op``, which append matmul timings to
    ``forward_time`` / ``backward_time``; any other rank raises ``ValueError``.
    With autograd disabled, the stock ``nn.Linear`` forward is timed and the
    elapsed seconds are appended to ``inference_time``.

    The timing arguments are lists (typically shared with the caller); when one
    is omitted a fresh private list is created instead of storing ``None``,
    which would fail on the first ``append``.
    """

    def __init__(
            self,
            in_features,
            out_features,
            bias=True,
            device=None,
            dtype=None,
            backward_time=None,
            forward_time=None,
            inference_time=None):
        super().__init__(
            in_features=in_features,
            out_features=out_features,
            bias=bias,
            device=device,
            dtype=dtype,
        )
        # Keep caller-supplied lists by identity so results are shared.
        self.backward_time = [] if backward_time is None else backward_time
        self.forward_time = [] if forward_time is None else forward_time
        self.inference_time = [] if inference_time is None else inference_time

    def forward(self, input):
        """Apply the layer, recording timings as described on the class."""
        if torch.is_grad_enabled():
            # Autograd active: use the custom timed autograd Functions.
            if input.dim() == 4:
                op = Linear4_op
            elif input.dim() == 3:
                op = Linear3_op
            else:
                raise ValueError("Not yet implement")
            return op.apply(input, self.weight, self.bias,
                            self.backward_time, self.forward_time)

        # Autograd disabled: time the stock nn.Linear forward.
        start = time.time()
        output = super().forward(input)
        end = time.time()
        self.inference_time.append(end - start)
        return output
    
class Conv2d(nn.Conv2d):
    """``nn.Conv2d`` that records per-call forward/backward/inference timings.

    With autograd enabled the convolution runs through ``Pointwise_conv_op``,
    which appends matmul timings to ``forward_time`` / ``backward_time``. That
    op only computes a 1x1 kernel with stride 1, no padding and ``groups=1``
    (it ignores the module's stride/padding/groups), so any other configuration
    raises ``ValueError`` instead of producing a differently-shaped or incorrect
    result. With autograd disabled, the stock ``nn.Conv2d`` forward is timed and
    the elapsed seconds are appended to ``inference_time``.

    Omitted timing lists default to fresh private lists rather than ``None``.
    """

    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size=1, stride=1,
            padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros',
            backward_time=None,
            forward_time=None,
            inference_time=None):
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
        )
        # Keep caller-supplied lists by identity so results are shared.
        self.backward_time = [] if backward_time is None else backward_time
        self.forward_time = [] if forward_time is None else forward_time
        self.inference_time = [] if inference_time is None else inference_time

    def _is_pointwise(self):
        """Return True if this configuration is what Pointwise_conv_op computes."""
        # Dilation and padding_mode have no effect on a 1x1 kernel with zero
        # padding, so they are not checked.
        return (tuple(self.kernel_size) == (1, 1)
                and tuple(self.stride) == (1, 1)
                and self.padding in ((0, 0), 'valid', 'same')
                and self.groups == 1)

    def forward(self, input):
        """Apply the convolution, recording timings as described on the class."""
        if torch.is_grad_enabled():
            if not self._is_pointwise():
                raise ValueError(
                    "Conv2d training path only supports 1x1 kernels with "
                    "stride 1, no padding and groups=1")
            return Pointwise_conv_op.apply(input, self.weight, self.bias,
                                           self.backward_time, self.forward_time)

        # Autograd disabled: time the stock nn.Conv2d forward.
        start = time.time()
        output = super().forward(input)
        end = time.time()
        self.inference_time.append(end - start)
        return output
    

def wrap_linear(linear, backward_time, forward_time, inference_time):
    """Return a timing ``Linear`` that shares ``linear``'s parameter storage.

    The new module is built on the same device and dtype as ``linear.weight``.
    It reuses the original weight/bias tensors (via ``.data``, so the storage is
    shared but the Parameter objects are new) and copies their
    ``requires_grad`` flags, so frozen layers stay frozen. The three timing
    lists are passed through to the new module unchanged.
    """
    weight = linear.weight
    new_linear = Linear(in_features=linear.in_features,
                        out_features=linear.out_features,
                        bias=linear.bias is not None,
                        device=weight.device,
                        dtype=weight.dtype,
                        backward_time=backward_time,
                        forward_time=forward_time,
                        inference_time=inference_time)

    new_linear.weight.data = weight.data
    new_linear.weight.requires_grad_(weight.requires_grad)
    if new_linear.bias is not None:
        new_linear.bias.data = linear.bias.data
        new_linear.bias.requires_grad_(linear.bias.requires_grad)
    return new_linear

def wrap_conv(conv2d, backward_time, forward_time, inference_time):
    """Return a timing ``Conv2d`` that mirrors ``conv2d`` and shares its storage.

    Every convolution hyper-parameter (kernel size, stride, padding, dilation,
    groups, padding mode, bias presence) is copied, so the non-autograd path
    computes the same result as the original layer. The weight/bias tensors are
    reused via ``.data`` (shared storage, new Parameter objects), and their
    ``requires_grad`` flags are preserved so frozen layers stay frozen.
    """
    new_conv2d = Conv2d(in_channels=conv2d.in_channels,
                        out_channels=conv2d.out_channels,
                        kernel_size=conv2d.kernel_size,
                        stride=conv2d.stride,
                        padding=conv2d.padding,
                        dilation=conv2d.dilation,
                        groups=conv2d.groups,
                        bias=conv2d.bias is not None,
                        padding_mode=conv2d.padding_mode,
                        backward_time=backward_time,
                        forward_time=forward_time,
                        inference_time=inference_time)

    new_conv2d.weight.data = conv2d.weight.data
    new_conv2d.weight.requires_grad_(conv2d.weight.requires_grad)
    if new_conv2d.bias is not None:
        new_conv2d.bias.data = conv2d.bias.data
        new_conv2d.bias.requires_grad_(conv2d.bias.requires_grad)
    return new_conv2d