import copy
import math
import random
from itertools import permutations, product
from typing import Dict, List, Tuple

import hydra
import numpy as np
import torch
import torch.nn.functional as F
from loguru import logger
from omegaconf import DictConfig
from tqdm import tqdm

from models.non_oblivious import SoftTreeEnsemble, TreeWiseSoftTreeEnsemble
from models.oblivious import (ObliviousTreeEnsemble,
                              TreeWiseObliviousTreeEnsemble)
from utils import (extract_parameters, find_best_permutation, load_tree_models,
                   patterns_to_indices, random_shuffle_decision_list,
                   random_shuffle_non_oblivious, random_shuffle_oblivious,
                   save_tree_model, seed_everything)

# Prefix of every model file this script writes (see the ``model_path``
# strings in ``main``). Presumably short for "weight matching" -- TODO confirm.
strategy = "wm"

# Newly created floating-point tensors default to double precision.
torch.set_default_dtype(torch.float64)
# Restrict PyTorch to a single thread for intra-op CPU parallelism.
torch.set_num_threads(1)


def _stack_tree_parameters(model, n_tree: int, oblivious: bool) -> torch.Tensor:
    """Concatenate a model's weighted parameters into an ``n_tree``-column matrix.

    Args:
        model: Ensemble accepted by ``extract_parameters``.
        n_tree: Number of trees; every parameter is reshaped to ``(-1, n_tree)``.
        oblivious: When true, axes 2 and 3 of ``leaf`` are swapped before it is
            flattened (the layout used for oblivious ensembles); otherwise
            ``leaf`` is flattened with a plain ``view``.

    Returns:
        The weight, bias and leaf blocks stacked along dim 0, with ``n_tree``
        columns (presumably one column per tree -- confirm against the
        parameter layout returned by ``extract_parameters``).
    """
    params = extract_parameters(model, weighting=True)
    weight = params["weight"].transpose(1, 2).reshape(-1, n_tree)
    bias = params["bias"].view(-1, n_tree)
    if oblivious:
        leaf = params["leaf"].transpose(2, 3).reshape(-1, n_tree)
    else:
        leaf = params["leaf"].view(-1, n_tree)
    return torch.cat((weight, bias, leaf))


def _permuted_copy(model_treewise, col_ind):
    """Return a deep copy of ``model_treewise`` with ``col_ind`` applied as a permutation."""
    permuted = copy.deepcopy(model_treewise)
    permuted.apply_permutation(col_ind)
    return permuted


def _flips_per_tree(flip_choice, col_ind, row_ind, all_patterns) -> list:
    """Translate the chosen flip-pattern index of each matched pair into flip indices."""
    return [
        patterns_to_indices([all_patterns[i]])[0]
        for i in flip_choice[col_ind, row_ind]
    ]


def _flip_trees(model_treewise, flips_per_tree, oblivious: bool) -> None:
    """Apply the given flips tree by tree, in place."""
    for tree_index, flip_nodes in enumerate(flips_per_tree):
        tree = model_treewise.trees[tree_index]
        for flip_node_index in flip_nodes:
            if oblivious:
                # Oblivious trees are flipped per depth, which is 1-indexed.
                tree.flip_children(flip_depth=flip_node_index + 1)
            else:
                tree.flip_children(flip_node_index)


def _reorder_trees(model_treewise, order_choice, col_ind, row_ind, all_ordering) -> None:
    """Apply the chosen ordering of each matched pair tree by tree, in place."""
    orders = [all_ordering[i] for i in order_choice[col_ind, row_ind]]
    for tree_index, order in enumerate(orders):
        model_treewise.trees[tree_index].apply_reordering(order)


@hydra.main(config_path="./config/", config_name="config", version_base=None)
def main(cfg: DictConfig) -> None:
    """Match the trees of model A to those of model B and save the permuted copies.

    For every tree type ("mode") the script loads both models, builds a
    tree-by-tree matrix of inner products between their stacked, weighted
    parameters, asks ``find_best_permutation`` for an assignment and saves a
    permuted copy of model A under ``./output/model/``.

    For depths up to 3 the candidate set is widened:

    * ``non_oblivious``: every pattern of child flips is tried.
    * ``oblivious``: every flip pattern and every ordering of the depth axis is
      tried, and four variants are saved (no operation, flips only,
      reordering only, both).
    * Decision-list modes: only the plain permutation is searched.

    For ``cfg.depth > 3`` the ``non_oblivious`` mode is skipped and the other
    modes only search the plain permutation.

    Note: the tensors named ``distance_matrices`` hold inner products, and the
    best candidate per tree pair is picked with ``torch.max``, so larger means
    more similar.

    Args:
        cfg: Hydra configuration. The fields read here are ``seed_a``,
            ``seed_b``, ``depth``, ``n_tree`` and ``model_a``; the whole object
            is also forwarded to the project helpers.

    Raises:
        NotImplementedError: If a mode has no matching branch (unreachable with
            the fixed mode list below).
    """
    seed_everything(cfg.seed_a)

    # Beyond depth 3 the flip/ordering candidates are not enumerated.
    skip_invariances = cfg.depth > 3
    skip_message = f"depth={cfg.depth}. Additional invariances are not considered."

    for mode in (
        "non_oblivious",
        "decision_list",
        "decision_list_no_binary",
        "oblivious",
    ):
        if skip_invariances and mode == "non_oblivious":
            logger.info(f"depth={cfg.depth}. {mode} is skipped.")
            continue

        logger.info(f"Mode: {mode}")

        # Only model A is ever transformed, so B's tree-wise form is unused.
        model_a, model_a_treewise, model_b, _model_b_treewise = load_tree_models(
            cfg, mode, cpu=True
        )
        # Scratch ensemble that receives the parameters of each candidate.
        model_a_clone = copy.deepcopy(model_a)

        if skip_invariances or mode in ("decision_list", "decision_list_no_binary"):
            # A single candidate: no flips at all.
            n_nodes = 0
            flip_indices_array = [[]]
        else:
            n_nodes = cfg.depth if mode == "oblivious" else 2**cfg.depth - 1
            all_patterns = list(product([False, True], repeat=n_nodes))
            flip_indices_array = patterns_to_indices(all_patterns)
        logger.info(f"flip_indices_array: {flip_indices_array}")

        if cfg.seed_a == cfg.seed_b:
            if mode == "non_oblivious":
                random_shuffle_non_oblivious(model_a_treewise, cfg)
            elif mode == "oblivious":
                random_shuffle_oblivious(model_a_treewise, cfg)
            else:
                random_shuffle_decision_list(model_a_treewise, cfg)

            logger.info("Same seeds are used. model_a is randomly shuffled.")

        if mode in ("non_oblivious", "decision_list", "decision_list_no_binary"):
            vector_b = _stack_tree_parameters(model_b, cfg.n_tree, oblivious=False)

            # One (tree of A) x (tree of B) inner-product matrix per flip pattern.
            distance_matrices = torch.zeros((2**n_nodes, cfg.n_tree, cfg.n_tree))
            for i, flip_indices in enumerate(tqdm(flip_indices_array)):
                model_a_treewise_clone = copy.deepcopy(model_a_treewise)
                # Flips are applied only in the non_oblivious case.
                if mode == "non_oblivious":
                    for flip_node_index in flip_indices:
                        # NOTE(review): the all -1 tensor presumably addresses
                        # every tree at once -- confirm against flip_children.
                        model_a_treewise_clone.flip_children(
                            torch.tensor([-1] * cfg.n_tree), flip_node_index
                        )
                model_a_clone.copy_parameters(model_a_treewise_clone)
                vector_a = _stack_tree_parameters(
                    model_a_clone, cfg.n_tree, oblivious=False
                )
                distance_matrices[i] = vector_a.T @ vector_b
            if skip_invariances:
                # flip_indices_array holds a single entry here, so the loop
                # above evaluated exactly the flip-free candidate.
                logger.info(skip_message)
            logger.info("Distance matrix is calculated")

            # Index 0 is the first candidate, i.e. the pattern with no flips.
            row_ind, col_ind = find_best_permutation(distance_matrices[0].t())

            logger.info(f"Permutation (without sign-flip): {col_ind}")

            save_tree_model(
                cfg,
                _permuted_copy(model_a_treewise, col_ind),
                model_path=f"./output/model/{strategy}_permuted_noop_{mode}_{cfg.model_a}",
            )
            if skip_invariances:
                logger.info(skip_message)
                continue

            if mode == "non_oblivious":
                # with flip: keep, per tree pair, the best flip pattern
                distance_matrix, resulting_flips = torch.max(distance_matrices, dim=0)
                row_ind, col_ind = find_best_permutation(distance_matrix.t())

                logger.info(f"Permutation (with sign-flip): {col_ind}")
                logger.info(f"Flip: {resulting_flips[col_ind, row_ind]}")

                model_a_treewise_flip = _permuted_copy(model_a_treewise, col_ind)
                _flip_trees(
                    model_a_treewise_flip,
                    _flips_per_tree(resulting_flips, col_ind, row_ind, all_patterns),
                    oblivious=False,
                )

                save_tree_model(
                    cfg,
                    model_a_treewise_flip,
                    model_path=f"./output/model/{strategy}_permuted_flip_non_oblivious_{cfg.model_a}",
                )
        elif mode == "oblivious":
            if skip_invariances:
                # Only the first ordering is evaluated on this path; building
                # all depth! orderings (and a matrix slice for each) would
                # grow factorially for no benefit. permutations() yields the
                # input order first, so this is the identity ordering.
                all_ordering = [next(permutations(np.arange(cfg.depth)))]
            else:
                all_ordering = list(permutations(np.arange(cfg.depth)))
            n_ordering = len(all_ordering)

            # Indexed as [flip pattern, ordering, tree of A, tree of B].
            distance_matrices = torch.zeros(
                (2**n_nodes, n_ordering, cfg.n_tree, cfg.n_tree)
            )
            vector_b = _stack_tree_parameters(model_b, cfg.n_tree, oblivious=True)

            for j, order in enumerate(tqdm(all_ordering, leave=False)):
                for i, flip_indices in enumerate(flip_indices_array):
                    model_a_treewise_clone = copy.deepcopy(model_a_treewise)
                    model_a_treewise_clone.apply_reordering([order] * cfg.n_tree)
                    for flip_node_index in flip_indices:
                        model_a_treewise_clone.flip_children(
                            torch.tensor([-1] * cfg.n_tree),
                            flip_node_index + 1,  # 1-index for oblivious tree
                        )
                    model_a_clone.copy_parameters(model_a_treewise_clone)
                    vector_a = _stack_tree_parameters(
                        model_a_clone, cfg.n_tree, oblivious=True
                    )
                    distance_matrices[i, j] = vector_a.T @ vector_b
            if skip_invariances:
                logger.info(skip_message)

            # without operation: first flip pattern (none) and first ordering
            distance_matrix = distance_matrices[0, 0]
            row_ind, col_ind = find_best_permutation(distance_matrix.t())

            logger.info(f"Permutation (without operation): {col_ind}")

            save_tree_model(
                cfg,
                _permuted_copy(model_a_treewise, col_ind),
                model_path=f"./output/model/{strategy}_permuted_noop_oblivious_{cfg.model_a}",
            )
            if skip_invariances:
                logger.info(skip_message)
                continue

            # only with flip: best flip pattern, ordering fixed to the first one
            best_over_flips, flip_choices = torch.max(distance_matrices, dim=0)
            distance_matrix, resulting_flips = best_over_flips[0], flip_choices[0]
            row_ind, col_ind = find_best_permutation(distance_matrix.t())

            logger.info(f"Permutation (only with sign-flip): {col_ind}")
            logger.info(f"Flip: {resulting_flips[col_ind, row_ind]}")

            model_a_treewise_onlyflip = _permuted_copy(model_a_treewise, col_ind)
            _flip_trees(
                model_a_treewise_onlyflip,
                _flips_per_tree(resulting_flips, col_ind, row_ind, all_patterns),
                oblivious=True,
            )

            save_tree_model(
                cfg,
                model_a_treewise_onlyflip,
                model_path=f"./output/model/{strategy}_permuted_onlyflip_oblivious_{cfg.model_a}",
            )

            # only with reordering: best ordering, flip pattern fixed to "none"
            best_over_orders, order_choices = torch.max(distance_matrices, dim=1)
            distance_matrix, resulting_orders = best_over_orders[0], order_choices[0]
            row_ind, col_ind = find_best_permutation(distance_matrix.t())

            logger.info(f"Permutation (only with reordering): {col_ind}")
            logger.info(f"Order: {resulting_orders[col_ind, row_ind]}")

            model_a_treewise_onlyorder = _permuted_copy(model_a_treewise, col_ind)
            _reorder_trees(
                model_a_treewise_onlyorder,
                resulting_orders,
                col_ind,
                row_ind,
                all_ordering,
            )

            save_tree_model(
                cfg,
                model_a_treewise_onlyorder,
                model_path=f"./output/model/{strategy}_permuted_onlyorder_oblivious_{cfg.model_a}",
            )

            # flip and reordering: search both candidate axes jointly
            n_flip, n_order = distance_matrices.shape[0], distance_matrices.shape[1]
            reshaped_distance_matrices = distance_matrices.reshape(
                n_flip * n_order,
                distance_matrices.shape[2],
                distance_matrices.shape[3],
            )
            distance_matrix, resulting_flips_orders = torch.max(
                reshaped_distance_matrices, dim=0
            )
            # The flat index is flip * n_order + order; split it back up.
            resulting_flips = resulting_flips_orders // n_order
            resulting_orders = resulting_flips_orders % n_order

            row_ind, col_ind = find_best_permutation(distance_matrix.t())

            logger.info(f"Permutation (with sign-flip and reordering): {col_ind}")
            logger.info(f"Order: {resulting_orders[col_ind, row_ind]}")
            logger.info(f"Flip: {resulting_flips[col_ind, row_ind]}")

            model_a_treewise_orderflip = _permuted_copy(model_a_treewise, col_ind)
            # Same sequence as when the candidates were scored: reorder, then flip.
            _reorder_trees(
                model_a_treewise_orderflip,
                resulting_orders,
                col_ind,
                row_ind,
                all_ordering,
            )
            _flip_trees(
                model_a_treewise_orderflip,
                _flips_per_tree(resulting_flips, col_ind, row_ind, all_patterns),
                oblivious=True,
            )

            save_tree_model(
                cfg,
                model_a_treewise_orderflip,
                model_path=f"./output/model/{strategy}_permuted_orderflip_oblivious_{cfg.model_a}",
            )
        else:
            raise NotImplementedError


# Run the script only when this file is executed directly; importing it as a
# module must not trigger the run.
if __name__ == "__main__":
    main()
