import math
import os
import random
import time
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor, TimeoutError
from functools import wraps
from typing import Dict, List, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from loguru import logger
from omegaconf import DictConfig
from scipy.optimize import linear_sum_assignment
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from tqdm import tqdm

from models.mlp import ReLUMLP, SigmoidMLP
from models.non_oblivious import (DecisionListEnsemble,
                                  DecisionListEnsembleNoBinary,
                                  SoftTreeEnsemble,
                                  TreeWiseDecisionListEnsemble,
                                  TreeWiseDecisionListEnsembleNoBinary,
                                  TreeWiseSoftTreeEnsemble)
from models.oblivious import (ObliviousTreeEnsemble,
                              TreeWiseObliviousTreeEnsemble)

# Run on the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
# `train` and `test` move every batch to this device.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")

# From here on, tensors created without an explicit dtype are float64
# (a process-wide side effect of importing this module).
torch.set_default_dtype(torch.float64)


def train(
    model: nn.Module, train_loader: DataLoader, optimizer: Optimizer, epoch: int
) -> None:
    """Train ``model`` for one pass over ``train_loader`` and log its accuracy.

    Each batch is moved to the module-level ``device``, scored with
    ``log_softmax`` over dim 1 followed by the negative log-likelihood loss,
    and used for one optimizer step. Accuracy is counted from the same forward
    pass that produces the loss.

    Args:
        model: Network whose outputs are treated as class logits along dim 1.
        train_loader: Loader yielding ``(inputs, targets)`` batches.
        optimizer: Optimizer stepped once per batch.
        epoch: Current epoch number; not used by this function.
    """
    model.train()
    n_correct = 0
    progress = tqdm(train_loader, leave=False, desc="Training...")
    for inputs, labels in progress:
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        log_probs = nn.functional.log_softmax(model(inputs), dim=1)
        batch_loss = F.nll_loss(log_probs, labels)
        predicted = log_probs.argmax(dim=1, keepdim=True)
        n_correct += predicted.eq(labels.view_as(predicted)).sum().item()
        batch_loss.backward()
        optimizer.step()
    accuracy = 100.0 * n_correct / len(train_loader.dataset)
    logger.info(f"Train Accuracy: {accuracy:.1f}%")


def test(model: SoftTreeEnsemble, test_loader: DataLoader) -> Tuple[float, float]:
    """Evaluate ``model`` on ``test_loader`` without gradient tracking.

    Outputs go through ``log_softmax`` over dim 1; the NLL loss is summed over
    all samples and then divided by the dataset size.

    Args:
        model: Model to evaluate (switched to eval mode).
        test_loader: Loader yielding ``(inputs, targets)`` batches.

    Returns:
        ``(mean_loss, accuracy)``, where ``accuracy`` is a percentage.
    """
    model.eval()
    loss_sum = 0.0
    hits = 0
    with torch.no_grad():
        for inputs, labels in tqdm(test_loader, leave=False, desc="Testing..."):
            inputs, labels = inputs.to(device), labels.to(device)
            log_probs = nn.functional.log_softmax(model(inputs), dim=1)
            loss_sum += F.nll_loss(log_probs, labels, reduction="sum").item()
            guesses = log_probs.argmax(dim=1, keepdim=True)
            hits += guesses.eq(labels.view_as(guesses)).sum().item()

    n_samples = len(test_loader.dataset)
    mean_loss = loss_sum / n_samples
    accuracy = 100.0 * hits / n_samples
    logger.info(f"Test Loss: {mean_loss:.2f}, Test Accuracy: {accuracy:.2f}%")
    return mean_loss, accuracy


def initialize_and_load_model(
    model_path: str,
    input_dim: int,
    output_dim: int,
    depth: int,
    alpha: float,
    n_tree: int,
    mode: str = "",
    cpu: bool = False,
):
    """Build the model class selected by ``mode`` and load a checkpoint into it.

    Args:
        model_path: Checkpoint path passed to ``torch.load``; its content is
            fed to ``load_state_dict``.
        input_dim: Constructor argument (passed positionally).
        output_dim: Constructor argument (passed positionally).
        depth: Constructor argument (passed positionally).
        alpha: Constructor argument (passed positionally).
        n_tree: Constructor argument (passed positionally).
        mode: One of ``"non_oblivious"``, ``"oblivious"``, ``"decision_list"``,
            ``"decision_list_no_binary"``, ``"relu_mlp"``, ``"sigmoid_mlp"``.
        cpu: If True, build the model on ``"cpu"`` and map the checkpoint to
            CPU; otherwise build it on ``"cuda"`` and let ``torch.load`` use its
            default placement.

    Returns:
        The constructed model with the checkpoint's state loaded.

    Raises:
        NotImplementedError: If ``mode`` is not one of the supported names.
    """
    if mode == "non_oblivious":
        model_class = SoftTreeEnsemble
    elif mode == "oblivious":
        model_class = ObliviousTreeEnsemble
    elif mode == "decision_list":
        model_class = DecisionListEnsemble
    elif mode == "decision_list_no_binary":
        model_class = DecisionListEnsembleNoBinary
    elif mode == "relu_mlp":
        model_class = ReLUMLP
    elif mode == "sigmoid_mlp":
        model_class = SigmoidMLP
    else:
        raise NotImplementedError(f"Unsupported mode: {mode!r}")

    device = "cpu" if cpu else "cuda"
    if "MLP" in model_class.__name__:
        # MLPs are moved with .to(); the tree ensembles take the device as a
        # constructor argument instead.
        model = model_class(input_dim, output_dim, depth, alpha, n_tree).to(device)
    else:
        model = model_class(input_dim, output_dim, depth, alpha, n_tree, device)

    # None keeps torch.load's default tensor placement.
    map_location = "cpu" if cpu else None
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(model_path, map_location=map_location)

    model.load_state_dict(checkpoint)
    return model


def _load_model_pair(cfg, mode, cpu, device, treewise_class, label, model_name):
    """Load ensemble checkpoint ``model_name`` and copy it into a tree-wise model.

    Args:
        cfg: Config providing ``dataset.input_dim``, ``dataset.output_dim``,
            ``depth``, ``alpha`` and ``n_tree``.
        mode: Model family name; also used as the checkpoint file prefix.
        cpu: Forwarded to ``initialize_and_load_model``.
        device: Device string passed to ``treewise_class``.
        treewise_class: Tree-wise class to instantiate.
        label: Name used in log messages (e.g. ``"A"``).
        model_name: Checkpoint suffix under ``./output/model/{mode}_``.

    Returns:
        ``(ensemble_model, treewise_model)``.
    """
    logger.info(f"loading model {label} (Ensemblewise, {mode})...")
    ensemble = initialize_and_load_model(
        f"./output/model/{mode}_{model_name}",
        cfg.dataset.input_dim,
        cfg.dataset.output_dim,
        cfg.depth,
        cfg.alpha,
        cfg.n_tree,
        mode,
        cpu,
    )
    logger.info(f"loading model {label} (Treewise, {mode})...")
    treewise = treewise_class(
        cfg.dataset.input_dim,
        cfg.dataset.output_dim,
        cfg.depth,
        cfg.alpha,
        cfg.n_tree,
        device,
    )
    logger.info("Copying parameters...")
    treewise.copy_parameters(ensemble)
    return ensemble, treewise


def load_tree_models(cfg, mode, cpu=False):
    """Load models A and B for ``mode`` in both ensemble and tree-wise form.

    Checkpoints are read from ``./output/model/{mode}_{cfg.model_a}`` and
    ``./output/model/{mode}_{cfg.model_b}``. Each one is copied into a freshly
    constructed tree-wise model via ``copy_parameters``.

    Args:
        cfg: Config providing ``dataset.input_dim``, ``dataset.output_dim``,
            ``depth``, ``alpha``, ``n_tree``, ``model_a`` and ``model_b``.
        mode: ``"non_oblivious"``, ``"oblivious"``, ``"decision_list"`` or
            ``"decision_list_no_binary"``.
        cpu: Build everything on ``"cpu"`` instead of ``"cuda"``.

    Returns:
        ``(model_a, model_a_treewise, model_b, model_b_treewise)``.

    Raises:
        NotImplementedError: If ``mode`` has no tree-wise counterpart.
    """
    if mode == "non_oblivious":
        treewise_class = TreeWiseSoftTreeEnsemble
    elif mode == "oblivious":
        treewise_class = TreeWiseObliviousTreeEnsemble
    elif mode == "decision_list":
        treewise_class = TreeWiseDecisionListEnsemble
    elif mode == "decision_list_no_binary":
        treewise_class = TreeWiseDecisionListEnsembleNoBinary
    else:
        raise NotImplementedError(f"Unsupported mode: {mode!r}")

    device = "cpu" if cpu else "cuda"

    model_a, model_a_treewise = _load_model_pair(
        cfg, mode, cpu, device, treewise_class, "A", cfg.model_a
    )
    model_b, model_b_treewise = _load_model_pair(
        cfg, mode, cpu, device, treewise_class, "B", cfg.model_b
    )
    return model_a, model_a_treewise, model_b, model_b_treewise


def random_shuffle_non_oblivious(model_a_treewise, cfg):
    """Randomly permute the trees, then randomly flip nodes in every tree.

    Randomness comes from Python's ``random`` module, so results are
    reproducible after ``random.seed``. For each tree, each of the
    ``2**cfg.depth - 1`` node indices is passed to ``flip_children``
    independently with probability 0.5.

    Args:
        model_a_treewise: Tree-wise model exposing ``apply_permutation`` and a
            ``trees`` sequence whose items expose ``flip_children``.
        cfg: Config providing ``n_tree`` and ``depth``.
    """
    n_nodes = 2**cfg.depth - 1
    perm = np.arange(cfg.n_tree)
    random.shuffle(perm)
    model_a_treewise.apply_permutation(perm)
    for tree_index in tqdm(range(cfg.n_tree), leave=False):
        flip_indices = [i for i in range(n_nodes) if random.random() < 0.5]
        tree = model_a_treewise.trees[tree_index]
        for node_index in flip_indices:
            tree.flip_children(node_index)


def random_shuffle_oblivious(model_a_treewise, cfg):
    """Randomly permute trees, flip split signs per level, and reorder levels.

    The tree order comes from Python's ``random``; the per-level signs and
    per-tree level orders come from torch's RNG.

    Args:
        model_a_treewise: Tree-wise oblivious model exposing
            ``apply_permutation``, ``flip_children(signs, depth)`` and
            ``apply_reordering``.
        cfg: Config providing ``n_tree`` and ``depth``.
    """
    tree_order = np.arange(cfg.n_tree)
    random.shuffle(tree_order)
    model_a_treewise.apply_permutation(tree_order)

    for level in range(1, cfg.depth + 1):
        # randint yields 0/1 per tree; map it onto -1/+1.
        sign_per_tree = 2 * torch.randint(0, 2, (cfg.n_tree,)) - 1
        model_a_treewise.flip_children(sign_per_tree, level)

    # One random permutation of the depth levels per tree.
    level_orders = [torch.randperm(cfg.depth) for _ in range(cfg.n_tree)]
    model_a_treewise.apply_reordering(torch.stack(level_orders))


def random_shuffle_decision_list(model_a_treewise, cfg):
    """Reorder the trees of ``model_a_treewise`` by a random permutation.

    The permutation is drawn with Python's ``random.shuffle``.

    Args:
        model_a_treewise: Tree-wise model exposing ``apply_permutation``.
        cfg: Config providing ``n_tree``.
    """
    n_tree = cfg.n_tree
    shuffled_order = np.arange(n_tree)
    random.shuffle(shuffled_order)
    model_a_treewise.apply_permutation(shuffled_order)


def find_best_permutation(
    distance_matrix: torch.Tensor, verbose=True
) -> Tuple[np.ndarray, np.ndarray]:
    """Solve the assignment problem that maximizes the summed matrix score.

    Args:
        distance_matrix: Score matrix; ``distance_matrix[i, j]`` is the score
            of matching row ``i`` with column ``j``. It is detached and copied
            to CPU before solving, so any device / grad state is accepted.
        verbose: Log the identity and optimal scores at debug level.

    Returns:
        ``(row_ind, col_ind)`` as returned by
        ``scipy.optimize.linear_sum_assignment(..., maximize=True)``.

    Raises:
        AssertionError: If the optimal score is lower than the identity
            matching's score by more than float tolerance.
    """
    row_ind, col_ind = linear_sum_assignment(
        distance_matrix.detach().cpu().numpy(), maximize=True
    )
    original_score = distance_matrix[row_ind, row_ind].sum()
    permuted_score = distance_matrix[row_ind, col_ind].sum()
    original, permuted = float(original_score), float(permuted_score)
    # The solver maximizes, so the optimum is never worse than the identity
    # matching. When the two tie, summing different entries can still differ
    # by rounding, so compare with a small tolerance instead of strictly.
    assert original <= permuted or math.isclose(
        original, permuted, rel_tol=1e-9, abs_tol=1e-12
    ), f"assignment score {permuted} is worse than identity score {original}"
    if verbose:
        logger.debug(f"original: {original_score}, permuted: {permuted_score}")
    return row_ind, col_ind


def patterns_to_indices(patterns: List[List[bool]]) -> List[List[int]]:
    """Convert each boolean pattern into the list of positions that are truthy.

    Example:
        >>> patterns_to_indices([[True, False, True], [False]])
        [[0, 2], []]
    """
    return [[i for i, value in enumerate(pattern) if value] for pattern in patterns]


def extract_parameters(model, weighting: bool = False) -> Dict[str, torch.Tensor]:
    """Group a model's parameters into stacked weight / bias / leaf tensors.

    Every tensor yielded by ``model.parameters()`` is detached and classified
    by comparing its shape, in order, against the expected ``"weight"``,
    ``"bias"`` and ``"leaf"`` shapes for the model's class (first match wins).
    All tensors of one kind are then stacked along a new leading dimension.

    Args:
        model: Instance of one of the supported classes, matched by exact
            class name: ``SoftTreeEnsemble``, ``DecisionListEnsemble``,
            ``DecisionListEnsembleNoBinary``, ``ObliviousTreeEnsemble``,
            ``TreeWiseSoftTree``, ``TreeWiseDecisionList``,
            ``TreeWiseDecisionListEnsemble`` or ``TreeWiseObliviousTree``.
            ``model.config`` must provide ``input_dim`` and ``output_dim``;
            ensemble classes also need ``n_tree``, oblivious classes ``depth``.
            Non-oblivious models must implement ``collect_parameters()``,
            which is called before reading the parameters.
        weighting: If True, scale row ``i`` of the stacked weight and bias
            tensors in place by ``sqrt(w_i)``, where ``w_i`` depends on the
            model family (see ``parameter_weighting``).

    Returns:
        A ``defaultdict`` mapping ``"weight"``, ``"bias"`` and ``"leaf"`` to
        the stacked tensors. A kind with no matching parameter is absent, and
        looking it up yields an empty list.

    Raises:
        ValueError: If the class is unsupported or a parameter matches none
            of the expected shapes.
    """

    def convert_to_tensor(params: Dict[str, List[torch.Tensor]]) -> None:
        # Replace each list of same-shaped tensors with one stacked tensor.
        for key in params.keys():
            params[key] = torch.stack(params[key], dim=0)

    def parameter_weighting(
        params: Dict[str, torch.Tensor], model_type: str
    ) -> Dict[str, torch.Tensor]:
        # Scale row i of the stacked weight and bias tensors by sqrt(w_i).
        n_nodes = params["weight"].shape[0]
        if "Oblivious" in model_type:
            # Every row gets the same factor 2**n_nodes.
            node_weights = [2**n_nodes] * n_nodes
        elif "DecisionListEnsembleNoBinary" in model_type:
            node_weights = [n_nodes - i for i in range(n_nodes)]
        elif "DecisionListEnsemble" in model_type:
            node_weights = [n_nodes + 1 - i for i in range(n_nodes)]
        else:  # Perfect binary tree
            # NOTE(review): "TreeWiseDecisionList" matches neither decision-list
            # branch above and is weighted here as a binary tree — confirm.
            # Row i sits at level floor(log2(i + 1)) (heap order); bit_length
            # computes that exactly without float log2.
            depth = (n_nodes + 1).bit_length() - 1
            node_weights = [
                2 ** (depth - ((i + 1).bit_length() - 1)) for i in range(n_nodes)
            ]
        for i, node_weight in enumerate(node_weights):
            params["weight"][i] *= np.sqrt(node_weight)
            params["bias"][i] *= np.sqrt(node_weight)
        return params

    # Dispatch on the exact class name: some supported names contain others
    # (e.g. "DecisionListEnsemble" in "TreeWiseDecisionListEnsemble"), so a
    # substring test would route them to the wrong shape table.
    ensemble_types = {
        "SoftTreeEnsemble",
        "DecisionListEnsemble",
        "DecisionListEnsembleNoBinary",
        "ObliviousTreeEnsemble",
    }
    treewise_types = {
        "TreeWiseSoftTree",
        "TreeWiseDecisionList",
        "TreeWiseDecisionListEnsemble",
        "TreeWiseObliviousTree",
    }

    model_type = type(model).__name__
    if model_type not in ensemble_types and model_type not in treewise_types:
        raise ValueError(f"Unsupported model type: {model_type}")

    config = model.config
    # Ensemble shapes lead with n_tree; tree-wise shapes use a leading 1.
    n_rows = config["n_tree"] if model_type in ensemble_types else 1
    expected_shapes = {
        "weight": (n_rows, config["input_dim"]),
        "bias": (n_rows,),
        "leaf": (config["output_dim"], n_rows),
    }
    if "Oblivious" in model_type:
        expected_shapes["leaf"] = (config["output_dim"], n_rows, 2 ** config["depth"])

    if "Oblivious" not in model_type:
        model.collect_parameters()
    params = defaultdict(list)

    for param in model.parameters():
        param = param.detach()
        for param_type, shape in expected_shapes.items():
            if param.shape == shape:
                params[param_type].append(param)
                break
        else:
            raise ValueError(
                f"Unexpected shape encountered for {model_type}: {param.shape}"
            )
    convert_to_tensor(params)
    if weighting:
        params = parameter_weighting(params, model_type)

    return params


def save_tree_model(cfg: DictConfig, model_a_treewise, model_path: str) -> None:
    """Rebuild the ensemble form of a tree-wise model on CPU and save it.

    The ensemble class is chosen from the tree-wise class name. "Oblivious" is
    checked first, then "NoBinary", then "DecisionList"; any other name maps to
    ``SoftTreeEnsemble``.

    Args:
        cfg: Config providing ``dataset.input_dim``, ``dataset.output_dim``,
            ``depth``, ``alpha`` and ``n_tree``.
        model_a_treewise: Tree-wise model whose parameters are copied.
        model_path: Destination passed to ``torch.save`` for the state dict.
    """
    type_name = type(model_a_treewise).__name__
    # Order matters: "NoBinary" names also contain "DecisionList".
    candidates = (
        ("Oblivious", ObliviousTreeEnsemble),
        ("NoBinary", DecisionListEnsembleNoBinary),
        ("DecisionList", DecisionListEnsemble),
    )
    ensemble_class = next(
        (cls for marker, cls in candidates if marker in type_name), SoftTreeEnsemble
    )
    ensemble = ensemble_class(
        cfg.dataset.input_dim,
        cfg.dataset.output_dim,
        cfg.depth,
        cfg.alpha,
        cfg.n_tree,
        device="cpu",
    )
    ensemble.copy_parameters(model_a_treewise)

    torch.save(ensemble.state_dict(), model_path)
    logger.info(f"Model Saved: {model_path}")


def seed_everything(seed):
    """Seed Python, NumPy and PyTorch RNGs and configure cuDNN for determinism.

    Args:
        seed: Integer seed applied to every generator.

    Note:
        ``PYTHONHASHSEED`` is only exported to the environment. It affects
        child processes, not string hashing in the already-running interpreter.
    """
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # deterministic=True alone is not enough: with benchmark mode on, cuDNN
    # autotuning may select different algorithms from run to run.
    torch.backends.cudnn.benchmark = False


def retry(max_attempts=3, timeout=60, initial_wait=15, backoff_factor=3):
    """Decorator factory: retry a call that times out or raises an exception.

    Each attempt runs ``func`` in a fresh single-worker thread pool and waits
    at most ``timeout`` seconds for its result. After a failed attempt (timeout
    or any ``Exception``), the wrapper sleeps and then tries again. The sleep
    starts at ``initial_wait`` seconds and is multiplied by ``backoff_factor``
    after every sleep. No sleep follows the final attempt.

    Args:
        max_attempts: Maximum number of calls to ``func``.
        timeout: Seconds to wait for a single attempt.
        initial_wait: Seconds to sleep after the first failure.
        backoff_factor: Multiplier applied to the sleep after each failure.

    Returns:
        A decorator whose wrapper returns ``func``'s result from the first
        successful attempt.

    Raises:
        RuntimeError: When every attempt fails; the last failure is chained
            as ``__cause__``.

    Note:
        A timed-out call cannot be killed. Its worker thread keeps running in
        the background while the next attempt starts. ``concurrent.futures``
        joins such threads before the interpreter exits.
    """

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            attempts = 0
            wait_between = initial_wait
            last_error = None
            while attempts < max_attempts:
                attempts += 1
                executor = ThreadPoolExecutor(max_workers=1)
                future = executor.submit(func, *args, **kwargs)
                try:
                    return future.result(timeout=timeout)
                except TimeoutError as e:
                    last_error = e
                    print(f"タイムアウトしました。リトライ回数: {attempts}")
                except Exception as e:
                    last_error = e
                    print(f"エラーが発生しました: {e}。リトライ回数: {attempts}")
                finally:
                    # wait=False: a `with` block would call shutdown(wait=True)
                    # and block on a hung call, defeating the timeout.
                    executor.shutdown(wait=False)
                if attempts < max_attempts:
                    time.sleep(wait_between)
                    wait_between *= backoff_factor
            raise RuntimeError(
                f"最大リトライ回数 {max_attempts} に達しました。"
            ) from last_error

        return wrapper

    return decorator
