import os
import torch
import dgl
from dgl.data import DGLDataset
from rich.progress import track
from src.data.tree_error_transform import TreeAdaptiveErrorsTransform
from .builders_graph import get_edge_builder


class MultiGraphDataset(DGLDataset):
    """Pair each graph of ``base_dataset`` with a companion (possibly refined) graph.

    With ``init_refine=True`` the companion graph is the base graph itself.
    Otherwise every companion graph is rebuilt from the base graph's edges plus
    the edges proposed by a ``TreeAdaptiveErrorsTransform`` driven by a cached
    per-sample error map, and its edge features ``edata["f"]`` are recomputed
    with the selected edge builder.

    Items are ``(base_graph, new_graph, in_globals, out_globals, idx)``.
    """

    def __init__(self, args, base_dataset, indices, epoch=None,
                 init_refine=True, stock_graph=False, error_maps=None,
                 edge_builder_name: str = "rich_edges"):
        """
        Args:
            args: Run configuration. The refinement path reads
                ``error_n_levels``, ``error_k_hop_levels``, ``error_min_points``
                and ``error_threshold``; the visualisation dump reads
                ``save_path``, ``dataset_name`` and ``run_name``. It is also
                forwarded to the edge builder inside its context dict.
            base_dataset: Indexable dataset whose items are 4-tuples with a DGL
                graph first (the refinement path reads its ``x``, ``y`` and
                ``pos`` node features).
            indices: Sequence used here only to name saved visualisation files.
            epoch: Tag used in saved visualisation file names.
            init_refine: If True, reuse the base graphs unchanged.
            stock_graph: If True (refinement path only), save sample 0's added
                edges and error map for visualisation.
            error_maps: Per-sample error tensors indexed by sample position;
                required when ``init_refine`` is False.
            edge_builder_name: Name passed to ``get_edge_builder``.

        Raises:
            ValueError: If ``init_refine`` is False and ``error_maps`` is None.
        """
        super().__init__(name='multi_graph_dataset')
        self.args         = args
        self.base_dataset = base_dataset
        self.indices      = indices
        self.epoch        = epoch
        self.stock_graph  = stock_graph
        self.num_samples  = len(base_dataset)
        self.graphs       = []
        self.eb           = get_edge_builder(edge_builder_name)

        if init_refine:
            for i in range(self.num_samples):
                graph, _, _, _ = self.base_dataset[i]
                self.graphs.append(graph)
            return

        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if error_maps is None:
            raise ValueError("Error_maps must be provided when init_refine=False!")

        self._refine_graphs(error_maps)

    def _refine_graphs(self, error_maps):
        """Append one error-refined graph per base sample to ``self.graphs``."""
        error_tree = TreeAdaptiveErrorsTransform(
            n_levels=self.args.error_n_levels,
            k_hop_levels=self.args.error_k_hop_levels,
            min_points=self.args.error_min_points,
            error_threshold=self.args.error_threshold,
            type_one_side="next_second",
            type_split_axis="variance_position",
            type_new_edges=None
        )

        desc = "🚀 Modifying graphs with new edges from cached errors"
        for idx in track(range(self.num_samples), total=self.num_samples, description=desc):
            ori_graph, _, _, _ = self.base_dataset[idx]
            err = error_maps[idx]

            new_graph, new_edges = self._build_refined_graph(error_tree, ori_graph, err)
            self.graphs.append(new_graph)

            if self.stock_graph and idx == 0:
                self._save_visualisation(idx, new_edges, err)

    def _build_refined_graph(self, error_tree, ori_graph, err):
        """Return ``(new_graph, new_edges)`` for one sample.

        ``new_graph`` has the edges returned by ``error_tree.postprocess``, the
        base graph's ``x``/``y``/``pos`` node features, and edge features
        ``edata["f"]`` from the edge builder. ``new_edges`` is the third
        ``postprocess`` output (only used for visualisation).
        """
        pos = ori_graph.ndata["pos"]
        device = ori_graph.device

        # The error transform is fed CPU tensors.
        src0, dst0 = ori_graph.edges()
        edge_index_orig = torch.stack([src0.cpu(), dst0.cpu()], dim=0)
        # squeeze(-1) drops a trailing singleton dim — presumably err is (num_nodes, 1); verify.
        out_transform = error_tree.transform(pos.cpu(), err.detach().cpu().squeeze(-1))
        new_edge_index, _, new_edges = error_tree.postprocess(edge_index_orig, out_transform)

        # The transform output is presumably on CPU; build the graph on the base
        # graph's device, because DGL rejects node/edge features that live on a
        # different device than the graph.
        src_new = torch.as_tensor(new_edge_index[0], device=device)
        dst_new = torch.as_tensor(new_edge_index[1], device=device)

        new_graph = dgl.graph((src_new, dst_new), num_nodes=pos.size(0), device=device)
        new_graph.ndata["x"]   = ori_graph.ndata["x"]
        new_graph.ndata["y"]   = ori_graph.ndata["y"]
        new_graph.ndata["pos"] = pos

        # Only `args` is supplied to the edge builder. Builders that read more
        # per-graph context (earlier experiments used keys such as centers,
        # radii, eta, W, t_ctr, did, alpha, cos/sin_angles, t_loc, n_loc, kappa)
        # need those entries added here.
        ctx = {"args": self.args}
        new_graph.edata["f"] = self.eb(pos, src_new, dst_new, ctx)
        # The per-edge types returned by postprocess are intentionally not stored.
        return new_graph, new_edges

    def _save_visualisation(self, idx, new_edges, err):
        """Print the added-edge shape and save edges + error map of sample ``idx``."""
        print(f"✈️ New edges of error control = {new_edges.shape}")
        viz_dir = os.path.join(self.args.save_path, f"{self.args.dataset_name}/{self.args.run_name}/visualisation")
        os.makedirs(viz_dir, exist_ok=True)
        torch.save(new_edges, os.path.join(viz_dir, f"sample_{self.indices[idx]}_edges_epoch_{self.epoch}.pt"))
        torch.save(err.detach().cpu(), os.path.join(viz_dir, f"sample_{self.indices[idx]}_errormap_epoch_{self.epoch}.pt"))

    def __getitem__(self, idx):
        """Return ``(base_graph, new_graph, in_globals, out_globals, idx)``.

        The base sample is re-read from ``base_dataset`` on every call; the
        companion graph comes from the precomputed ``self.graphs``. Prints both
        edge counts whenever ``idx == 0`` (debug aid).
        """
        base_graph, in_globals, out_globals, _ = self.base_dataset[idx]
        new_graph = self.graphs[idx]

        if idx == 0:
            print(f"Base graph: {base_graph.number_of_edges()} edges, New graph: {new_graph.number_of_edges()} edges")
        return base_graph, new_graph, in_globals, out_globals, idx

    def __len__(self):
        """Return the number of precomputed companion graphs."""
        return len(self.graphs)
