import torch
import torch.nn as nn
from torch.autograd import Function

from ..compression.rank.power_iteration import decompose_tensor_keep_projection
from ..compression.rank.hosvd_power_4_mode import hosvd_power4

def SVD_var(weight, var, use_k=False):
    """Truncated SVD of ``weight`` by fixed rank or by explained variance.

    Args:
        weight: Matrix passed to ``torch.linalg.svd(..., full_matrices=False)``.
        var: When ``use_k`` is true, the number of singular triplets to keep.
            Otherwise the explained-variance threshold: the smallest rank
            whose cumulative squared singular values reach this fraction of
            the total is selected.
        use_k: Interpret ``var`` as a rank instead of a variance threshold.

    Returns:
        A 7-tuple ``(U_k, S_k, Vt_k, U, S, Vt, k)`` where ``U_k``/``Vt_k`` are
        the leading ``k`` columns/rows of ``U``/``Vt``, ``S_k`` and ``S`` are
        diagonal matrices of the kept / all singular values, and ``k`` is the
        rank used (``var`` itself, unchanged, when ``use_k`` is true).
    """
    U, S, Vt = torch.linalg.svd(weight, full_matrices=False)

    if use_k:
        k = var
    else:
        energy = S**2
        explained_variance = torch.cumsum(energy, dim=0) / torch.sum(energy)
        k = torch.searchsorted(explained_variance, var).item() + 1
        # Rounding can leave the last cumulative ratio slightly below the
        # threshold (e.g. var == 1.0), which would make searchsorted point one
        # past the end. Never report more components than actually exist.
        k = min(k, S.shape[-1])

    return U[:, :k], torch.diag_embed(S[:k]), Vt[:k, :], U, torch.diag_embed(S), Vt, k

#######################################################################

class Linear_WASI4_op(Function):
    """Low-rank linear op for 4-D inputs with a compressed-activation backward.

    The forward pass applies the factored weight ``US_k @ Vt_k`` to the input.
    The weight gradient is rebuilt in backward from a 4-mode Tucker-style
    decomposition of the activation (core ``S_activation`` plus four factor
    matrices) rather than from the input itself.
    """

    @staticmethod
    def forward(ctx, *args):
        """Compute ``input @ Vt_k.T @ US_k.T (+ bias)``.

        Args (positional, in order):
            input: Activation; the einsums in ``backward`` treat it as 4-D.
            weight: Not used in the computation; it is an argument so that
                ``backward`` can return a gradient for it.
            bias: Optional bias added along the last axis, or ``None``.
            US_k, Vt_k: Factors of the weight approximation.
            S_activation: Core tensor of the activation decomposition.
            U_list_activation: Sequence of four factor matrices, one per mode.
        """
        input, weight, bias, US_k, Vt_k, S_activation, U_list_activation = args

        # Infer output
        output = input@Vt_k.t()@US_k.t()
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)

        ctx.save_for_backward(
            S_activation,
            U_list_activation[0],
            U_list_activation[1],
            U_list_activation[2],
            U_list_activation[3],
            bias,
            US_k,
            Vt_k,
        )

        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for ``input``, ``weight`` and ``bias``.

        The remaining four forward arguments receive ``None``.
        """
        # Load the information that is saved from forwardpass
        S, U1, U2, U3, U4, bias, US_k, Vt_k = ctx.saved_tensors

        grad_input = grad_weight = grad_bias = None

        if ctx.needs_input_grad[0]:
            grad_input = grad_output@US_k@Vt_k
        if ctx.needs_input_grad[1]:
            # Contract grad_output with the activation factors one mode at a
            # time so the full activation is never materialised.
            Z1 = torch.einsum("Ba,BHWD->aHWD", U1, grad_output) # Shape: (B, K1) and (B, H, W, D) -> (K1, H, W, D)
            Z2 = torch.einsum("Hb,abcd->aHcd", U2, S) # Shape: (H, K2) and (K1, K2, K3, K4) -> (K1, H, K3, K4)
            Z3 = torch.einsum("Wc,aHWD->aHcD", U3, Z1) # Shape: (W, K3) and (K1, H, W, D) -> (K1, H, K3, D)
            Z4 = torch.einsum("Cd,aHcd->aHCc", U4, Z2) # Shape: (C, K4) and (K1, H, K3, K4) -> (K1, H, C, K3)
            grad_weight = torch.einsum("aHcD,aHCc->DC", Z3, Z4) # Shape: (K1, H, K3, D) and (K1, H, C, K3) -> (D, C)

        if bias is not None and ctx.needs_input_grad[2]:
            # Bias is broadcast over every leading axis in forward, so its
            # gradient must be summed over all of them, not just the batch.
            grad_bias = grad_output.reshape(-1, grad_output.shape[-1]).sum(0)

        return grad_input, grad_weight, grad_bias, None, None, None, None
    
class Linear_WASI3_op(Function):
    """Low-rank linear op for 3-D inputs with a compressed-activation backward.

    The forward pass applies the factored weight ``US_k @ Vt_k`` to the input.
    The weight gradient is rebuilt in backward from a 3-mode Tucker-style
    decomposition of the activation (core ``S_activation`` plus three factor
    matrices) rather than from the input itself.
    """

    @staticmethod
    def forward(ctx, *args):
        """Compute ``input @ Vt_k.T @ US_k.T (+ bias)``.

        Args (positional, in order):
            input: Activation; the einsums in ``backward`` treat it as 3-D.
            weight: Not used in the computation; it is an argument so that
                ``backward`` can return a gradient for it.
            bias: Optional bias added along the last axis, or ``None``.
            US_k, Vt_k: Factors of the weight approximation.
            S_activation: Core tensor of the activation decomposition.
            U_list_activation: Sequence whose first three entries are the
                per-mode factor matrices.
        """
        input, weight, bias, US_k, Vt_k, S_activation, U_list_activation = args

        # Infer output
        output = input@Vt_k.t()@US_k.t()
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)

        ctx.save_for_backward(
            S_activation,
            U_list_activation[0],
            U_list_activation[1],
            U_list_activation[2],
            bias,
            US_k,
            Vt_k,
        )

        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for ``input``, ``weight`` and ``bias``.

        The remaining four forward arguments receive ``None``.
        """
        # Load the information that is saved from forwardpass
        S, U1, U2, U3, bias, US_k, Vt_k = ctx.saved_tensors

        grad_input = grad_weight = grad_bias = None

        if ctx.needs_input_grad[0]:
            grad_input = grad_output@US_k@Vt_k
        if ctx.needs_input_grad[1]:
            # Contract grad_output with the activation factors one mode at a
            # time so the full activation is never materialised.
            Z1 = torch.einsum('blo,bk->lok', grad_output, U1) # Shape: B, L, O and B, K1 -> L, O, K1
            Z2 = torch.einsum('abc,lb->acl', S, U2) # Shape: K1, K2, K3 and L, K2 -> K1, K3, L
            Z3 = torch.einsum('acl,ic->ail', Z2, U3) # Shape: K1, K3, L and I, K3 -> K1, I, L
            grad_weight = torch.einsum('lok,kil->oi', Z1, Z3) # Shape: L, O, K1 and K1, I, L -> O, I

        if bias is not None and ctx.needs_input_grad[2]:
            # Bias is broadcast over every leading axis in forward, so its
            # gradient must be summed over all of them, not just the batch.
            grad_bias = grad_output.reshape(-1, grad_output.shape[-1]).sum(0)

        return grad_input, grad_weight, grad_bias, None, None, None, None
    
class Linear_WASI_inference_op(Function):
    """Forward-only low-rank linear op: ``input @ Vt_k.T @ US_k.T (+ bias)``."""

    @staticmethod
    def forward(ctx, *args):
        """Apply the factored weight to ``input``.

        Args (positional, in order):
            input: Activation tensor.
            bias: Optional bias added along the last axis, or ``None``.
            US_k, Vt_k: Factors of the weight approximation.
        """
        input, bias, US_k, Vt_k = args

        # Infer output
        output = input@Vt_k.t()@US_k.t()
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)

        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Always raise: this op defines no gradient.

        Nothing is saved in ``forward``, so no gradient can be formed here.
        Failing loudly is clearer than returning a bare ``None`` for four
        forward inputs.
        """
        raise NotImplementedError(
            "Linear_WASI_inference_op is forward-only and does not support backward"
        )

class Linear_WASI(nn.Linear):
    """``nn.Linear`` variant that can run with a low-rank weight and
    compressed activations.

    When ``activate`` is false the layer behaves exactly like ``nn.Linear``.
    When it is true:

    * with autograd enabled, the input is decomposed by ``hosvd_power4`` and
      the weight is factored (``SVD_var`` on the first pass,
      ``decompose_tensor_keep_projection`` afterwards); the result is fed to
      ``Linear_WASI4_op`` / ``Linear_WASI3_op`` depending on the input rank;
    * with autograd disabled, the weight is factored using the cached
      projection and applied through ``Linear_WASI_inference_op``.

    Args:
        in_features, out_features, bias, device, dtype: Forwarded to
            ``nn.Linear``.
        activate: Enable the low-rank path.
        activation_ranks: Passed as ``rank`` to ``hosvd_power4``.
        explained_variance_threshold: Passed to ``SVD_var`` for the initial
            weight factorisation.
    """

    def __init__(
            self,
            in_features,
            out_features,
            bias=True,
            device=None,
            dtype=None,
            activate=False,
            activation_ranks=1,
            explained_variance_threshold=1.0):
        super(Linear_WASI, self).__init__(
            in_features=in_features,
            out_features=out_features,
            bias=bias,
            device=device,
            dtype=dtype
        )
        self.activate = activate
        self.activation_ranks = activation_ranks
        self.explained_variance_threshold = explained_variance_threshold

        # Becomes True after the first training-mode forward; from then on
        # cached projections are handed back to the decomposition helpers.
        self.reuse_map = False

        # for weight
        self.previous_q = None
        self.previous_p = None
        self.previous_k = None

        # for activation
        self.u_list_activation = None

    def forward(self, input):
        """Run the layer; see the class docstring for the three code paths.

        Raises:
            ValueError: In the low-rank training path when ``input`` is
                neither 3-D nor 4-D.
        """
        if not self.activate:
            return super().forward(input)

        if not torch.is_grad_enabled():
            # Inference path.
            if self.previous_p is None:
                # No factorisation has been cached yet (no training-mode
                # forward has run), so there is no rank/projection to reuse.
                # Use the dense weight instead of failing on a None factor.
                return super().forward(input)

            # NOTE(review): device is hard-coded to 'cuda'; the helper's use of
            # it is not visible here - confirm before running on other devices.
            p, q = decompose_tensor_keep_projection(self.weight, previous_p=self.previous_p, reuse_p=self.reuse_map, rank=self.previous_k, device='cuda')
            # Use the (p, q) pair returned together, as the training path does,
            # without overwriting the cached training state.
            return Linear_WASI_inference_op.apply(input, self.bias, p, q.t())

        # Training path. Pick the op first so an unsupported input rank fails
        # before any cached state is modified.
        if input.dim() == 4:
            low_rank_op = Linear_WASI4_op
        elif input.dim() == 3:
            low_rank_op = Linear_WASI3_op
        else:
            raise ValueError("Not implemented for input dim = {}".format(input.dim()))

        # Decompose activation
        S_activation, self.u_list_activation = hosvd_power4(input, previous_Ulist=self.u_list_activation, reuse_U=self.reuse_map, rank=self.activation_ranks)

        if not self.reuse_map:
            # Decompose weight: the first pass uses an SVD to choose the rank
            # from the explained-variance threshold.
            self.previous_p, Sk_torch, Vtk_torch, _, _, _, k = SVD_var(self.weight.clone().detach(), self.explained_variance_threshold)
            self.previous_k = k
            q = (Sk_torch@Vtk_torch).t()
            self.reuse_map = True
        else:
            # Later passes refresh the factors from the cached projection.
            # NOTE(review): device is hard-coded to 'cuda' - see note above.
            self.previous_p, q = decompose_tensor_keep_projection(self.weight, previous_p=self.previous_p, reuse_p=self.reuse_map, rank=self.previous_k, device='cuda')

        return low_rank_op.apply(input, self.weight, self.bias, self.previous_p, q.t(), S_activation, self.u_list_activation)
    

def wrap_linearWASI(linear, active, activation_ranks, explained_variance_threshold):
    """Build a ``Linear_WASI`` with the same shape and parameters as ``linear``.

    The new layer's weight (and bias, when present) ``.data`` is pointed at
    the source layer's ``.data``, so no values are copied.

    Args:
        linear: Source layer providing ``in_features``, ``out_features``,
            ``weight`` and ``bias``.
        active: Forwarded as ``activate``.
        activation_ranks: Forwarded unchanged.
        explained_variance_threshold: Forwarded unchanged.

    Returns:
        The newly constructed ``Linear_WASI``.
    """
    wrapped = Linear_WASI(
        in_features=linear.in_features,
        out_features=linear.out_features,
        bias=linear.bias is not None,
        activate=active,
        activation_ranks=activation_ranks,
        explained_variance_threshold=explained_variance_threshold,
    )

    wrapped.weight.data = linear.weight.data
    if wrapped.bias is not None:
        wrapped.bias.data = linear.bias.data

    return wrapped