import torch
import torch.distributed as dist

class CheapLowRankProjector:
    """Low-rank projector that keeps a subset of the columns of a basis matrix Q.

    Instead of building a dense projection, the projector computes the projected
    matrix once (``A @ Q`` for a right projection, ``Q.T @ A`` for a left
    projection). It ranks the columns (right) or rows (left) of that product by L1
    norm and remembers the indices of the ``rank`` largest ones. All later
    projections use ``Q[:, indices]``.

    Q layout: every method indexes ``Q[:, indices]``. The selected basis vectors are
    therefore the *columns* of Q for both projection sides ("Q from optimizer state"
    layout, PhD #11, page 47). Q is passed in by the caller on every call and is not
    stored here.

    Args:
        p: 2-D tensor of shape ``(n, m)``. Its shape decides the projection side,
            and its device is where the index buffers are allocated.
        rank: number of basis vectors kept (``r``).
        proj: stored as ``self.proj``; not read by this class's methods.
        rotate_subspace: when True, the previously selected indices are also kept,
            so that states can be rotated from the old subspace into the new one.
    """

    def __init__(self, p, rank, proj, rotate_subspace):
        self.rank = rank
        self.proj = proj
        # indices_prev is only allocated when subspace rotation is requested.
        self.rotate_states = rotate_subspace

        self.size = None           # min(n, m), or m for the special-cased shape in _setup
        self.indices_crt = None    # indices of the currently selected columns of Q
        self.indices_prev = None   # indices saved before the latest change_subspace (rotation only)
        self.is_right_proj = None  # True: X @ Q[:, idx]; False: Q[:, idx].T @ X

        self.steps = 0
        # NOTE(review): dist.get_rank() is the *global* rank. On multi-node runs this
        # string can name a device that does not exist locally (e.g. 'cuda:8' on an
        # 8-GPU node). This class never reads self.device; confirm how callers use it.
        self.device = f'cuda:{dist.get_rank()}' if dist.is_initialized() else 'cuda:0'

        self._setup(p)

    def _setup(self, p):
        """Choose the projection side from ``p.shape`` and allocate the index buffers."""
        n, m = p.shape
        if n == 1024 and m == 4096:
            # Fix for Llama-3-8B, which has a layer of size (1024, 4096): force a
            # right projection over a basis of size m, even though n < m.
            self.is_right_proj = True
            self.size = 4096
        else:
            self.is_right_proj = n >= m
            self.size = min(n, m)

        self.indices_crt = torch.zeros(self.rank, dtype=torch.int32, device=p.device)
        if self.rotate_states:
            self.indices_prev = torch.zeros(self.rank, dtype=torch.int32, device=p.device)

    def inc_step(self):
        """Advance the internal step counter (read by change_subspace)."""
        self.steps += 1

    def change_subspace(self, Q, A, out=None):
        """Select a new subspace from ``A`` and return ``A`` projected onto it.

        Computes ``P = A @ Q`` (right projection) or ``P = Q.T @ A`` (left
        projection). It ranks the columns (right) or rows (left) of ``P`` by their
        L1 norm and stores the indices of the ``rank`` largest ones in
        ``indices_crt``. The order of those indices is unspecified, because topk is
        called with ``sorted=False``. The projected matrix is sliced directly from
        ``P``, which avoids a second multiplication ``A @ Q[:, idx]`` /
        ``Q[:, idx].T @ A``.

        When rotation is enabled and ``steps > 1``, the indices selected so far are
        first copied to ``indices_prev``.

        Args:
            Q: basis matrix; its columns are the candidate basis vectors.
            A: matrix to project.
            out: optional tensor that receives the projected matrix.

        Returns:
            The projected matrix (``P[:, idx]`` or ``P[idx, :]``) when ``out`` is
            None. Otherwise None, and the result is copied into ``out``.
        """
        if self.steps > 1 and self.rotate_states:
            self.indices_prev.copy_(self.indices_crt)

        if self.is_right_proj:
            P = A @ Q
            norms = P.norm(p=1, dim=0)  # L1 norm of each column of P
        else:
            P = Q.T @ A
            norms = P.norm(p=1, dim=1)  # L1 norm of each row of P

        indices = torch.topk(input=norms, k=self.rank, sorted=False).indices
        self.indices_crt.copy_(indices)  # keeps the preallocated int32 buffer
        del indices, norms

        if self.is_right_proj:
            selected = P[:, self.indices_crt]
        else:
            selected = P[self.indices_crt, :]

        if out is None:
            return selected
        out.copy_(selected)

    def get_subspace_rotation_matrix(self, Q):
        """Return the ``(r, r)`` matrix that maps states from the previous to the current subspace.

        Right projection: ``Q[:, prev].T @ Q[:, crt]``.
        Left projection:  ``Q[:, crt].T @ Q[:, prev]``.
        """
        assert self.rotate_states, 'The optimizer was not initialized with rotate_subspace=True'

        icrt = self.indices_crt
        iprev = self.indices_prev

        if self.is_right_proj:
            return Q[:, iprev].T @ Q[:, icrt]  # (m, r).T @ (m, r) = (r, r)
        return Q[:, icrt].T @ Q[:, iprev]  # (n, r).T @ (n, r) = (r, r)

    def rotate_subspace(self, R, w):
        """Rotate ``w`` in place: ``w <- w @ R`` (right projection) or ``w <- R @ w`` (left).

        ``R`` is presumably the matrix returned by get_subspace_rotation_matrix.
        """
        assert self.rotate_states, 'The optimizer was not initialized with rotate_subspace=True'
        # Compute into a temporary and copy back. Passing ``w`` both as an operand
        # and as ``out=`` to matmul aliases the output with an input, which BLAS
        # kernels do not support and which can silently corrupt the result.
        rotated = w @ R if self.is_right_proj else R @ w
        w.copy_(rotated)

    def from_higher_to_lower_dimensions(self, Q, X):
        """Project ``X`` onto the currently selected columns of ``Q``."""
        icrt = self.indices_crt

        if self.is_right_proj:
            return X @ Q[:, icrt]  # (n, m) @ (m, r) = (n, r)
        return Q[:, icrt].T @ X  # (n, r).T @ (n, m) = (r, m)

    def from_lower_to_higher_dimensions(self, Q, x, out=None):
        """Map a low-rank ``x`` back to full size using the currently selected columns of ``Q``.

        Returns the result when ``out`` is None; otherwise writes it into ``out``
        (which must not alias ``x``) and returns None.
        """
        icrt = self.indices_crt

        if self.is_right_proj:
            # (n, r) @ (m, r).T = (n, m)
            if out is None:
                return x @ Q[:, icrt].T
            torch.matmul(x, Q[:, icrt].T, out=out)
        else:
            # (n, r) @ (r, m) = (n, m)
            if out is None:
                return Q[:, icrt] @ x
            torch.matmul(Q[:, icrt], x, out=out)
