import os.path as osp
from typing import Callable, Optional

import numpy as np
import torch
from torch_geometric.data import InMemoryDataset

from temporal_graph.data import TemporalData


class STARDataset(InMemoryDataset):
    r"""Discrete-time dynamic graph dataset loaded from a local ``.npz`` file.

    Supported names (case-insensitive): ``'dblp3'``, ``'dblp5'``,
    ``'reddit'`` and ``'brain'``. :meth:`download` is a no-op, so the raw
    file ``{name}.npz`` must already exist at ``self.raw_paths[0]``. The
    file is expected to contain the arrays

    * ``attmats`` -- node features, ``(N, T, D)``,
    * ``labels`` -- per-node label rows, ``(N, C)``; reduced to class
      indices with ``argmax`` over axis 1 (presumably one-hot -- confirm),
    * ``adjs`` -- one dense adjacency matrix per snapshot, ``(T, N, N)``.

    Processing flattens all snapshots into a single :class:`TemporalData`
    edge list: every non-zero entry ``adjs[i, u, v]`` becomes an edge
    ``u -> v`` with timestamp ``i``. Only the non-zero pattern is kept; the
    entry values themselves are discarded.

    Args:
        root: Root directory of the dataset.
        name: Dataset name, one of the supported names above.
        transform: Forwarded to :class:`InMemoryDataset`.
        pre_transform: If given, applied to the processed
            :class:`TemporalData` once, before it is saved.
        force_reload: Forwarded to :class:`InMemoryDataset`.

    Raises:
        ValueError: If ``name`` is not a supported dataset name.
    """

    # Accepted values of ``name`` after lower-casing.
    _NAMES = ('dblp3', 'dblp5', 'reddit', 'brain')

    def __init__(
        self,
        root: str,
        name: str,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        force_reload: bool = False,
    ) -> None:
        # Set before ``super().__init__``: the path properties below read
        # ``self.name`` and the base constructor may already need them.
        self.name = name.lower()
        # Explicit check instead of ``assert`` so validation survives -O.
        if self.name not in self._NAMES:
            raise ValueError(
                f"Unknown dataset name '{name}'; expected one of "
                f"{list(self._NAMES)}")

        super().__init__(root, transform, pre_transform,
                         force_reload=force_reload)
        self.load(self.processed_paths[0], data_cls=TemporalData)

    @property
    def processed_dir(self) -> str:
        # Scoped by ``name`` because every dataset writes the same
        # ``data.pt`` file name; a shared ``root`` would otherwise collide.
        return osp.join(self.root, self.name, 'processed')

    @property
    def raw_file_names(self) -> str:
        return f'{self.name}.npz'

    @property
    def processed_file_names(self) -> str:
        return 'data.pt'

    def download(self) -> None:
        # Intentionally a no-op: there is no remote source, the raw file has
        # to be provided manually. Kept non-raising so an existing processed
        # cache can still be loaded if the raw file has been removed.
        pass

    def process(self) -> None:
        """Convert the raw ``.npz`` arrays into one :class:`TemporalData`."""
        # ``NpzFile`` keeps the archive open; the context manager closes it.
        # Each array is fully read on access, so it stays valid afterwards.
        with np.load(self.raw_paths[0]) as file:
            x = file['attmats']  # (N, T, D)
            y = file['labels']  # (N, C)
            adjs = file['adjs']  # (T, N, N)

        x = torch.from_numpy(x).to(torch.float)
        y = torch.from_numpy(y.argmax(1)).to(torch.long)

        # ``np.nonzero`` on the stacked (T, N, N) array returns indices in
        # C (row-major) order: grouped by snapshot, then source, then
        # destination -- the same order as calling ``adj.nonzero()`` per
        # snapshot and concatenating, but it also handles T == 0 without
        # concatenating an empty list.
        t, src, dst = (torch.from_numpy(idx).to(torch.long)
                       for idx in np.nonzero(adjs))
        data = TemporalData(src=src, dst=dst, t=t, x=x, y=y)

        if self.pre_transform is not None:
            data = self.pre_transform(data)

        self.save([data], self.processed_paths[0])

    def __repr__(self) -> str:
        return f'{self.name}()'
