import torch
from tqdm import tqdm
from typing import Union
import numpy as np

def shuffle_2d_tensor(x: torch.Tensor) -> torch.Tensor:
    """Return a copy of ``x`` whose entries are randomly permuted.

    The permutation is drawn over the *flattened* tensor, so values are free
    to move between rows and columns; only the multiset of values and the
    shape are preserved. The input tensor is not modified.

    Args:
        x: Tensor to shuffle. Any rank is accepted; the result has the same
            shape, dtype and device as ``x``.

    Returns:
        A new tensor with the same shape as ``x`` and its entries in random
        order.
    """
    flat = x.reshape(-1)

    # The permutation must live on the same device as the data:
    # ``index_select`` does not accept a CPU index for a CUDA tensor.
    indices = torch.randperm(flat.shape[0], device=flat.device)

    # Restore the full original shape instead of assuming exactly two axes.
    return flat.index_select(0, indices).reshape(x.shape)

class EffectiveDimensionality:
    """Estimate the effective dimensionality of a set of feature vectors.

    The estimate is derived from the eigenvalue spectrum of the sample
    covariance matrix of ``x`` (rows are samples, columns are features):

    * ``marchenko=False``: the participation ratio
      ``(sum(l)) ** 2 / sum(l ** 2)`` of the eigenvalues ``l``.
    * ``marchenko=True``: the number of eigenvalues that are strictly larger
      than the eigenvalue of the same rank obtained from a copy of the
      (centred) data whose entries were randomly shuffled.

    Attributes:
        flatten: If true, ``compute`` flattens every sample to one vector.
        device: Device the un-batched path moves the data to.
        batch_size: Number of rows per step when accumulating the covariance
            matrix, or ``None`` to compute it with a single matrix product.
        progress: Show a ``tqdm`` progress bar for the batched accumulation.
        use_numpy_backend: Use NumPy instead of torch for the eigenvalue
            decomposition. Always ``False`` after construction; may be set
            on the instance.
        marchenko: Selects the shuffled-baseline estimator (see above).
    """

    def __init__(
        self, flatten=True, device="cuda:0", batch_size: Union[None, int] = 64, progress: bool =  False, marchenko = True
    ):
        self.flatten = flatten
        self.device = device
        self.batch_size = batch_size
        self.progress = progress
        self.use_numpy_backend = False
        self.marchenko = marchenko

    def compute(self, x):
        """Return the effective dimensionality of ``x`` as a Python float.

        Args:
            x: Tensor whose first axis indexes samples. With ``flatten`` set,
                all remaining axes are merged; otherwise ``x`` must be 2-D.

        Returns:
            The participation ratio, or (with ``marchenko``) the count of
            eigenvalues exceeding the shuffled baseline.

        Raises:
            AssertionError: If ``x`` is not 2-D after optional flattening.
        """
        if self.flatten:
            # ``reshape`` (unlike ``view``) also accepts non-contiguous input.
            x = x.reshape(len(x), -1)

        assert (
            x.ndim == 2
        ), f"Expected number of dimensions on x to be 2 but got: {x.ndim}"

        if self.marchenko:
            eigen_values, eigen_values_shuffled = self.get_eigenvalues_torch(x)
            return self.get_effective_dim_marchenko(
                eigen_values=eigen_values,
                eigen_values_shuffled=eigen_values_shuffled,
            )

        return self.get_effective_dim(self.get_eigenvalues_torch(x))

    @staticmethod
    def _centered(x):
        """Return a column-centred floating-point copy of ``x``.

        float64 input stays float64; every other dtype is computed in
        float32. The subtraction is out-of-place so the caller's tensor is
        never modified.
        """
        dtype = torch.float64 if x.dtype == torch.float64 else torch.float32
        x = x.to(dtype)
        return x - x.mean(dim=0, keepdim=True)

    def _covariance(self, centered, batch_size):
        """Return ``centered.T @ centered / (n - 1)``, optionally row-batched."""
        n_samples, n_features = centered.shape

        if batch_size is None:
            gram = torch.matmul(centered.t(), centered)
        else:
            # Accumulate in the data's own dtype/device so float64 input is
            # not silently truncated to float32.
            gram = torch.zeros(
                n_features, n_features, dtype=centered.dtype, device=centered.device
            )
            starts = range(0, n_samples, batch_size)
            if self.progress:
                starts = tqdm(
                    starts,
                    desc=f"Calculating covariance matrix with batch size {batch_size}",
                )
            for start in starts:
                batch = centered[start : start + batch_size]
                gram += torch.matmul(batch.t(), batch)

        return gram / (n_samples - 1)

    def _eigenvalues_descending(self, cov_matrix):
        """Return the eigenvalues of a symmetric matrix, largest first.

        The result lives on the same device as ``cov_matrix`` for both
        backends. Only eigenvalues are needed, so the eigenvector-free
        ``eigvalsh`` routines are used.
        """
        if self.use_numpy_backend:
            values = np.linalg.eigvalsh(cov_matrix.cpu().numpy())
            eigenvalues = torch.from_numpy(values).to(cov_matrix.device)
        else:
            eigenvalues = torch.linalg.eigvalsh(cov_matrix)

        # Both backends return ascending order; flip to descending.
        return eigenvalues.flip(0)

    def _spectra(self, x, batch_size, with_shuffled):
        """Compute the spectrum of ``x`` and optionally of its shuffled copy.

        Returns a single tensor when ``with_shuffled`` is false, otherwise the
        tuple ``(eigenvalues, eigenvalues_shuffled)``; all are descending.
        """
        centered = self._centered(x)
        eigenvalues = self._eigenvalues_descending(
            self._covariance(centered, batch_size)
        )
        if not with_shuffled:
            return eigenvalues

        # NOTE(review): the shuffled copy is not re-centred per column, which
        # matches the previous behaviour -- confirm this is intended.
        shuffled = shuffle_2d_tensor(x=centered)
        del centered  # release the centred copy before the second covariance
        eigenvalues_shuffled = self._eigenvalues_descending(
            self._covariance(shuffled, batch_size)
        )
        return eigenvalues, eigenvalues_shuffled

    def get_eigenvalues_torch_original(self, x):
        """Return the descending covariance eigenvalues of ``x`` (un-batched).

        ``x`` is moved to ``self.device`` and the covariance matrix is formed
        with a single matrix product. Always returns one tensor, regardless
        of ``marchenko``.
        """
        return self._spectra(x.to(self.device), batch_size=None, with_shuffled=False)

    def get_eigenvalues_torch_batched(self, x):
        """Return the descending covariance eigenvalues of ``x`` (batched).

        The covariance matrix is accumulated ``self.batch_size`` rows at a
        time on the device ``x`` already lives on. ``x`` itself is left
        untouched.

        Returns:
            ``eigenvalues`` when ``marchenko`` is false, otherwise the tuple
            ``(eigenvalues, eigenvalues_shuffled)``.
        """
        return self._spectra(
            x, batch_size=self.batch_size, with_shuffled=self.marchenko
        )

    def get_eigenvalues_torch(self, x):
        """Dispatch to the batched or un-batched eigenvalue computation.

        Returns:
            ``eigenvalues`` when ``marchenko`` is false, otherwise the tuple
            ``(eigenvalues, eigenvalues_shuffled)`` -- for both the batched
            and the un-batched configuration.
        """
        if self.batch_size is not None:
            return self.get_eigenvalues_torch_batched(x=x)
        if self.marchenko:
            # The un-batched path must also provide the shuffled baseline,
            # otherwise ``compute`` cannot unpack its result.
            return self._spectra(
                x.to(self.device), batch_size=None, with_shuffled=True
            )
        return self.get_eigenvalues_torch_original(x=x)

    def get_effective_dim(self, eigen_values):
        """Return the participation ratio ``(sum l)^2 / sum(l^2)`` as a float."""
        squared_sum = torch.sum(eigen_values) ** 2
        sum_of_squares = torch.sum(eigen_values**2)
        effective_dim = squared_sum / sum_of_squares
        return effective_dim.item()

    def get_effective_dim_marchenko(self, eigen_values, eigen_values_shuffled):
        """Count eigenvalues strictly above the shuffled baseline.

        Args:
            eigen_values: Spectrum of the data.
            eigen_values_shuffled: Spectrum of the shuffled data, in the same
                order and of the same length as ``eigen_values``.

        Returns:
            The number of positions where ``eigen_values`` exceeds
            ``eigen_values_shuffled``, as a float.
        """
        dimensionality = ((eigen_values - eigen_values_shuffled) > 0).float().sum()

        return dimensionality.item()
