from nesim.utils.grid_size import find_rectangle_dimensions
import torch.nn.functional as F
from einops import rearrange
import torch.nn as nn
import torch

def get_similarity_matrix(a, b, eps=1e-8):
    """
    Pairwise cosine similarity between the rows of `a` and the rows of `b`.

    `a` and `b` are expected to be 2 dimensional, (num_items, hidden_dim).
    Every row norm is floored at `eps` before dividing, so an all-zero row
    yields zeros instead of a division by zero.

    Returns a (a.shape[0], b.shape[0]) matrix whose entry [i, j] is the
    cosine similarity between a[i] and b[j].
    """

    def _unit_rows(x):
        # Per-row L2 norm kept as a column so it broadcasts across hidden_dim.
        row_norms = x.norm(dim=1, keepdim=True)
        floor = eps * torch.ones_like(row_norms)
        return x / torch.max(row_norms, floor)

    return torch.mm(_unit_rows(a), _unit_rows(b).t())

import torch

def get_indices_above_threshold(matrix, threshold):
    """
    Locate every entry of `matrix` that is strictly greater than `threshold`.

    Returns an integer tensor of shape (num_hits, matrix.ndim); each row holds
    the coordinates of one qualifying entry.
    """
    exceeds_threshold = matrix > threshold
    return exceeds_threshold.nonzero(as_tuple=False)


import torch
from collections import defaultdict
from tqdm import tqdm

def prune(matrix: torch.Tensor, redundant_pairs: torch.Tensor) -> tuple[torch.Tensor, dict]:
    """
    Collapse groups of redundant rows of `matrix` onto one representative row.

    Row i and row j are "friends" when the pair [i, j] (in either order) is
    listed in `redundant_pairs`. Rows are visited in increasing index order:
    every row that has not been absorbed yet is kept as a representative and
    absorbs all of its friends that were neither kept nor absorbed before.
    The grouping is greedy, not transitive: a friend-of-a-friend is only
    absorbed if it is a direct friend of the representative.

    Args:
        matrix: tensor whose first dimension indexes the rows to prune.
        redundant_pairs: integer tensor (or nested list) of shape
            (num_pairs, 2) holding row-index pairs.

    Returns:
        A tuple (pruned_matrix, mapping):
            pruned_matrix: the kept rows of `matrix`, in their original order,
                with the dtype and device of `matrix`.
            mapping: dict from row index in `pruned_matrix` to the list of
                original row indices it stands for. The list starts with the
                kept row itself; every original row appears exactly once
                across all lists.

    Raises:
        ValueError: if a pair references a row index outside of `matrix`.
    """
    num_rows = matrix.shape[0]
    pairs = redundant_pairs.tolist() if hasattr(redundant_pairs, "tolist") else redundant_pairs

    # One pass over the pairs builds the adjacency lists (instead of rescanning
    # the complete pair list once per row).
    friends_of_each_row = defaultdict(list)
    for x, y in pairs:
        if not (0 <= x < num_rows and 0 <= y < num_rows):
            raise ValueError(
                f"redundant pair ({x}, {y}) is out of range for a matrix with {num_rows} rows"
            )
        friends_of_each_row[x].append(y)
        if x != y:
            friends_of_each_row[y].append(x)

    claimed = set()  # rows that are already kept or absorbed
    kept_rows = []
    mapping = {}
    for row in range(num_rows):
        if row in claimed:
            continue
        claimed.add(row)
        group = [row]
        # .get avoids inserting empty lists into the defaultdict for friendless rows.
        for friend in friends_of_each_row.get(row, ()):
            if friend not in claimed:
                claimed.add(friend)
                group.append(friend)
        mapping[len(kept_rows)] = group
        kept_rows.append(row)

    keep_index = torch.tensor(kept_rows, dtype=torch.long, device=matrix.device)
    new_matrix = matrix.index_select(0, keep_index)
    return new_matrix, mapping



# Define the inverse_prune function to restore the original matrix
def inverse_prune(pruned_matrix: torch.Tensor, redundant_indices: list[tuple[int, int]]) -> torch.Tensor:
    """
    Reconstruct the original matrix from the pruned matrix by reinserting redundant rows.

    The pairs are replayed in order. For each (redundant_row, keep_row), a copy
    of the row currently at position `keep_row` is inserted at position
    `redundant_row`; both positions refer to the partially restored matrix at
    that step and follow `list.insert` semantics.

    The result is gathered directly from `pruned_matrix`, so it keeps the
    dtype and device of the input (a round trip through Python lists would
    turn e.g. float64 into float32 and move the data to the CPU).

    Args:
        pruned_matrix (torch.Tensor): The pruned tensor.
        redundant_indices (list[tuple[int, int]]): List of tuples (redundant_row, keep_row).

    Returns:
        torch.Tensor: The reconstructed matrix.
    """
    # Track, for every row of the restored matrix, which pruned row it copies.
    row_sources = list(range(pruned_matrix.shape[0]))

    for redundant_row, keep_row in redundant_indices:
        row_sources.insert(redundant_row, row_sources[keep_row])

    gather_index = torch.tensor(row_sources, dtype=torch.long, device=pruned_matrix.device)
    return pruned_matrix.index_select(0, gather_index)


def generate_index_tensor(A, B):
    """
    Build an (A, B, 2) integer tensor whose entry [i, j] is the pair (i, j).
    """
    # Column vector of row ids and row vector of column ids, broadcast to (A, B).
    row_ids, col_ids = torch.broadcast_tensors(
        torch.arange(A)[:, None],
        torch.arange(B)[None, :],
    )

    # Pair the two coordinate grids along a new trailing axis.
    return torch.stack((row_ids, col_ids), dim=-1)

class PrunedLinear(nn.Module):
    """
    Linear layer in which highly similar output neurons share one weight row.

    At construction the cosine similarity between every pair of output neurons
    (rows of the wrapped layer's weight) is computed, a set of redundant pairs
    is selected, and `prune` collapses them. The forward pass multiplies by
    the pruned weight and then copies every pruned output back to all original
    neuron positions it stands for, so the output width equals that of the
    wrapped layer.

    Args:
        linear_layer: the `nn.Linear` to compress.
        cossim_threshold: used only when `fraction_of_masked_weights` is None;
            every neuron pair with cosine similarity strictly above this value
            is treated as redundant.
        fraction_of_masked_weights: if given, the fraction (0..1) of all
            neuron pairs, taken from the most similar ones, to treat as
            redundant. After construction the attribute of the same name
            holds the fraction of weight entries actually removed.

    Raises:
        ValueError: if `fraction_of_masked_weights` is outside [0, 1].
    """

    def __init__(
        self,
        linear_layer: nn.Linear,
        cossim_threshold: float = 0.99,
        fraction_of_masked_weights=None,
    ):
        super().__init__()
        weight = linear_layer.weight.data.detach()
        num_neurons = weight.shape[0]
        similarity_matrix = get_similarity_matrix(
            a=weight,
            b=weight
        )

        # Strict lower triangle: every unordered neuron pair (i > j) exactly once,
        # without materialising a full (N, N, 2) index grid.
        pair_rows, pair_cols = torch.tril_indices(
            num_neurons, num_neurons, offset=-1, device=similarity_matrix.device
        )
        # similarity_values[p] is the similarity between the two neurons listed in
        # neuron_pair_index_tensor[p].
        similarity_values = similarity_matrix[pair_rows, pair_cols]
        neuron_pair_index_tensor = torch.stack((pair_rows, pair_cols), dim=-1)

        # Derived from the layer itself rather than a hard-coded neuron count.
        num_all_pairs = neuron_pair_index_tensor.shape[0]

        if fraction_of_masked_weights is None:
            selected_pairs = torch.nonzero(
                similarity_values > cossim_threshold, as_tuple=False
            ).flatten()
            k = selected_pairs.numel()
        else:
            if not 0.0 <= fraction_of_masked_weights <= 1.0:
                raise ValueError(
                    f"fraction_of_masked_weights must be in [0, 1], got {fraction_of_masked_weights}"
                )
            k = int(num_all_pairs * fraction_of_masked_weights)
            _, selected_pairs = torch.topk(similarity_values, k=k)

        self.redundant_neuron_pairs = neuron_pair_index_tensor[selected_pairs, :]
        print(f"redundant_neuron_pairs.shape: {self.redundant_neuron_pairs.shape} k: {k} num all pairs: {num_all_pairs}")

        pruned_weight, self.prune_mapping = prune(
            matrix=weight,
            redundant_pairs=self.redundant_neuron_pairs
        )
        self.weight = nn.Parameter(pruned_weight)
        if linear_layer.bias is None:
            self.register_parameter("bias", None)
        else:
            self.bias = nn.Parameter(linear_layer.bias.data)
        self.actual_num_output_neurons = num_neurons
        self.fraction_of_masked_weights = (
            (weight.numel() - self.weight.numel()) / weight.numel()
            if weight.numel()
            else 0.0
        )
        self._build_restore_index()

        print(f"fraction_of_masked_weights: {self.fraction_of_masked_weights} k: {k} redundant_neuron_pairs: {self.redundant_neuron_pairs.shape} pruned_weight = {self.weight.shape}")

    def _build_restore_index(self):
        """Flatten `self.prune_mapping` into two aligned index tensors used by `deprune_layer_output`."""
        positions, sources = [], []
        for pruned_idx, original_rows in self.prune_mapping.items():
            # Accept a single index, a list of indices, or an index tensor.
            rows = torch.as_tensor(original_rows, dtype=torch.long).reshape(-1).tolist()
            positions.extend(rows)
            sources.extend([int(pruned_idx)] * len(rows))

        device = self.weight.device
        # Non-persistent buffers: they follow .to()/.cuda() but stay out of the state_dict.
        self.register_buffer(
            "_restore_positions",
            torch.tensor(positions, dtype=torch.long, device=device),
            persistent=False,
        )
        self.register_buffer(
            "_restore_sources",
            torch.tensor(sources, dtype=torch.long, device=device),
            persistent=False,
        )

    def deprune_layer_output(self, y_before_bias):
        """
        Expand pruned outputs (..., num_pruned) to (..., actual_num_output_neurons).

        Every original neuron listed in `prune_mapping` receives the output of
        its pruned row; neurons not listed stay zero. Works for any number of
        leading dimensions and keeps the dtype and device of the input.
        Assumes each original neuron is listed at most once in the mapping.
        """
        assert y_before_bias.shape[-1] == self.weight.shape[0]

        outputs = y_before_bias.new_zeros(
            (*y_before_bias.shape[:-1], self.actual_num_output_neurons)
        )
        outputs[..., self._restore_positions] = y_before_bias[..., self._restore_sources]

        return outputs

    def forward(self, x):
        """Apply the pruned weight, restore the full output width, then add the bias (if any)."""
        y_before_bias = nn.functional.linear(
            input=x,
            weight=self.weight,
            bias=None
        )
        actual_output = self.deprune_layer_output(
            y_before_bias=y_before_bias
        )
        assert actual_output.shape[-1] == self.actual_num_output_neurons
        if self.bias is None:
            return actual_output
        return actual_output + self.bias


class DownsampledLinear(nn.Module):
    """
    Linear layer whose output neurons are laid out on a 2-D grid and downsampled.

    The wrapped layer's weight (output, input) is reshaped to a
    (height, width, input) grid, shrunk with nearest-exact interpolation by
    `factor_h` / `factor_w`, and stored as `downsampled_weight`. The forward
    pass computes the small output, reshapes it to the small grid, upsamples
    it bilinearly back to the full grid and adds the original bias.

    The total parameter count shrinkage is roughly equal to: factor_h * factor_w

    Args:
        linear_layer: the `nn.Linear` to compress.
        factor_h: downsampling factor along the grid height.
        factor_w: downsampling factor along the grid width.
        device: device the compressed weight and the bias are moved to.
    """

    def __init__(
        self,
        linear_layer: nn.Linear,
        factor_h: int = 2,
        factor_w: int = 2,
        device: str = "cuda:0"
    ):
        super().__init__()

        ## weight.shape: output, input
        weight = linear_layer.weight.data.detach()
        # presumably returns an object with .height/.width whose product is `area` -- TODO confirm
        size = find_rectangle_dimensions(area=weight.shape[0])
        grid = weight.reshape(size.height, size.width, weight.shape[1])
        grid = rearrange(grid, "h w e -> e h w").unsqueeze(0)

        downsampled_grid = F.interpolate(
            grid, scale_factor=(1 / factor_h, 1 / factor_w), mode="nearest-exact"
        ).squeeze(0)

        # The layout the rows of downsampled_weight were flattened from. The forward
        # pass must reshape with exactly this (h, w); re-deriving a rectangle from
        # the area alone is not guaranteed to give the same factorisation.
        self._small_grid_hw = (
            int(downsampled_grid.shape[-2]),
            int(downsampled_grid.shape[-1]),
        )

        self.downsampled_weight = nn.Parameter(
            rearrange(downsampled_grid, "e h w -> (h w) e").to(device)
        )

        # Kept as a public attribute for existing readers; not used by the forward pass.
        self.small_grid_size = find_rectangle_dimensions(area=self.downsampled_weight.shape[0])
        self.num_output_neurons = linear_layer.weight.shape[0]
        if linear_layer.bias is None:
            self.register_parameter("bias", None)
        else:
            self.bias = nn.Parameter(linear_layer.bias.detach().to(device))
        self.size = size
        self.factor_h = factor_h
        self.factor_w = factor_w

    def forward_compressed(self, x):
        """
        Run the compressed layer and upsample its output to the full width.

        Args:
            x: input of shape (batch, in_features) or (batch, seq, in_features).

        Returns:
            (batch, 1, out_features) for 2-D input and (batch, seq, out_features)
            for 3-D input.
            NOTE(review): the extra singleton axis for 2-D input is inherited
            behaviour; confirm whether callers expect (batch, out_features).

        Raises:
            ValueError: if the input is neither 2-D nor 3-D.
        """
        y_before_bias = F.linear(
            input=x, weight=self.downsampled_weight, bias=None
        )
        small_h, small_w = self._small_grid_hw

        if y_before_bias.ndim == 2:
            ## batch, smol -> batch, 1, small_h, small_w
            y_grid = y_before_bias.reshape(y_before_bias.shape[0], 1, small_h, small_w)
        elif y_before_bias.ndim == 3:
            ## batch, seq, smol -> batch, seq, small_h, small_w
            y_grid = y_before_bias.reshape(
                y_before_bias.shape[0], y_before_bias.shape[1], small_h, small_w
            )
        else:
            raise ValueError(
                f"expected a 2D or 3D input, got a tensor with shape {tuple(x.shape)}"
            )

        ## batch, seq, small_h, small_w -> batch, seq, h, w (seq is treated as channels)
        y_before_bias_upsampled = F.interpolate(
            y_grid, size=(self.size.height, self.size.width), mode="bilinear"
        )
        ## batch, seq, h, w -> batch, seq, (h w)
        y_before_bias_upsampled = y_before_bias_upsampled.flatten(start_dim=2)

        if self.bias is None:
            return y_before_bias_upsampled
        return y_before_bias_upsampled + self.bias

    def forward(self, x):
        """Alias for `forward_compressed`."""
        return self.forward_compressed(x)

    def __repr__(self):
        return f"DownsampledLinear(in_features={self.downsampled_weight.shape[1]}, out_features={self.num_output_neurons}, factor_h={self.factor_h}, factor_w={self.factor_w})"

    @property
    def weight(self):
        """The compressed weight, exposed under the attribute name `nn.Linear` uses."""
        return self.downsampled_weight