import copy
import random
from collections import defaultdict
from itertools import permutations, product
from typing import Dict, List, Tuple

import hydra
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from loguru import logger
from omegaconf import DictConfig
from scipy.optimize import linear_sum_assignment
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm

from dataset import get_dataloader
from utils import (extract_parameters, find_best_permutation, load_tree_models,
                   patterns_to_indices, random_shuffle_decision_list,
                   random_shuffle_non_oblivious, random_shuffle_oblivious,
                   save_tree_model, seed_everything)

# Process-wide torch settings: single CPU thread and float64 as the default
# floating-point dtype for every tensor created afterwards.
torch.set_num_threads(1)
torch.set_default_dtype(torch.float64)

# Prefix embedded in every saved model path below. NOTE(review): "am" presumably
# abbreviates "activation matching" (trees are matched on their outputs) — confirm.
strategy = "am"


def _parameter_vector(tree) -> torch.Tensor:
    """Flatten a tree's parameters into one 1-D vector for inner-product matching.

    Concatenates the ``"weight"``, ``"bias"`` and ``"leaf"`` entries returned by
    ``extract_parameters(tree, weighting=True)``, in that order.
    """
    params = extract_parameters(tree, weighting=True)
    return torch.cat(
        (
            params["weight"].view(-1),
            params["bias"].view(-1),
            params["leaf"].view(-1),
        )
    )


def _permuted_copy(model, col_ind):
    """Return a deep copy of ``model`` with the tree permutation ``col_ind`` applied."""
    permuted = copy.deepcopy(model)
    permuted.apply_permutation(col_ind)
    return permuted


def _flip_nodes(tree, flip_indices, offset: int = 0) -> None:
    """Call ``tree.flip_children`` for every index in ``flip_indices`` plus ``offset``."""
    for node_index in flip_indices:
        tree.flip_children(node_index + offset)


def _flip_nodes_one_indexed(tree, flip_indices) -> None:
    """Flip children for oblivious trees, whose ``flip_children`` takes 1-based indices."""
    _flip_nodes(tree, flip_indices, offset=1)


def _reorder(tree, order) -> None:
    """Apply one reordering candidate to ``tree`` in place."""
    tree.apply_reordering(order)


def _reorder_then_flip(tree, candidate) -> None:
    """Apply an ``(order, flip_indices)`` candidate: reorder first, then flip (1-based)."""
    order, flip_indices = candidate
    tree.apply_reordering(order)
    _flip_nodes(tree, flip_indices, offset=1)


def _align_trees(
    model_a,
    model_b,
    n_tree: int,
    candidates,
    apply_candidate,
    *,
    prepare_b=None,
    inner_progress: bool = False,
) -> List:
    """Greedily align each tree of ``model_a`` to the same-index tree of ``model_b``.

    For every tree index, each candidate transform is applied to a copy of
    ``model_a``'s tree and scored by the inner product of its parameter vector
    with that of ``model_b``'s tree. The highest-scoring candidate (the first
    one on ties) is then applied in place to ``model_a``'s tree.

    Args:
        model_a: Ensemble exposing ``.trees``; its trees are modified in place.
        model_b: Reference ensemble; only deep copies of its trees are touched.
        n_tree: Number of leading trees to align.
        candidates: Re-iterable collection of candidate transforms.
        apply_candidate: ``apply_candidate(tree, candidate)`` applies one
            candidate to ``tree`` in place.
        prepare_b: Optional hook called on the copied reference tree before
            its parameters are extracted.
        inner_progress: Show a progress bar over the candidates of each tree.

    Returns:
        The chosen candidate for each tree index.

    Raises:
        RuntimeError: If no candidate's score compares greater than ``-inf``
            for some tree (e.g. every score is NaN).
    """
    chosen = []
    with torch.no_grad():
        for tree_index in tqdm(range(n_tree), leave=False, desc="Searching..."):
            tree_b = copy.deepcopy(model_b.trees[tree_index])
            if prepare_b is not None:
                prepare_b(tree_b)
            # Recomputed for every tree: scoring against another tree's vector
            # would silently pick the wrong transform.
            vector_b = _parameter_vector(tree_b)

            best_score = float("-inf")
            best_candidate = None
            iterator = tqdm(candidates, leave=False) if inner_progress else candidates
            for candidate in iterator:
                tree_a = copy.deepcopy(model_a.trees[tree_index])
                apply_candidate(tree_a, candidate)
                score = _parameter_vector(tree_a) @ vector_b
                if score > best_score:
                    best_score = score
                    best_candidate = candidate

            if best_candidate is None:
                raise RuntimeError(
                    f"No alignment candidate produced a valid score for tree {tree_index}."
                )
            chosen.append(best_candidate)

        for tree_index, candidate in enumerate(
            tqdm(chosen, leave=False, desc="Fixing...")
        ):
            apply_candidate(model_a.trees[tree_index], candidate)
    return chosen


@hydra.main(config_path="./config/", config_name="config", version_base=None)
def main(cfg: DictConfig) -> None:
    """Align model_a's trees to model_b's and save the aligned variants of model_a.

    For each tree architecture mode, a tree permutation is obtained from
    ``find_best_permutation`` applied to the inner products of per-tree outputs
    on one random test batch; the permuted model_a is saved. For depth <= 3,
    per-tree invariances (``flip_children`` patterns and, for oblivious trees,
    ``apply_reordering`` over permutations of ``range(depth)``) are searched
    exhaustively by parameter inner product and saved as additional variants.
    """
    seed_everything(cfg.seed_a)
    for mode in (
        "non_oblivious",
        "decision_list",
        "decision_list_no_binary",
        "oblivious",
    ):
        if (cfg.depth > 3) and mode == "non_oblivious":
            logger.info(f"depth={cfg.depth}. {mode} is skipped.")
            continue

        logger.info(f"Mode: {mode}")

        _, model_a_treewise, _, model_b_treewise = load_tree_models(
            cfg, mode, cpu=True
        )

        if cfg.depth <= 3:
            # Exhaustive search spaces; only built for shallow trees.
            n_nodes = cfg.depth if mode == "oblivious" else 2**cfg.depth - 1
            all_patterns = list(product([False, True], repeat=n_nodes))
            all_ordering = list(permutations(np.arange(cfg.depth)))
            flip_indices_array = patterns_to_indices(all_patterns)

        if cfg.seed_a == cfg.seed_b:
            # With identical seeds model_a is shuffled (presumably so that the
            # alignment is non-trivial).
            if mode == "non_oblivious":
                random_shuffle_non_oblivious(model_a_treewise, cfg)
            elif mode == "oblivious":
                random_shuffle_oblivious(model_a_treewise, cfg)
            else:
                random_shuffle_decision_list(model_a_treewise, cfg)

            logger.info("Same seeds are used. model_a is randomly shuffled.")

        _, test_loader = get_dataloader(cfg.dataset.name, cfg.batch_size, cpu=True)
        random_indices = torch.randperm(len(test_loader.dataset))[: cfg.batch_size]
        subset_loader = DataLoader(
            Subset(test_loader.dataset, random_indices), batch_size=cfg.batch_size
        )

        # The subset has at most batch_size samples, so this yields one batch.
        with torch.no_grad():
            for data, _target in tqdm(subset_loader, leave=False):
                output_a = model_a_treewise.forward(data, treewise=True)
                output_b = model_b_treewise.forward(data, treewise=True)
        logger.info("Inference is completed")

        # NOTE(review): assumes transpose(0, 2) moves the tree axis last so that
        # reshape(-1, n_tree) keeps one column per tree — confirm against the
        # output layout of forward(treewise=True).
        output_a = output_a.transpose(0, 2).reshape(-1, cfg.n_tree)
        output_b = output_b.transpose(0, 2).reshape(-1, cfg.n_tree)
        # Inner products between per-tree output columns (larger = more similar,
        # despite the variable name).
        distance_matrix = output_a.T @ output_b
        logger.info("Distance matrix is calculated")

        _, col_ind = find_best_permutation(distance_matrix.t())

        logger.info(f"Permutation (without sign-flip): {col_ind}")

        if mode in ("non_oblivious", "decision_list", "decision_list_no_binary"):
            save_tree_model(
                cfg,
                _permuted_copy(model_a_treewise, col_ind),
                model_path=f"./output/model/{strategy}_permuted_noop_{mode}_{cfg.model_a}",
            )
            if cfg.depth > 3:
                logger.info(
                    f"depth={cfg.depth}. Additional invariances are not considered."
                )
                continue

            if mode == "non_oblivious":  # with flip
                model_a_treewise_flip = _permuted_copy(model_a_treewise, col_ind)
                # NOTE(review): only the reference tree gets collect_parameters(),
                # and only in this mode — confirm the asymmetry is intended.
                _align_trees(
                    model_a_treewise_flip,
                    model_b_treewise,
                    cfg.n_tree,
                    flip_indices_array,
                    _flip_nodes,
                    prepare_b=lambda tree: tree.collect_parameters(),
                )
                save_tree_model(
                    cfg,
                    model_a_treewise_flip,
                    model_path=f"./output/model/{strategy}_permuted_flip_non_oblivious_{cfg.model_a}",
                )
        elif mode == "oblivious":
            save_tree_model(
                cfg,
                _permuted_copy(model_a_treewise, col_ind),
                model_path=f"./output/model/{strategy}_permuted_noop_oblivious_{cfg.model_a}",
            )
            if cfg.depth > 3:
                logger.info(
                    f"depth={cfg.depth}. Additional invariances are not considered."
                )
                continue

            # only with flip
            model_a_treewise_onlyflip = _permuted_copy(model_a_treewise, col_ind)
            _align_trees(
                model_a_treewise_onlyflip,
                model_b_treewise,
                cfg.n_tree,
                flip_indices_array,
                _flip_nodes_one_indexed,
            )
            save_tree_model(
                cfg,
                model_a_treewise_onlyflip,
                model_path=f"./output/model/{strategy}_permuted_onlyflip_oblivious_{cfg.model_a}",
            )

            # only with ordering
            model_a_treewise_onlyorder = _permuted_copy(model_a_treewise, col_ind)
            _align_trees(
                model_a_treewise_onlyorder,
                model_b_treewise,
                cfg.n_tree,
                all_ordering,
                _reorder,
                inner_progress=True,
            )
            save_tree_model(
                cfg,
                model_a_treewise_onlyorder,
                model_path=f"./output/model/{strategy}_permuted_onlyorder_oblivious_{cfg.model_a}",
            )

            # flip and reordering; product() visits orderings in the outer
            # position and flip patterns in the inner one, like a nested loop.
            model_a_treewise_orderflip = _permuted_copy(model_a_treewise, col_ind)
            _align_trees(
                model_a_treewise_orderflip,
                model_b_treewise,
                cfg.n_tree,
                list(product(all_ordering, flip_indices_array)),
                _reorder_then_flip,
                inner_progress=True,
            )
            save_tree_model(
                cfg,
                model_a_treewise_orderflip,
                model_path=f"./output/model/{strategy}_permuted_orderflip_oblivious_{cfg.model_a}",
            )
        else:
            raise NotImplementedError


if __name__ == "__main__":
    # main is wrapped by hydra.main (config_path="./config/", config_name="config"),
    # which builds `cfg` and passes it in, so no arguments are given here.
    main()
