import torch
import torch.nn as nn
from torch.autograd import Function
import numpy as np

from ..compression.rank.power_iteration import decompose_tensor_keep_projection

def SVD_var(weight, var, use_k=False):
    """Factorize ``weight`` with a reduced SVD and truncate the factors.

    Args:
        weight: Tensor passed to ``torch.linalg.svd(..., full_matrices=False)``.
            The variance branch accumulates along dim 0 of the singular
            values, so it assumes a single 2-D matrix -- TODO confirm that
            no caller passes batched input.
        var: Meaning depends on ``use_k``. With ``use_k=True`` it is the
            number of leading singular triplets to keep. Otherwise it is the
            explained-variance threshold, compared against the cumulative
            ratio of squared singular values.
        use_k: Select a fixed rank (``True``) or a variance threshold
            (``False``).

    Returns:
        With ``use_k=True`` a 7-tuple
        ``(U_k, diag(S_k), Vt_k, U, diag(S), Vt, var)``.
        Otherwise an 8-tuple
        ``(U_k, diag(S_k), Vt_k, U, diag(S), Vt, k, S)`` where ``k`` is the
        selected rank. The two branches intentionally keep their different
        arities so existing callers continue to unpack correctly.
    """
    U, S, Vt = torch.linalg.svd(weight, full_matrices=False)

    if use_k:
        return U[:, :var], torch.diag_embed(S[:var]), Vt[:var, :], U, torch.diag_embed(S), Vt, var

    energy = S ** 2
    explained_variance = torch.cumsum(energy, dim=0) / torch.sum(energy)
    # First index whose cumulative ratio reaches `var`, converted to a count.
    k = torch.searchsorted(explained_variance, var).item() + 1
    # If `var` is never reached (threshold above the last cumulative ratio,
    # e.g. because rounding leaves it slightly below 1.0) searchsorted
    # returns len(S) and k would overshoot by one. Slicing would silently
    # clamp the factors, but the reported rank must match their real width.
    k = min(k, S.shape[0])
    return U[:, :k], torch.diag_embed(S[:k]), Vt[:k, :], U, torch.diag_embed(S), Vt, k, S
    

#######################################################################

class Linear_WSI_op(Function):
    """Autograd function for a linear layer evaluated through low-rank factors.

    The forward pass multiplies the input by two factor matrices instead of
    the dense weight. The backward pass returns a gradient for the dense
    weight computed from the saved input and the incoming gradient.
    """

    @staticmethod
    def forward(ctx, *args):
        """Compute ``input @ Vt_k.T @ US_k.T (+ bias)``.

        Args (positional, packed in ``*args``):
            input: Activations whose last dimension is the input-feature axis.
            weight: Dense weight. It is not used in the computation; it is
                passed so autograd can route a gradient to it.
            bias: Optional bias added to the output, or ``None``.
            US_k: Left factor, applied transposed as the second product.
            Vt_k: Right factor, applied transposed as the first product.

        Returns:
            The output tensor of the low-rank product.
        """
        input, weight, bias, US_k, Vt_k = args

        # Low-rank forward pass.
        output = input @ Vt_k.t() @ US_k.t()
        if bias is not None:
            # Broadcasting over the leading dimensions.
            output += bias

        ctx.save_for_backward(input, bias, US_k, Vt_k)

        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for ``(input, weight, bias, US_k, Vt_k)``.

        The factor matrices never receive a gradient (``None``). The weight
        and bias gradients are reduced over every leading dimension, so the
        input may have any number of batch-like dimensions.
        """
        input, bias, US_k, Vt_k = ctx.saved_tensors

        grad_input = grad_weight = grad_bias = None

        if ctx.needs_input_grad[0]:
            grad_input = grad_output @ US_k @ Vt_k

        # Collapse all leading dimensions so one code path serves 2-D, 3-D,
        # 4-D, ... inputs alike.
        flat_grad = grad_output.reshape(-1, grad_output.shape[-1])

        if ctx.needs_input_grad[1]:
            flat_input = input.reshape(-1, input.shape[-1])
            # (out, N) @ (N, in) -> (out, in)
            grad_weight = flat_grad.t() @ flat_input

        if bias is not None and ctx.needs_input_grad[2]:
            # Sum over every non-feature position so the shape is (out,).
            grad_bias = flat_grad.sum(0)

        return grad_input, grad_weight, grad_bias, None, None

class Linear_WSI(nn.Linear):
    """``nn.Linear`` whose forward pass can run through a low-rank factorization.

    When ``activate`` is false the layer behaves exactly like ``nn.Linear``.
    When it is true the weight is factorized and the product is evaluated by
    ``Linear_WSI_op``. The factorization is either recomputed with a full SVD
    on every call (``WSI_with_sub_iter=False``), or computed once with an SVD
    and then refreshed by ``decompose_tensor_keep_projection`` from the
    cached projection (``WSI_with_sub_iter=True``).

    Args:
        in_features, out_features, bias, device, dtype: Forwarded to
            ``nn.Linear``.
        activate: Enable the low-rank forward path.
        rank: Passed to ``SVD_var`` as its ``var`` argument (an
            explained-variance threshold, since ``use_k`` is left at its
            default).
        size: Optional logging container. It is indexed as ``size[0..3]``
            (lists appended to on grad-enabled forwards) and, in the full-SVD
            mode, ``size[4][layer_idx]``. Assumed to be prepared by the
            caller -- TODO confirm its exact schema.
        layer_idx: Key used into ``size[4]``.
        WSI_with_sub_iter: Select the cached-projection mode (``True``) or
            the SVD-every-call mode (``False``).
    """

    def __init__(
            self,
            in_features,
            out_features,
            bias=True,
            device=None,
            dtype=None,
            activate=False,
            rank=1,
            size = None,
            layer_idx=None,
            WSI_with_sub_iter=True):
        super(Linear_WSI, self).__init__(
            in_features=in_features,
            out_features=out_features,
            bias=bias,
            device=device,
            dtype=dtype
        )
        self.activate = activate
        self.rank = rank
        # Becomes True once the initial SVD has produced a projection.
        self.reuse_q = False
        self.previous_q = None
        # Cached left factor reused between grad-enabled forwards.
        self.previous_p = None

        self.WSI_with_sub_iter = WSI_with_sub_iter

        self.previous_weight = None
        # Width of the cached projection, handed to the helper as its rank.
        self.previous_k = None

        self.size = size
        self.layer_idx = layer_idx

    def _record_sizes(self, input, rank_dim):
        """Append shape statistics to ``self.size``; return True if recorded.

        Nothing is recorded when gradients are disabled or when no logging
        container was supplied.
        """
        if not torch.is_grad_enabled() or self.size is None:
            return False
        self.size[0].append(self.weight.shape[0])
        self.size[1].append(rank_dim)
        self.size[2].append(self.weight.shape[1])
        self.size[3].append(input.shape)
        return True

    def _refresh_factors(self):
        """Call the projection helper with the cached state; return ``(p, q)``."""
        # The device string follows the weight instead of being fixed to
        # 'cuda', so CPU runs work. For CUDA weights this yields the same
        # 'cuda' string that was passed before.
        return decompose_tensor_keep_projection(
            self.weight,
            previous_p=self.previous_p,
            reuse_p=self.reuse_q,
            rank=self.previous_k,
            device=self.weight.device.type,
        )

    def forward(self, input):
        """Apply the layer, through low-rank factors when ``activate`` is set."""
        if not self.activate:
            # Plain dense linear layer.
            return super().forward(input)

        if not self.WSI_with_sub_iter:
            # Full SVD on every call; nothing is cached.
            p, Sk_torch, Vtk_torch, *_, eigen_values = SVD_var(self.weight.detach(), self.rank)
            q = (Sk_torch @ Vtk_torch).t()
            if self._record_sizes(input, p.shape[1]):
                self.size[4][self.layer_idx].append(eigen_values)
            return Linear_WSI_op.apply(input, self.weight, self.bias, p, q.t())

        if torch.is_grad_enabled():
            if not self.reuse_q:
                # First grad-enabled call: initialise the projection by SVD.
                self.previous_p, Sk_torch, Vtk_torch, *_ = SVD_var(self.weight.detach(), self.rank)
                # Use the real factor width so the stored rank always
                # matches the cached projection.
                self.previous_k = self.previous_p.shape[1]
                q = (Sk_torch @ Vtk_torch).t()
                self.reuse_q = True
            else:
                self.previous_p, q = self._refresh_factors()
            self._record_sizes(input, self.previous_p.shape[1])
            return Linear_WSI_op.apply(input, self.weight, self.bias, self.previous_p, q.t())

        # Gradients disabled: cached state is left untouched.
        if self.previous_p is None:
            # No projection cached yet (no grad-enabled forward has run), so
            # fall back to a one-off SVD instead of failing on a None factor.
            p, Sk_torch, Vtk_torch, *_ = SVD_var(self.weight.detach(), self.rank)
            q = (Sk_torch @ Vtk_torch).t()
        else:
            p, q = self._refresh_factors()
        # Use the (p, q) pair returned together, as the grad-enabled branch
        # does, rather than pairing the new q with the stale cached p.
        return Linear_WSI_op.apply(input, self.weight, self.bias, p, q.t())
    

def wrap_linearWSI(linear, active, rank, size, layer_idx, WSI_with_sub_iter=True):
    """Build a ``Linear_WSI`` that mirrors an existing linear layer.

    The new layer is created with the same feature sizes and bias setting as
    ``linear``, then its parameter data is pointed at ``linear``'s parameter
    data (assignment, not a copy).

    Args:
        linear: Source layer providing ``in_features``, ``out_features``,
            ``weight`` and ``bias``.
        active: Forwarded as ``activate``.
        rank: Forwarded as ``rank``.
        size: Forwarded as ``size``.
        layer_idx: Forwarded as ``layer_idx``.
        WSI_with_sub_iter: Forwarded as ``WSI_with_sub_iter``.

    Returns:
        The newly constructed ``Linear_WSI`` instance.
    """
    wrapped = Linear_WSI(
        in_features=linear.in_features,
        out_features=linear.out_features,
        bias=linear.bias is not None,
        activate=active,
        rank=rank,
        WSI_with_sub_iter=WSI_with_sub_iter,
        size=size,
        layer_idx=layer_idx,
    )

    # Reuse the source layer's tensors rather than the fresh initialisation.
    wrapped.weight.data = linear.weight.data
    if wrapped.bias is None:
        return wrapped

    wrapped.bias.data = linear.bias.data
    return wrapped