from collections import OrderedDict

import torch
import torch.nn.functional as F

from experiments.weight_matching import weight_matching, mlp_permutation_spec


def random_permute(weights1, biases1, weights2, biases2):
    """Apply an independent random hidden-unit permutation to each network of
    the second set.

    For every batch item, a random permutation is drawn for each hidden layer
    (the output layer's units are never permuted) and applied with
    `permute_state_dict`. For an MLP with element-wise activations this leaves
    the function computed by the network unchanged.

    `weights1` and `biases1` are not used to compute the permutation; they are
    accepted so that the signature matches `optimal_permute`. Only the batch
    size is read from `weights1`.

    Args:
    - weights1: tuple of tensors of shape (batch_size, in_channels, out_channels, 1)
    - biases1: tuple of tensors of shape (batch_size, out_channels, 1)
    - weights2: tuple of tensors of shape (batch_size, in_channels, out_channels, 1)
    - biases2: tuple of tensors of shape (batch_size, out_channels, 1)

    Returns:
    - permuted_weights: tuple of tensors of shape (batch_size, in_channels, out_channels, 1)
    - permuted_biases: tuple of tensors of shape (batch_size, out_channels, 1)
    """
    batch_size = weights1[0].shape[0]
    num_layers = len(weights2)

    def to_state_dict(idx):
        # Convert one batch item to the 'seq.{l}.weight' / 'seq.{l}.bias'
        # format expected by `permute_state_dict`, with weights transposed
        # to (out_channels, in_channels).
        state_dict = OrderedDict()
        for layer in range(num_layers):
            state_dict[f'seq.{layer}.weight'] = weights2[layer][idx].squeeze(-1).T
            state_dict[f'seq.{layer}.bias'] = biases2[layer][idx].squeeze(-1)
        return state_dict

    batch_permuted_weights = []
    for idx in range(batch_size):
        state_dict = to_state_dict(idx)
        # One permutation per hidden layer (permutes that layer's output rows).
        # They are generated on the weights' device so `index_select` receives
        # indices on the same device as the tensor it indexes.
        perms = []
        for layer in range(num_layers - 1):
            weight = state_dict[f'seq.{layer}.weight']
            perms.append(torch.randperm(weight.shape[0], device=weight.device))
        batch_permuted_weights.append(permute_state_dict(state_dict, perms))

    # Transform the permuted state dicts back into the original batched format.
    permuted_weights = tuple(
        torch.stack(
            [item[f'seq.{layer}.weight'].T.unsqueeze(-1) for item in batch_permuted_weights],
            dim=0,
        )
        for layer in range(num_layers)
    )
    permuted_biases = tuple(
        torch.stack(
            [item[f'seq.{layer}.bias'].unsqueeze(-1) for item in batch_permuted_weights],
            dim=0,
        )
        for layer in range(num_layers)
    )
    return permuted_weights, permuted_biases


def optimal_permute(weights1, biases1, weights2, biases2):
    """Given two sets of weights and biases, find the permutation of the
    weights and biases of the second set that minimizes the distance between
    the two, and apply it. The permutation is found by approximate weight
    matching from Git Re-Basin [1];
    see https://github.com/themrzmaster/git-re-basin-pytorch.

    NOTE: The network is assumed to be an MLP. The number of layers L is taken
    from ``len(weights1)``, and the permutation spec is built with
    ``mlp_permutation_spec(L - 1)`` (i.e. ``mlp_permutation_spec(2)`` for the
    original 3-layer case). This assumes ``mlp_permutation_spec(n)`` describes
    an MLP with n hidden layers keyed 'seq.{i}.weight' / 'seq.{i}.bias' —
    TODO confirm for L != 3.

    References:
    [1] Ainsworth, Samuel K., Jonathan Hayase, and Siddhartha Srinivasa.
    "Git re-basin: Merging models modulo permutation symmetries".
    In: ICLR 2023.

    Notes:
    1. The weight matching is performed for each batch item separately.
    2. We also have to transform the weights and biases into the desirable
    format, which is an OrderedDict of tensors. The ordering of the tensors in
    the OrderedDict is:
        a. weights of the first layer
        b. biases of the first layer
        c. weights of the second layer
        d. biases of the second layer
        e. ...
    3. The weights and biases in the OrderedDict are of shape (out_channels,
    in_channels) and (out_channels), respectively.
    4. The keys of the OrderedDict are of the form 'seq.{i}.weight' and
    'seq.{i}.bias', where i is the index of the layer.

    Args:
    - weights1: tuple of tensors of shape (batch_size, in_channels, out_channels, 1)
    - biases1: tuple of tensors of shape (batch_size, out_channels, 1)
    - weights2: tuple of tensors of shape (batch_size, in_channels, out_channels, 1)
    - biases2: tuple of tensors of shape (batch_size, out_channels, 1)

    Returns:
    - permuted_weights: tuple of tensors of shape (batch_size, in_channels, out_channels, 1)
    - permuted_biases: tuple of tensors of shape (batch_size, out_channels, 1)
    """
    batch_size = weights1[0].shape[0]
    num_layers = len(weights1)

    def to_state_dict(weights, biases, idx):
        # Convert one batch item to the OrderedDict format described above
        # (weights transposed to (out_channels, in_channels)).
        state_dict = OrderedDict()
        for layer in range(num_layers):
            state_dict[f'seq.{layer}.weight'] = weights[layer][idx].squeeze(-1).T
            state_dict[f'seq.{layer}.bias'] = biases[layer][idx].squeeze(-1)
        return state_dict

    # The spec only depends on the architecture, so build it once.
    perm_spec = mlp_permutation_spec(num_layers - 1)

    batch_permuted_weights = []
    for idx in range(batch_size):
        formatted_weights1 = to_state_dict(weights1, biases1, idx)
        formatted_weights2 = to_state_dict(weights2, biases2, idx)
        optimal_perm = weight_matching(perm_spec, formatted_weights1, formatted_weights2)
        # permute_state_dict expects one index vector per hidden layer, in
        # layer order; this relies on weight_matching returning its dict in
        # that order (as the original code did).
        optimal_perm = [v.long() for v in optimal_perm.values()]
        batch_permuted_weights.append(permute_state_dict(formatted_weights2, optimal_perm))

    # Transform the permuted state dicts back into the original batched format.
    permuted_weights = tuple(
        torch.stack(
            [item[f'seq.{layer}.weight'].T.unsqueeze(-1) for item in batch_permuted_weights],
            dim=0,
        )
        for layer in range(num_layers)
    )
    permuted_biases = tuple(
        torch.stack(
            [item[f'seq.{layer}.bias'].unsqueeze(-1) for item in batch_permuted_weights],
            dim=0,
        )
        for layer in range(num_layers)
    )
    return permuted_weights, permuted_biases


def mixup_weights_labels(
    batch, mixup_dist, device, num_classes, optimal_perm=False,
    random_perm=False
):
    """Mix up pairs of networks and their labels.

    Each example of ``batch1`` is linearly interpolated with the corresponding
    example of ``batch2`` using a coefficient ``lam`` drawn from
    ``mixup_dist``: ``lam * x1 + (1 - lam) * x2`` for weights, biases and
    one-hot labels. Before mixing, the second networks can be aligned to the
    first ones (``optimal_perm``) or randomly permuted (``random_perm``);
    ``optimal_perm`` takes precedence when both are set.

    Args:
        batch: pair ``(batch1, batch2)``; each exposes ``.weights`` (tuple of
            (B, in, out, 1) tensors), ``.biases`` (tuple of (B, out, 1)
            tensors), ``.label`` (integer class labels, shape (B,)) and
            ``.to(device)``.
        mixup_dist: object with a ``sample(shape)`` method (e.g. a
            ``torch.distributions`` distribution). ``sample((B,))`` may return
            shape (B,) or (B, 1).
        device: device the mixed tensors are placed on.
        num_classes: number of classes for the one-hot labels.
        optimal_perm: align batch2 to batch1 with ``optimal_permute`` first.
        random_perm: randomly permute batch2's hidden units first.

    Returns:
        ``((weights, biases), label)`` where ``weights`` and ``biases`` are
        lists of mixed tensors (one per layer) and ``label`` is the
        (B, num_classes) soft label.
    """
    batch1, batch2 = batch
    lam = mixup_dist.sample((batch1.label.shape[0],))
    # A scalar-parameter distribution gives shape (B,); make it (B, 1) so it
    # broadcasts against the (B, num_classes) one-hot labels and can be
    # indexed as lam[:, None, None, :] below.
    if lam.dim() == 1:
        lam = lam.unsqueeze(-1)
    # The distribution's parameters may live on a different device than the data.
    lam = lam.to(device)

    if optimal_perm:
        perm_weights, perm_biases = optimal_permute(
            batch1.weights, batch1.biases, batch2.weights, batch2.biases)
    elif random_perm:
        perm_weights, perm_biases = random_permute(
            batch1.weights, batch1.biases, batch2.weights, batch2.biases)
    else:
        perm_weights, perm_biases = batch2.weights, batch2.biases
    perm_weights = tuple(w.to(device) for w in perm_weights)
    perm_biases = tuple(b.to(device) for b in perm_biases)

    batch1 = batch1.to(device)
    batch2 = batch2.to(device)
    label = (lam * F.one_hot(batch1.label, num_classes=num_classes)
             + (1 - lam) * F.one_hot(batch2.label, num_classes=num_classes))

    # Per-example coefficients reshaped to broadcast over (B, in, out, 1)
    # weights and (B, out, 1) biases.
    lam_w = lam[:, None, None, :]
    lam_b = lam[:, None, :]
    weights = [
        lam_w * weights1 + (1 - lam_w) * weights2
        for weights1, weights2 in zip(batch1.weights, perm_weights)
    ]
    biases = [
        lam_b * biases1 + (1 - lam_b) * biases2
        for biases1, biases2 in zip(batch1.biases, perm_biases)
    ]

    return (weights, biases), label


def permute_dim(x: torch.Tensor, dim: int, perm: torch.LongTensor) -> torch.Tensor:
    """Permute a dimension of a tensor with the permutation vector `perm`.

    Args:
        x: n-dimensional Tensor of shape (dim_1, ..., dim_n).
        dim: Dimension to permute.
        perm: Permutation vector of shape (dim_i,), where dim_i is the size of
            dimension `dim` of `x`. Element ``j`` of the result along `dim` is
            element ``perm[j]`` of `x`.

    Returns:
        Tensor of shape (dim_1, ..., dim_n).
    """
    # index_select requires the index tensor on the same device as `x`
    # (e.g. CPU permutations applied to CUDA weights).
    return x.index_select(dim, perm.to(x.device))


def permute_state_dict(
    state_dict: OrderedDict[str, torch.Tensor],
    permutations: list[torch.LongTensor]
) -> OrderedDict[str, torch.Tensor]:
    """Permute the weights of a neural network with the permutation vectors `permutations`.

    Each permutation acts on a pair of consecutive layers: permutation ``l``
    permutes the rows of layer ``l``'s weight (and layer ``l``'s bias) and the
    columns of layer ``l + 1``'s weight. The inputs of the first layer and the
    outputs (weight rows and bias) of the last layer are left unchanged.

    The keys of the state dict are assumed to be of the form 'seq.l.weight' or
    'seq.l.bias', where 'l' is the layer index, starting from 0, and weights
    are stored as (out_features, in_features).

    Args:
        state_dict: L-layer neural network weights and biases, stored in a state
        dict.
        permutations: List of permutation vectors, one for each pair of consecutive
        layers in the neural network, size `L-1`.

    Returns: A state dict with the weights permuted.

    Raises:
        ValueError: if the number of permutations is not L - 1, or if a
        permutation's length does not match the layers it acts on.
    """
    def layer_of(key: str) -> int:
        # 'seq.<l>.weight' -> l. Splitting (rather than reading key[4]) also
        # handles layer indices with more than one digit.
        return int(key.split('.')[1])

    num_layers = max(layer_of(key) for key in state_dict) + 1
    if len(permutations) != num_layers - 1:
        raise ValueError('Number of permutations must equal number of layers - 1.')

    # Test that the permutation vectors are valid: permutation i must match the
    # output size of layer i and the input size of layer i + 1.
    weight_shapes = {
        layer_of(key): params.shape
        for key, params in state_dict.items() if 'weight' in key
    }
    for i, perm in enumerate(permutations):
        rows = weight_shapes.get(i)
        cols = weight_shapes.get(i + 1)
        if ((rows is not None and perm.shape[0] != rows[0])
                or (cols is not None and perm.shape[0] != cols[1])):
            raise ValueError(f'Permutation {i} has invalid shape.')

    new_state_dict = OrderedDict()
    for key, params in state_dict.items():
        layer_index = layer_of(key)
        # Permutation of this layer's outputs (none for the last layer) and of
        # its inputs (the previous layer's output permutation; none for layer 0).
        out_perm = permutations[layer_index] if layer_index < num_layers - 1 else None
        in_perm = permutations[layer_index - 1] if layer_index > 0 else None
        if 'weight' in key:
            if out_perm is not None:
                params = permute_dim(params, 0, out_perm)
            if in_perm is not None:
                params = permute_dim(params, 1, in_perm)
        elif out_perm is not None:
            params = permute_dim(params, 0, out_perm)
        new_state_dict[key] = params
    return new_state_dict
