import torch
import itertools
from copy import deepcopy
from collections import defaultdict
from torch_geometric.utils import to_undirected


class LeafErrors:
    """Leaf of the partition tree: a group of points with one error value each.

    Attributes set by ``__init__``:
        points: 2-D tensor; rows are points, columns are coordinates (the
            methods index it as ``points[:, i]`` and reduce over dim 0).
        point_indices: tensor of indices, one per row of ``points``; used to
            index a positions tensor in ``fill_positions``.
        errors: tensor with one entry per point (indexed along dim 0).
        parent: the owning node, or ``None`` for a root leaf.

    Attributes set later, by other code:
        node_index: index assigned to this leaf itself; must be set by the
            tree builder before ``fill_positions`` or ``max_idx`` is called.
        pos: representative position, computed by ``update_pos``.
    """

    def __init__(self, points, point_indices, errors, parent=None):
        self.points = points
        self.point_indices = point_indices
        self.parent = parent
        self.errors = errors

    def fill_positions(self, positions):
        """Write the point coordinates and this leaf's own position into ``positions``.

        Requires ``node_index`` and ``pos`` to have been set beforehand.
        """
        positions[self.point_indices] = self.points
        positions[self.node_index] = self.pos

    def update_pos(self):
        """Set ``pos`` to the mean of the points along dim 0."""
        self.pos = torch.mean(self.points, dim=0)

    def max_idx(self):
        """Return the largest index used by this leaf (its own or any point's) as an int.

        A leaf may hold zero points; in that case only ``node_index`` counts,
        because taking the maximum of an empty tensor raises.
        """
        if self.point_indices.numel() == 0:
            return self.node_index
        return max(self.node_index, int(self.point_indices.max()))

    def traverse(self):
        """Return the leaves below this object, i.e. a list holding only this leaf."""
        return [self]

    def range(self, i):
        """Return ``(min, max)`` of coordinate ``i`` over the points.

        Requires at least one point: ``torch.min`` / ``torch.max`` raise on
        an empty tensor.
        """
        return torch.min(self.points[:, i]), torch.max(self.points[:, i])

    def __repr__(self):
        return f"Leaf(points={self.points.shape[0]}, errors={self.errors.shape[0]})"


class Node:
    """Internal node of the binary partition tree.

    Attributes:
        parent: the owning node, or ``None`` for the root.
        node_index: index assigned to this node.
        left, right: child objects (nodes or leaves); ``None`` while unset or
            after ``remove_child``.
        pos: midpoint of the children's positions, set by ``update_pos``.
    """

    def __init__(self, parent, node_index):
        self.parent = parent
        self.node_index = node_index
        # Children start out absent so that empty() and __repr__ work on a
        # freshly constructed node instead of raising AttributeError.
        self.left = None
        self.right = None

    def update_pos(self):
        """Recompute the children's positions, then set ``pos`` to their midpoint.

        Both children must be present.
        """
        self.left.update_pos()
        self.right.update_pos()
        self.pos = (self.left.pos + self.right.pos) / 2.0

    def max_idx(self):
        """Return the largest index used in this subtree. Both children must be present."""
        return max(self.left.max_idx(), self.right.max_idx(), self.node_index)

    def fill_positions(self, positions):
        """Write this node's position and, recursively, the children's into ``positions``."""
        positions[self.node_index] = self.pos
        self.left.fill_positions(positions)
        self.right.fill_positions(positions)

    def range(self, i):
        """Return ``(min, max)`` of coordinate ``i`` over both children."""
        leftrange = self.left.range(i)
        rightrange = self.right.range(i)
        return min(leftrange[0], rightrange[0]), max(leftrange[1], rightrange[1])

    @staticmethod
    def from_leaves(parent, node_index, leftleaf, rightleaf):
        """Create a node owning the two given children and re-parent them to it."""
        result = Node(parent, node_index)
        result.left = leftleaf
        result.right = rightleaf
        leftleaf.parent = result
        rightleaf.parent = result
        return result

    def remove_child(self, child):
        """Detach ``child`` (matched by identity) from this node.

        Raises:
            ValueError: if ``child`` is neither the left nor the right child.
        """
        if self.left is child:
            self.left = None
        elif self.right is child:
            self.right = None
        else:
            raise ValueError("Leaf not found")

    def change_child(self, old, new):
        """Replace the child ``old`` (matched by identity) with ``new``.

        Raises:
            ValueError: if ``old`` is neither the left nor the right child.
        """
        if self.left is old:
            self.left = new
        elif self.right is old:
            self.right = new
        else:
            raise ValueError("Leaf not found")

    def empty(self):
        """Return True when the node has no children."""
        return self.left is None and self.right is None

    def traverse(self, parent=None):
        """Return an iterator over the leaves of this subtree, left to right.

        Missing children (``None``) are skipped. ``parent`` is unused and
        kept only for interface compatibility.
        """
        branches = [
            child.traverse()
            for child in (self.left, self.right)
            if child is not None
        ]
        return itertools.chain(*branches)

    def __repr__(self):
        return f"Node({self.left}, {self.right})"


class TreeAdaptiveErrorsTransform:
    """Build an error-driven binary partition tree over points and derive extra graph edges from it.

    The tree is grown by repeatedly splitting leaves at a high-error point
    along a chosen axis (``create_tree``). Edges are then generated from each
    split point to the points of its leaves and between split points of
    nodes and their ancestors (``create_edge_indices``).

    Args:
        n_levels: number of splitting rounds performed by ``transform``.
        k_hop_levels: how many tree levels an ancestor edge jumps upwards.
        min_points: leaves with fewer points are not split.
        error_threshold: if given, only points whose error is ``>=`` this
            value receive an edge from their leaf's split point.
        type_one_side: fallback used when the highest-error point puts every
            point on one side: ``"next_second"`` (second-highest error) or
            ``"next_weighted_variance"`` (sample a point with probability
            proportional to its error).
        type_split_axis: how the split axis is chosen:
            ``"variance_position"``, ``"median_absolute_deviation"`` or
            ``"weighted_variance"``.
        type_new_edges: ``"same_axis"`` adds an ancestor edge only when both
            nodes split along the same axis; any other value adds it always.
    """

    def __init__(
            self, n_levels, k_hop_levels,
            min_points, error_threshold=None,
            type_one_side="next_second",
            type_split_axis="variance_position",
            type_new_edges="same_axis"
    ):
        self.n_levels = n_levels
        self.k_hop_levels = k_hop_levels
        self.min_points = min_points
        self.error_threshold = error_threshold
        self.type_one_side = type_one_side
        self.type_split_axis = type_split_axis
        self.type_new_edges = type_new_edges

    def split(self, leaf, axis_idx, min_points):
        """Split ``leaf`` in two along ``axis_idx`` at a high-error point.

        Points whose coordinate is strictly below the split point's go to the
        first returned leaf, all others to the second.

        Returns:
            ``(left, right, split_point, split_point_index)``, or
            ``(leaf, None, None, None)`` when the leaf cannot be split.

        Raises:
            ValueError: if the fallback is needed and ``type_one_side`` is
                not a supported strategy.
        """
        n_points = leaf.points.shape[0]
        # Too few points to split further (an empty leaf can never be split).
        if n_points < min_points or n_points == 0:
            return leaf, None, None, None

        coords = leaf.points[:, axis_idx]

        # First choice: split at the point with the highest error.
        chosen_idx = torch.argmax(leaf.errors)
        masks = coords < coords[chosen_idx]

        # If every point falls on one side, retry with another high-error point.
        if torch.all(masks) or torch.all(~masks):
            ord_error = torch.argsort(leaf.errors, descending=True)

            if self.type_one_side == "next_second":
                if n_points < 2:
                    # There is no second point to fall back to.
                    return leaf, None, None, None
                chosen_idx = ord_error[1]

            elif self.type_one_side == "next_weighted_variance":
                # The sample is a position in the *sorted* order; map it back
                # to the index of the corresponding point.
                sampled = torch.multinomial(leaf.errors[ord_error], 1).item()
                chosen_idx = ord_error[sampled]

            else:
                raise ValueError(f"Unsupported one-side strategy: {self.type_one_side}")

            # NOTE: the retry may still be one-sided (e.g. all points share
            # the coordinate); one of the returned leaves is then empty.
            masks = coords < coords[chosen_idx]

        split_point = leaf.points[chosen_idx]
        split_point_index = leaf.point_indices[chosen_idx]

        return (
            LeafErrors(leaf.points[masks], leaf.point_indices[masks], leaf.errors[masks]),
            LeafErrors(leaf.points[~masks], leaf.point_indices[~masks], leaf.errors[~masks]),
            split_point,
            split_point_index
        )

    def _select_split_axis(self, leaf):
        """Return the axis index along which ``leaf`` is split, per ``type_split_axis``."""
        if self.type_split_axis == "variance_position":
            return torch.argmax(torch.var(leaf.points, dim=0))

        if self.type_split_axis == "median_absolute_deviation":
            median_values = torch.median(leaf.points, dim=0).values
            mad = torch.median(torch.abs(leaf.points - median_values), dim=0).values
            return torch.argmax(mad)

        if self.type_split_axis == "weighted_variance":
            return torch.argmax(weighted_variance(leaf.points, leaf.errors))

        raise ValueError(f"Unsupported split axis type: {self.type_split_axis}")

    def create_tree(self, levels, points, errors, min_points):
        """Grow the partition tree for ``points`` over ``levels`` splitting rounds.

        Points keep the indices ``0 .. n-1``. Internal nodes receive the next
        free indices as they are created; leaves are numbered last, in
        traversal order.

        Returns:
            The root of the tree (a single leaf if nothing was split).

        Raises:
            ValueError: if a leaf is to be split and ``type_split_axis`` is
                unsupported.
        """
        root = LeafErrors(points, torch.arange(points.shape[0]), errors)
        # Node indices continue right after the point indices.
        next_index = itertools.count(points.shape[0])

        for _ in range(levels):
            # Materialise the list first: the tree is modified while looping.
            leaves_to_split = [leaf for leaf in root.traverse() if leaf.points.shape[0] >= min_points]
            for leaf in leaves_to_split:
                axis_idx = self._select_split_axis(leaf)
                left, right, split_point, split_point_index = self.split(leaf, axis_idx, min_points)

                if left is None or right is None:
                    continue

                new_node = Node.from_leaves(leaf.parent, next(next_index), left, right)
                new_node.separator = split_point
                new_node.split_point_index = split_point_index
                new_node.split_axis = axis_idx

                if leaf.parent is None:
                    root = new_node
                else:
                    leaf.parent.change_child(leaf, new_node)

        for leaf in root.traverse():
            leaf.node_index = next(next_index)
        return root

    @staticmethod
    def _point_edges(leaf, error_threshold):
        """Return edges from the parent's split point to the leaf's points, or ``None``.

        With a threshold, only points whose error is ``>= error_threshold``
        are connected. ``None`` is returned when no point qualifies.
        """
        targets = leaf.point_indices
        if error_threshold is not None:
            targets = targets[leaf.errors >= error_threshold]
        if targets.shape[0] == 0:
            return None
        sources = torch.full(targets.shape[0:1], leaf.parent.split_point_index)
        return torch.stack([sources, targets])

    def create_edge_indices_local(self, curr_nodes, k_hop_levels, idx_str="node_index", error_threshold=None):
        """Connect each node in ``curr_nodes`` to its ancestor ``k_hop_levels`` levels up.

        Nodes without a parent are skipped. The ancestor walk stops early at
        the root. The attribute named ``idx_str`` supplies the index used for
        each end of the edge.

        Returns:
            ``(edges, ancestors)``: a list of ``(2, E)`` edge tensors and the
            distinct ancestors reached, in first-seen order.
        """
        result = []
        parent_nodes = []
        seen_parents = set()  # ids of nodes already in parent_nodes

        for node in curr_nodes:
            if node.parent is None:
                continue

            parent_node = node
            for _ in range(k_hop_levels):
                if parent_node.parent is not None:
                    parent_node = parent_node.parent

            # A leaf additionally contributes edges to its individual points.
            if isinstance(node, LeafErrors):
                point_edges = self._point_edges(node, error_threshold)
                if point_edges is not None:
                    result.append(point_edges)

            idx_node = getattr(node, idx_str)
            idx_parent_node = getattr(parent_node, idx_str)

            # With "same_axis", only connect nodes that split along the same axis.
            if self.type_new_edges != "same_axis" or node.split_axis == parent_node.split_axis:
                result.append(torch.tensor([idx_parent_node, idx_node]).reshape(2, 1))

            if id(parent_node) not in seen_parents:
                seen_parents.add(id(parent_node))
                parent_nodes.append(parent_node)

        return result, parent_nodes

    def create_edge_indices(self, tree, k_hop_levels=1, error_threshold=None):
        """Generate the tree-derived edges, grouped by level.

        Level 1 connects each leaf's points to the split point of the leaf's
        parent. Higher levels connect split points of nodes to those of their
        ancestors, moving ``k_hop_levels`` levels up per round. The tree is
        only read, never modified.

        Returns:
            A ``defaultdict`` mapping level -> ``(2, E)`` edge tensor. Levels
            without edges are absent.

        Raises:
            ValueError: if ``k_hop_levels`` is smaller than 1 (the upward
                walk would never make progress).
        """
        if k_hop_levels < 1:
            raise ValueError(f"k_hop_levels must be at least 1, got {k_hop_levels}")

        result = []
        level = 1
        curr_nodes = []
        seen_parents = set()

        for leaf in tree.traverse():
            parent = leaf.parent
            if len(leaf.points) > 0 and parent is not None:
                point_edges = self._point_edges(leaf, error_threshold)
                if point_edges is not None:
                    result.append((level, point_edges))
            # Each parent is queued once, even though both of its children
            # may be leaves; otherwise its ancestor edge would be duplicated.
            if parent is not None and id(parent) not in seen_parents:
                seen_parents.add(id(parent))
                curr_nodes.append(parent)

        while len(curr_nodes) > 1:
            level += 1
            edge_indices, curr_nodes = self.create_edge_indices_local(
                curr_nodes, k_hop_levels, idx_str="split_point_index",
                error_threshold=error_threshold
            )
            result.extend((level, edge_index) for edge_index in edge_indices)

        # Merge all edges of a level into a single tensor.
        level_to_edges = defaultdict(list)
        for lvl, edges in result:
            level_to_edges[lvl].append(edges)

        for lvl in level_to_edges:
            level_to_edges[lvl] = torch.hstack(level_to_edges[lvl])
        return level_to_edges

    def postprocess(self, edge_index, out_transform):
        """Append the tree edges from ``out_transform`` to ``edge_index``.

        Args:
            edge_index: ``(2, E)`` tensor of original edges.
            out_transform: dict with key ``"tree_edge_indices"`` mapping
                level -> ``(2, E_lvl)`` edge tensor (as returned by
                ``transform``).

        Returns:
            ``(new_edge_index, all_edge_types, new_edges)`` where edge types
            are ``0`` for original and ``1`` for tree edges.
        """
        tree_edge_indices = out_transform["tree_edge_indices"]
        all_edges = list(tree_edge_indices.values())
        if all_edges:
            new_edges = torch.cat(all_edges, dim=1)
        else:
            # No tree edges at all (e.g. nothing was split): torch.cat cannot
            # take an empty list, so use an empty edge tensor instead.
            new_edges = edge_index.new_zeros((2, 0))

        original_edge_type = torch.zeros(edge_index.shape[1], dtype=torch.long)  # 0
        augmented_edge_type = torch.ones(new_edges.shape[1], dtype=torch.long)   # 1

        all_edge_types = torch.cat([original_edge_type, augmented_edge_type], dim=0)
        new_edge_index = torch.cat([edge_index, new_edges], dim=1)
        return new_edge_index, all_edge_types, new_edges

    def transform(self, mesh_pos, errors, ret_tree=False):
        """Build the tree for ``mesh_pos`` / ``errors`` and return its undirected edges per level.

        Returns:
            dict with ``"tree_edge_indices"`` (level -> edge tensor) and
            ``"tree"`` (the tree root if ``ret_tree`` else ``None``).
        """
        tree = self.create_tree(self.n_levels, mesh_pos, errors, self.min_points)
        new_edge_indices = self.create_edge_indices(tree, self.k_hop_levels, self.error_threshold)

        for lvl in new_edge_indices:
            new_edge_indices[lvl] = to_undirected(new_edge_indices[lvl])

        return {
            "tree_edge_indices": new_edge_indices,
            "tree": tree if ret_tree else None,
        }

    def get_plot_data(self, edge_indices):
        """Return a list of ``{"level": ..., "edges": ...}`` dicts, one per level."""
        return [{"level": lvl, "edges": edges} for lvl, edges in edge_indices.items()]


def weighted_variance(points, errors):
    """Return the per-axis variance of ``points`` weighted by ``errors``.

    ``points`` is reduced over dim 0, with ``errors[k]`` acting as the weight
    of row ``k``. Both the mean and the variance are normalised by the sum of
    the errors.
    """
    weights = errors[:, None]  # column vector so it broadcasts across the axes
    total_weight = torch.sum(errors)

    center = torch.sum(points * weights, dim=0) / total_weight
    spread = (points - center) ** 2
    return torch.sum(spread * weights, dim=0) / total_weight
