import math

import torch

##################################################
# Singleton helpers: iteration Manager, SeedSetter
##################################################
class Singleton(type):
    """Metaclass that builds each of its classes once and reuses that instance afterwards."""

    # Cache shared by every class using this metaclass: class object -> its instance.
    _instances = {}

    def __call__(cls, *args, **kwargs):
        """
        Return the cached instance of ``cls``, constructing it on the first call only.

        The arguments are forwarded to the regular constructor on the first
        call; on later calls they are ignored and the cached instance is returned.
        """
        registry = cls._instances
        if cls in registry:
            return registry[cls]
        registry[cls] = super().__call__(*args, **kwargs)
        return registry[cls]


class Manager(metaclass=Singleton):
    """
    Shared holder (single instance via the Singleton metaclass) for two counters:
    ``iter`` and ``global_iter``. Both are advanced explicitly by the caller.
    """

    def __init__(self):
        # Start both counters at zero so that get_*/update_* work even when
        # init_iter() was never called explicitly (previously AttributeError).
        self.init_iter()

    def init_iter(self):
        """Reset both counters to zero."""
        self.iter = 0
        self.global_iter = 0

    def update_iter(self):
        """Advance ``iter`` by one."""
        self.iter += 1

    def get_iter(self):
        """Return the current value of ``iter``."""
        return self.iter

    def update_global(self):
        """Advance ``global_iter`` by one."""
        self.global_iter += 1

    def get_global_iter(self):
        """Return the current value of ``global_iter``."""
        return self.global_iter


class SeedSetter(metaclass=Singleton):
    """
    Shared holder (single instance via the Singleton metaclass) for a seeded
    ``torch.Generator`` plus a snapshot of the state returned by
    ``torch.get_rng_state()``.

    The seed is supplied from outside through ``set_seed``.
    """

    def __init__(self):
        # Both stay None until set_seed() has been called.
        self.g = None
        self.current_rng_state = None

    def set_seed(self, seed, device):
        """
        Snapshot the current torch RNG state, then create a fresh generator
        on ``device`` seeded with ``seed``.
        """
        self.current_rng_state = torch.get_rng_state()
        self.g = torch.Generator(device=device)
        self.g.manual_seed(seed)

    def reset_seed(self):
        """Restore the RNG state that was captured by the last ``set_seed`` call."""
        snapshot = self.current_rng_state
        assert snapshot is not None, "SeedSetter has not been initialized with set_seed."
        torch.set_rng_state(snapshot)

    def get_generator(self):
        """Return the generator created by the last ``set_seed`` call."""
        generator = self.g
        assert generator is not None, "SeedSetter's generator is not initialized. Call set_seed first."
        return generator


def rng():
    """Return the generator currently held by the shared SeedSetter instance."""
    seed_setter = SeedSetter()
    return seed_setter.get_generator()

######################
# PAMM Projector Class
######################
class PAMMProjector:
    """
    Projects a (B, L, N) tensor onto K randomly chosen, unit-normalised rows
    ("centroids") of its flattened (B*L, N) form.

    ``project`` is stateful: a call made while no assignments are cached is
    treated as the forward call; if ``training`` is true it caches the
    per-row assignments, and the *next* call is then treated as the
    backward-side call, which clears the cache again.
    """

    def __init__(self, rank, proj_type, gap, layer_name, save_dir, epsilon):
        """
        Args:
            rank: fraction of the B*L rows used as centroids
                (K = int(rank * B*L), at least 1, must not exceed B*L).
            proj_type: 'id', 'pamm' or 'pamm_epsilon'.
            gap: the sampling seed is bumped on iterations where
                ``iter % gap == 0`` (at most once per iteration);
                ``None`` keeps the seed fixed.
            layer_name: stored as-is (used for gradient debugging).
            save_dir: stored as-is (used for gradient debugging).
            epsilon: 'pamm_epsilon' only - rows whose relative distance to
                their assigned centroid is <= epsilon are marked as selected.
        """
        self.rank = rank
        self.proj_type = proj_type
        self.seed = torch.randint(1, int(1e6), (1,)).item()
        self.b = None
        self.seq_ln = None
        self.n = None
        self.gap = gap
        self.last_update = None
        self.cdot = False
        self.epsilon = epsilon

        # Needed for debug grads
        self.num_of_step = 0
        self.layer_name = layer_name
        self.save_dir = save_dir
        # -------------

        # State cached by a training forward call and consumed by the next call.
        self.indices = None           # centroid index assigned to every row
        self.alphas = None            # row norm * cosine to the assigned centroid
        self.selected_indices = None  # 'pamm_epsilon' only: boolean row mask
        self.K = None                 # number of centroids used in the last forward call

    def _refresh_seed(self, step):
        """Bump the sampling seed once per iteration on iterations that are a multiple of ``gap``."""
        not_yet_updated = (self.last_update is None) or (self.last_update < step)
        if not_yet_updated and (self.gap is not None) and ((step % self.gap) == 0):
            self.seed += 1
            self.last_update = step

    def _sample_centroid_rows(self, total_rows, device):
        """Return K distinct row indices drawn with a generator seeded by ``self.seed``."""
        gen = torch.Generator(device=device)
        gen.manual_seed(self.seed)

        # At least one centroid: K == 0 would make the max over centroids
        # below fail on an empty dimension.
        k = max(1, int(self.rank * total_rows))
        self.K = k
        if k > total_rows:
            raise ValueError(f"K={k} cannot exceed total_rows={total_rows}")

        return torch.randperm(total_rows, generator=gen, device=device)[:k]

    def _assign_rows(self, x_2d):
        """
        Pick centroids and assign every row to its most similar one.

        Returns ``(idx, x_norms, centroids, assignments, alphas)`` where
        ``idx`` are the sampled row indices, ``x_norms`` the (rows, 1) row
        norms, ``centroids`` the unit-normalised sampled rows, ``assignments``
        the index of the centroid with the largest cosine for each row, and
        ``alphas`` the row norm times that cosine.
        """
        idx = self._sample_centroid_rows(x_2d.shape[0], x_2d.device)

        x_norms = x_2d.norm(dim=1, keepdim=True)
        unit_rows = x_2d / (x_norms + 1e-10)  # epsilon guards against zero rows
        centroids = unit_rows[idx]

        # max cos(theta) --> minimal theta
        cosines, assignments = (unit_rows @ centroids.T).max(dim=1)
        alphas = x_norms * cosines.unsqueeze(-1)
        return idx, x_norms, centroids, assignments, alphas

    def project(self, x, training):
        """
        x is of size: (B, L, N)

        Returns:
            * 'id': ``x`` itself.
            * forward call ('pamm' / 'pamm_epsilon'): the (K, N) unit-normalised
              centroids. When ``training`` is true the assignments and alphas
              (and, for 'pamm_epsilon', the selection mask) are cached on self.
            * backward-side call, 'pamm': ``x`` flattened to (B*L, N).
            * backward-side call, 'pamm_epsilon': the rows of the flattened
              ``x`` picked by the mask cached during the forward call.
            * any other ``proj_type``: ``None``.

        Raises:
            ValueError: if K exceeds the number of rows B*L.
        """
        if self.proj_type == 'id':
            return x

        self.b, self.seq_ln, self.n = x.shape
        step = Manager().get_iter()

        if self.proj_type not in ('pamm', 'pamm_epsilon'):
            # Unrecognised projection types have always yielded None here.
            return None

        use_epsilon = self.proj_type == 'pamm_epsilon'
        # reshape (not view) so that non-contiguous inputs are accepted too.
        x_2d = x.reshape(self.b * self.seq_ln, self.n)

        if self.indices is not None:
            # -------------- BACKWARD PASS --------------
            self.indices = None
            self.alphas = None
            if not use_epsilon:
                return x_2d
            row_mask = self.selected_indices
            self.selected_indices = None
            return x_2d[row_mask]

        # -------------- FORWARD PASS --------------
        self._refresh_seed(step)
        idx, x_norms, centroids, assignments, alphas = self._assign_rows(x_2d)

        row_mask = None
        if use_epsilon:
            # Relative distance between each raw row and its (raw) assigned centroid.
            assigned_raw = x_2d[idx][assignments]
            relative_norm = (x_2d - assigned_raw).norm(dim=1, keepdim=True)
            row_mask = relative_norm / (x_norms + 1e-10) <= self.epsilon

        if training:
            if use_epsilon:
                # squeeze(-1) keeps the mask 1-D even for a single row; a bare
                # squeeze() would yield a 0-d mask and change the rank of the
                # tensor returned by the backward-side call.
                self.selected_indices = row_mask.squeeze(-1)
            self.indices = assignments
            self.alphas = alphas

        return centroids
