import matplotlib.pyplot as plt
from nesim.utils.checkpoint import load_and_filter_state_dict_keys
from nesim.experiments.mnist import get_untrained_model
import torch
import torch.nn as nn
import os


from nesim.grid.one_dimensional import BaseGrid1dLinear


def get_file_size_bytes(filename):
    """Return the on-disk size of ``filename`` in bytes, as reported by ``os.stat``."""
    return os.stat(filename).st_size


def keep_top_p(tensor, p=0.1):
    """
    Keep only the top fraction of elements (by absolute value) of each row of
    ``tensor`` along the last dimension, setting the rest to 0.

    Parameters:
    - tensor (torch.Tensor): Input tensor (typically 2D). Every slice along the
      last dimension is thresholded independently. Complex tensors are ranked
      by magnitude.
    - p (float): Fraction of values to keep per row (default is 0.1 for 10%).
      At least one and at most all elements of each row are kept.

    Returns:
    - torch.Tensor: Same shape and dtype as ``tensor`` with only the top values
      of each row kept. Elements tied with a row's k-th largest magnitude are
      all kept, so a row can keep slightly more than k values.
    """
    if tensor.numel() == 0:
        return tensor  # Return empty tensor if input tensor is empty

    row_length = tensor.size(-1)
    # Clamp so that p > 1 never asks topk for more elements than a row holds.
    num_elements_to_keep = min(row_length, max(1, int(row_length * p)))

    magnitudes = tensor.abs()
    top_values, _ = torch.topk(
        magnitudes, num_elements_to_keep, dim=-1, largest=True, sorted=True
    )
    # topk sorts descending, so the last column is each row's k-th largest
    # magnitude. Slicing with ``-1:`` keeps the dimension so the per-row
    # threshold broadcasts back over its own row (a global ``.min()`` would let
    # rows with large values keep far more than k elements).
    row_threshold = top_values[..., -1:]

    # zeros_like matches dtype (incl. complex) and device of the input.
    return torch.where(magnitudes >= row_threshold, tensor, torch.zeros_like(tensor))


def dense_to_sparse_2d(input_tensor):
    """Convert a dense 2-D tensor into a sparse COO tensor.

    Only the non-zero entries are stored. The sparse tensor has the same size
    as ``input_tensor`` and its values keep the input's dtype.
    """
    coords = input_tensor.nonzero().t()  # (2, nnz): row indices, column indices
    rows, cols = coords[0], coords[1]
    return torch.sparse_coo_tensor(coords, input_tensor[rows, cols], input_tensor.size())


def sparse_to_dense_2d(sparse_tensor):
    """Materialize ``sparse_tensor`` as a dense tensor, filling unstored entries with zeros."""
    return sparse_tensor.to_dense()


def load_compressed_linear_layer(filename, bias=True, map_location=None):
    """
    Rebuild an ``nn.Linear`` layer from a file written by ``CompressedLinear.save``.

    The saved sparse FFT signal is densified and inverted with ``torch.fft.ifft``
    (last dimension); its real part becomes the layer's weight.

    Parameters:
    - filename (str | os.PathLike): Path of the saved state dict.
    - bias (bool): Whether the rebuilt layer has a bias. When truthy, the saved
      ``"bias"`` tensor is loaded into it.
    - map_location: Forwarded to ``torch.load`` (e.g. ``"cpu"`` to load a
      GPU-saved file on a CPU-only machine). Default ``None`` keeps
      ``torch.load``'s default behaviour.

    Returns:
    - nn.Linear: The reconstructed layer.

    Raises:
    - FileNotFoundError: If ``filename`` does not exist.
    - ValueError: If ``bias`` is truthy but the file stores no bias.
    """
    # Explicit check instead of ``assert``, which is stripped under ``python -O``.
    if not os.path.exists(filename):
        raise FileNotFoundError(f"Invalid filename: {filename}")
    state_dict = torch.load(filename, map_location=map_location)

    linear_layer = nn.Linear(
        in_features=state_dict["in_features"],
        out_features=state_dict["out_features"],
        bias=bias,
    )

    # nn.Linear creates a bias for any truthy ``bias``, so test truthiness
    # (``bias is True`` would leave a randomly initialised bias for e.g. ``1``).
    if bias:
        saved_bias = state_dict.get("bias")
        if saved_bias is None:
            raise ValueError(
                f"{filename} stores no bias; load it with bias=False"
            )
        linear_layer.bias.data = saved_bias

    reconstructed = torch.fft.ifft(state_dict["sparse_fft_signal"].to_dense())
    # ``.real`` of a complex tensor is a strided view into the complex storage.
    # Make it contiguous so the weight owns a compact real buffer; otherwise
    # saving the layer would serialise the entire complex storage.
    linear_layer.weight.data = reconstructed.real.contiguous()
    return linear_layer


class CompressedLinear:
    """Sparse-FFT compressed representation of an ``nn.Linear`` layer.

    The layer is wrapped in ``BaseGrid1dLinear``, and ``torch.fft.fft`` is
    applied to its ``grid`` (last dimension). The spectrum is then thresholded
    with ``keep_top_p`` and stored as a sparse COO tensor via
    ``dense_to_sparse_2d``. ``save`` writes the result with ``torch.save``, and
    ``load_compressed_linear_layer`` rebuilds an ``nn.Linear`` from that file.
    """

    @torch.no_grad()
    def __init__(
        self,
        linear_layer: nn.Linear,
        fft_top_p: float = 1e-3,
        *,
        verbose: bool = True,
    ) -> None:
        """
        Parameters:
        - linear_layer (nn.Linear): Layer to compress. Layers created with
          ``bias=False`` are supported; ``self.bias`` is then ``None``.
        - fft_top_p (float): Passed to ``keep_top_p`` as ``p``.
        - verbose (bool): Print the compression-quality metrics (default True).
          They are always stored in ``self.compression_quality``.
        """
        self.out_features, self.in_features = linear_layer.weight.shape
        one_dimensional_grid_container = BaseGrid1dLinear(linear_layer=linear_layer)
        fft_signal = torch.fft.fft(one_dimensional_grid_container.grid)
        self.thresholded_fft_signal = keep_top_p(tensor=fft_signal, p=fft_top_p)
        # ``linear_layer.bias`` is None for bias-free layers; keep that as None
        # instead of crashing on ``.data``.
        self.bias = None if linear_layer.bias is None else linear_layer.bias.data

        self.sparse_fft_signal = dense_to_sparse_2d(
            input_tensor=self.thresholded_fft_signal
        )

        # Score the *thresholded* spectrum: the full spectrum round-trips
        # exactly through ifft, so it would measure only the grid itself and
        # never the information lost by thresholding.
        self.compression_quality = self.evaluate_compression_quality(
            fft_signal=self.thresholded_fft_signal,
            original_weights_transposed=linear_layer.weight.data,
        )
        if verbose:
            print(self.compression_quality)

    def state_dict(self):
        """Return the data ``load_compressed_linear_layer`` needs, as a plain dict."""
        state_dict = {
            "sparse_fft_signal": self.sparse_fft_signal,
            "in_features": self.in_features,
            "out_features": self.out_features,
            "bias": self.bias,  # None when the source layer had no bias
        }
        return state_dict

    def save(self, filename: str):
        """Serialize ``self.state_dict()`` to ``filename`` with ``torch.save``."""
        torch.save(self.state_dict(), filename)

    def validate_state_dict(self, state_dict):
        """Placeholder: currently performs no validation."""
        pass

    @torch.no_grad()
    def evaluate_compression_quality(self, fft_signal, original_weights_transposed):
        """Compare weights reconstructed from ``fft_signal`` with reference weights.

        Parameters:
        - fft_signal (torch.Tensor): Spectrum to invert with ``torch.fft.ifft``
          along the last dimension. Only the real part of the result is used.
        - original_weights_transposed (torch.Tensor): Reference weights; must be
          shape-compatible with the reconstruction.

        Returns:
        - dict: ``"cosine_similarity_real"`` (cosine similarity along dim 1,
          averaged) and ``"mse_loss_real"`` (mean squared error), as floats.
        """
        # Invert once and reuse for both metrics.
        reconstructed = torch.fft.ifft(fft_signal).real
        target = original_weights_transposed.real

        cosine_similarity_real = (
            nn.functional.cosine_similarity(reconstructed, target).mean().item()
        )
        # mse_loss already reduces to a scalar mean by default.
        mse_loss_real = nn.functional.mse_loss(reconstructed, target).item()

        return {
            "cosine_similarity_real": cosine_similarity_real,
            "mse_loss_real": mse_loss_real,
        }


# Demo: compress one hidden layer of a trained MNIST model, reload it, and
# compare the reconstruction and file sizes against the original layer.
# NOTE(review): this runs at import time; consider moving it under
# ``if __name__ == "__main__":`` if the helpers above are imported elsewhere.
filename = "./checkpoints/mnist/bimt_mnist_apply_ring_loss_1d_scale_1.0_radius_0_5_hidden_size_1024_apply_every_1_steps/best/mnist-epoch=02-val_loss=0.088-val_acc=0.974.ckpt"
model = get_untrained_model(hidden_size=1024)
model.load_state_dict(load_and_filter_state_dict_keys(filename))

original_linear_layer = model[4]

# Size baseline. NOTE(review): this pickles the whole module, while
# compressed.pth stores a plain dict, so the ratio includes pickling overhead.
torch.save(original_linear_layer, "original.pth")

compressed_linear = CompressedLinear(linear_layer=original_linear_layer, fft_top_p=0.005)

compressed_linear.save(filename="compressed.pth")

loaded_compressed_linear_layer = load_compressed_linear_layer(
    filename="compressed.pth", bias=True
)

# The original weight requires grad; without no_grad/.item() the f-strings
# would print tensors decorated with ``grad_fn=...`` instead of plain numbers.
with torch.no_grad():
    cossim = (
        nn.functional.cosine_similarity(
            original_linear_layer.weight, loaded_compressed_linear_layer.weight
        )
        .mean()
        .item()
    )
    mse = nn.functional.mse_loss(
        original_linear_layer.weight, loaded_compressed_linear_layer.weight
    ).item()

print(f"COSSIM: {cossim}")
print(f"MSE: {mse}")

compressed_size = get_file_size_bytes("compressed.pth")
original_size = get_file_size_bytes("original.pth")
# Format the percentage directly; ``round(ratio, 8) * 100`` can print
# floating-point noise such as 12.345678000000001.
print(
    f"Compression ratio: {compressed_size / original_size * 100:.6f}% of original size"
)

model[4] = loaded_compressed_linear_layer
torch.save(model.state_dict(), "reconstructed.pth")
