from abc import abstractmethod
from dataclasses import dataclass
from typing import Any, Callable, Literal

import torch
from torch_geometric.data import Data as PygData
from torch_geometric.datasets import AQSOL, ZINC

from molecular_regression.MyPCQM4Mv2 import MyPCQM4Mv2
from ngab import BatchedSparseGraphs
from ngab import BatchedSignals
from ngab import SparseGraph


class TorchBatch:
    """
    Base class for batch containers whose attributes can be moved to a device.

    Subclasses are expected to provide ``collate_fn`` and ``__len__``.
    NOTE: this class does not use ``abc.ABCMeta``, so the ``@abstractmethod``
    markers below are advisory only and do not prevent instantiation.

    Methods
    -------
    to(device) -> TorchBatch
        Moves every instance attribute that has a ``to`` attribute to the
        specified device, in place, and returns the batch itself.
    collate_fn(elems: list[Any]) -> TorchBatch
        Static method to collate a list of elements into a batch.
    __len__() -> int
        Returns the number of elements in the batch.
    """

    def to(self, device) -> "TorchBatch":
        """
        Move the batch's movable attributes to ``device``.

        Each instance attribute exposing a ``to`` attribute is replaced by the
        result of ``value.to(device)``; other attributes are left untouched.

        Returns
        -------
        TorchBatch
            The same instance (the batch is mutated in place).
        """
        # Iterate over a snapshot so that reassigning attributes can never
        # interfere with the iteration over the instance dictionary.
        for attr, value in list(vars(self).items()):
            if hasattr(value, "to"):
                setattr(self, attr, value.to(device))
        return self

    @staticmethod
    @abstractmethod
    def collate_fn(elems: list[Any]) -> "TorchBatch":
        """Collate a list of elements into a batch (to be overridden)."""
        pass

    @abstractmethod
    def __len__(self) -> int:
        """Return the number of elements in the batch (to be overridden)."""
        pass


@dataclass
class MoleculesBatch(TorchBatch):
    """
    A batch of molecular data for graph neural network processing.
    Attributes:
        graphs (BatchedSparseGraphs): Batched sparse graphs representing the molecular structures.
        x (BatchedSignals): Batched signals built from the node features shifted by one (``1 + elem.x``).
        x_padded (torch.Tensor): The shifted node features padded with 0 (batch first) and cast to long.
        pe (BatchedSignals): Batched signals for positional encodings.
        pe_padded (torch.Tensor): ``1 + elem.pe`` padded with 0 (batch first).
        mask (torch.Tensor): Boolean tensor that is True where ``x_padded`` is non-zero,
            i.e. False on padded positions (assumes the shifted features are never 0 — TODO confirm).
        targets (torch.Tensor): Tensor containing the concatenated target values of the batch.
    Methods:
        collate_fn(elems: list[PygData]) -> MoleculesBatch:
            Static method to collate a list of PygData objects into a MoleculesBatch.
        __len__() -> int:
            Returns the number of graphs in the batch.
    """

    graphs: BatchedSparseGraphs
    x: BatchedSignals
    x_padded: torch.Tensor
    pe: BatchedSignals
    pe_padded: torch.Tensor
    mask: torch.Tensor
    targets: torch.Tensor

    @staticmethod
    def collate_fn(elems: list[PygData]) -> "MoleculesBatch":
        """
        Collate a list of PyG data objects into a single MoleculesBatch.

        Args:
            elems (list[PygData]): Data objects providing ``edge_index``, ``x``, ``pe`` and ``y``.

        Returns:
            MoleculesBatch: The collated batch.
        """
        graphs = BatchedSparseGraphs.from_graphs(
            [
                SparseGraph(elem.edge_index[0], elem.edge_index[1], len(elem.x))
                for elem in elems
            ]
        )

        # Node features are shifted by one so that 0 can serve as the padding value.
        shifted_x = [1 + elem.x for elem in elems]
        x = BatchedSignals.from_signals(shifted_x)
        x_padded = torch.nn.utils.rnn.pad_sequence(
            shifted_x, batch_first=True, padding_value=0
        ).long()

        pe = BatchedSignals.from_signals([elem.pe for elem in elems])
        # NOTE(review): the padded encodings are shifted by one while `pe` above is
        # not; kept as-is because the intent cannot be confirmed from this file.
        pe_padded = torch.nn.utils.rnn.pad_sequence(
            [1 + elem.pe for elem in elems], batch_first=True, padding_value=0
        )

        targets = torch.cat([elem.y for elem in elems])

        # Drop only a trailing singleton feature axis. A bare `.squeeze()` would also
        # remove the batch axis for a single-graph batch (or the node axis when the
        # longest graph has one node), producing a mask of the wrong rank.
        if x_padded.dim() > 2:
            node_ids = x_padded.squeeze(-1)
        else:
            node_ids = x_padded
        mask = node_ids != 0  # comparison already yields a boolean tensor

        return MoleculesBatch(graphs, x, x_padded, pe, pe_padded, mask, targets)

    def __len__(self) -> int:
        """Return the number of graphs in the batch."""
        return len(self.graphs)


def ConverterPlugin(
    pe_func: Callable[[SparseGraph], torch.Tensor] | None = None,
) -> Callable[[PygData], PygData]:
    """
    Creates a converter plugin that applies a positional encoding function to a PyG data object.

    Args:
        pe_func (Callable[[SparseGraph], torch.Tensor] | None, optional): A function that takes a
            SparseGraph and returns its positional encoding. When None, a zero tensor of shape
            (number of nodes, 1) is used instead. Defaults to None.

    Returns:
        Callable[[PygData], PygData]: a function usable as pre_transform for PyG datasets.
    """

    def converter_plugin(
        data: PygData,
    ) -> PygData:
        """
        Converts a PyG data object and attaches a positional encoding to it.

        A new data object is built that keeps only edge_index, x (with a trailing
        axis added) and y; the positional encoding is stored in its 'pe' attribute.

        Args:
            data (PygData): The input data object containing edge_index, x, and y attributes.

        Returns:
            PygData: The new data object with an added 'pe' attribute.

        Raises:
            AttributeError: If edge_index, x or y is absent or None on the input.
        """
        # Explicit check instead of `assert` (asserts vanish under `python -O`).
        # Comparing against None also covers objects that expose unset fields as
        # attributes returning None, for which `hasattr` is always true.
        missing = [
            field
            for field in ("edge_index", "x", "y")
            if getattr(data, field, None) is None
        ]
        if missing:
            raise AttributeError(
                f"Input data is missing required attribute(s): {', '.join(missing)}"
            )

        converted = PygData(
            edge_index=data.edge_index, x=data.x.unsqueeze(-1), y=data.y
        )
        num_nodes = len(converted.x)
        if pe_func is not None:
            converted.pe = pe_func(
                SparseGraph(
                    converted.edge_index[0], converted.edge_index[1], num_nodes
                )
            )
        else:
            # No encoding requested: fall back to a single zero column per node.
            converted.pe = torch.zeros((num_nodes, 1))
        return converted

    return converter_plugin


def setup_data(
    name: Literal["AQSOL", "PCQM4Mv2", "ZINC"],
    split: Literal["train", "val", "test"],
    pe_name: str,
    pe_dim: int,
    pe_func: Callable[[SparseGraph], torch.Tensor],
    *,
    max_len: int | None = None,
) -> torch.utils.data.Dataset:
    """
    Sets up the dataset for molecular regression tasks.

    Parameters:
    - name (Literal["AQSOL", "PCQM4Mv2", "ZINC"]): The name of the dataset to use.
    - split (Literal["train", "val", "test"]): The data split to use.
    - pe_name (str): The name of the positional encoding.
    - pe_func (Callable[[SparseGraph], torch.Tensor]): A function to compute positional encodings.
    - max_len (int | None, optional): The maximum length of the dataset. Defaults to None.

    Returns:
    - torch.utils.data.Dataset: The dataset object for the specified configuration.
    
    Raises:
    - ValueError: If an unknown dataset name is provided.
    """

    dataset: torch.utils.data.Dataset
    if name == "AQSOL":
        dataset = AQSOL(
            root=f".data/AQSOL/{pe_name}", split=split, pre_transform=ConverterPlugin(pe_func)
        )
    elif name == "PCQM4Mv2":
        dataset = MyPCQM4Mv2(
            root=f".data/PCQM4Mv2/{pe_name}",pe_func=pe_func, pe_dim=pe_dim ,split=split,
        )
    elif name == "ZINC":
        def transform(data):
            data.x = data.x.squeeze(-1)
            return data
        dataset = ZINC(
            root=f".data/ZINC/{pe_name}", split=split, pre_transform=ConverterPlugin(pe_func), transform=transform, subset=True
        )
    else:
        raise ValueError(f"Unknown dataset name: {name}")

    if max_len is not None:
        return dataset[:max_len]
    else:
        return dataset

