import json
from typing import Any, Dict, List, Tuple, Union

import networkx as nx
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset


class Reasoning_Graph_Dataset(Dataset):
    """Dataset of claim-dependency graphs loaded from a JSON file.

    The file must contain a top-level ``"data"`` list. Each example in it
    provides ``"claims"`` (a list of dicts holding the feature columns named
    in ``use_cols`` plus a ``"manual_annotation"`` label) and ``"dep_graph"``
    (a square child→parent adjacency matrix over those claims).

    Item ``idx`` is a pair ``(x, y)``:
        x["features"]:  float tensor [n, len(use_cols)]
        x["adj"]:       float tensor [n, n], parent→child edges
        x["ancestors"]: bool tensor  [n, n], (i, j) set if i is an ancestor of j
        y:              long tensor  [n], manual annotations
    """

    def _compute_adj(self, dep_graph: List[List[int]]) -> torch.Tensor:
        """
        Convert dependency graph from child→parent form into
        a parent→child adjacency matrix.

        Args:
            dep_graph: adjacency matrix where dep_graph[i][j] = 1
                       if claim i depends on claim j (child→parent).

        Returns:
            adj_parent_to_child: float tensor, shape [n, n],
                                 with edges in parent→child convention.

        Raises:
            ValueError: if dep_graph is not a square matrix.
        """
        adj_child_to_parent = np.asarray(dep_graph, dtype=float)
        # A claim-less example serialises as ``[]``, which NumPy reads as a
        # 1-D array of shape (0,); normalise it to a proper 0x0 matrix so the
        # result is always 2-D.
        if adj_child_to_parent.ndim == 1 and adj_child_to_parent.size == 0:
            adj_child_to_parent = adj_child_to_parent.reshape(0, 0)
        if (
            adj_child_to_parent.ndim != 2
            or adj_child_to_parent.shape[0] != adj_child_to_parent.shape[1]
        ):
            raise ValueError(
                f"dep_graph must be a square matrix, got shape {adj_child_to_parent.shape}"
            )
        # Transposing flips every edge i→j (child→parent) into j→i (parent→child).
        adj_parent_to_child = adj_child_to_parent.T
        return torch.tensor(adj_parent_to_child, dtype=torch.float)

    def _compute_ancestors(self, adj_matrix: torch.Tensor) -> torch.Tensor:
        """
        Compute the ancestor matrix of the graph using NetworkX.

        Args:
            adj_matrix: [n, n] adjacency matrix in parent→child convention;
                        any nonzero entry (i, j) is treated as an edge i→j.

        Returns:
            ancestor_matrix: [n, n] bool tensor where entry (i, j) is True if
                             node i is an ancestor of node j, else False.
        """
        n = adj_matrix.shape[0]
        if n == 0:
            # Nothing to traverse; avoid handing an empty array to networkx.
            return torch.zeros((0, 0), dtype=torch.bool)

        # detach/cpu so tensors that require grad or live on another device
        # can still be converted to NumPy.
        G = nx.from_numpy_array(
            adj_matrix.detach().cpu().numpy(), create_using=nx.DiGraph
        )

        ancestor_matrix = np.zeros((n, n), dtype=bool)
        for j in G.nodes():
            # nx.ancestors(G, j) = every node with a directed path into j.
            for i in nx.ancestors(G, j):
                ancestor_matrix[i, j] = True

        return torch.tensor(ancestor_matrix, dtype=torch.bool)

    def _build_features(self, claims: List[Dict[str, Any]], use_cols: List[str]) -> torch.Tensor:
        """
        Stack the selected numeric columns of every claim into a feature matrix.

        Args:
            claims: list of claim dictionaries.
            use_cols: names of the claim keys to use as features, in order.

        Returns:
            features: float tensor of shape [len(claims), len(use_cols)].

        Raises:
            KeyError: if a claim lacks one of ``use_cols``.
        """
        feature_list = []
        for claim in claims:
            feat_vec = []
            for col in use_cols:
                if col not in claim:
                    raise KeyError(f"Missing column {col} in claim: {claim}")
                feat_vec.append(float(claim[col]))
            feature_list.append(feat_vec)
        # torch.tensor([]) is 1-D; reshape keeps the [n, d] layout even when
        # there are no claims (and is a no-op otherwise).
        return torch.tensor(feature_list, dtype=torch.float).reshape(len(claims), len(use_cols))

    def _compute_labels(self, claims: List[Dict[str, Any]]) -> torch.Tensor:
        """
        Extract manual annotations as ground-truth labels.

        Args:
            claims: list of claim dictionaries.

        Returns:
            labels: tensor of shape [n], dtype long, built from each claim's
                    ``manual_annotation`` via ``int()`` (expected to be 0/1;
                    values are not range-checked here).
        """
        labels = [int(c["manual_annotation"]) for c in claims]
        return torch.tensor(labels, dtype=torch.long)

    def __init__(self, path: str, use_cols: List[str]):
        """
        Load every example from the JSON file at ``path``.

        Args:
            path: path to a UTF-8 JSON file with a top-level ``"data"`` list.
            use_cols: claim keys to use as node features, in order.

        Raises:
            KeyError: if an example or claim is missing a required key.
            ValueError: if an example's ``dep_graph`` is not square or its
                        size does not match the number of claims.
        """
        with open(path, 'r', encoding='utf-8') as f:
            raw_data = json.load(f)

        # Store raw data for access to graph_annotations
        self.raw_data = raw_data
        data = raw_data['data']

        self.x = []
        self.y = []

        for ex_idx, ex in enumerate(data):
            features = self._build_features(ex['claims'], use_cols)
            adj_graph = self._compute_adj(ex['dep_graph'])
            # Node i of the graph must correspond to claim i; a size mismatch
            # would otherwise silently misalign features, edges and labels.
            if adj_graph.shape[0] != features.shape[0]:
                raise ValueError(
                    f"Example {ex_idx}: dep_graph covers {adj_graph.shape[0]} nodes "
                    f"but there are {features.shape[0]} claims"
                )
            ancestors = self._compute_ancestors(adj_graph)
            labels = self._compute_labels(ex['claims'])

            self.x.append({
                "features": features,
                "adj": adj_graph,
                "ancestors": ancestors,
            })
            self.y.append(labels)

    def __len__(self) -> int:
        """Return the number of examples."""
        return len(self.x)

    def __getitem__(
        self, idx: Union[int, slice]
    ) -> Union[
        Tuple[Dict[str, Any], torch.Tensor],
        List[Tuple[Dict[str, Any], torch.Tensor]],
    ]:
        """Return ``(x, y)`` for an integer index, or a list of them for a slice."""
        if isinstance(idx, slice):
            return [self[ii] for ii in range(*idx.indices(len(self)))]
        return self.x[idx], self.y[idx]