import torch
import torch.nn as nn
from torch.autograd import Function
import time

from ..compression.power_iteration import decompose_tensor, decompose_tensor_keep_projection
from ..compression.hosvd_power_4_mode import hosvd_power4, restore_hosvd_power4

def SVD_var(weight, var, use_k=False):
    """Truncate the SVD of ``weight`` by a fixed rank or by explained variance.

    Args:
        weight: Tensor to decompose with ``torch.linalg.svd``. The slicing and
            the ``cumsum`` over dim 0 below assume a 2-D matrix.
        var: If ``use_k`` is true, the number of leading singular triplets to
            keep. Otherwise the explained-variance threshold: the smallest rank
            whose cumulative squared singular values reach this fraction of the
            total is selected.
        use_k: Selects how ``var`` is interpreted.

    Returns:
        A 7-tuple ``(U_k, S_k, Vt_k, U, S, Vt, k)`` where the first three are
        the truncated factors (``S_k`` as a diagonal matrix), the next three
        are the untruncated factors (``S`` as a diagonal matrix), and ``k`` is
        the rank used. With ``use_k=True`` the returned ``k`` is ``var`` as
        given by the caller.
    """
    U, S, Vt = torch.linalg.svd(weight, full_matrices=False)

    if use_k:
        return U[:, :var], torch.diag_embed(S[:var]), Vt[:var, :], U, torch.diag_embed(S), Vt, var

    energy = S ** 2
    explained_variance = torch.cumsum(energy, dim=0) / torch.sum(energy)
    k = torch.searchsorted(explained_variance, var).item() + 1
    # searchsorted returns len(S) when the threshold is never reached (rounding
    # in the cumulative ratio, or var > 1); without the clamp the reported k
    # would exceed the number of columns actually returned.
    k = min(k, S.shape[0])
    return U[:, :k], torch.diag_embed(S[:k]), Vt[:k, :], U, torch.diag_embed(S), Vt, k

#######################################################################

class Linear_WASI4_op(Function):
    """Autograd function for a factorised linear map applied to 4-D inputs.

    The forward pass computes ``input @ Vt_k.T @ US_k.T (+ bias)``. The weight
    gradient is not computed from the raw input; it is reconstructed from a
    core tensor and four factor matrices describing the activation (in this
    file they are produced by ``hosvd_power4``). Wall-clock timings are
    accumulated into the caller-provided lists.
    """

    @staticmethod
    def forward(ctx, *args):
        """Run the forward pass.

        ``args`` is ``(input, weight, bias, US_k, Vt_k, S_activation,
        U_list_activation, backward_time, forward_time)``. ``weight`` is not
        used in the computation; it is an argument so that autograd routes the
        weight gradient returned by ``backward`` to it. ``forward_time`` must
        already hold at least one entry: the elapsed time is added to its last
        element.
        """
        (input, weight, bias, US_k, Vt_k, S_activation, U_list_activation,
         backward_time, forward_time) = args

        start_f = time.time()

        output = input @ Vt_k.t() @ US_k.t()
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)

        ctx.save_for_backward(
            S_activation,
            U_list_activation[0],
            U_list_activation[1],
            U_list_activation[2],
            U_list_activation[3],
            bias,
            US_k,
            Vt_k,
        )
        end_f = time.time()
        forward_time[-1] += (end_f - start_f)
        ctx.backward_time = backward_time

        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for ``input``, ``weight`` and ``bias``.

        The remaining six forward arguments receive ``None``. The elapsed time
        of the gradient computation is appended to ``backward_time``.
        """
        S, U1, U2, U3, U4, bias, US_k, Vt_k = ctx.saved_tensors

        grad_input = grad_weight = grad_bias = None

        backward_time = ctx.backward_time

        start = time.time()

        if ctx.needs_input_grad[0]:
            grad_input = grad_output @ US_k @ Vt_k
        if ctx.needs_input_grad[1]:
            # Contract grad_output with the factorised activation one mode at a
            # time, so the full activation tensor is never rebuilt.
            Z1 = torch.einsum("Ba,BHWD->aHWD", U1, grad_output)  # (B, K1), (B, H, W, D) -> (K1, H, W, D)
            Z2 = torch.einsum("Hb,abcd->aHcd", U2, S)  # (H, K2), (K1, K2, K3, K4) -> (K1, H, K3, K4)
            Z3 = torch.einsum("Wc,aHWD->aHcD", U3, Z1)  # (W, K3), (K1, H, W, D) -> (K1, H, K3, D)
            Z4 = torch.einsum("Cd,aHcd->aHCc", U4, Z2)  # (C, K4), (K1, H, K3, K4) -> (K1, H, C, K3)
            grad_weight = torch.einsum("aHcD,aHCc->DC", Z3, Z4)  # -> (D, C)

        if bias is not None and ctx.needs_input_grad[2]:
            # Sum over every leading dimension so the gradient has the bias
            # shape. Summing only dim 0 leaves an (H, W, D) tensor and relies on
            # the autograd engine reducing it implicitly.
            grad_bias = grad_output.sum(dim=tuple(range(grad_output.dim() - 1)))

        end = time.time()
        backward_time.append(end - start)

        return grad_input, grad_weight, grad_bias, None, None, None, None, None, None
    
class Linear_WASI3_op(Function):
    """Autograd function for a factorised linear map applied to 3-D inputs.

    The forward pass computes ``input @ Vt_k.T @ US_k.T (+ bias)``. The weight
    gradient is reconstructed from a core tensor and the first three factor
    matrices of ``U_list_activation`` rather than from the raw input. Wall-clock
    timings are accumulated into the caller-provided lists.
    """

    @staticmethod
    def forward(ctx, *args):
        """Run the forward pass.

        ``args`` is ``(input, weight, bias, US_k, Vt_k, S_activation,
        U_list_activation, backward_time, forward_time)``. ``weight`` is not
        used in the computation; it is an argument so that autograd routes the
        weight gradient returned by ``backward`` to it. ``forward_time`` must
        already hold at least one entry: the elapsed time is added to its last
        element.
        """
        (input, weight, bias, US_k, Vt_k, S_activation, U_list_activation,
         backward_time, forward_time) = args

        start_f = time.time()

        output = input @ Vt_k.t()
        output = output @ US_k.t()
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)

        ctx.save_for_backward(
            S_activation,
            U_list_activation[0],
            U_list_activation[1],
            U_list_activation[2],
            bias,
            US_k,
            Vt_k,
        )

        end_f = time.time()
        forward_time[-1] += (end_f - start_f)
        ctx.backward_time = backward_time

        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for ``input``, ``weight`` and ``bias``.

        The remaining six forward arguments receive ``None``. The elapsed time
        of the gradient computation is appended to ``backward_time``.
        """
        S, U1, U2, U3, bias, US_k, Vt_k = ctx.saved_tensors

        grad_input = grad_weight = grad_bias = None

        backward_time = ctx.backward_time
        start = time.time()

        if ctx.needs_input_grad[0]:
            grad_input = grad_output @ US_k
            grad_input = grad_input @ Vt_k
        if ctx.needs_input_grad[1]:
            # Contract grad_output with the factorised activation one mode at a
            # time, so the full activation tensor is never rebuilt.
            Z1 = torch.einsum('blo,bk->lok', grad_output, U1)  # (B, L, O), (B, K1) -> (L, O, K1)
            Z2 = torch.einsum('abc,lb->acl', S, U2)  # (K1, K2, K3), (L, K2) -> (K1, K3, L)
            Z3 = torch.einsum('acl,ic->ail', Z2, U3)  # (K1, K3, L), (I, K3) -> (K1, I, L)
            grad_weight = torch.einsum('lok,kil->oi', Z1, Z3)  # -> (O, I)

        if bias is not None and ctx.needs_input_grad[2]:
            # Sum over every leading dimension so the gradient has the bias
            # shape. Summing only dim 0 leaves an (L, O) tensor and relies on
            # the autograd engine reducing it implicitly.
            grad_bias = grad_output.sum(dim=tuple(range(grad_output.dim() - 1)))

        end = time.time()
        backward_time.append(end - start)

        return grad_input, grad_weight, grad_bias, None, None, None, None, None, None
    
class Linear_WASI_inference_op(Function):
    """Forward-only factorised linear map that records its own run time.

    Computes ``input @ Vt_k.T @ US_k.T (+ bias)`` and appends the elapsed
    wall-clock time to ``inference_time``. No gradient is implemented.
    """

    @staticmethod
    def forward(ctx, *args):
        """Apply the factorised map; ``args`` is
        ``(input, bias, US_k, Vt_k, inference_time)``."""
        x, bias, US_k, Vt_k, inference_time = args

        tic = time.time()

        # Project onto the low-rank space first, then expand to the output.
        projected = x @ Vt_k.t()
        output = projected @ US_k.t()
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)

        toc = time.time()
        inference_time.append(toc - tic)

        return output

    @staticmethod
    def backward(ctx, grad_output):
        """No gradients are produced; this op is meant for no-grad use."""
        return None

class Linear_WASI(nn.Linear):
    """``nn.Linear`` variant that can run through factorised weight/activation ops.

    When ``activate`` is false the layer behaves exactly like ``nn.Linear``.
    When it is true:

    * with gradients enabled, the input is decomposed with ``hosvd_power4``,
      the weight is factorised with ``decompose_tensor_keep_projection`` (seeded
      on the first call by ``SVD_var``), and the result is computed by
      ``Linear_WASI4_op`` / ``Linear_WASI3_op`` depending on the input rank;
    * with gradients disabled, only the weight is factorised and
      ``Linear_WASI_inference_op`` is used.

    Args:
        in_features, out_features, bias, device, dtype: Forwarded to
            ``nn.Linear``.
        activate: Enables the factorised path.
        rank_activation: Passed as ``rank`` to ``hosvd_power4``.
        rank_weight: Passed as the threshold to ``SVD_var`` and as ``rank`` to
            ``decompose_tensor_keep_projection``.
        backward_time, forward_time, inference_time: Lists that receive timing
            measurements. A fresh list is created for each one left as ``None``.
    """

    def __init__(
            self,
            in_features,
            out_features,
            bias=True,
            device=None,
            dtype=None,
            activate=False,
            rank_activation=1,
            rank_weight=1.0,
            backward_time=None,
            forward_time=None,
            inference_time=None):
        super(Linear_WASI, self).__init__(
            in_features=in_features,
            out_features=out_features,
            bias=bias,
            device=device,
            dtype=dtype
        )
        # The ops call .append / index [-1] on these; fall back to private lists
        # so the factorised path does not crash when the caller passes nothing.
        self.backward_time = backward_time if backward_time is not None else []
        self.forward_time = forward_time if forward_time is not None else []
        self.inference_time = inference_time if inference_time is not None else []

        self.activate = activate
        self.rank_activation = rank_activation
        self.rank_weight = rank_weight

        # False until the weight projection has been seeded by SVD_var.
        self.reuse_map = False

        # State for the weight factorisation.
        self.previous_q = None
        self.previous_p = None
        self.previous_k = None

        # State for the activation factorisation.
        self.u_list_activation = None

    def forward(self, input):
        """Apply the layer to ``input``.

        Raises:
            ValueError: On the factorised training path when ``input`` is
                neither 3-D nor 4-D.
        """
        if not self.activate:
            return super().forward(input)

        if not torch.is_grad_enabled():
            # Inference: the activation is not decomposed; only the weight is.
            p, q = decompose_tensor_keep_projection(
                self.weight,
                previous_p=self.previous_p,
                reuse_p=self.reuse_map,
                rank=self.rank_weight,
                device=input.device,
            )
            # Keep using the projection stored during training when there is
            # one. If no training step has run yet, fall back to the projection
            # returned alongside q (the same pairing the training path uses)
            # instead of handing None to the op.
            US_k = self.previous_p if self.previous_p is not None else p
            return Linear_WASI_inference_op.apply(input, self.bias, US_k, q.t(), self.inference_time)

        # Training: decompose the activation. reuse_U is read before reuse_map
        # is flipped below, so the very first call passes reuse_U=False.
        S_activation, self.u_list_activation = hosvd_power4(
            input,
            previous_Ulist=self.u_list_activation,
            reuse_U=self.reuse_map,
            rank=self.rank_activation,
            device=input.device,
        )

        if not self.reuse_map:
            # First training step: seed the projection from a truncated SVD of
            # the current weight. This seeding is deliberately left untimed.
            self.previous_p, _, _, _, _, _, k = SVD_var(self.weight.clone().detach(), self.rank_weight)
            self.previous_k = k
            self.reuse_map = True

        start_f = time.time()
        self.previous_p, q = decompose_tensor_keep_projection(
            self.weight,
            previous_p=self.previous_p,
            reuse_p=self.reuse_map,
            rank=self.rank_weight,
            device=input.device,
        )
        end_f = time.time()
        # The op adds its own elapsed time to this entry via forward_time[-1].
        self.forward_time.append(end_f - start_f)

        if input.dim() == 4:
            op = Linear_WASI4_op
        elif input.dim() == 3:
            op = Linear_WASI3_op
        else:
            raise ValueError("Chưa triển khai cho input có dim khác 3 hoặc 4")

        return op.apply(
            input,
            self.weight,
            self.bias,
            self.previous_p,
            q.t(),
            S_activation,
            self.u_list_activation,
            self.backward_time,
            self.forward_time,
        )
    

def wrap_linearWASI(linear, active, rank_activation, rank_weight, backward_time, forward_time, inference_time):
    """Build a ``Linear_WASI`` layer that takes over the parameters of ``linear``.

    The new layer is created with the same feature sizes and bias setting as
    ``linear``; its weight (and bias, when present) ``.data`` is then pointed at
    the source layer's tensors, so no copy is made.

    Args:
        linear: Source layer providing ``in_features``, ``out_features``,
            ``weight`` and ``bias``.
        active: Forwarded as ``activate``.
        rank_activation, rank_weight: Forwarded unchanged.
        backward_time, forward_time, inference_time: Timing lists forwarded
            unchanged.

    Returns:
        The configured ``Linear_WASI`` instance.
    """
    wrapped = Linear_WASI(
        in_features=linear.in_features,
        out_features=linear.out_features,
        bias=linear.bias is not None,
        activate=active,
        rank_activation=rank_activation,
        rank_weight=rank_weight,
        backward_time=backward_time,
        forward_time=forward_time,
        inference_time=inference_time,
    )

    # Rebind .data so the wrapper shares the original parameter tensors.
    wrapped.weight.data = linear.weight.data
    if wrapped.bias is None:
        return wrapped

    wrapped.bias.data = linear.bias.data
    return wrapped