# Copyright (c) 2023 Merantix Momentum GmbH

# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:

# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
#
# Modified by <Anonymous Authors>, <2026>

import torch
import itertools
from copy import deepcopy
from collections import defaultdict
from torch_geometric.utils import to_undirected


class Leaf:
    """Terminal tree cell holding a subset of points and their global indices.

    Args:
        points: Coordinates of the points in this cell; ``range`` indexes it as
            ``points[:, axis]``, so it is presumably (n_points, dim).
        point_indices: Global indices of those points, used to scatter them
            back into a positions tensor.
        parent: Enclosing ``Node``, or ``None`` for a root leaf.

    ``node_index`` is not set here; it must be assigned externally before
    ``fill_positions`` or ``max_idx`` is called. ``pos`` only exists after
    ``update_pos`` has run.
    """

    def __init__(self, points, point_indices, parent=None):
        self.points, self.point_indices, self.parent = points, point_indices, parent

    def fill_positions(self, positions):
        """Write the member points, then this leaf's own ``pos``, into ``positions``."""
        positions[self.point_indices] = self.points
        positions[self.node_index] = self.pos

    def update_pos(self):
        """Set ``pos`` to the coordinate-wise mean of the member points."""
        self.pos = self.points.mean(dim=0)

    def max_idx(self):
        """Return the largest index referenced by this leaf (own or member)."""
        own_index = self.node_index
        largest_member = self.point_indices.max()
        return max(own_index, largest_member)

    def traverse(self):
        """A leaf's traversal is a one-element list containing itself."""
        return [self]

    def range(self, i):
        """Return ``(min, max)`` of coordinate column ``i`` over the member points."""
        column = self.points[:, i]
        return column.min(), column.max()

    def __repr__(self):
        count = self.points.shape[0]
        return f"Leaf({count})"


class Node:
    """Internal tree node with two children, ``left`` and ``right``.

    Args:
        parent: Enclosing ``Node``, or ``None`` for the root.
        node_index: Index assigned to this node.

    ``left``/``right`` are not created by ``__init__``; ``from_leaves``
    attaches them (or they are assigned directly). ``pos`` only exists after
    ``update_pos`` has run.
    """

    def __init__(self, parent, node_index):
        self.parent = parent
        self.node_index = node_index

    def update_pos(self):
        """Recompute both children's positions, then set ``pos`` to their midpoint."""
        for child in (self.left, self.right):
            child.update_pos()
        self.pos = (self.left.pos + self.right.pos) / 2.0

    def max_idx(self):
        """Return the largest index in this subtree, including this node's own."""
        left_max = self.left.max_idx()
        right_max = self.right.max_idx()
        return max(left_max, right_max, self.node_index)

    def fill_positions(self, positions):
        """Write this node's ``pos``, then recursively its subtree's, into ``positions``."""
        positions[self.node_index] = self.pos
        for child in (self.left, self.right):
            child.fill_positions(positions)

    def range(self, i):
        """Return ``(min, max)`` of coordinate ``i`` over the whole subtree."""
        left_lo, left_hi = self.left.range(i)
        right_lo, right_hi = self.right.range(i)
        return min(left_lo, right_lo), max(left_hi, right_hi)

    @staticmethod
    def from_leaves(parent, node_index, leftleaf, rightleaf):
        """Create a node that owns ``leftleaf``/``rightleaf`` and re-parent both to it."""
        node = Node(parent, node_index)
        node.left, node.right = leftleaf, rightleaf
        for child in (leftleaf, rightleaf):
            child.parent = node
        return node

    def _slot_of(self, child):
        """Name of the attribute ("left"/"right") holding ``child``; ValueError if neither."""
        if self.left == child:
            return "left"
        if self.right == child:
            return "right"
        raise ValueError("Leaf not found")

    def remove_child(self, child):
        """Detach ``child`` by setting its slot to ``None``."""
        setattr(self, self._slot_of(child), None)

    def change_child(self, old, new):
        """Replace child ``old`` with ``new`` in the same slot."""
        setattr(self, self._slot_of(old), new)

    def empty(self):
        """True when both child slots are ``None``."""
        if self.left is not None:
            return False
        return self.right is None

    def traverse(self, parent=None):
        """Iterate over the leaves below this node, left subtree first.

        ``parent`` is accepted for signature compatibility but is unused.
        """
        left_part = self.left.traverse()
        right_part = self.right.traverse()
        return itertools.chain(left_part, right_part)

    def __repr__(self):
        return "Node({}, {})".format(self.left, self.right)


class TreeOpTransform:
    """Build a median-split tree over point positions and derive per-level edges.

    The tree is grown for ``n_levels`` rounds by splitting every leaf along its
    highest-variance axis at the median (``create_tree``). Edge indices are
    then produced per tree level, linking points/nodes to the ancestor that
    lies ``k_hop_levels`` levels above them.

    Args:
        n_levels: Number of splitting rounds used by ``transform``.
        k_hop_levels: Levels to climb when linking a node to its ancestor;
            must be >= 1 for the edge-building loops to terminate.
        k_neighbors: Stored on the instance; not used by the methods shown here.
    """

    def __init__(self, n_levels, k_hop_levels, k_neighbors):
        self.n_levels = n_levels
        self.k_hop_levels = k_hop_levels
        self.k_neighbors = k_neighbors

    @staticmethod
    def _validate_k_hop_levels(k_hop_levels):
        """Raise ``ValueError`` unless ``k_hop_levels`` is at least 1."""
        # With zero (or negative) hops every node is its own "ancestor", so the
        # set of current nodes never shrinks and the level loops never end.
        if k_hop_levels < 1:
            raise ValueError(f"k_hop_levels must be >= 1, got {k_hop_levels}")

    def split(self, leaf, i):
        """Split ``leaf`` at the NaN-ignoring median of coordinate column ``i``.

        Points strictly below the median go left; everything else (including
        rows whose column-``i`` value is NaN) goes right.

        Returns:
            ``(left_leaf, right_leaf, median, split_point_index)``. Here
            ``split_point_index`` is the global index of the point closest to
            the leaf's mean. Rows containing NaN are only considered when every
            row has one. Returns four ``None`` values when the leaf is empty or
            either side of the split would be empty.
        """
        if leaf.points.numel() == 0:
            return None, None, None, None

        col = leaf.points[:, i]
        median = torch.nanmedian(col)

        mask_left = (col < median) & ~torch.isnan(col)
        mask_right = ~mask_left  # >= median, plus NaN entries

        left = leaf.points[mask_left]
        right = leaf.points[mask_right]
        left_idx = leaf.point_indices[mask_left]
        right_idx = leaf.point_indices[mask_right]

        if left.shape[0] == 0 or right.shape[0] == 0:
            return None, None, None, None

        # Representative point: smallest squared distance to the (NaN-ignoring) mean.
        centre = torch.nanmean(leaf.points, dim=0)
        dist_sq_sum = torch.nansum((leaf.points - centre) ** 2, dim=1)
        # nansum drops NaN terms, so a row with missing coordinates looks
        # artificially close to the centre (an all-NaN row scores 0). Exclude
        # such rows whenever at least one complete row is available.
        incomplete = torch.isnan(leaf.points).any(dim=1)
        if not bool(incomplete.all()):
            dist_sq_sum = dist_sq_sum.masked_fill(incomplete, float("inf"))
        split_point_index = leaf.point_indices[torch.argmin(dist_sq_sum)]

        return Leaf(left, left_idx), Leaf(right, right_idx), median, split_point_index

    def create_tree(self, levels, points):
        """Grow a median-split tree over ``points`` for ``levels`` rounds.

        In every round each current leaf is split along its highest-variance
        axis (``split_var``); leaves that cannot be split stay as they are.
        Internal nodes are numbered consecutively starting at
        ``points.shape[0]``. Once splitting is finished, the remaining leaves
        receive the next indices in traversal order.

        Returns:
            The root: a ``Leaf`` if no split succeeded, otherwise a ``Node``.
            Each ``Node`` also carries ``separator`` (the median),
            ``split_point_index`` and ``split_axis``.
        """
        tree = Leaf(points, torch.arange(points.shape[0]))
        idx = points.shape[0] - 1

        def idx_counter():
            # Hand out the next free index (point indices occupy 0..n-1).
            nonlocal idx
            idx += 1
            return idx

        for _ in range(levels):
            # Snapshot the leaves: the tree is rewired while we iterate.
            for leaf in list(tree.traverse()):
                leftleaf, rightleaf, median, split_point_index, i = self.split_var(leaf)

                # Unsplittable leaf (too few points, or an empty side after the median cut).
                if leftleaf is None or rightleaf is None:
                    continue

                next_node = Node.from_leaves(leaf.parent, idx_counter(), leftleaf, rightleaf)
                next_node.separator = median
                next_node.split_point_index = split_point_index
                next_node.split_axis = i

                if leaf.parent is None:
                    tree = next_node
                else:
                    leaf.parent.change_child(leaf, next_node)

        # Fill in last indices: leaves are numbered after all internal nodes.
        for leaf in list(tree.traverse()):
            leaf.node_index = idx_counter()

        return tree

    def create_edge_indices_local(self, curr_nodes, k_hop_levels, idx_str="node_index"):
        """Link each non-root node to the ancestor ``k_hop_levels`` levels up.

        The climb stops early at the root. The index stored in attribute
        ``idx_str`` is used for both endpoints.

        Returns:
            ``(edges, parent_nodes)``. ``edges`` is a list of ``(2, 1)`` tensors
            ``[[ancestor_idx], [node_idx]]``. ``parent_nodes`` holds the
            distinct ancestors (by identity) in first-seen order.
        """
        result = []
        parent_nodes = []
        seen = set()  # ids of nodes already in parent_nodes, for O(1) membership

        for node in curr_nodes:
            if node.parent is None:
                # The root has no ancestor to link to.
                continue

            parent_node = node
            for _ in range(k_hop_levels):
                # Climb one level, stopping at the root.
                if parent_node.parent is not None:
                    parent_node = parent_node.parent

            idx_node = getattr(node, idx_str)
            idx_parent_node = getattr(parent_node, idx_str)

            edge_index = torch.tensor([idx_parent_node, idx_node]).reshape(2, 1)
            result.append(edge_index)

            if id(parent_node) not in seen:
                seen.add(id(parent_node))
                parent_nodes.append(parent_node)

        return result, parent_nodes

    def create_edge_indices_simple(self, tree, k_hop_levels=1):
        """Build per-level edges without adding any extra nodes to the graph.

        Each internal node is represented by its ``split_point_index`` (a real
        point). Level 1 links every leaf's points to the parent's split point.
        Each further level links split points to the ancestor ``k_hop_levels``
        above them, until at most one node remains.

        Returns:
            ``defaultdict`` mapping level -> edge tensor concatenated along dim 1.

        Raises:
            ValueError: If ``k_hop_levels`` < 1 (the level loop would not end).
        """
        self._validate_k_hop_levels(k_hop_levels)
        # ``tree`` is only read and every tensor below is freshly allocated,
        # so no defensive deepcopy of the (tensor-heavy) tree is needed.
        result = []
        level = 1
        curr_nodes = []

        for leaf in tree.traverse():
            if len(leaf.points) > 0 and leaf.parent is not None:
                # Base connections: parent's split point -> every point of this leaf.
                edge_index = torch.stack(
                    [torch.full(leaf.point_indices.shape[0:1], leaf.parent.split_point_index), leaf.point_indices]
                )
                result.append((level, edge_index))
            curr_nodes.append(leaf.parent)

        while len(curr_nodes) > 1:
            level += 1
            edge_indices, curr_nodes = self.create_edge_indices_local(
                curr_nodes, k_hop_levels, idx_str="split_point_index"
            )
            result.extend((level, edge_index) for edge_index in edge_indices)

        # Group the edges by level and concatenate each group.
        level_to_edges = defaultdict(list)
        for lvl, edges in result:
            level_to_edges[lvl].append(edges)

        for lvl in level_to_edges:
            level_to_edges[lvl] = torch.hstack(level_to_edges[lvl])
        return level_to_edges

    def create_edge_indices(self, tree, k_hop_levels=1):
        """Build per-level edges using every tree node's own ``node_index``.

        Level 1 links each leaf's ``node_index`` to all of its points. Each
        further level links nodes to the ancestor ``k_hop_levels`` above them
        (see ``create_edge_indices_local``), until at most one node remains.

        Returns:
            ``defaultdict`` mapping level -> edge tensor concatenated along dim 1.

        Raises:
            ValueError: If ``k_hop_levels`` < 1 (the level loop would not end).
        """
        self._validate_k_hop_levels(k_hop_levels)
        # ``tree`` is only read and every tensor below is freshly allocated,
        # so no defensive deepcopy of the (tensor-heavy) tree is needed.
        result = defaultdict(list)
        level = 1
        curr_nodes = []

        for leaf in tree.traverse():
            if len(leaf.points) > 0:
                # Base connections: leaf node -> each of its points.
                edge_index = torch.stack(
                    [torch.full(leaf.point_indices.shape[0:1], leaf.node_index), leaf.point_indices]
                )
                result[level].append(edge_index)
            curr_nodes.append(leaf)

        while len(curr_nodes) > 1:
            level += 1
            edge_indices, curr_nodes = self.create_edge_indices_local(curr_nodes, k_hop_levels)

            if len(edge_indices) > 0:
                result[level] = edge_indices

        for lvl in result:
            result[lvl] = torch.hstack(result[lvl])
        return result

    def split_var(self, leaf):
        """Split ``leaf`` along its axis of largest variance.

        Variance is computed with ``correction=0`` over rows that contain no
        NaN.

        Returns:
            The four values from ``split`` followed by the chosen axis ``i``.
            Returns five ``None`` values when fewer than two points (or fewer
            than two complete points) are available.
        """
        if leaf.points.shape[0] < 2:
            return None, None, None, None, None

        clean_points = leaf.points[~torch.isnan(leaf.points).any(dim=1)]
        if clean_points.shape[0] < 2:
            return None, None, None, None, None
        var = torch.var(clean_points, dim=0, correction=0)

        i = int(torch.argmax(torch.nan_to_num(var, nan=-float('inf'))))
        return *self.split(leaf, i), i

    def postprocess(self, edge_index, out_transform):
        """Append every level's tree edges from ``out_transform`` to ``edge_index``.

        Args:
            edge_index: Existing edges, concatenated with the tree edges along
                dim 1 (presumably shape (2, E) — confirm with callers).
            out_transform: Dict as returned by ``transform``. Only its
                ``"tree_edge_indices"`` mapping (level -> edges) is read.

        Returns:
            ``(new_edge_index, new_edges)``: the combined edge index, and the
            tree edges of all levels concatenated. When there are no tree
            edges (e.g. the tree could not be split), ``new_edges`` is an empty
            ``(edge_index.shape[0], 0)`` tensor with ``edge_index``'s dtype and
            device.
        """
        # Aggregate edge indices across levels.
        tree_edge_indices = out_transform["tree_edge_indices"]
        all_edges = [tree_edge_indices[lvl] for lvl in tree_edge_indices]
        if all_edges:
            new_edges = torch.cat(all_edges, dim=1)
        else:
            # torch.cat rejects an empty list; an unsplittable tree has no levels.
            new_edges = edge_index.new_empty((edge_index.shape[0], 0))

        new_edge_index = torch.cat([edge_index, new_edges], dim=1)

        return new_edge_index, new_edges

    def transform(self, mesh_pos, ret_tree=False):
        """Build the tree over ``mesh_pos`` and return undirected per-level edges.

        Returns:
            Dict with ``"tree_edge_indices"`` and ``"tree"``:
            - ``"tree_edge_indices"``: level -> undirected edge index, built by
              ``create_edge_indices_simple`` and passed through ``to_undirected``.
            - ``"tree"``: the root when ``ret_tree`` is true, else ``None``.
        """
        tree = self.create_tree(self.n_levels, mesh_pos)
        edge_indices = self.create_edge_indices_simple(tree, self.k_hop_levels)

        # To undirected
        for lvl in edge_indices:
            edge_indices[lvl] = to_undirected(edge_indices[lvl])

        return {
            "tree_edge_indices": edge_indices,
            "tree": tree if ret_tree else None,
        }

    def get_plot_data(self, edge_indices):
        """Flatten a level -> edges mapping into a list of plot records.

        Returns:
            One ``{"level": lvl, "edges": edges}`` dict per level, in the
            mapping's iteration order.
        """
        return [{"level": lvl, "edges": edges} for lvl, edges in edge_indices.items()]
