import torch
import dgl
import numpy as np
from dgl.data import DGLDataset
from rich.progress import track
from src.data.tree_transform import TreeOpTransform
from .builders_graph import get_node_builder, get_edge_builder
from src.scalers.scaler_pack import build_composite_scaler


class BaseGraphDataset(DGLDataset):
    """In-memory DGL dataset that scales its inputs/targets and pre-builds all graphs.

    All graphs are constructed eagerly in ``__init__`` and kept in
    ``self.graphs``; ``__getitem__`` only indexes into the prepared lists.

    The ``data`` dict is read through these keys: ``"X_nodes"``,
    ``"X_edges"``, ``"X_node_tags"``, ``"X_distances"``, ``"Object_ids"``,
    ``"Cells_list"``, ``"X_scalars"``, ``"Y_scalars"`` and ``"Y_fields"``.
    The per-graph entries are NumPy arrays (they go through
    ``torch.from_numpy``); ``X_scalars`` / ``Y_scalars`` are single arrays
    indexed by graph along their first axis.

    NOTE(review): each ``X_edges[i]`` is transposed and then indexed as
    ``[0]`` / ``[1]`` for source / destination, so it is presumably stored as
    ``(num_edges, 2)`` — confirm against the data loader.

    Attributes:
        in_scaler, out_scaler, y_scaler: Scalers used for the global inputs,
            global outputs and per-node targets. Fitted here when
            ``data_type == "train"``, otherwise taken from the arguments.
        input_globals, output_globals: Scaled global scalars as float
            tensors, or an empty 1-D tensor when the source array is empty.
        num_node_tags: Width of the one-hot node-tag encoding shared by every
            graph of this dataset.
        graphs: List of built ``dgl.DGLGraph`` objects.
    """

    def __init__(
        self,
        args,
        data: dict,
        data_type: str,
        in_scaler=None,
        out_scaler=None,
        y_scaler=None,
        node_builder_name: str = "rich_nodes",
        edge_builder_name: str = "rich_edges",
        use_tree_algo: bool = False,
        num_node_tags=None,
    ):
        """Scale the raw arrays and build one graph per sample.

        Args:
            args: Run configuration; ``args.y_scaler_name`` is read in train
                mode and the whole object is forwarded to the builders.
            data: Raw arrays, see the class docstring for the keys used.
            data_type: ``"train"`` fits fresh scalers on ``data``; any other
                value re-uses the scalers passed in.
            in_scaler: Fitted scaler for ``X_scalars`` (ignored in train mode).
            out_scaler: Fitted scaler for ``Y_scalars`` (ignored in train mode).
            y_scaler: Fitted scaler for ``Y_fields`` (ignored in train mode).
            node_builder_name: Name resolved through ``get_node_builder``.
            edge_builder_name: Name resolved through ``get_edge_builder``.
            use_tree_algo: If true, edges are recomputed by the tree transform
                instead of being taken from ``X_edges`` as-is.
            num_node_tags: Optional fixed width of the one-hot tag encoding.
                When ``None`` it is inferred as ``max tag id + 1`` over the
                whole dataset. Pass the training dataset's ``num_node_tags``
                for validation/test splits so all splits share one width.

        Raises:
            ValueError: If ``data_type`` is not ``"train"`` and any of the
                three scalers is missing.
        """
        super().__init__(name="base_graph_dataset")
        self.args = args
        self.data = data
        self.N = len(data["X_nodes"])
        self.nb = get_node_builder(node_builder_name)
        self.eb = get_edge_builder(edge_builder_name)
        self.use_tree_algo = use_tree_algo
        self.tree_transform = TreeOpTransform(n_levels=10, k_hop_levels=3, k_neighbors=8)

        # raw tensors
        self._nodes = [torch.from_numpy(a).float() for a in data["X_nodes"]]
        self._edges = [torch.from_numpy(a.T).long() for a in data["X_edges"]]
        self._node_tags = [torch.from_numpy(a).long() for a in data["X_node_tags"]]
        self._dists = [torch.from_numpy(a).float().unsqueeze(1) for a in data["X_distances"]]
        self._obj_ids = data["Object_ids"]
        self._cells_list = data["Cells_list"]

        # scalers: fitted on the training split, re-used everywhere else
        is_train = data_type == "train"
        if is_train:
            self.in_scaler = build_composite_scaler("std")
            self.out_scaler = build_composite_scaler("std")
            self.y_scaler = build_composite_scaler(args.y_scaler_name)
        else:
            # Explicit raise instead of `assert`: asserts vanish under
            # `python -O`, which would defer the failure to an opaque
            # AttributeError on None further down.
            if in_scaler is None or out_scaler is None or y_scaler is None:
                raise ValueError(
                    "in_scaler, out_scaler and y_scaler must all be provided "
                    "when data_type is not 'train'"
                )
            self.in_scaler, self.out_scaler, self.y_scaler = in_scaler, out_scaler, y_scaler

        self.input_globals = self._scaled_globals(self.in_scaler, data["X_scalars"], is_train)
        self.output_globals = self._scaled_globals(self.out_scaler, data["Y_scalars"], is_train)
        self._Y_fields = self._scaled_fields(data["Y_fields"], is_train)
        self._node_globals = self._tile_globals_per_node()

        # One encoding width for the whole dataset so that every graph ends
        # up with the same tag-feature dimensionality.
        self.num_node_tags = self._resolve_num_node_tags(num_node_tags)

        self.graphs = []
        for i in track(range(self.N), total=self.N, description=f"🚀 Building {data_type} graphs"):
            self.graphs.append(self._build_graph(i))

    @staticmethod
    def _scaled_globals(scaler, values, fit: bool) -> torch.Tensor:
        """Scale a global-scalar array, fitting ``scaler`` first when ``fit`` is set.

        An empty array is not passed to the scaler at all; an empty 1-D
        tensor is returned instead.
        """
        if not values.size:
            return torch.empty((0,))
        scaled = scaler.fit_transform(values) if fit else scaler.transform(values)
        return torch.from_numpy(scaled).float()

    def _scaled_fields(self, fields, fit: bool) -> list:
        """Scale the per-graph target arrays with ``self.y_scaler``.

        In fit mode the arrays are stacked row-wise so the scaler sees every
        node of every graph at once; the scaled result is then cut back into
        per-graph pieces using the original row counts.
        """
        if not fit:
            return [torch.from_numpy(self.y_scaler.transform(arr)).float() for arr in fields]

        lengths = [arr.shape[0] for arr in fields]
        scaled = self.y_scaler.fit_transform(np.vstack(fields))
        bounds = np.cumsum([0] + lengths)
        return [
            torch.from_numpy(scaled[start:stop]).float()
            for start, stop in zip(bounds[:-1], bounds[1:])
        ]

    def _tile_globals_per_node(self) -> list:
        """Repeat each graph's scaled global inputs once per node.

        Returns one ``(num_nodes, n_scalars)`` tensor per graph; when there
        are no global inputs the tensors have zero columns.
        """
        n_sca = self.input_globals.size(1) if self.input_globals.numel() > 0 else 0
        tiled = []
        for i, pos in enumerate(self._nodes):
            n_nodes = pos.size(0)
            if n_sca:
                tiled.append(self.input_globals[i].unsqueeze(0).repeat(n_nodes, 1))
            else:
                tiled.append(torch.empty((n_nodes, 0), dtype=torch.float32))
        return tiled

    def _resolve_num_node_tags(self, num_node_tags) -> int:
        """Return the one-hot width: the explicit value, else ``max tag id + 1``.

        Graphs without any node are skipped when inferring the width; a
        dataset with no tags at all yields ``0``.
        """
        if num_node_tags is not None:
            return int(num_node_tags)
        largest = max(
            (int(t.max()) for t in self._node_tags if t.numel() > 0),
            default=-1,
        )
        return largest + 1

    def _one_hot_tags(self, tag_ids: torch.Tensor) -> torch.Tensor:
        """One-hot encode ``tag_ids`` with the dataset-wide width, as float."""
        if self.num_node_tags == 0:
            # one_hot cannot encode into zero classes; emit a zero-column block.
            return torch.zeros((tag_ids.size(0), 0), dtype=torch.float32)
        return torch.nn.functional.one_hot(tag_ids, self.num_node_tags).float()

    def _edge_index_from_tree_algo(self, pos: torch.Tensor, eid: torch.Tensor):
        """Recompute the edge index with the tree transform.

        Returns the first two rows of the index produced by
        ``tree_transform.postprocess`` as ``(src, dst)``.
        """
        out_t = self.tree_transform.transform(pos)
        new_idx, _ = self.tree_transform.postprocess(eid, out_t)
        return new_idx[0], new_idx[1]

    def _build_graph(self, i: int) -> dgl.DGLGraph:
        """Assemble the DGL graph for sample ``i``.

        Node data: ``"x"`` (node-builder features), ``"y"`` (scaled targets)
        and ``"pos"`` (raw node array). Edge data: ``"f"`` (edge-builder
        features).
        """
        pos = self._nodes[i]
        eid = self._edges[i]
        tags = self._one_hot_tags(self._node_tags[i])
        dist = self._dists[i]
        sca = self._node_globals[i]
        y = self._Y_fields[i]
        ctx = {"obj_ids": self._obj_ids[i], "cells": self._cells_list[i], "args": self.args}

        # node feats
        x = self.nb(pos, tags, dist, sca, ctx)

        # graph connectivity
        if self.use_tree_algo:
            src, dst = self._edge_index_from_tree_algo(pos, eid)
        else:
            src, dst = eid[0], eid[1]

        g = dgl.graph((src, dst), num_nodes=pos.size(0))
        g.ndata["x"] = x
        g.ndata["y"] = y
        g.ndata["pos"] = pos

        # edge feats
        g.edata["f"] = self.eb(pos, src, dst, ctx)
        return g

    def __getitem__(self, idx):
        """Return ``(graph, input_globals, output_globals, idx)`` for sample ``idx``.

        A globals entry is an empty 1-D tensor when the dataset has no such
        globals.
        """
        g = self.graphs[idx]
        ig = self.input_globals[idx] if self.input_globals.numel() > 0 else torch.empty((0,))
        og = self.output_globals[idx] if self.output_globals.numel() > 0 else torch.empty((0,))
        return g, ig, og, idx

    def __len__(self):
        """Return the number of pre-built graphs."""
        return len(self.graphs)