import torch as th
import torch.nn as nn
from torch.autograd import Function
import torch.nn.functional as F
from typing import Any
import math
import os
import time

# Restrict CPU math libraries to a single thread (presumably so the timings
# collected by the ops in this module are comparable between runs).
# NOTE(review): "OPENLABS_NUM_THREADS" is not a variable any library is known to
# read; it looks like a typo for OPENBLAS_NUM_THREADS. The original key is kept
# and the OpenBLAS one is added. OpenBLAS reads its variable when the library is
# loaded, and torch is already imported above, so this may only take effect in
# child processes.
os.environ["OPENLABS_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"
th.set_num_threads(1)
class LinearSVDOp(Function):
    """Linear map ``F.linear(x, weight, bias)`` that saves compressed activations.

    ``forward`` returns the exact ``F.linear`` result but keeps only a projected
    copy of ``x`` for backward. ``backward`` builds the weight gradient from that
    projected copy and routes ``grad_y`` through ``CompressionTensor_gy`` before
    forming the input gradient. Both gradients are therefore approximations
    unless the projections are lossless (e.g. identity matrices).

    ``apply`` takes seven positional arguments:
        x: 3-D tensor ``(b, l, c)``.
        weight, bias: as for ``F.linear``; ``bias`` may be ``None``.
        CompressionTensor_x: ``(r, c)`` if ``c < l`` (the last dim of ``x`` is
            projected), otherwise ``(r, l)`` (dim 1 of ``x`` is projected).
        CompressionTensor_gy: the same convention applied to ``grad_y`` of shape
            ``(b, l, o)``: ``(r, o)`` if ``o < l``, otherwise ``(r, l)``.
        backward_time, forward_time: lists that receive the elapsed seconds of
            each backward / forward call, or ``None`` to skip timing.
    """

    @staticmethod
    def jvp(ctx: Any, *grad_inputs: Any) -> Any:
        # Forward-mode AD is not implemented; defer to the base class (which, in
        # PyTorch, raises NotImplementedError).
        return Function.jvp(ctx, *grad_inputs)

    @staticmethod
    def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
        x, weight, bias, CompressionTensor_x, CompressionTensor_gy, backward_time, forward_time = args

        # NOTE(review): on CUDA, kernels launch asynchronously, so wall-clock
        # timings here may not reflect device execution time without a sync.
        start = time.perf_counter()

        if x.shape[2] < x.shape[1]:
            # Project the feature dim: (b, l, c) x (r, c) -> (b, l, r).
            x_compressed = th.einsum("blc,rc->blr", x, CompressionTensor_x)
        else:
            # Project dim 1: (r, l) @ (b, l, c) -> (b, r, c).
            x_compressed = CompressionTensor_x @ x

        y = F.linear(x, weight, bias)

        if forward_time is not None:
            forward_time.append(time.perf_counter() - start)

        # Save for backward: this op keeps only the compressed activation (plus
        # the projection matrices), not the full ``x``.
        ctx.backward_time = backward_time
        ctx.input_shape = x.shape
        ctx.has_bias = bias is not None
        ctx.save_for_backward(x_compressed, weight, CompressionTensor_x, CompressionTensor_gy)
        return y

    @staticmethod
    def backward(ctx: Any, *grad_outputs: Any) -> Any:
        x_compressed, weight, CompressionTensor_x, CompressionTensor_gy = ctx.saved_tensors
        shape_x = ctx.input_shape
        (grad_y,) = grad_outputs

        start = time.perf_counter()

        if grad_y.shape[2] < grad_y.shape[1]:
            # Project grad_y's output-feature dim with the (r, o) matrix.
            grad_y_compress = th.einsum("blo,ro->blr", grad_y, CompressionTensor_gy)
            weight_compress = th.einsum("ro,oi->ri", CompressionTensor_gy, weight)
            grad_x = th.einsum("blr,ri->bli", grad_y_compress, weight_compress)
        else:
            # Project grad_y's dim 1 with the (r, l) matrix, then expand back.
            grad_y_compress = th.einsum("rl,blo->bro", CompressionTensor_gy, grad_y)
            weight_compress = th.einsum("bro,oi->bri", grad_y_compress, weight)
            grad_x = th.einsum("rl,bri->bli", CompressionTensor_gy, weight_compress)

        # Mirror the branch taken in forward to undo the projection of x.
        if shape_x[2] < shape_x[1]:
            grad_w_partial = th.einsum("blo,blr->bor", grad_y, x_compressed)
            grad_w = th.einsum("bor,ri->oi", grad_w_partial, CompressionTensor_x)
        else:
            grad_w_partial = th.einsum("blo,rl->bor", grad_y, CompressionTensor_x)
            grad_w = th.einsum("bor,bri->oi", grad_w_partial, x_compressed)

        if ctx.backward_time is not None:
            ctx.backward_time.append(time.perf_counter() - start)

        grad_b = grad_y.sum(dim=(0, 1)) if ctx.has_bias else None

        # One entry per forward input; projections and timing lists get None.
        return grad_x, grad_w, grad_b, None, None, None, None
    

class LinearCompressClass(nn.Linear):
    """``nn.Linear`` that can route training-mode forwards through ``LinearSVDOp``.

    When the ``activating`` attribute is truthy and the module is in training
    mode, ``forward`` calls ``LinearSVDOp`` with transposed, detached copies of
    the ``CompressionTensor_x`` / ``CompressionTensor_gy`` attributes placed on the
    input's device. Otherwise it behaves exactly like ``nn.Linear``.

    ``activating``, ``CompressionTensor_x`` and ``CompressionTensor_gy`` are not
    created here; presumably they are assigned by calibration code elsewhere.
    ``backward_time`` / ``forward_time`` are handed to ``LinearSVDOp``;
    ``inference_time`` is only stored.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device=None,
        dtype=None,
        backward_time=None,
        forward_time=None,
        inference_time=None,
    ) -> None:
        super().__init__(in_features, out_features, bias, device, dtype)
        self.backward_time = backward_time
        self.forward_time = forward_time
        self.inference_time = inference_time

    def forward(self, input: th.Tensor) -> th.Tensor:
        # A freshly constructed layer has no ``activating`` attribute yet; treat
        # that as "compression off" instead of raising AttributeError.
        if not (getattr(self, "activating", False) and self.training):
            return super().forward(input)

        device = input.device
        # Detached copies: the projections are constants for autograd, and copying
        # keeps later in-place edits of the attributes away from saved tensors.
        compression_x = self.CompressionTensor_x.detach().T.to(device, copy=True)
        compression_gy = self.CompressionTensor_gy.detach().T.to(device, copy=True)
        return LinearSVDOp.apply(
            input,
            self.weight,
            self.bias,
            compression_x,
            compression_gy,
            self.backward_time,
            self.forward_time,
        )

class PointwiseConvSVDOp(Function):
    """1x1 convolution (stride 1, no padding) that saves compressed activations.

    The ``(b, c_i, h, w)`` input is treated as ``b`` sequences of ``h * w`` tokens
    with ``c_i`` features. The exact output is computed with ``F.linear``, but only
    a projected copy of the tokens is kept for backward. The same projection
    conventions as ``LinearSVDOp`` apply, with ``l = h * w``:
        CompressionTensor_x: ``(r, c_i)`` if ``c_i < l``, otherwise ``(r, l)``.
        CompressionTensor_gy: ``(r, c_o)`` if ``c_o < l``, otherwise ``(r, l)``.
    ``weight`` is assumed to be ``(c_o, c_i, 1, 1)``. ``backward_time`` /
    ``forward_time`` are lists receiving elapsed seconds, or ``None`` to skip
    timing.
    """

    @staticmethod
    def jvp(ctx: Any, *grad_inputs: Any) -> Any:
        # Forward-mode AD is not implemented; defer to the base class (which, in
        # PyTorch, raises NotImplementedError).
        return Function.jvp(ctx, *grad_inputs)

    @staticmethod
    def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
        inp, weight, bias, CompressionTensor_x, CompressionTensor_gy, backward_time, forward_time = args

        # NOTE(review): on CUDA, kernels launch asynchronously, so wall-clock
        # timings here may not reflect device execution time without a sync.
        start = time.perf_counter()

        b, _, h, w = inp.shape
        c_o, c_i = weight.shape[:2]
        # reshape rather than squeeze: squeeze would also drop c_o or c_i when
        # either of them is 1, breaking F.linear and the backward unpacking.
        weight_2d = weight.reshape(c_o, c_i)
        # reshape rather than view so non-contiguous inputs are accepted.
        x = inp.reshape(b, c_i, h * w).transpose(1, 2)

        if x.shape[2] < x.shape[1]:
            x_compressed = th.einsum("blc,rc->blr", x, CompressionTensor_x)
        else:
            x_compressed = CompressionTensor_x @ x

        y = F.linear(x, weight_2d, bias)

        if forward_time is not None:
            forward_time.append(time.perf_counter() - start)

        y = y.transpose(1, 2).reshape(b, c_o, h, w)

        # Save for backward: only the compressed tokens, not the full input.
        ctx.input_shape = x.shape
        ctx.spatial = (h, w)
        ctx.has_bias = bias is not None
        ctx.backward_time = backward_time
        ctx.save_for_backward(x_compressed, weight_2d, CompressionTensor_x, CompressionTensor_gy)
        return y

    @staticmethod
    def backward(ctx: Any, *grad_outputs: Any) -> Any:
        x_compressed, weight, CompressionTensor_x, CompressionTensor_gy = ctx.saved_tensors
        (grad_y,) = grad_outputs
        shape_x = ctx.input_shape
        b, _, c_i = shape_x
        c_o = weight.shape[0]
        h, w = ctx.spatial
        # (b, c_o, h, w) -> (b, h*w, c_o); reshape tolerates non-contiguous grads.
        grad_y = grad_y.reshape(b, c_o, h * w).transpose(1, 2)

        start = time.perf_counter()

        if grad_y.shape[2] < grad_y.shape[1]:
            grad_y_compress = th.einsum("blo,ro->blr", grad_y, CompressionTensor_gy)
            weight_compress = th.einsum("ro,oi->ri", CompressionTensor_gy, weight)
            grad_x = th.einsum("blr,ri->bli", grad_y_compress, weight_compress)
        else:
            grad_y_compress = th.einsum("rl,blo->bro", CompressionTensor_gy, grad_y)
            weight_compress = th.einsum("bro,oi->bri", grad_y_compress, weight)
            grad_x = th.einsum("rl,bri->bli", CompressionTensor_gy, weight_compress)

        # Mirror the branch taken in forward to undo the projection of x.
        if shape_x[2] < shape_x[1]:
            grad_w_partial = th.einsum("blo,blr->bor", grad_y, x_compressed)
            grad_w = th.einsum("bor,ri->oi", grad_w_partial, CompressionTensor_x)
        else:
            grad_w_partial = th.einsum("blo,rl->bor", grad_y, CompressionTensor_x)
            grad_w = th.einsum("bor,bri->oi", grad_w_partial, x_compressed)

        if ctx.backward_time is not None:
            ctx.backward_time.append(time.perf_counter() - start)

        # Restore the convolution layouts expected by autograd.
        grad_x = grad_x.transpose(1, 2).reshape(b, c_i, h, w)
        grad_w = grad_w.reshape(c_o, c_i, 1, 1)

        grad_b = grad_y.sum(dim=(0, 1)) if ctx.has_bias else None
        return grad_x, grad_w, grad_b, None, None, None, None


class PointwiseConvCompressClass(nn.Conv2d):
    """``nn.Conv2d`` that can route training-mode forwards through ``PointwiseConvSVDOp``.

    When the ``activating`` attribute is truthy and the module is in training
    mode, ``forward`` calls ``PointwiseConvSVDOp`` with transposed, detached copies
    of the ``CompressionTensor_x`` / ``CompressionTensor_gy`` attributes placed on
    the input's device. Otherwise it behaves exactly like ``nn.Conv2d``.

    ``activating`` and the compression tensors are not created here; presumably
    they are assigned by calibration code elsewhere. ``backward_time`` /
    ``forward_time`` are handed to the op; ``inference_time`` is only stored.
    ``device`` / ``dtype`` are optional and forwarded to ``nn.Conv2d``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
                 padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', backward_time=None,
                 forward_time=None,
                 inference_time=None,
                 device=None,
                 dtype=None):
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            device=device,
            dtype=dtype,
        )
        self.backward_time = backward_time
        self.forward_time = forward_time
        self.inference_time = inference_time

    def forward(self, input: th.Tensor) -> th.Tensor:
        # A freshly constructed layer has no ``activating`` attribute yet; treat
        # that as "compression off" instead of raising AttributeError.
        if not (getattr(self, "activating", False) and self.training):
            return super().forward(input)

        # NOTE(review): this path does not pass self.stride / self.padding /
        # self.dilation / self.groups to PointwiseConvSVDOp, which always computes
        # a stride-1, unpadded 1x1 convolution. It only matches super().forward for
        # such layers; confirm that only those layers are activated.
        device = input.device
        compression_x = self.CompressionTensor_x.detach().T.to(device, copy=True)
        compression_gy = self.CompressionTensor_gy.detach().T.to(device, copy=True)
        return PointwiseConvSVDOp.apply(
            input,
            self.weight,
            self.bias,
            compression_x,
            compression_gy,
            self.backward_time,
            self.forward_time,
        )

def wrap_pointwise_conv_compression_layer(
    conv_layer: nn.Conv2d, backward_time, forward_time, inference_time, **kwargs
):
    """Build a PointwiseConvCompressClass mirroring ``conv_layer``.

    The new module's weight (and bias, if any) is pointed at ``conv_layer``'s
    parameter data via ``.data`` assignment, so the two modules share storage
    rather than holding copies.

    Args:
        conv_layer: source convolution whose hyper-parameters and parameters
            are reused.
        backward_time, forward_time, inference_time: forwarded to the new module.
        **kwargs: accepted and ignored.

    Returns:
        The new PointwiseConvCompressClass instance.
    """
    new_op = PointwiseConvCompressClass(
        in_channels=conv_layer.in_channels,
        out_channels=conv_layer.out_channels,
        kernel_size=conv_layer.kernel_size,
        stride=conv_layer.stride,
        padding=conv_layer.padding,
        # Copy these as well so the dense (non-activating) path computes the same
        # convolution as conv_layer and the shared weight's shape is consistent
        # with the module's groups setting.
        dilation=conv_layer.dilation,
        groups=conv_layer.groups,
        padding_mode=conv_layer.padding_mode,
        bias=conv_layer.bias is not None,
        backward_time=backward_time,
        forward_time=forward_time,
        inference_time=inference_time,
    )
    new_op.weight.data = conv_layer.weight.data
    if new_op.bias is not None:
        new_op.bias.data = conv_layer.bias.data
    return new_op

def wrap_linear_compression_layer(
    linear_layer: nn.Linear, backward_time, forward_time, inference_time, **kwargs
):
    """Build a LinearCompressClass mirroring ``linear_layer``.

    The new module's weight (and bias, if any) is pointed at ``linear_layer``'s
    parameter data via ``.data`` assignment, so the two modules share storage
    rather than holding copies.

    Args:
        linear_layer: source linear layer whose sizes and parameters are reused.
        backward_time, forward_time, inference_time: forwarded to the new module.
        **kwargs: accepted and ignored.

    Returns:
        The new LinearCompressClass instance.
    """
    source_weight = linear_layer.weight
    new_linear = LinearCompressClass(
        in_features=linear_layer.in_features,
        out_features=linear_layer.out_features,
        bias=linear_layer.bias is not None,
        # Allocate on the source layer's device/dtype so the placeholder
        # parameters (replaced just below) are not built on the default CPU/float32.
        device=source_weight.device,
        dtype=source_weight.dtype,
        backward_time=backward_time,
        forward_time=forward_time,
        inference_time=inference_time,
    )
    new_linear.weight.data = source_weight.data
    if linear_layer.bias is not None:
        new_linear.bias.data = linear_layer.bias.data
    return new_linear