import itertools
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from loguru import logger

# Process-wide side effect: executing this module (including merely importing
# it) makes float64 the default floating dtype for every tensor and parameter
# created afterwards, in this module and in any code that imports it.
torch.set_default_dtype(torch.float64)


class ObliviousTreeEnsemble(nn.Module):
    """Soft oblivious decision-tree ensemble evaluated for all trees at once.

    Each tree applies one soft split per depth level, so a tree of depth ``d``
    has ``d`` split hyperplanes and ``2**d`` leaves. The level-``k`` splits of
    all ``n_tree`` trees are packed into ``inner_nodes[k]``, an
    ``nn.Linear(input_dim, n_tree)`` whose row ``t`` belongs to tree ``t``.
    Leaf values live in ``leaves`` with shape ``(output_dim, n_tree, 2**depth)``.
    The output is the sum over trees of each tree's path-probability-weighted
    leaf values.

    Args:
        input_dim: Number of input features.
        output_dim: Number of outputs per sample.
        depth: Number of split levels per tree.
        alpha: Scale applied to every split logit before the sigmoid.
        n_tree: Number of trees.
        device: Device for all parameters. When falsy, CUDA is used if
            available, otherwise CPU.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        depth: int,
        alpha: float,
        n_tree: int,
        device=None,
    ):
        super(ObliviousTreeEnsemble, self).__init__()
        if device:
            self.device = device
        else:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.config = {
            "input_dim": input_dim,
            "output_dim": output_dim,
            "depth": depth,
            "alpha": alpha,
            "n_tree": n_tree,
        }
        # One packed Linear per depth level; row t of inner_nodes[k] is the
        # level-k split of tree t.
        self.inner_nodes = nn.ModuleList(
            [nn.Linear(input_dim, n_tree).to(self.device) for _ in range(depth)]
        )
        # Borrow nn.Linear's default initialisation for the leaf values: each
        # Linear(n_tree, output_dim).weight is (output_dim, n_tree); stacking
        # 2**depth of them gives (2**depth, output_dim, n_tree), which is
        # permuted to (output_dim, n_tree, 2**depth).
        with torch.no_grad():
            leaf_init = torch.stack(
                [nn.Linear(n_tree, output_dim).weight for _ in range(2**depth)]
            ).permute(1, 2, 0)
        # Put the leaves on the same device as inner_nodes; otherwise forward()
        # mixes devices whenever device is not the CPU.
        self.leaves = nn.Parameter(leaf_init.to(self.device))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the ensemble prediction summed over trees.

        Args:
            x: Input batch. Dimension 1 is squeezed if it has size 1, and the
                result is reshaped to ``(batch, input_dim)``.

        Returns:
            Tensor of shape ``(batch, output_dim)``.
        """
        x = torch.squeeze(x, 1).reshape(x.shape[0], self.config["input_dim"])
        n_tree = self.config["n_tree"]
        alpha = self.config["alpha"]
        # decisions[k][b, t]: weight of the "1" branch at level k of tree t
        # for sample b (the "0" branch gets 1 - decisions[k][b, t]).
        decisions = [
            torch.sigmoid(alpha * (torch.matmul(x, node.weight.t()) + node.bias))
            for node in self.inner_nodes
        ]
        pred = torch.zeros(x.shape[0], n_tree, self.config["output_dim"]).to(
            self.device
        )
        # Leaves are enumerated MSB-first: bit k of the path (k = 0 is the most
        # significant) selects the "1" (d) or "0" (1 - d) branch at level k.
        for leaf_index, path in enumerate(
            itertools.product([0, 1], repeat=len(decisions))
        ):
            # Seed with a ones tensor (not the int 1) so depth == 0 also works.
            reach = x.new_ones(x.shape[0], n_tree)
            for level, bit in enumerate(path):
                reach = reach * (decisions[level] if bit else 1 - decisions[level])
            # (output_dim, n_tree) -> (1, n_tree, output_dim)
            leaf_values = self.leaves[:, :, leaf_index].transpose(0, 1).unsqueeze(0)
            pred += reach.unsqueeze(-1) * leaf_values

        return torch.sum(pred, dim=1)

    def copy_parameters(self, tree_wise_model):
        """Overwrite this model's parameters with those of a tree-wise ensemble.

        Values are copied into the existing parameter storage, so the two
        models do not share memory afterwards.

        Args:
            tree_wise_model: Object exposing ``trees[t].inner_nodes[k]`` (a
                Linear with one output) and ``trees[t].leaves`` of shape
                ``(output_dim, 1, 2**depth)``, such as a
                ``TreeWiseObliviousTreeEnsemble`` with the same config.
        """
        with torch.no_grad():
            for tree_index in range(self.config["n_tree"]):
                tree = tree_wise_model.trees[tree_index]
                for level in range(self.config["depth"]):
                    # Row tree_index of the packed Linear holds this tree's split.
                    self.inner_nodes[level].weight[tree_index] = (
                        tree.inner_nodes[level].weight.squeeze()
                    )
                    self.inner_nodes[level].bias[tree_index] = (
                        tree.inner_nodes[level].bias.squeeze()
                    )
                # (output_dim, 1, n_leaves) -> (output_dim, n_leaves)
                self.leaves[:, tree_index, :] = tree.leaves[:, 0, :]

        logger.info(
            "Parameters successfully copied: TreeWiseObliviousTreeEnsemble->ObliviousTreeEnsemble"
        )


class TreeWiseObliviousTree(nn.Module):
    """A single soft oblivious decision tree with its own parameters.

    Level ``k`` uses one split ``inner_nodes[k]`` (``nn.Linear(input_dim, 1)``)
    shared by every node at that level. Leaf values are stored in ``leaves``
    with shape ``(output_dim, 1, 2**depth)``. Besides ``forward`` the class
    offers two output-preserving transformations: ``flip_children`` (negate
    one split and swap the affected leaves) and ``apply_reordering`` (permute
    the split levels and relabel the leaves accordingly).

    Args:
        input_dim: Number of input features.
        output_dim: Number of outputs per sample.
        depth: Number of split levels.
        alpha: Scale applied to every split logit before the sigmoid.
        tree_index: Stored in ``config``; not otherwise used by this class.
        device: Device for all parameters. When falsy, CUDA is used if
            available, otherwise CPU.
    """

    def __init__(self, input_dim, output_dim, depth, alpha, tree_index, device=None):
        super(TreeWiseObliviousTree, self).__init__()
        if device:
            self.device = device
        else:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.config = {
            "input_dim": input_dim,
            "output_dim": output_dim,
            "depth": depth,
            "alpha": alpha,
            "tree_index": tree_index,
        }
        self.inner_nodes = nn.ModuleList(
            [nn.Linear(input_dim, 1).to(self.device) for _ in range(depth)]
        )
        self.leaves = nn.Parameter(
            torch.randn(output_dim, 1, 2**depth).to(self.device)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return this tree's prediction.

        Args:
            x: Input batch. Dimension 1 is squeezed if it has size 1, and the
                result is reshaped to ``(batch, input_dim)``.

        Returns:
            Tensor of shape ``(batch, 1, output_dim)``.
        """
        x = torch.squeeze(x, 1).reshape(x.shape[0], self.config["input_dim"])
        alpha = self.config["alpha"]
        treewise_pred = torch.zeros(x.shape[0], 1, self.config["output_dim"]).to(
            self.device
        )
        # decisions[k][b, 0]: weight of the "1" branch at level k for sample b.
        decisions = [
            torch.sigmoid(alpha * (torch.matmul(x, node.weight.t()) + node.bias))
            for node in self.inner_nodes
        ]
        # Leaves are enumerated MSB-first: bit k of the path (k = 0 is the most
        # significant) selects the "1" (d) or "0" (1 - d) branch at level k.
        for leaf_index, path in enumerate(
            itertools.product([0, 1], repeat=len(decisions))
        ):
            # Seed with a ones tensor (not the int 1) so depth == 0 also works.
            reach = x.new_ones(x.shape[0], 1)
            for level, bit in enumerate(path):
                reach = reach * (decisions[level] if bit else 1 - decisions[level])
            # (output_dim, 1) -> (1, 1, output_dim)
            leaf_values = self.leaves[:, :, leaf_index].transpose(0, 1).unsqueeze(0)
            treewise_pred += reach.unsqueeze(-1) * leaf_values
        return treewise_pred

    def flip_children(self, flip_depth: int) -> None:
        """Swap the two children at level ``flip_depth`` without changing outputs.

        Negating the split's weight and bias turns ``sigmoid(z)`` into
        ``sigmoid(-z) == 1 - sigmoid(z)``. Every leaf therefore trades places
        with the leaf whose path differs only at that level.

        Args:
            flip_depth: 1-indexed level to flip, in ``[1, depth]``.

        Raises:
            AssertionError: If ``flip_depth`` is 0.
            IndexError: If ``flip_depth`` is outside ``[1, depth]``. This is
                raised before any parameter is modified.
        """
        assert flip_depth != 0  # 1-index
        depth = self.config["depth"]
        if not 1 <= flip_depth <= depth:
            raise IndexError(f"flip_depth must be in [1, {depth}], got {flip_depth}")

        # Leaf indices are MSB-first (level 0 is the most significant bit), so
        # level flip_depth - 1 is the bit with value 2**(depth - flip_depth).
        # XOR-ing with it pairs each leaf with its sibling.
        sibling_bit = 2 ** (depth - flip_depth)
        new_order = [leaf ^ sibling_bit for leaf in range(2**depth)]

        node = self.inner_nodes[flip_depth - 1]
        # Rebind .data rather than negating in place, so any tensor that might
        # share storage with these parameters is left untouched.
        node.weight.data = -node.weight.data
        node.bias.data = -node.bias.data
        # Advanced indexing returns a fresh tensor: new[..., k] = old[..., new_order[k]].
        self.leaves.data = self.leaves.data[:, :, new_order]

    def apply_reordering(self, order: tuple[int, ...]) -> None:
        """Permute the split levels while keeping the tree's outputs unchanged.

        After the call, level ``k`` uses the split that was previously at
        level ``order[k]``. The leaves are relabelled so that each leaf keeps
        the same path condition.

        Args:
            order: Permutation of ``range(depth)``, given as a sequence of ints
                or a 1-D integer tensor (e.g. a row of ``torch.randperm``).

        Raises:
            ValueError: If ``order`` is not a permutation of ``range(depth)``.
                This is raised before any parameter is modified.
        """
        depth = self.config["depth"]
        # int() accepts both Python ints and 0-d integer tensors.
        order = [int(level) for level in order]
        if sorted(order) != list(range(depth)):
            raise ValueError(
                f"order must be a permutation of range({depth}), got {order}"
            )

        self.inner_nodes = nn.ModuleList([self.inner_nodes[level] for level in order])

        # source[new_index] is the old leaf index whose value moves to new_index.
        source = [0] * 2**depth
        for old_index, pattern in enumerate(itertools.product([0, 1], repeat=depth)):
            new_pattern = [pattern[level] for level in order]
            new_index = sum(
                bit << (depth - 1 - pos) for pos, bit in enumerate(new_pattern)
            )
            source[new_index] = old_index
        self.leaves.data = self.leaves.data[:, :, source]


class TreeWiseObliviousTreeEnsemble:
    """Ensemble of independent ``TreeWiseObliviousTree`` objects.

    Keeping one module per tree allows per-tree transformations (permuting
    trees, reordering levels, flipping children) to be applied. Parameters
    can be copied from an ``ObliviousTreeEnsemble`` with the same config.
    This container itself is not an ``nn.Module``; use
    ``self.trees.parameters()`` to reach the tree parameters.

    Args:
        input_dim: Number of input features.
        output_dim: Number of outputs per sample.
        depth: Number of split levels per tree.
        alpha: Scale applied to every split logit before the sigmoid.
        n_tree: Number of trees.
        device: Device for all parameters. When falsy, CUDA is used if
            available, otherwise CPU.
    """

    def __init__(self, input_dim, output_dim, depth, alpha, n_tree, device=None):
        if device:
            self.device = device
        else:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.config = {
            "input_dim": input_dim,
            "output_dim": output_dim,
            "depth": depth,
            "alpha": alpha,
            "n_tree": n_tree,
        }
        # A ModuleList (the same type apply_permutation produces) keeps the
        # container type stable and exposes all tree parameters.
        self.trees = nn.ModuleList(
            [
                TreeWiseObliviousTree(
                    input_dim, output_dim, depth, alpha, tree_index, self.device
                )
                for tree_index in range(n_tree)
            ]
        )

    def apply_permutation(self, perm: torch.Tensor) -> None:
        """Reorder the trees so that new position ``i`` holds old tree ``perm[i]``.

        Raises:
            ValueError: If ``perm`` is not a permutation of ``range(n_tree)``.
                A repeated index would otherwise put the same tree module in
                two slots.
        """
        perm = [int(i) for i in perm]
        if sorted(perm) != list(range(self.config["n_tree"])):
            raise ValueError(
                f"perm must be a permutation of range({self.config['n_tree']})"
            )
        self.trees = nn.ModuleList([self.trees[i] for i in perm])

    def apply_reordering(self, order: torch.Tensor) -> None:
        """Permute split levels per tree; ``order[t]`` is passed to tree ``t``."""
        for tree_index in range(self.config["n_tree"]):
            self.trees[tree_index].apply_reordering(order[tree_index])

    def flip_children(self, signs: torch.Tensor, flip_depth: int) -> None:
        """Flip level ``flip_depth`` (1-indexed) of every tree whose sign is -1."""
        assert flip_depth != 0  # 1-index

        for tree_index in range(self.config["n_tree"]):
            if signs[tree_index] == -1:
                self.trees[tree_index].flip_children(flip_depth)

    def forward(self, x: torch.Tensor, treewise=False) -> torch.Tensor:
        """Evaluate every tree on ``x``.

        Returns:
            The per-tree predictions stacked to ``(n_tree, batch, output_dim)``
            when ``treewise`` is true, otherwise their sum with shape
            ``(batch, output_dim)``.
        """
        treewise_output = [
            self.trees[i].forward(x) for i in range(self.config["n_tree"])
        ]
        if treewise:
            return torch.stack(treewise_output).squeeze(dim=2)
        return sum(treewise_output).squeeze(dim=1)

    def copy_parameters(self, oblivious_model):
        """Overwrite every tree's parameters with those of a packed ensemble.

        Values are written with ``copy_`` into each tree's existing storage.
        Assigning slices of ``oblivious_model``'s tensors to ``.data`` would
        instead make the trees' parameters views of that model's weights, so
        an in-place update to either model would silently change the other.

        Args:
            oblivious_model: An ``ObliviousTreeEnsemble``-like object with the
                same config (packed ``inner_nodes[k]`` whose row ``t`` belongs
                to tree ``t``, and ``leaves`` of shape
                ``(output_dim, n_tree, 2**depth)``).
        """
        with torch.no_grad():
            for tree_index in range(self.config["n_tree"]):
                tree = self.trees[tree_index]
                for level in range(self.config["depth"]):
                    source = oblivious_model.inner_nodes[level]
                    target = tree.inner_nodes[level]
                    # (input_dim,) -> (1, input_dim) and scalar -> (1,)
                    target.weight.copy_(source.weight[tree_index].unsqueeze(0))
                    target.bias.copy_(source.bias[tree_index].unsqueeze(0))
                # (output_dim, n_leaves) -> (output_dim, 1, n_leaves)
                tree.leaves.copy_(
                    oblivious_model.leaves[:, tree_index, :].unsqueeze(1)
                )

        logger.info(
            "Parameters successfully copied: ObliviousTreeEnsemble->TreeWiseObliviousTreeEnsemble"
        )


if __name__ == "__main__":
    # Smoke test: converting between the two ensemble representations, and
    # applying the output-preserving transformations of the tree-wise
    # ensemble, must leave predictions unchanged.
    input_dim = 2
    output_dim = 10
    depth = 3
    alpha = 1.0
    n_tree = 100

    # device = "cuda" if torch.cuda.is_available() else "cpu"
    device = "cpu"
    x = torch.Tensor([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]).to(device)

    # TreeWiseObliviousTreeEnsemble -> ObliviousTreeEnsemble
    # (ote.copy_parameters(twote) overwrites ote with twote's parameters)
    ote = ObliviousTreeEnsemble(input_dim, output_dim, depth, alpha, n_tree, device)
    twote = TreeWiseObliviousTreeEnsemble(
        input_dim, output_dim, depth, alpha, n_tree, device
    )

    ote.copy_parameters(twote)

    assert torch.isclose(ote.forward(x), twote.forward(x)).all()
    logger.info(
        "Prediction matched: ObliviousTreeEnsemble <--> TreeWiseObliviousTreeEnsemble"
    )

    # ObliviousTreeEnsemble -> TreeWiseObliviousTreeEnsemble
    # (twote.copy_parameters(ote) overwrites twote with a fresh ote's parameters)
    ote = ObliviousTreeEnsemble(input_dim, output_dim, depth, alpha, n_tree, device)

    twote.copy_parameters(ote)

    assert torch.isclose(ote.forward(x), twote.forward(x)).all()
    logger.info(
        "Prediction matched: ObliviousTreeEnsemble <--> TreeWiseObliviousTreeEnsemble"
    )

    # Permutation of the trees
    perm = torch.randperm(n_tree)
    before = twote.forward(x)
    twote.apply_permutation(perm)
    after = twote.forward(x)
    assert torch.isclose(before, after).all(), f"{before}, {after}"
    logger.info("Prediction matched: Before Permutation <--> After Permutation")

    # Reordering of the split levels, independently for each tree
    before = twote.forward(x)
    order = torch.stack([torch.randperm(depth) for _ in range(n_tree)])
    twote.apply_reordering(order)
    after = twote.forward(x)
    assert torch.isclose(before, after).all(), f"{before}, {after}"
    logger.info("Prediction matched: Before Reordering <--> After Reordering")

    # SignFlip: trees with sign -1 have their children swapped at every level
    before = twote.forward(x)
    signs = torch.randint(0, 2, (n_tree,)) * 2 - 1  # {0, 1} * 2 - 1 -> {-1, 1}
    for flip_depth in range(1, depth + 1):
        twote.flip_children(signs, flip_depth)
        after = twote.forward(x)
        assert torch.isclose(before, after).all(), f"{before}, {after}"
        logger.info(
            f"Prediction matched (flip_depth={flip_depth}): Before SignFlip <--> After SignFlip"
        )

    # Parameter counts of ObliviousTreeEnsemble(44 inputs, 2 outputs, 256 trees)
    # for depths 1-3.
    for model_depth in range(1, 4):
        model = ObliviousTreeEnsemble(44, 2, model_depth, alpha, 256, device)
        print(sum(p.numel() for p in model.parameters()))
