import torch

# Projection types accepted by CheapProjector.project / project_back.
AVAILABLE_PROJECTIONS = ['dct', 'hdm', 'randn-qr']


class CheapProjector:
    """Project 2-D gradients onto a subset of ``rank`` columns of a matrix ``Q``.

    On refresh steps the gradient is multiplied by the full ``Q``; the ``rank``
    columns (right projection) or rows (left projection) of the product with
    the largest L1 norm are kept, and their positions are stored in
    ``self.indices``. On all other steps only the stored columns of ``Q`` are
    used.

    ``Q`` is supplied by the caller on every call and is never stored here.
    NOTE(review): ``Q`` is presumably an orthogonal basis matching
    ``proj_type`` -- confirm against the caller.

    The default ``proj_type="std"`` is not in ``AVAILABLE_PROJECTIONS``, so
    ``project`` / ``project_back`` raise ``ValueError`` unless a supported
    type is passed explicitly.
    """

    def __init__(
        self, rank, verbose=False, update_proj_gap=200, alpha=1.0, proj_type="std"
    ):
        self.rank = rank
        self.verbose = verbose  # stored only; not read anywhere in this class
        self.update_proj_gap = update_proj_gap
        self.alpha = alpha
        # int32 tensor of selected column positions; created by the first
        # successful selection in project() and updated in place afterwards.
        self.indices = None
        self.proj_type = proj_type
        # Projection side; decided once from the first gradient's shape.
        self.is_right_proj = None

    def _check_proj_type(self):
        """Raise ValueError when ``self.proj_type`` is not a supported type."""
        if self.proj_type not in AVAILABLE_PROJECTIONS:
            raise ValueError(f'Projection type {self.proj_type} is currently not supported')

    def _store_indices(self, selected):
        """Save freshly selected positions, reusing the existing buffer if any."""
        if self.indices is None:
            self.indices = selected.to(torch.int32)
        else:
            # In-place so that outside references to the buffer stay valid.
            self.indices.copy_(selected)

    def project(self, Q, full_rank_grad, iter):
        """Return the low-rank projection of ``full_rank_grad``.

        Args:
            Q: Square matrix whose columns are selected from. It must be
                multiplicable as ``full_rank_grad @ Q`` for a right projection
                or ``Q.T @ full_rank_grad`` for a left projection.
            full_rank_grad: 2-D gradient tensor of shape ``(n, m)``.
            iter: Step counter. The selection is refreshed when it is 0 or a
                multiple of ``update_proj_gap``, and also on the very first
                call regardless of its value.

        Returns:
            ``(n, rank)`` tensor for a right projection, ``(rank, m)`` for a
            left projection.

        Raises:
            ValueError: If ``proj_type`` is not in ``AVAILABLE_PROJECTIONS``.
        """
        self._check_proj_type()

        n, m = full_rank_grad.shape
        if self.is_right_proj is None:
            # Project on the right when the gradient is tall or square.
            # NOTE(review): the 1024x4096 case is a hard-coded override that
            # forces a right projection for that exact shape -- presumably
            # tied to a specific layer shape; confirm with the caller.
            self.is_right_proj = n >= m or (n == 1024 and m == 4096)
        right = self.is_right_proj

        # Without a previous selection there is nothing valid to reuse, so the
        # first call always selects, even when `iter` is off-schedule.
        refresh = (
            self.indices is None
            or iter == 0
            or iter % self.update_proj_gap == 0
        )

        if refresh:
            if right:
                projected = full_rank_grad @ Q
                scores = projected.norm(p=1, dim=0)
            else:
                projected = Q.T @ full_rank_grad
                scores = projected.norm(p=1, dim=1)
            selected = torch.topk(input=scores, k=self.rank, sorted=False).indices
            self._store_indices(selected)
            # The full product is already available; slice it instead of
            # multiplying again.
            if right:
                return projected[:, self.indices]
            return projected[self.indices, :]

        basis = Q[:, self.indices]
        if right:
            return full_rank_grad @ basis
        return basis.T @ full_rank_grad

    def project_back(self, Q, low_rank_grad):
        """Map a low-rank tensor back to full shape and scale it by ``alpha``.

        Args:
            Q: The same matrix that was passed to ``project``.
            low_rank_grad: Tensor shaped like the output of ``project``.

        Returns:
            Full-shape tensor multiplied by ``self.alpha``.

        Raises:
            ValueError: If ``proj_type`` is not in ``AVAILABLE_PROJECTIONS``.
            RuntimeError: If no column selection exists yet, i.e. ``project``
                has not completed successfully.
        """
        self._check_proj_type()
        if self.indices is None:
            raise RuntimeError(
                'project_back() called before project(): no projection indices selected'
            )

        basis = Q[:, self.indices]
        if self.is_right_proj:
            full_rank_grad = low_rank_grad @ basis.T
        else:
            full_rank_grad = basis @ low_rank_grad
        return full_rank_grad * self.alpha
