import itertools
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from loguru import logger

# Make float64 the default floating-point dtype for tensors and layers created
# after this point. NOTE: this is a process-wide side effect of importing the
# module, not something scoped to the classes below.
torch.set_default_dtype(torch.float64)


class InnerNode:
    """Soft split node shared by every tree of an ensemble.

    A single linear layer produces ``n_tree`` logits per sample, so one node
    object stores the split at this position for all trees at once.  The node
    is a plain Python object rather than an ``nn.Module``; whoever owns the
    tree has to register ``fc`` if it is meant to be trained.
    """

    def __init__(self, config, depth, device=None):
        """Create the node and recursively build the subtree below it.

        Args:
            config: Mapping providing ``input_dim``, ``n_tree``, ``depth`` and
                ``alpha`` (leaf nodes additionally read ``output_dim``).
            depth: Depth of this node; children are inner nodes while
                ``depth < config["depth"]`` and leaves otherwise.
            device: Device for the layer; CUDA is used when available if
                this is falsy.
        """
        if device:
            self.device = device
        else:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.config = config
        self.leaf = False
        self.fc = nn.Linear(
            self.config["input_dim"], self.config["n_tree"], bias=True
        ).to(self.device)
        self.prob = None
        self.path_prob = None
        self.left = None
        self.right = None
        self.leaf_accumulator = []
        # Defined up front so the attribute exists before the first reset().
        self.penalties = []

        self.build_child(depth)

    def build_child(self, depth):
        """Attach two inner nodes, or two leaves once the depth limit is hit."""
        if depth < self.config["depth"]:
            self.left = InnerNode(self.config, depth + 1, self.device)
            self.right = InnerNode(self.config, depth + 1, self.device)
        else:
            self.left = LeafNode(self.config, self.device)
            self.right = LeafNode(self.config, self.device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # decision function
        """Return ``sigmoid(alpha * fc(x))``, the probability of going right."""
        # torch.sigmoid is the non-deprecated spelling of F.sigmoid, and
        # self.fc(x) computes the same x @ W^T + b as the manual matmul.
        return torch.sigmoid(self.config["alpha"] * self.fc(x))

    def calc_prob(self, x, path_prob):
        """Propagate path probabilities down to the leaves.

        Args:
            x: Input batch passed unchanged to every node.
            path_prob: Probability of reaching this node,
                ``[batch_size, n_tree]``.

        Returns:
            List with one ``[path_prob, Q]`` pair per leaf below this node,
            left subtree first.
        """
        self.prob = self.forward(x)  # probability of selecting right node
        path_prob = path_prob.to(self.device)  # path_prob: [batch_size, n_tree]
        self.path_prob = path_prob
        left_leaves = self.left.calc_prob(x, path_prob * (1 - self.prob))
        right_leaves = self.right.calc_prob(x, path_prob * self.prob)
        # Rebuild the accumulator instead of extending it: a second call
        # without an intervening reset() (e.g. after an exception aborted the
        # previous pass) must not return duplicated leaves.
        self.leaf_accumulator = [*left_leaves, *right_leaves]
        return self.leaf_accumulator

    def reset(self):
        """Clear per-pass bookkeeping on this node and its whole subtree."""
        self.leaf_accumulator = []
        self.penalties = []
        self.left.reset()
        self.right.reset()


class AsymInnerNode(InnerNode):
    """Inner node of a decision list: only the left branch keeps splitting.

    The right child is always a leaf.  At the depth limit the left child is a
    leaf too, or an ``EmptyNode`` when ``no_binary`` is set.
    """

    def __init__(self, config, depth, device=None, no_binary=False):
        """Create the node and build the list below it.

        Args:
            config: Same mapping as for ``InnerNode``.
            depth: Depth of this node.
            device: Device for the layer; CUDA is used when available if
                this is falsy.
            no_binary: When true, the deepest left child is an ``EmptyNode``
                instead of a regular leaf.
        """
        # Must be assigned before the base constructor runs: it calls the
        # overridden build_child(), which reads self.no_binary.
        self.no_binary = no_binary
        # The base constructor resolves self.device and builds the children
        # through build_child(); calling build_child() again here would only
        # rebuild, and throw away, the entire subtree at every level.
        super().__init__(config, depth, device)

    def build_child(self, depth):
        """Attach the next list node on the left and a leaf on the right."""
        if depth < self.config["depth"]:
            self.left = AsymInnerNode(
                self.config, depth + 1, self.device, self.no_binary
            )
            self.right = LeafNode(self.config, self.device)
        else:
            if self.no_binary:
                self.left = EmptyNode(self.config, self.device)
            else:
                self.left = LeafNode(self.config, self.device)
            self.right = LeafNode(self.config, self.device)


class LeafNode:
    """Terminal node storing one prediction vector per tree."""

    def __init__(self, config, device=None):
        """Allocate the leaf parameter of shape ``[output_dim, n_tree]``.

        Args:
            config: Mapping providing ``n_tree`` and ``output_dim``.
            device: Device for the parameter; CUDA is used when available if
                this is falsy.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.config = config
        self.leaf = True

        # A throwaway nn.Linear is created only to reuse its default weight
        # initialisation; the layer itself is not kept.
        initial_weight = nn.Linear(config["n_tree"], config["output_dim"]).weight
        self.param = nn.Parameter(initial_weight.to(self.device))

    def forward(self) -> nn.Parameter:
        """Return the raw leaf parameter."""
        return self.param

    def calc_prob(self, x, path_prob):
        """Pair the incoming path probability with this leaf's values.

        Args:
            x: Unused; accepted so leaves and inner nodes share a signature.
            path_prob: Probability of reaching this leaf,
                ``[batch_size, n_tree]``.

        Returns:
            A one-element list ``[[path_prob, Q]]`` where ``Q`` is the leaf
            parameter expanded to ``[batch_size, output_dim, n_tree]``.
        """
        path_prob = path_prob.to(self.device)  # [batch_size, n_tree]
        batch_size = path_prob.size(0)
        expanded_shape = (
            batch_size,
            self.config["output_dim"],
            self.config["n_tree"],
        )
        leaf_values = self.forward().expand(expanded_shape)
        return [[path_prob, leaf_values]]

    def reset(self):
        """Leaves keep no per-pass state, so there is nothing to clear."""
        return None


class EmptyNode(LeafNode):
    """Leaf whose values are fixed at zero and excluded from training."""

    def __init__(self, config, device=None):
        """Allocate a frozen all-zero parameter of shape ``[output_dim, n_tree]``.

        ``LeafNode.__init__`` is not called; every attribute is set here.

        Args:
            config: Mapping providing ``n_tree`` and ``output_dim``.
            device: Device for the parameter; CUDA is used when available if
                this is falsy.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.config = config
        self.leaf = True

        zeros = torch.zeros(config["output_dim"], config["n_tree"])
        # requires_grad=False keeps the zeros fixed during optimisation.
        self.param = nn.Parameter(zeros.to(self.device), requires_grad=False)


class SoftTreeEnsemble(nn.Module):
    """Ensemble of ``n_tree`` soft binary trees evaluated as one structure.

    All trees share a single node graph; each node stores the parameters of
    every tree along the ``n_tree`` axis.  The prediction is the sum over
    leaves and trees of ``path_prob * leaf_value``.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        depth: int,
        alpha: float,
        n_tree: int,
        device=None,
    ):
        """Build the tree and register its parameters.

        Args:
            input_dim: Number of input features after flattening.
            output_dim: Size of the prediction vector.
            depth: Number of levels of inner nodes.
            alpha: Scale applied to the split logits before the sigmoid.
            n_tree: Number of trees in the ensemble.
            device: Device for all parameters; CUDA is used when available
                if this is falsy.
        """
        super().__init__()
        if device:
            self.device = device
        else:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"

        config = {
            "input_dim": input_dim,
            "output_dim": output_dim,
            "depth": depth,
            "alpha": alpha,
            "n_tree": n_tree,
        }
        self.config = config
        self.root = InnerNode(config, depth=1, device=self.device)
        self.collect_parameters()

    def collect_parameters(self):
        """Register every node's layer or leaf parameter on this module.

        Nodes are plain objects, so their tensors are only visible to
        ``parameters()`` through ``module_list`` and ``param_list``.  Nodes
        are visited in breadth-first order, left child before right.
        """
        self.module_list = nn.ModuleList()
        self.param_list = nn.ParameterList()
        # Level-by-level traversal yields the same order as a FIFO queue
        # without the O(n) cost of list.pop(0).
        level = [self.root]
        while level:
            next_level = []
            for node in level:
                if node.leaf:
                    self.param_list.append(node.param)
                else:
                    self.module_list.append(node.fc)
                    next_level.append(node.left)
                    next_level.append(node.right)
            level = next_level

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the summed prediction of all trees.

        Args:
            x: Batch that can be reshaped to ``[batch_size, input_dim]``.

        Returns:
            Tensor of shape ``[batch_size, output_dim]``.
        """
        x = torch.squeeze(x, 1).reshape(x.shape[0], self.config["input_dim"])

        # Every sample reaches the root with probability one in every tree.
        path_prob_init = torch.ones(
            x.shape[0], self.config["n_tree"], device=self.device
        )
        pred = torch.zeros(x.shape[0], self.config["output_dim"], device=self.device)
        try:
            # calc_prob returns one [path_prob, Q] pair per leaf:
            #   path_prob: [batch_size, n_tree]  probability of reaching the leaf
            #   Q:         [batch_size, output_dim, n_tree]  leaf predictions
            leaf_accumulator = self.root.calc_prob(x, path_prob_init)
            for path_prob, Q in leaf_accumulator:
                pred += torch.sum(path_prob.unsqueeze(1) * Q, dim=2)
        finally:
            # Clear the per-node accumulators even when the pass fails, so
            # the next call does not start from stale state.
            self.root.reset()
        return pred

    def copy_parameters(self, treewise_soft_tree_ensemble):
        """Copy parameters from a tree-wise ensemble into this one, in place.

        Args:
            treewise_soft_tree_ensemble: Object exposing ``trees``, each with a
                ``root`` whose structure mirrors ``self.root``.  Tree ``i`` is
                written to index ``i`` of the ``n_tree`` axis.
        """

        def copy_params_from_tree_wise(soft_tree_node, tree_wise_node, tree_index):
            if not soft_tree_node.leaf:
                soft_tree_node.fc.weight.data[
                    tree_index
                ] = tree_wise_node.fc.weight.data.squeeze()
                soft_tree_node.fc.bias.data[
                    tree_index
                ] = tree_wise_node.fc.bias.data.squeeze()

                copy_params_from_tree_wise(
                    soft_tree_node.left, tree_wise_node.left, tree_index
                )
                copy_params_from_tree_wise(
                    soft_tree_node.right, tree_wise_node.right, tree_index
                )
            else:
                soft_tree_node.param.data[
                    :, tree_index
                ] = tree_wise_node.param.data.squeeze()

        for tree_index, tree in enumerate(treewise_soft_tree_ensemble.trees):
            copy_params_from_tree_wise(self.root, tree.root, tree_index)
        logger.info(
            f"Parameters are Copied Successfully ({type(treewise_soft_tree_ensemble).__name__}->{type(self).__name__})"
        )


class DecisionListEnsemble(SoftTreeEnsemble):
    """Ensemble of soft decision lists built from ``AsymInnerNode``.

    Forward pass, parameter collection and parameter copying are inherited
    from ``SoftTreeEnsemble``; only the tree structure differs.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        depth: int,
        alpha: float,
        n_tree: int,
        device=None,
    ):
        """Build the decision list and register its parameters.

        Args:
            input_dim: Number of input features after flattening.
            output_dim: Size of the prediction vector.
            depth: Number of list nodes.
            alpha: Scale applied to the split logits before the sigmoid.
            n_tree: Number of lists in the ensemble.
            device: Device for all parameters; CUDA is used when available
                if this is falsy.
        """
        # Initialise nn.Module directly.  SoftTreeEnsemble.__init__ would
        # first build and register a complete binary tree (2**depth leaves)
        # that is replaced a few lines later, wasting time and memory.
        super(SoftTreeEnsemble, self).__init__()
        if device:
            self.device = device
        else:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"

        config = {
            "input_dim": input_dim,
            "output_dim": output_dim,
            "depth": depth,
            "alpha": alpha,
            "n_tree": n_tree,
        }
        self.config = config
        self.root = AsymInnerNode(config, depth=1, device=self.device, no_binary=False)
        self.collect_parameters()


class DecisionListEnsembleNoBinary(SoftTreeEnsemble):
    """Decision-list ensemble whose deepest left child is an ``EmptyNode``.

    Forward pass, parameter collection and parameter copying are inherited
    from ``SoftTreeEnsemble``; only the tree structure differs.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        depth: int,
        alpha: float,
        n_tree: int,
        device=None,
    ):
        """Build the decision list and register its parameters.

        Args:
            input_dim: Number of input features after flattening.
            output_dim: Size of the prediction vector.
            depth: Number of list nodes.
            alpha: Scale applied to the split logits before the sigmoid.
            n_tree: Number of lists in the ensemble.
            device: Device for all parameters; CUDA is used when available
                if this is falsy.
        """
        # Initialise nn.Module directly.  SoftTreeEnsemble.__init__ would
        # first build and register a complete binary tree (2**depth leaves)
        # that is replaced a few lines later, wasting time and memory.
        super(SoftTreeEnsemble, self).__init__()
        if device:
            self.device = device
        else:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"

        config = {
            "input_dim": input_dim,
            "output_dim": output_dim,
            "depth": depth,
            "alpha": alpha,
            "n_tree": n_tree,
        }
        self.config = config
        self.root = AsymInnerNode(config, depth=1, device=self.device, no_binary=True)
        self.collect_parameters()


class TreeWiseInnerNode:
    """Soft split node belonging to a single tree.

    The linear layer emits one logit per sample.  The node is a plain Python
    object rather than an ``nn.Module``; whoever owns the tree has to
    register ``fc`` if it is meant to be trained.
    """

    def __init__(
        self,
        config,
        depth,
        tree_index,
        device,
    ):
        """Create the node and recursively build the subtree below it.

        Args:
            config: Mapping providing ``input_dim``, ``depth`` and ``alpha``
                (leaf nodes additionally read ``output_dim``).
            depth: Depth of this node; children are inner nodes while
                ``depth < config["depth"]`` and leaves otherwise.
            tree_index: Index of the owning tree, stored on the node.
            device: Device for the layer; CUDA is used when available if
                this is falsy.
        """
        if device:
            self.device = device
        else:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.config = config
        self.leaf = False
        self.tree_index = tree_index
        self.fc = nn.Linear(self.config["input_dim"], 1, bias=True).to(self.device)
        self.prob = None
        self.path_prob = None
        self.left = None
        self.right = None
        self.leaf_accumulator = []
        # Defined up front so the attribute exists before the first reset().
        self.penalties = []
        self.build_child(depth)

    def build_child(self, depth):
        """Attach two inner nodes, or two leaves once the depth limit is hit."""
        if depth < self.config["depth"]:
            self.left = TreeWiseInnerNode(
                self.config, depth + 1, self.tree_index, self.device
            )
            self.right = TreeWiseInnerNode(
                self.config, depth + 1, self.tree_index, self.device
            )
        else:
            self.left = TreeWiseLeafNode(self.config, self.tree_index, self.device)
            self.right = TreeWiseLeafNode(self.config, self.tree_index, self.device)

    def forward(self, x):  # decision function
        """Return ``sigmoid(alpha * fc(x))``, the probability of going right."""
        # torch.sigmoid is the non-deprecated spelling of F.sigmoid, and
        # self.fc(x) computes the same x @ W^T + b as the manual matmul.
        return torch.sigmoid(self.config["alpha"] * self.fc(x))

    def calc_prob(self, x, path_prob):
        """Propagate path probabilities down to the leaves.

        Args:
            x: Input batch passed unchanged to every node.
            path_prob: Probability of reaching this node, ``[batch_size, 1]``.

        Returns:
            List with one ``[path_prob, Q]`` pair per leaf below this node,
            left subtree first.
        """
        self.prob = self.forward(x)  # probability of selecting right node
        path_prob = path_prob.to(self.device)  # path_prob: [batch_size, 1]
        self.path_prob = path_prob
        left_leaves = self.left.calc_prob(x, path_prob * (1 - self.prob))
        right_leaves = self.right.calc_prob(x, path_prob * self.prob)
        # Rebuild the accumulator instead of extending it: a second call
        # without an intervening reset() (e.g. after an exception aborted the
        # previous pass) must not return duplicated leaves.
        self.leaf_accumulator = [*left_leaves, *right_leaves]
        return self.leaf_accumulator

    def reset(self):
        """Clear per-pass bookkeeping on this node and its whole subtree."""
        self.leaf_accumulator = []
        self.penalties = []
        self.left.reset()
        self.right.reset()


class TreeWiseAsymInnerNode(TreeWiseInnerNode):
    """Inner node of a single decision list: only the left branch splits.

    The right child is always a leaf.  At the depth limit the left child is a
    leaf too, or a ``TreeWiseEmptyNode`` when ``no_binary`` is set.
    """

    def __init__(self, config, depth, tree_index, device, no_binary=False):
        """Create the node and build the list below it.

        Args:
            config: Same mapping as for ``TreeWiseInnerNode``.
            depth: Depth of this node.
            tree_index: Index of the owning tree, stored on the node.
            device: Device for the layer; CUDA is used when available if
                this is falsy.
            no_binary: When true, the deepest left child is a
                ``TreeWiseEmptyNode`` instead of a regular leaf.
        """
        # Must be assigned before the base constructor runs: it calls
        # build_child(), which now reads self.no_binary.
        self.no_binary = no_binary
        super().__init__(config, depth, tree_index, device)

    def build_child(self, depth):
        """Build the decision-list children during base-class construction.

        Without this override the base constructor created a full binary
        subtree at every list node, only for it to be replaced right away.
        """
        self.new_build_child(depth)

    def new_build_child(self, depth):
        """Attach the next list node on the left and a leaf on the right."""
        if depth < self.config["depth"]:
            self.left = TreeWiseAsymInnerNode(
                self.config, depth + 1, self.tree_index, self.device, self.no_binary
            )
            self.right = TreeWiseLeafNode(self.config, self.tree_index, self.device)
        else:
            if self.no_binary:
                self.left = TreeWiseEmptyNode(self.config, self.tree_index, self.device)
            else:
                self.left = TreeWiseLeafNode(self.config, self.tree_index, self.device)
            self.right = TreeWiseLeafNode(self.config, self.tree_index, self.device)


class TreeWiseLeafNode:
    """Terminal node of a single tree, holding one prediction vector."""

    def __init__(self, config, tree_index, device):
        """Allocate a normally-initialised parameter of shape ``[output_dim, 1]``.

        Args:
            config: Mapping providing ``output_dim``.
            tree_index: Index of the owning tree, stored on the node.
            device: Device for the parameter; CUDA is used when available if
                this is falsy.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.config = config
        self.leaf = True
        self.tree_index = tree_index

        initial_values = torch.randn(config["output_dim"], 1)
        self.param = nn.Parameter(initial_values.to(self.device))

    def forward(self):
        """Return the raw leaf parameter."""
        return self.param

    def calc_prob(self, x, path_prob):
        """Pair the incoming path probability with this leaf's values.

        Args:
            x: Unused; accepted so leaves and inner nodes share a signature.
            path_prob: Probability of reaching this leaf, ``[batch_size, 1]``.

        Returns:
            A one-element list ``[[path_prob, Q]]`` where ``Q`` is the leaf
            parameter expanded to ``[batch_size, output_dim, 1]``.
        """
        path_prob = path_prob.to(self.device)  # [batch_size, 1]
        batch_size = path_prob.size(0)
        expanded_shape = (batch_size, self.config["output_dim"], 1)
        leaf_values = self.forward().expand(expanded_shape)
        return [[path_prob, leaf_values]]

    def reset(self):
        """Leaves keep no per-pass state, so there is nothing to clear."""
        return None


class TreeWiseEmptyNode(TreeWiseLeafNode):
    """Leaf of a single tree whose values are fixed at zero."""

    def __init__(self, config, tree_index, device):
        """Initialise as a regular leaf, then swap in a frozen zero parameter.

        Args:
            config: Mapping providing ``output_dim``.
            tree_index: Index of the owning tree, stored on the node.
            device: Device for the parameter; CUDA is used when available if
                this is falsy.
        """
        super().__init__(config, tree_index, device)
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.leaf = True

        zeros = torch.zeros(self.config["output_dim"], 1)
        # Replaces the random parameter created by the base class;
        # requires_grad=False keeps the zeros fixed during optimisation.
        self.param = nn.Parameter(zeros.to(self.device), requires_grad=False)


class TreeWiseSoftTree(nn.Module):
    """A single soft binary tree built from ``TreeWiseInnerNode`` objects."""

    def __init__(self, input_dim, output_dim, depth, alpha, tree_index, device=None):
        """Build the tree and register its parameters.

        Args:
            input_dim: Number of input features after flattening.
            output_dim: Size of the prediction vector.
            depth: Number of levels of inner nodes.
            alpha: Scale applied to the split logits before the sigmoid.
            tree_index: Index of this tree inside its ensemble.
            device: Device for all parameters; CUDA is used when available
                if this is falsy.
        """
        super().__init__()
        if device:
            self.device = device
        else:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.config = {
            "input_dim": input_dim,
            "output_dim": output_dim,
            "depth": depth,
            "alpha": alpha,
            "tree_index": tree_index,
        }
        self.root = TreeWiseInnerNode(
            self.config, depth=1, tree_index=tree_index, device=self.device
        )
        self.collect_parameters()

    def collect_parameters(self):
        """Register every node's layer or leaf parameter on this module.

        Nodes are plain objects, so their tensors are only visible to
        ``parameters()`` through ``module_list`` and ``param_list``.  Nodes
        are visited in breadth-first order, left child before right.
        """
        self.module_list = nn.ModuleList()
        self.param_list = nn.ParameterList()
        # Level-by-level traversal yields the same order as a FIFO queue
        # without the O(n) cost of list.pop(0).
        level = [self.root]
        while level:
            next_level = []
            for node in level:
                if node.leaf:
                    self.param_list.append(node.param)
                else:
                    self.module_list.append(node.fc)
                    next_level.append(node.left)
                    next_level.append(node.right)
            level = next_level

    def flip_children(self, flip_node_index):
        """Swap the children of one inner node without changing the output.

        Negating the split weights turns ``p`` into ``1 - p``, which together
        with the swap routes every sample to the same subtree as before.

        Args:
            flip_node_index: Heap-style index of the node to flip (root is 0,
                children of ``i`` are ``2 * i + 1`` and ``2 * i + 2``).

        Raises:
            ValueError: If the index refers to a leaf or to no node at all.
        """
        pending = [(self.root, 0)]
        while pending:
            node, current_index = pending.pop()
            if current_index == flip_node_index:
                if node.leaf:
                    raise ValueError(
                        f"node {flip_node_index} is a leaf and cannot be flipped"
                    )
                node.fc.weight = torch.nn.Parameter(-node.fc.weight)
                node.fc.bias = torch.nn.Parameter(-node.fc.bias)
                node.left, node.right = node.right, node.left
                return
            # Leaves have no children (and no ``left``/``right`` attributes),
            # and child indices only grow, so both cases end this branch.
            if node.leaf or current_index > flip_node_index:
                continue
            pending.append((node.left, current_index * 2 + 1))
            pending.append((node.right, current_index * 2 + 2))
        raise ValueError(f"no node with index {flip_node_index}")

    def forward(self, x):
        """Return this tree's prediction.

        Args:
            x: Batch that can be reshaped to ``[batch_size, input_dim]``.

        Returns:
            Tensor of shape ``[batch_size, output_dim]``.
        """
        x = torch.squeeze(x, 1).reshape(x.shape[0], self.config["input_dim"])

        # Every sample reaches the root with probability one.
        path_prob_init = torch.ones(x.shape[0], 1, device=self.device)
        pred = torch.zeros(x.shape[0], self.config["output_dim"], device=self.device)
        try:
            # calc_prob returns one [path_prob, Q] pair per leaf:
            #   path_prob: [batch_size, 1]  probability of reaching the leaf
            #   Q:         [batch_size, output_dim, 1]  leaf predictions
            leaf_accumulator = self.root.calc_prob(x, path_prob_init)
            for path_prob, Q in leaf_accumulator:
                pred += torch.sum(path_prob.unsqueeze(1) * Q, dim=2)
        finally:
            # Clear the per-node accumulators even when the pass fails, so
            # the next call does not start from stale state.
            self.root.reset()
        return pred


class TreeWiseDecisionList(TreeWiseSoftTree):
    """A single soft decision list built from ``TreeWiseAsymInnerNode``.

    Forward pass, parameter collection and child flipping are inherited from
    ``TreeWiseSoftTree``; only the tree structure differs.
    """

    def __init__(self, input_dim, output_dim, depth, alpha, tree_index, device=None):
        """Build the decision list and register its parameters.

        Args:
            input_dim: Number of input features after flattening.
            output_dim: Size of the prediction vector.
            depth: Number of list nodes.
            alpha: Scale applied to the split logits before the sigmoid.
            tree_index: Index of this list inside its ensemble.
            device: Device for all parameters; CUDA is used when available
                if this is falsy.
        """
        # Initialise nn.Module directly.  TreeWiseSoftTree.__init__ would
        # first build and register a complete binary tree that is replaced a
        # few lines later, wasting time and memory.
        super(TreeWiseSoftTree, self).__init__()
        if device:
            self.device = device
        else:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.config = {
            "input_dim": input_dim,
            "output_dim": output_dim,
            "depth": depth,
            "alpha": alpha,
            "tree_index": tree_index,
        }
        self.root = TreeWiseAsymInnerNode(
            self.config,
            depth=1,
            tree_index=tree_index,
            device=self.device,
            no_binary=False,
        )
        self.collect_parameters()


class TreeWiseDecisionListNoBinary(TreeWiseSoftTree):
    """Single decision list whose deepest left child is a ``TreeWiseEmptyNode``.

    Forward pass, parameter collection and child flipping are inherited from
    ``TreeWiseSoftTree``; only the tree structure differs.
    """

    def __init__(self, input_dim, output_dim, depth, alpha, tree_index, device=None):
        """Build the decision list and register its parameters.

        Args:
            input_dim: Number of input features after flattening.
            output_dim: Size of the prediction vector.
            depth: Number of list nodes.
            alpha: Scale applied to the split logits before the sigmoid.
            tree_index: Index of this list inside its ensemble.
            device: Device for all parameters; CUDA is used when available
                if this is falsy.
        """
        # Initialise nn.Module directly.  TreeWiseSoftTree.__init__ would
        # first build and register a complete binary tree that is replaced a
        # few lines later, wasting time and memory.
        super(TreeWiseSoftTree, self).__init__()
        if device:
            self.device = device
        else:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.config = {
            "input_dim": input_dim,
            "output_dim": output_dim,
            "depth": depth,
            "alpha": alpha,
            "tree_index": tree_index,
        }
        self.root = TreeWiseAsymInnerNode(
            self.config,
            depth=1,
            tree_index=tree_index,
            device=self.device,
            no_binary=True,
        )
        self.collect_parameters()


class TreeWiseSoftTreeEnsemble:
    """Ensemble that keeps every tree as an independent ``TreeWiseSoftTree``.

    This class is a plain container, not an ``nn.Module``; the trees
    themselves live in an ``nn.ModuleList`` stored in ``self.trees``.
    """

    def __init__(self, input_dim, output_dim, depth, alpha, n_tree, device=None):
        """Create ``n_tree`` independent trees.

        Args:
            input_dim: Number of input features after flattening.
            output_dim: Size of the prediction vector.
            depth: Number of levels of inner nodes per tree.
            alpha: Scale applied to the split logits before the sigmoid.
            n_tree: Number of trees.
            device: Device for all parameters; CUDA is used when available
                if this is falsy.
        """
        if device:
            self.device = device
        else:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.config = {
            "input_dim": input_dim,
            "output_dim": output_dim,
            "depth": depth,
            "alpha": alpha,
            "n_tree": n_tree,
        }
        self.trees = nn.ModuleList(
            [
                TreeWiseSoftTree(
                    input_dim=self.config["input_dim"],
                    output_dim=self.config["output_dim"],
                    depth=self.config["depth"],
                    alpha=self.config["alpha"],
                    tree_index=i,
                    device=self.device,
                )
                for i in range(self.config["n_tree"])
            ]
        )

    def forward(self, x, treewise=False):
        """Evaluate every tree on ``x``.

        Args:
            x: Input batch forwarded to each tree.
            treewise: When true, return the per-tree outputs stacked along a
                new leading axis instead of their sum.

        Returns:
            The stacked per-tree outputs, or their element-wise sum.
        """
        treewise_output = [
            self.trees[i].forward(x) for i in range(self.config["n_tree"])
        ]
        if treewise:
            return torch.stack(treewise_output)
        return sum(treewise_output)

    def copy_parameters(self, soft_tree_ensemble):
        """Copy parameters from a joint ensemble into the individual trees.

        Values are written in place into the existing parameters.  Rebinding
        the attributes to fresh ``nn.Parameter`` objects would leave each
        tree's ``param_list`` pointing at stale leaf tensors, would alias the
        source ensemble's storage, and would drop ``requires_grad=False`` on
        frozen leaves.

        Args:
            soft_tree_ensemble: Object exposing ``root`` whose structure
                mirrors each tree's root; index ``i`` of its ``n_tree`` axis
                is copied into tree ``i``.
        """

        def set_node_params(tree_wise_node, soft_tree_node, tree_index):
            # Inner node: copy this tree's slice of the split layer.
            if not soft_tree_node.leaf:
                tree_wise_node.fc.weight.copy_(
                    soft_tree_node.fc.weight[tree_index].unsqueeze(0)
                )
                tree_wise_node.fc.bias.copy_(
                    soft_tree_node.fc.bias[tree_index].unsqueeze(0)
                )

                set_node_params(tree_wise_node.left, soft_tree_node.left, tree_index)
                set_node_params(tree_wise_node.right, soft_tree_node.right, tree_index)
            elif hasattr(tree_wise_node, "param"):
                tree_wise_node.param.copy_(
                    soft_tree_node.param[:, tree_index].unsqueeze(1)
                )

        # In-place writes to parameters must not be tracked by autograd.
        with torch.no_grad():
            for tree_index, tree in enumerate(self.trees):
                set_node_params(tree.root, soft_tree_ensemble.root, tree_index)
        logger.info(
            f"Parameters are Copied Successfully ({type(soft_tree_ensemble).__name__}->{type(self).__name__})"
        )

    @staticmethod
    def _flip_node(root, node_index):
        """Negate the split of node ``node_index`` under ``root`` and swap its children."""
        pending = [(root, 0)]
        while pending:
            node, current_index = pending.pop()
            if current_index == node_index:
                if node.leaf:
                    raise ValueError(
                        f"node {node_index} is a leaf and cannot be flipped"
                    )
                node.fc.weight = torch.nn.Parameter(-node.fc.weight)
                node.fc.bias = torch.nn.Parameter(-node.fc.bias)
                node.left, node.right = node.right, node.left
                return
            # Leaves have no children (and no ``left``/``right`` attributes),
            # and child indices only grow, so both cases end this branch.
            if node.leaf or current_index > node_index:
                continue
            pending.append((node.left, current_index * 2 + 1))
            pending.append((node.right, current_index * 2 + 2))
        raise ValueError(f"no node with index {node_index}")

    def flip_children(self, signs, node_index):
        """Flip node ``node_index`` in every tree whose sign is ``-1``.

        Negating the split weights and swapping the children leaves the
        tree's output unchanged.

        Args:
            signs: Indexable with one entry per tree; trees whose entry
                equals ``-1`` are flipped, all others are left alone.
            node_index: Heap-style index of the node to flip (root is 0,
                children of ``i`` are ``2 * i + 1`` and ``2 * i + 2``).

        Raises:
            ValueError: If the index refers to a leaf or to no node at all
                in a tree that has to be flipped.
        """
        for tree_index in range(self.config["n_tree"]):
            if signs[tree_index] == -1:
                self._flip_node(self.trees[tree_index].root, node_index)

    def apply_permutation(self, perm):
        """Reorder the trees so that position ``k`` holds tree ``perm[k]``."""
        sorted_trees = nn.ModuleList([self.trees[i] for i in perm])
        self.trees = sorted_trees


class TreeWiseDecisionListEnsemble(TreeWiseSoftTreeEnsemble):
    """Tree-wise ensemble whose members are ``TreeWiseDecisionList`` objects.

    The base-class constructor is not called; every attribute it would set
    is assigned here with the decision-list tree type instead.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        depth: int,
        alpha: float,
        n_tree: int,
        device=None,
    ):
        """Create ``n_tree`` independent decision lists.

        Args:
            input_dim: Number of input features after flattening.
            output_dim: Size of the prediction vector.
            depth: Number of list nodes per member.
            alpha: Scale applied to the split logits before the sigmoid.
            n_tree: Number of decision lists.
            device: Device for all parameters; CUDA is used when available
                if this is falsy.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.config = dict(
            input_dim=input_dim,
            output_dim=output_dim,
            depth=depth,
            alpha=alpha,
            n_tree=n_tree,
        )

        members = []
        for index in range(self.config["n_tree"]):
            members.append(
                TreeWiseDecisionList(
                    input_dim=self.config["input_dim"],
                    output_dim=self.config["output_dim"],
                    depth=self.config["depth"],
                    alpha=self.config["alpha"],
                    tree_index=index,
                    device=self.device,
                )
            )
        self.trees = nn.ModuleList(members)


class TreeWiseDecisionListEnsembleNoBinary(TreeWiseSoftTreeEnsemble):
    """Tree-wise ensemble whose members are ``TreeWiseDecisionListNoBinary``.

    The base-class constructor is not called; every attribute it would set
    is assigned here with the no-binary decision-list tree type instead.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        depth: int,
        alpha: float,
        n_tree: int,
        device=None,
    ):
        """Create ``n_tree`` independent no-binary decision lists.

        Args:
            input_dim: Number of input features after flattening.
            output_dim: Size of the prediction vector.
            depth: Number of list nodes per member.
            alpha: Scale applied to the split logits before the sigmoid.
            n_tree: Number of decision lists.
            device: Device for all parameters; CUDA is used when available
                if this is falsy.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.config = dict(
            input_dim=input_dim,
            output_dim=output_dim,
            depth=depth,
            alpha=alpha,
            n_tree=n_tree,
        )

        members = []
        for index in range(self.config["n_tree"]):
            members.append(
                TreeWiseDecisionListNoBinary(
                    input_dim=self.config["input_dim"],
                    output_dim=self.config["output_dim"],
                    depth=self.config["depth"],
                    alpha=self.config["alpha"],
                    tree_index=index,
                    device=self.device,
                )
            )
        self.trees = nn.ModuleList(members)


if __name__ == "__main__":
    # Self-check: the joint and tree-wise representations must produce the
    # same predictions after parameters are copied in either direction.
    input_dim = 2
    output_dim = 10
    depth = 3
    alpha = 1.0
    n_tree = 100

    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.Tensor([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [4.0, 4.0]]).to(device)

    for ensemble_class, treewise_class in zip(
        (SoftTreeEnsemble, DecisionListEnsemble, DecisionListEnsembleNoBinary),
        (
            TreeWiseSoftTreeEnsemble,
            TreeWiseDecisionListEnsemble,
            TreeWiseDecisionListEnsembleNoBinary,
        ),
    ):
        logger.info(f"{ensemble_class}, {treewise_class}")
        # Tree-wise -> joint.
        twste = treewise_class(input_dim, output_dim, depth, alpha, n_tree, device)
        ste = ensemble_class(input_dim, output_dim, depth, alpha, n_tree, device)
        ste.copy_parameters(twste)
        assert torch.isclose(ste.forward(x), twste.forward(x)).all()

        # Joint -> tree-wise.
        twste = treewise_class(input_dim, output_dim, depth, alpha, n_tree, device)
        ste = ensemble_class(input_dim, output_dim, depth, alpha, n_tree, device)
        twste.copy_parameters(ste)
        assert torch.isclose(ste.forward(x), twste.forward(x)).all()

        # Flipping children must leave the output unchanged.  The check is
        # limited to full binary trees, where nodes 0-2 are inner nodes at
        # depth 3.  Compare the class itself: type(ensemble_class).__name__
        # is always "type", which silently disabled this branch.
        if ensemble_class is SoftTreeEnsemble:
            signs = torch.Tensor([-1] * n_tree)
            before_flip = twste.forward(x)
            twste.flip_children(signs, node_index=0)
            twste.flip_children(signs, node_index=1)
            twste.flip_children(signs, node_index=2)
            after_flip = twste.forward(x)
            # .all() is required: the truth value of a multi-element tensor
            # is ambiguous and would raise instead of asserting.
            assert torch.isclose(before_flip, after_flip).all()

    # Report how the parameter count grows with depth.
    for depth in range(1, 4, 1):
        model = SoftTreeEnsemble(44, 2, depth, alpha, 256, device)
        print(sum(p.numel() for p in model.parameters()))
