import torch.nn as nn
import torch
from typing import List
from nesim.utils.getting_modules import get_module_by_name
from nesim.grid.two_dimensional import BaseGrid2dLinear
from nesim.utils.grid_size import find_rectangle_dimensions
from einops import rearrange
from nesim.losses.nesim_loss import (
    NesimConfig,
    NesimLoss,
    NeighbourhoodCosineSimilarity,
)
import copy
from torchtyping import TensorType
from nesim.utils.tensor_mapping import apply_mapping, find_mapping
import matplotlib.pyplot as plt
import copy
import numpy as np


class GlobalModelWeightSorter:
    """Reorder the hidden units of a chain of ``nn.Linear`` layers.

    For every consecutive pair ``(first, second)`` of the layers named in
    ``layer_names``, the output units of ``first`` (rows of its weight and
    entries of its bias) are reordered according to a lexicographic sort of
    its weight rows. The input columns of ``second`` are reordered with the
    same mapping. The outputs of the last listed layer are never reordered.

    NOTE(review): function preservation relies on whatever sits between two
    listed layers (e.g. ReLU) acting element-wise, and on ``apply_mapping``
    applying the same permutation along dim 0 and dim 1 — confirm.
    """

    def __init__(self, layer_names: List[str], device: str):
        """
        Args:
            layer_names: names of ``nn.Linear`` submodules, resolved with
                ``get_module_by_name``. Make sure that you provide layer_names
                in a proper sequence: the same exact sequence as seen in the
                forward pass definition.
            device: device handed to ``BaseGrid2dLinear`` when building grids.
        """
        self.layer_names = layer_names
        self.device = device

    def run(self, model):
        """Return a sorted deep copy of ``model``; ``model`` itself is untouched.

        Raises:
            AssertionError: if a named module is not an ``nn.Linear``.
        """
        model = copy.deepcopy(model)

        layers = []
        for name in self.layer_names:
            layer = get_module_by_name(module=model, name=name)
            assert isinstance(
                layer, nn.Linear
            ), f"Expected nn.Linear but got: {type(layer)}"
            layers.append(layer)

        # With fewer than two layers there is no downstream layer whose input
        # columns could absorb a permutation, so there is nothing to sort.
        if len(layers) < 2:
            return model

        for layer_index in range(len(layers) - 1):
            # Store the returned modules so the next pair starts from the layer
            # whose input columns were just permuted.
            layers[layer_index], layers[layer_index + 1] = self.sort_layer_pair(
                first_layer=layers[layer_index],
                second_layer=layers[layer_index + 1],
            )

        for layer_name, layer in zip(self.layer_names, layers):
            self._set_module_by_name(model=model, name=layer_name, module=layer)

        return model

    @staticmethod
    def _set_module_by_name(model, name, module):
        """Assign ``module`` at the dot-separated path ``name`` inside ``model``."""
        *parent_path, attr_name = name.split(".")
        parent = model
        for part in parent_path:
            parent = getattr(parent, part)
        # setattr on the direct parent; a dotted name passed to setattr on the
        # root would register a bogus child literally named "a.b".
        setattr(parent, attr_name, module)

    def get_hwe_grid_linear(self, layer: nn.Linear):
        """Build a ``(height, width, in_features)`` grid from ``layer``'s weights.

        ``height`` and ``width`` come from ``find_rectangle_dimensions`` applied
        to ``layer.out_features`` (presumably ``height * width == out_features``
        — confirm against that helper).
        """
        size = find_rectangle_dimensions(area=layer.out_features)

        grid = BaseGrid2dLinear(
            linear_layer=layer,
            height=size.height,
            width=size.width,
            device=self.device,
        ).grid
        assert grid.shape[0] == size.height
        assert grid.shape[1] == size.width
        assert grid.shape[2] == layer.in_features
        return grid

    def sort_tensor_along_first_and_second_dim(self, tensor):
        """Lexicographically sort the ``dim0 * dim1`` vectors of a 3-D tensor.

        The first two dims are flattened into rows of length ``dim2``. Rows are
        ordered by their first element, ties broken by the second, and so on.
        The result is reshaped back to ``(dim0, dim1, dim2)``. It is a new
        tensor with the input's dtype and device, detached from autograd.
        """
        dim0, dim1, dim2 = tensor.size()

        # reshape (not view) so non-contiguous inputs are accepted; detach so
        # converting to numpy works for tensors that require grad
        rows = tensor.detach().reshape(dim0 * dim1, dim2)

        keys = rows.cpu()
        if keys.dtype == torch.bfloat16:
            # numpy has no bfloat16; float32 represents every bfloat16 exactly,
            # so the ordering is unchanged
            keys = keys.float()
        keys = keys.numpy()

        # np.lexsort treats the LAST key as primary, so reverse the column order
        # to make column 0 the primary key
        order = np.lexsort(keys.T[::-1])
        sorted_rows = rows[torch.as_tensor(order, device=rows.device)]

        return sorted_rows.reshape(dim0, dim1, dim2)

    @torch.no_grad()
    def sort_layer_pair(
        self,
        first_layer: nn.Linear,
        second_layer: nn.Linear
    ):
        """Sort the outputs of ``first_layer`` (dim 0 in weight space) in place.

        This function does the following:
        - rearrange weights of first layer along output dim
        - rearrange biases of first layer (when the layer has a bias)
        - rearrange weights of second layer along input dim

        Returns:
            ``(first_layer, second_layer)``: the same module objects passed in,
            with their ``.data`` replaced.
        """
        # Work on a copy so building the grid cannot touch the live layer.
        grid = self.get_hwe_grid_linear(
            layer=copy.deepcopy(first_layer)
        )
        # NOTE: the sort flattens (h, w) back into rows, so the resulting order
        # depends only on the set of weight rows, not on the grid layout.
        grid_sorted_along_h_and_w = self.sort_tensor_along_first_and_second_dim(tensor=grid)

        sorted_layer_weights = rearrange(
            grid_sorted_along_h_and_w,
            "h w e -> (h w) e"
        )

        mapping_from_original_layer_to_sorted = find_mapping(
            x=first_layer.weight.data,
            y=sorted_layer_weights,
            dim=0
        )
        first_layer.weight.data = apply_mapping(
            y=first_layer.weight.data,
            mapping=mapping_from_original_layer_to_sorted,
            dim=0
        )
        # nn.Linear(..., bias=False) stores bias as None
        if first_layer.bias is not None:
            first_layer.bias.data = apply_mapping(
                y=first_layer.bias.data,
                mapping=mapping_from_original_layer_to_sorted,
                dim=0
            )
        # The second layer consumes the permuted outputs, so permute its input
        # columns identically.
        second_layer.weight.data = apply_mapping(
            second_layer.weight.data,
            mapping=mapping_from_original_layer_to_sorted,
            dim=1
        )
        return first_layer, second_layer



# Sanity check: reordering hidden units should leave the model's outputs unchanged.
# Fall back to CPU so the check also runs on machines without a GPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
in_features = 3
out_features = 4
hidden_size = 50
batch_size = 100


model = nn.Sequential(
    nn.Linear(in_features, hidden_size),  ## 0
    nn.ReLU(),
    nn.Linear(hidden_size, hidden_size),  ## 2
    nn.ReLU(),
    nn.Linear(hidden_size, hidden_size),  ## 4
    nn.ReLU(),
    nn.Linear(hidden_size, out_features),  ## 6
).to(device)

# Names are the nn.Sequential indices of the Linear layers, in forward order.
sorter = GlobalModelWeightSorter(
    layer_names=[
        "0",
        "2",
        "4",
        "6"
    ],
    device=device,
)
new_model = sorter.run(model=model).to(device)

# Input width must match the first layer's in_features.
x = torch.randn(batch_size, in_features).to(device)
with torch.no_grad():
    y_original_model = model(x)
    y_our_model = new_model(x)

print(
    torch.isclose(y_original_model, y_our_model, rtol=1e-05, atol=5e-08, equal_nan=False).all().item()
)

print(
    f"MSE: {torch.nn.functional.mse_loss(y_original_model, y_our_model)}"
)
print(
    f"COSSIM: {torch.nn.functional.cosine_similarity(y_original_model, y_our_model).mean()}"
)