import torch

# Values of ``proj_type`` that FrugalCheapProjection accepts; any other value
# makes update_proj / project_down / project_up raise NameError.
AVAILABLE_PROJECTIONS = ['dct', 'hdm', 'randn-qr']


class FrugalCheapProjection:
    """Project 2-D gradients onto a subset of the columns of a fixed basis ``Q``.

    ``Q`` itself is never recomputed: :meth:`update_proj` only re-selects which
    ``rank`` of its columns to use, namely those along which the gradient has
    the largest L1 coefficient mass. :meth:`project_down` / :meth:`project_up`
    then map between the full gradient and the selected columns.

    Args:
        Q: basis matrix, indexed by column. On the "right" side a gradient is
            multiplied by ``Q`` from the right (so ``Q`` needs ``grad_shape[1]``
            rows); on the "left" side by ``Q.T`` from the left (so ``Q`` needs
            ``grad_shape[0]`` rows). ``project_up`` inverts ``project_down`` on
            the selected subspace only if ``Q`` has orthonormal columns --
            presumably true for the bases in ``AVAILABLE_PROJECTIONS``; TODO
            confirm.
        density: if strictly inside (0, 1), the fraction of the projected
            dimension to keep (rounded with ``round``); otherwise used as an
            absolute rank via ``int(density)``.
        grad_shape: shape ``(rows, cols)`` of the gradients to be projected.
        proj_side: 'right', 'left', 'std' (compress the smaller dimension;
            ties go to 'right') or 'reverse_std' (compress the larger
            dimension; ties go to 'left').
        proj_type: must be one of ``AVAILABLE_PROJECTIONS`` for the projection
            methods to work. NOTE(review): the default "svd" is not in that
            list, so relying on the default makes ``update_proj`` raise
            NameError.
        verbose: stored on the instance but not read by this class.

    Raises:
        NameError: if ``proj_side`` is not one of the four values above.
    """

    def __init__(self, Q, density, grad_shape, proj_side='std', proj_type="svd", verbose=False):
        self.verbose = verbose
        self.ortho_matrix = None  # not read by any method of this class
        if (proj_side == 'right' or
                (proj_side == 'std' and grad_shape[0] >= grad_shape[1]) or
                (proj_side == 'reverse_std' and grad_shape[0] < grad_shape[1])):
            self.proj_side = "right"
            self.rank = self._rank_from_density(grad_shape[1], density)
        elif (proj_side == 'left' or
              (proj_side == 'reverse_std' and grad_shape[0] >= grad_shape[1]) or
              (proj_side == 'std' and grad_shape[0] < grad_shape[1])):
            if grad_shape[0] == 1024 and grad_shape[1] == 4096:
                # HACK: 1024x4096 gradients are always projected from the right,
                # overriding a 'left' resolution (explicit or via std/reverse_std).
                # The motivation is not visible here -- TODO document/parameterise.
                self.proj_side = "right"
                self.rank = self._rank_from_density(grad_shape[1], density)
            else:
                self.proj_side = "left"
                self.rank = self._rank_from_density(grad_shape[0], density)
        else:
            raise NameError("Wrong proj_side for DCT Projector")
        self.proj_type = proj_type
        self.Q = Q
        # int32 indices of the currently / previously selected columns of Q;
        # allocated lazily by update_proj on the gradient's device.
        self.indices_crt = None
        self.indices_prev = None
        self.calls_to_update_proj = 0

    @staticmethod
    def _rank_from_density(dim, density):
        """Return how many basis columns to keep for a dimension of size ``dim``."""
        # A density strictly inside (0, 1) is a fraction of ``dim``; anything
        # else is taken as an absolute rank.
        return round(dim * density) if 0 < density < 1 else int(density)

    def _selected_basis(self):
        """Return the columns of ``Q`` chosen by the last :meth:`update_proj` call."""
        if self.indices_crt is None:
            # Without this check ``Q[:, None]`` would silently insert a new axis
            # instead of selecting columns.
            raise RuntimeError("update_proj must be called before projecting")
        return self.Q[:, self.indices_crt]

    def _unsupported_side_error(self):
        """Build the error raised for a projection side other than 'left'/'right'."""
        return RuntimeError(
            f'Projection side "{self.proj_side}" is not implemented '
            f'for proj_type="{self.proj_type}"!')

    def update_proj(self, full_rank_grad):
        """Re-select the ``rank`` columns of ``Q`` carrying the most gradient mass.

        Every column of ``Q`` is scored by the L1 norm of the gradient's
        coefficients along it, and the indices of the top-``rank`` columns
        (in no particular order) are written into ``indices_crt``. From the
        second call on, the previous selection is first saved into
        ``indices_prev`` so :meth:`get_subspace_rotation_matrix` can relate
        the old and new subspaces.

        Raises:
            NameError: if ``proj_type`` is not in ``AVAILABLE_PROJECTIONS``.
            RuntimeError: if ``proj_side`` is neither 'left' nor 'right'.
        """
        if self.proj_type not in AVAILABLE_PROJECTIONS:
            raise NameError("Wrong proj_type for FrugalCheapProjector")

        device = full_rank_grad.device
        if self.indices_crt is None:
            self.indices_crt = torch.zeros(self.rank, dtype=torch.int32, device=device)

        self.calls_to_update_proj += 1

        if self.calls_to_update_proj > 1:
            if self.indices_prev is None:
                self.indices_prev = torch.zeros(self.rank, dtype=torch.int32, device=device)
            self.indices_prev.copy_(self.indices_crt)

        if self.proj_side == "right":
            # Column j of ``coeffs`` holds the gradient's coefficients along Q[:, j].
            coeffs = full_rank_grad @ self.Q
            scores = coeffs.norm(p=1, dim=0)
        elif self.proj_side == "left":
            # Row j of ``coeffs`` holds the gradient's coefficients along Q[:, j].
            coeffs = self.Q.T @ full_rank_grad
            scores = coeffs.norm(p=1, dim=1)
        else:
            raise self._unsupported_side_error()
        top = torch.topk(input=scores, k=self.rank, sorted=False).indices
        self.indices_crt.copy_(top)

    def project_down(self, full_rank):
        """Project ``full_rank`` onto the currently selected columns of ``Q``.

        'right': ``full_rank @ Q[:, idx]`` (compresses the column dimension).
        'left':  ``Q[:, idx].T @ full_rank`` (compresses the row dimension).

        Raises:
            NameError: if ``proj_type`` is not in ``AVAILABLE_PROJECTIONS``.
            RuntimeError: if :meth:`update_proj` has not been called yet.
        """
        if self.proj_type not in AVAILABLE_PROJECTIONS:
            raise NameError("Wrong proj_type for DCT Projector")
        if self.proj_side == "right":
            return full_rank @ self._selected_basis()
        if self.proj_side == "left":
            return self._selected_basis().T @ full_rank
        raise self._unsupported_side_error()

    def project_up(self, low_rank):
        """Map a low-rank tensor back to full size with the selected columns of ``Q``.

        'right': ``low_rank @ Q[:, idx].T``; 'left': ``Q[:, idx] @ low_rank``.

        Raises:
            NameError: if ``proj_type`` is not in ``AVAILABLE_PROJECTIONS``.
            RuntimeError: if :meth:`update_proj` has not been called yet.
        """
        if self.proj_type not in AVAILABLE_PROJECTIONS:
            raise NameError("Wrong proj_type for DCT Projector")
        if self.proj_side == 'right':
            return low_rank @ self._selected_basis().T
        if self.proj_side == 'left':
            return self._selected_basis() @ low_rank
        raise self._unsupported_side_error()

    def get_subspace_rotation_matrix(self, Q):
        """Return the ``rank x rank`` matrix taking the previous subspace to the current one.

        Args:
            Q: the basis whose columns ``indices_prev`` / ``indices_crt`` index
                (presumably the same ``Q`` given to the constructor).

        Returns:
            ``Q[:, prev].T @ Q[:, crt]`` for 'right' (apply as ``w @ R``),
            ``Q[:, crt].T @ Q[:, prev]`` for 'left' (apply as ``R @ w``), or
            ``None`` when no previous selection exists yet (fewer than two
            :meth:`update_proj` calls) or the side is neither of those.
        """
        icrt = self.indices_crt
        iprev = self.indices_prev
        if iprev is None:
            return None
        if self.proj_side == 'right':
            return Q[:, iprev].T @ Q[:, icrt]
        if self.proj_side == 'left':
            return Q[:, icrt].T @ Q[:, iprev]
        return None

    def rotate_subspace(self, R, w):
        """Re-express ``w`` in place from the previous subspace in the current one.

        Args:
            R: matrix from :meth:`get_subspace_rotation_matrix`; ``None`` means
                there is no previous subspace, so ``w`` is left unchanged.
            w: low-rank tensor expressed in the previous subspace; overwritten
                in place.
        """
        if R is None:
            return
        # Compute the product into a fresh tensor and copy it back: passing
        # ``out=w`` would make the matmul output alias one of its inputs, which
        # GEMM does not support.
        if self.proj_side == 'right':
            w.copy_(torch.matmul(w, R))
        elif self.proj_side == 'left':
            w.copy_(torch.matmul(R, w))
