import torch

class dct_projector:
    """Low-rank gradient projector built from selected columns of a fixed matrix ``Q``.

    ``get_orthogonal_matrix_dct`` picks ``rank`` columns of ``Q``; ``project``
    and ``project_back`` then multiply by that column subset (or its
    transpose).

    NOTE(review): ``project_back`` uses the transpose of the selected columns
    as the inverse map, so ``Q`` is presumably orthogonal (a DCT basis, going
    by the class name) -- confirm against the caller.

    Args:
        Q: 2-D basis matrix whose columns are candidates for selection.
        rank: Number of columns of ``Q`` to keep.
        proj_type: ``'right'`` (gradient is multiplied by ``Q`` on the right)
            or ``'left'`` (gradient is multiplied by ``Q.T`` on the left).
    """

    def __init__(self, Q, rank, proj_type):
        self.rank = rank
        self.proj_type = proj_type
        # Indices of the selected columns of Q; stays None until
        # get_orthogonal_matrix_dct has been called successfully.
        self.indices = None
        self.Q = Q

    def _unsupported_proj_type(self):
        """Build the error raised for any ``proj_type`` other than 'right'/'left'."""
        return ValueError(f'Projection type {self.proj_type} is currently not supported')

    def _selected_columns(self):
        """Return ``Q[:, indices]``, failing loudly if no selection exists yet."""
        if self.indices is None:
            # Indexing with None would silently insert a new axis instead of
            # selecting columns, so refuse explicitly.
            raise RuntimeError(
                'No columns selected yet: call get_orthogonal_matrix_dct() '
                'before project() or project_back()'
            )
        return self.Q[:, self.indices]

    def project(self, full_rank_grad):
        """Project a full-rank gradient onto the selected columns of ``Q``.

        Args:
            full_rank_grad: Gradient tensor to compress.

        Returns:
            ``full_rank_grad @ Q[:, indices]`` for ``'right'``,
            ``Q[:, indices].T @ full_rank_grad`` for ``'left'``.

        Raises:
            ValueError: If ``proj_type`` is neither ``'right'`` nor ``'left'``.
            RuntimeError: If no columns have been selected yet.
        """
        if self.proj_type == 'right':
            return torch.matmul(full_rank_grad, self._selected_columns())
        if self.proj_type == 'left':
            return torch.matmul(self._selected_columns().T, full_rank_grad)
        raise self._unsupported_proj_type()

    def project_back(self, low_rank_grad):
        """Map a low-rank gradient back to the full-rank space.

        Args:
            low_rank_grad: Tensor previously produced by ``project``.

        Returns:
            ``low_rank_grad @ Q[:, indices].T`` for ``'right'``,
            ``Q[:, indices] @ low_rank_grad`` for ``'left'``.

        Raises:
            ValueError: If ``proj_type`` is neither ``'right'`` nor ``'left'``.
            RuntimeError: If no columns have been selected yet.
        """
        if self.proj_type == 'right':
            return torch.matmul(low_rank_grad, self._selected_columns().T)
        if self.proj_type == 'left':
            return torch.matmul(self._selected_columns(), low_rank_grad)
        raise self._unsupported_proj_type()

    def get_orthogonal_matrix_dct(self, full_rank_grad):
        """Select the ``rank`` columns of ``Q`` that capture the most gradient mass.

        The gradient is expressed in the basis ``Q`` and the ``rank`` basis
        vectors with the largest L1 norm of coefficients are kept. The chosen
        indices are stored in ``self.indices`` (an int32 buffer that is
        allocated on the first call and overwritten in place afterwards).

        Args:
            full_rank_grad: Gradient tensor used to rank the columns of ``Q``.

        Raises:
            ValueError: If ``proj_type`` is neither ``'right'`` nor ``'left'``.
        """
        if self.proj_type == 'right':
            coefficients = torch.matmul(full_rank_grad, self.Q)
            norms = coefficients.norm(p=1, dim=0)  # one score per column of Q
        elif self.proj_type == 'left':
            coefficients = torch.matmul(self.Q.T, full_rank_grad)
            norms = coefficients.norm(p=1, dim=1)  # one score per column of Q
        else:
            # Previously this case was a silent no-op that left an all-zero
            # index buffer behind; fail the same way project() does.
            raise self._unsupported_proj_type()
        del coefficients  # release the full-size product before topk

        selected = torch.topk(input=norms, k=self.rank, sorted=False).indices

        # Allocate only after a successful selection so a failed call never
        # leaves a zero-filled buffer that looks like a valid choice.
        if self.indices is None:
            self.indices = torch.zeros(self.rank, dtype=torch.int32, device=full_rank_grad.device)
        self.indices.copy_(selected)
class projector:
    """Low-rank gradient projector whose basis comes from SVD or power iteration.

    The basis is stored in ``self.ortho_matrix`` by
    ``get_orthogonal_matrix_svd`` or ``power_iteration`` and is then used by
    ``project`` / ``project_back``.

    Args:
        rank: Number of singular directions to keep.
        proj_type: ``'right'`` or ``'left'``. The default ``'std'`` is not
            handled by any method of this class; every method raises
            ``ValueError`` until ``proj_type`` is set to a supported value.
    """

    def __init__(self, rank, proj_type='std'):
        self.rank = rank
        self.proj_type = proj_type
        # Set by get_orthogonal_matrix_svd / power_iteration.
        self.ortho_matrix = None

    def _unsupported_proj_type(self):
        """Build the error raised for any ``proj_type`` other than 'right'/'left'."""
        return ValueError(f'Projection type {self.proj_type} is currently not supported')

    def _require_basis(self):
        """Return ``ortho_matrix``, failing loudly if it has not been computed."""
        if self.ortho_matrix is None:
            raise RuntimeError(
                'No projection basis yet: call get_orthogonal_matrix_svd() or '
                'power_iteration() before project() or project_back()'
            )
        return self.ortho_matrix

    def project(self, full_rank_grad):
        """Project a full-rank gradient onto the stored basis.

        Returns:
            ``full_rank_grad @ ortho_matrix.t()`` for ``'right'``,
            ``ortho_matrix.t() @ full_rank_grad`` for ``'left'``.

        Raises:
            ValueError: If ``proj_type`` is neither ``'right'`` nor ``'left'``.
            RuntimeError: If no basis has been computed yet.
        """
        if self.proj_type == 'right':
            return torch.matmul(full_rank_grad, self._require_basis().t())
        if self.proj_type == 'left':
            return torch.matmul(self._require_basis().t(), full_rank_grad)
        raise self._unsupported_proj_type()

    def project_back(self, low_rank_grad):
        """Map a low-rank gradient back to the full-rank space.

        Returns:
            ``low_rank_grad @ ortho_matrix`` for ``'right'``,
            ``ortho_matrix @ low_rank_grad`` for ``'left'``.

        Raises:
            ValueError: If ``proj_type`` is neither ``'right'`` nor ``'left'``.
            RuntimeError: If no basis has been computed yet.
        """
        if self.proj_type == 'right':
            return torch.matmul(low_rank_grad, self._require_basis())
        if self.proj_type == 'left':
            return torch.matmul(self._require_basis(), low_rank_grad)
        raise self._unsupported_proj_type()

    # svd decomposition
    def get_orthogonal_matrix_svd(self, grad, svd_lowrank=False):
        """Compute the basis from the leading singular vectors of ``grad``.

        For ``'right'`` the first ``rank`` rows of ``Vh`` are stored, for
        ``'left'`` the first ``rank`` columns of ``U``. Inputs whose dtype is
        not ``torch.float`` are decomposed in float32 and the result is cast
        back to the input dtype.

        Args:
            grad: Tensor to decompose (its ``.data`` is used).
            svd_lowrank: If true, use the randomized ``torch.svd_lowrank``
                (with ``q = rank + 2`` and one iteration) instead of the
                exact ``torch.linalg.svd``.

        Raises:
            ValueError: If ``proj_type`` is neither ``'right'`` nor ``'left'``.
        """
        # Validate before the decomposition so a misconfigured projector does
        # not pay for an SVD and then die on an unbound local.
        if self.proj_type not in ('right', 'left'):
            raise self._unsupported_proj_type()

        # Neither SVD routine modifies its input, so no defensive copy is needed.
        matrix = grad.data
        original_type = matrix.dtype
        original_device = matrix.device
        # torch.linalg.svd doesn't support half precision types such as torch.bfloat16
        needs_cast = original_type != torch.float
        if needs_cast:
            matrix = matrix.float()

        if svd_lowrank:
            # q is a slightly overestimated rank of the matrix
            U, _, V = torch.svd_lowrank(matrix, q=self.rank + 2, niter=1)
            Vh = V.t()
        else:
            U, _, Vh = torch.linalg.svd(matrix, full_matrices=False)

        if self.proj_type == 'right':
            ortho_matrix = Vh[:self.rank, :]
        else:
            ortho_matrix = U[:, :self.rank]

        if needs_cast:
            ortho_matrix = ortho_matrix.to(device=original_device, dtype=original_type)

        self.ortho_matrix = ortho_matrix

    def power_iteration(self, matrix, init, intermediate_orthogonalization=False):
        """Refine the basis with one power-iteration step starting from ``init``.

        For ``'right'`` the step is ``matrix.t() @ (matrix @ init.t())``,
        orthogonalized and stored transposed; for ``'left'`` it is
        ``matrix @ (matrix.t() @ init)``, orthogonalized and stored as is.

        Args:
            matrix: Tensor whose singular directions are approximated.
            init: Starting basis, in the same layout as ``ortho_matrix`` for
                the current ``proj_type``.
            intermediate_orthogonalization: Also orthogonalize the
                intermediate product between the two multiplications.

        Raises:
            ValueError: If ``proj_type`` is neither ``'right'`` nor ``'left'``.
        """
        if self.proj_type == 'right':
            U = matrix @ init.t()
            if intermediate_orthogonalization:  # Not necessary for computing right singular vectors
                U = Gram_Schmidt(U)
            projection_map = matrix.t() @ U
            del U

            projection_map = Gram_Schmidt(projection_map)
            self.ortho_matrix = projection_map.t()

        elif self.proj_type == 'left':
            V = matrix.t() @ init
            if intermediate_orthogonalization:  # Not necessary for computing left singular vectors
                V = Gram_Schmidt(V)
            projection_map = matrix @ V
            del V

            projection_map = Gram_Schmidt(projection_map)
            self.ortho_matrix = projection_map

        else:
            # Previously a silent no-op that left ortho_matrix untouched.
            raise self._unsupported_proj_type()
def Gram_Schmidt(matrix):
    """Orthonormalize the columns of ``matrix`` via a QR decomposition.

    The factorization is always carried out in float32, because
    torch.linalg.qr doesn't support half precision types such as
    torch.bfloat16; the Q factor is then cast back to the input dtype.

    Args:
        matrix: Tensor whose columns are to be orthonormalized.

    Returns:
        The Q factor of the QR decomposition, in the dtype of ``matrix``.
    """
    input_dtype = matrix.dtype
    q_factor = torch.linalg.qr(matrix.to(dtype=torch.float32)).Q
    return q_factor.to(dtype=input_dtype)