import copy
from collections import defaultdict
from typing import Dict, List, NamedTuple, Tuple

import hydra
import numpy as np
import torch
import torch.nn.functional as F
from loguru import logger
from omegaconf import DictConfig
from scipy.optimize import linear_sum_assignment
from tqdm import tqdm

from models.mlp import ReLUMLP, SigmoidMLP
from utils import (find_best_permutation, initialize_and_load_model,
                   seed_everything)

# Run the heavy matrix work on a GPU when one is visible; weight_matching moves
# the per-permutation similarity matrices onto this device.
use_cuda = torch.cuda.is_available()
if use_cuda:
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Tensors created without an explicit dtype from here on (e.g. the torch.zeros
# similarity matrices in weight_matching) are float64.
torch.set_default_dtype(torch.float64)

# Tag for this alignment strategy (presumably "weight matching"); used as the
# prefix of the saved permuted-model filename.
strategy = "wm"


class PermutationSpec(NamedTuple):
    """Two-way description of which parameter axes each permutation acts on.

    Both mappings describe the same relation from opposite sides.
    """

    # permutation name -> list of (parameter key, axis) pairs it reindexes
    perm_to_axes: Dict[str, List[Tuple[str, int]]]
    # parameter key -> one entry per axis: a permutation name, or None if unpermuted
    axes_to_perm: Dict[str, tuple]


def permutation_spec_from_axes_to_perm(axes_to_perm: dict) -> PermutationSpec:
    """Build a PermutationSpec by inverting a ``param key -> per-axis perms`` map.

    Each non-``None`` entry ``axes_to_perm[wk][axis] == p`` contributes the pair
    ``(wk, axis)`` to ``perm_to_axes[p]``. Pairs appear in the iteration order
    of ``axes_to_perm`` and, within one key, by increasing axis. The given
    ``axes_to_perm`` object itself is stored in the result unchanged.
    """
    inverse: dict = {}
    for key, perms in axes_to_perm.items():
        for ax, name in enumerate(perms):
            if name is None:
                continue  # this axis is not permuted
            inverse.setdefault(name, []).append((key, ax))
    return PermutationSpec(perm_to_axes=inverse, axes_to_perm=axes_to_perm)


def mlp_permutation_spec(num_hidden_layers: int) -> PermutationSpec:
    """Return the PermutationSpec for an MLP whose parameters are named ``layers.{i}.*``.

    The units of hidden layer ``i`` (``0 <= i < num_hidden_layers``) are permuted
    by ``P_i``. It acts on axis 0 of ``layers.{i}.weight`` and ``layers.{i}.bias``,
    and on axis 1 of the next layer's weight. The input axis of
    ``layers.0.weight`` and the output axis of the final layer
    ``layers.{num_hidden_layers}`` are never permuted.

    We assume that one permutation cannot appear in two axes of the same weight array.
    """
    assert num_hidden_layers >= 1
    last = num_hidden_layers
    axes_to_perm = {"layers.0.weight": ("P_0", None)}
    # Hidden weights 1..last-1: (this layer's units, previous layer's units).
    for i in range(1, last):
        axes_to_perm[f"layers.{i}.weight"] = (f"P_{i}", f"P_{i-1}")
    # Every hidden bias follows its own layer's unit permutation.
    for i in range(last):
        axes_to_perm[f"layers.{i}.bias"] = (f"P_{i}",)
    # Final layer: only its input axis (the last hidden layer) is permuted.
    axes_to_perm[f"layers.{last}.weight"] = (None, f"P_{last-1}")
    axes_to_perm[f"layers.{last}.bias"] = (None,)
    return permutation_spec_from_axes_to_perm(axes_to_perm)


def get_permuted_param(ps: PermutationSpec, perm, k: str, params, except_axis=None):
    """Return parameter ``k`` of ``params`` reindexed along its permuted axes.

    For every axis of ``params[k]`` that ``ps.axes_to_perm[k]`` maps to a
    permutation name ``p``, the tensor is reindexed along that axis with
    ``perm[p]``, i.e. ``out[..., j, ...] = w[..., perm[p][j], ...]``. Axes mapped
    to ``None``, and ``except_axis``, are left as they are.

    Args:
        ps: permutation spec; only ``ps.axes_to_perm`` is read.
        perm: mapping from permutation name to an index sequence (tensor,
            array or list), expected to be a permutation of that axis' indices.
        k: key of the parameter to permute.
        params: mapping from parameter key to tensor.
        except_axis: an axis to skip (weight matching skips the axis whose
            permutation it is currently solving for).

    Returns:
        The reindexed tensor, or ``params[k]`` itself when no axis is permuted.
    """
    w = params[k]
    for axis, p in enumerate(ps.axes_to_perm[k]):
        if axis == except_axis or p is None:
            continue
        # index_select requires an integer index living on the same device as w.
        index = torch.as_tensor(perm[p], dtype=torch.long, device=w.device)
        w = torch.index_select(w, axis, index)
    return w


def apply_permutation(ps: PermutationSpec, perm, params):
    """Return a new dict holding every entry of ``params`` permuted by ``perm``.

    Each key of ``params`` must be present in ``ps.axes_to_perm``. The input
    mapping is not modified.
    """
    permuted = {}
    for name in params:
        permuted[name] = get_permuted_param(ps, perm, name, params)
    return permuted


def weight_matching(
    ps: PermutationSpec, params_a, params_b, max_iter=100, init_perm=None
):
    """Find permutations of ``params_a``'s units that best align it with ``params_b``.

    Runs coordinate ascent over the permutations in ``ps``. Each sweep visits
    every permutation in random order and holds the others fixed. For the
    visited one, it solves a linear assignment problem that maximises the
    summed inner products between model A's (otherwise-permuted) weights and
    model B's weights along that permutation's axes. Stops after a sweep with
    no improvement, or after ``max_iter`` sweeps.

    Args:
        ps: permutation spec describing which axes each permutation acts on.
        params_a: state dict of the model to be permuted.
        params_b: state dict of the reference model (never modified).
        max_iter: maximum number of sweeps over all permutations.
        init_perm: optional starting permutations ``{name: index tensor}``.
            Identity permutations are used when omitted. The mapping is
            copied, so the caller's dict is left untouched.

    Returns:
        ``{perm_name: LongTensor}``. Passing it to ``apply_permutation`` with
        ``params_a`` yields parameters aligned to ``params_b``.
    """
    perm_sizes = {
        p: params_a[axes[0][0]].shape[axes[0][1]] for p, axes in ps.perm_to_axes.items()
    }

    # Work on a copy so the assignments below never overwrite the caller's init_perm.
    if init_perm is None:
        perm = {p: torch.arange(n) for p, n in perm_sizes.items()}
    else:
        perm = dict(init_perm)
    perm_names = list(perm.keys())

    for iteration in range(max_iter):
        progress = False
        # Random visiting order each sweep (draws from torch's global RNG).
        for p_ix in torch.randperm(len(perm_names)).tolist():
            p = perm_names[p_ix]
            n = perm_sizes[p]
            # A[i, j]: similarity between unit i of model A (all other perms
            # applied) and unit j of model B along this permutation's axes.
            A = torch.zeros((n, n)).to(device)
            for wk, axis in ps.perm_to_axes[p]:
                w_a = get_permuted_param(ps, perm, wk, params_a, except_axis=axis)
                w_b = params_b[wk]
                w_a = torch.moveaxis(w_a, axis, 0).reshape((n, -1))
                w_b = torch.moveaxis(w_b, axis, 0).reshape((n, -1))
                w_a, w_b = w_a.to(device), w_b.to(device)
                A += w_a @ w_b.T

            A = A.detach().cpu()
            # Solve on A.T: for each unit j of B (row index ri[j] == j) choose
            # the unit ci[j] of A that should be moved into position j.
            ri, ci = linear_sum_assignment(A.numpy().T, maximize=True)
            assert (torch.as_tensor(ri) == torch.arange(len(ri))).all()
            rows = torch.as_tensor(ri, dtype=torch.long)
            new_perm = torch.as_tensor(ci, dtype=torch.long)
            current = torch.as_tensor(perm[p], dtype=torch.long, device="cpu")

            oldL = A[current, rows].sum().item()
            newL = A[new_perm, rows].sum().item()
            logger.info(
                f"{iteration}/{p}: current distance={newL}, improvement={newL-oldL}"
            )
            progress = progress or newL > oldL + 1e-12

            perm[p] = new_perm
        if not progress:
            break

    return perm


def _load_cpu_model(cfg: DictConfig, mode: str, model_name):
    """Load checkpoint ``./output/model/{mode}_{model_name}`` using cfg's settings, on CPU."""
    checkpoint = f"./output/model/{mode}_{model_name}"
    model = initialize_and_load_model(
        checkpoint,
        cfg.dataset.input_dim,
        cfg.dataset.output_dim,
        cfg.depth,
        cfg.alpha,
        cfg.n_tree,
        mode,
    )
    return model.to("cpu")


@hydra.main(config_path="./config/", config_name="config", version_base=None)
def main(cfg: DictConfig) -> None:
    """Permute model_a to align with model_b, for each MLP variant, and save it.

    For both ``relu_mlp`` and ``sigmoid_mlp``:

    1. Load the two checkpoints on CPU.
    2. If both models share a seed, shuffle model_a's hidden units at random.
       Identical seeds presumably produce identical models, so this gives the
       matching something to recover.
    3. Run weight matching.
    4. Save the permuted model_a state dict to
       ``./output/model/{strategy}_permuted_{mode}_{model_a}``.
    """
    seed_everything(cfg.seed_a)
    for mode in ("relu_mlp", "sigmoid_mlp"):
        logger.info(f"Mode: {mode}")

        model_a = _load_cpu_model(cfg, mode, cfg.model_a)
        model_b = _load_cpu_model(cfg, mode, cfg.model_b)

        ps = mlp_permutation_spec(cfg.depth)
        if cfg.seed_a == cfg.seed_b:
            state_a = model_a.state_dict()
            shuffle = {}
            for name, axes in ps.perm_to_axes.items():
                key, ax = axes[0]
                shuffle[name] = torch.randperm(state_a[key].shape[ax])
            model_a.load_state_dict(apply_permutation(ps, shuffle, model_a.state_dict()))

            logger.info("Same seeds are used. model_a is randomly shuffled.")

        final_permutation = weight_matching(
            ps, model_a.state_dict(), model_b.state_dict()
        )
        aligned_state = apply_permutation(ps, final_permutation, model_a.state_dict())
        save_path = f"./output/model/{strategy}_permuted_{mode}_{cfg.model_a}"
        torch.save(aligned_state, save_path)
        logger.info(f"Model Saved: {save_path}")


if __name__ == "__main__":
    # The @hydra.main decorator builds `cfg` from ./config/config.yaml (plus any
    # command-line overrides), so main() is called without arguments here.
    main()
