import os
import pathlib
import torch
from torch_geometric.data import InMemoryDataset, Data
from torch_geometric.utils.sparse import dense_to_sparse

import utils.utils as utils
from datasets.abstract_dataset import AbstractDataModule, AbstractDatasetInfos
from fast_jtnn import *


class ZINCDataset(InMemoryDataset):
    """One split of the preprocessed ZINC junction-tree data as a PyG ``InMemoryDataset``.

    ``process`` reads the raw split file, converts every record into a
    :class:`torch_geometric.data.Data` object and saves the collated result;
    ``__init__`` then loads that collated file.

    Args:
        dataset_name: Dataset name; stored on the instance, not otherwise used here.
        split: ``'train'``, ``'val'`` or ``'test'``.
        root: Dataset root directory (PyG derives the raw/processed paths from it).
        transform, pre_transform, pre_filter: Standard PyG dataset hooks.
    """

    # Width of the one-hot node features (node vocabulary size).
    num_node_classes = 780
    # Width of the one-hot edge features.
    num_edge_classes = 2

    def __init__(self, dataset_name, split, root, transform=None, pre_transform=None, pre_filter=None):
        # Set before ``super().__init__``, which may trigger ``process`` /
        # ``processed_file_names`` -- both of which read ``self.split``.
        self.dataset_name = dataset_name
        self.split = split
        super().__init__(root, transform, pre_transform, pre_filter)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        """Raw files; ``process`` indexes them by split (train=0, val=1, test=2)."""
        return ['train.pt', 'valid.pt', 'test.pt']

    @property
    def processed_file_names(self):
        """One processed file per split, named ``<split>.pt``."""
        return [self.split + '.pt']

    def process(self):
        """Convert the raw file of ``self.split`` into a collated in-memory dataset.

        Raises:
            KeyError: if ``self.split`` is not ``'train'``, ``'val'`` or ``'test'``.
            ValueError: if a raw adjacency matrix is not ``(num_nodes, num_nodes)``
                or has non-zero entries on or below its diagonal.
        """
        file_idx = {'train': 0, 'val': 1, 'test': 2}
        raw_dataset = torch.load(self.raw_paths[file_idx[self.split]])

        data_list = []
        # Each raw record is (x, edge, mol_tree, target_data["targets"][i], i).
        for raw_data in raw_dataset:
            mol_idx = raw_data[4]

            x = torch.from_numpy(raw_data[0]).long()
            num_nodes = len(x)
            X = torch.nn.functional.one_hot(x, num_classes=self.num_node_classes).float()
            # Both regression targets as a single (1, 2) row: [plogp, qed].
            y = torch.tensor([[raw_data[3]["plogp"], raw_data[3]["qed"]]]).float()

            adj = torch.from_numpy(raw_data[1])
            # Explicit checks rather than ``assert`` so validation survives ``python -O``.
            if adj.shape != (num_nodes, num_nodes):
                raise ValueError(
                    f"Record {mol_idx}: adjacency shape {tuple(adj.shape)} does not match "
                    f"{num_nodes} nodes"
                )
            # The raw adjacency stores each edge once, strictly above the diagonal.
            if (torch.tril(adj) != 0).any():
                raise ValueError(
                    f"Record {mol_idx}: expected a strictly upper-triangular adjacency matrix"
                )
            adj = adj + adj.T  # symmetrise: each edge appears in both directions
            edge_index, edge_attr = dense_to_sparse(adj)
            edge_attr = torch.nn.functional.one_hot(
                edge_attr.long(), num_classes=self.num_edge_classes
            ).float()

            mol_tree = raw_data[2]
            depth = torch.tensor([mol_tree.nodes[idx].depth for idx in mol_tree.node_order_in_graph])
            # Levels count upward from the deepest node (maximum depth -> level 0).
            level = depth.max() - depth

            data = Data(
                x=X,
                edge_index=edge_index,
                edge_attr=edge_attr,
                y=y,
                n_nodes=num_nodes,
                level=level,
                mol_idx=torch.tensor([mol_idx]),
            )

            if self.pre_filter is not None and not self.pre_filter(data):
                continue
            if self.pre_transform is not None:
                data = self.pre_transform(data)

            data_list.append(data)
        torch.save(self.collate(data_list), self.processed_paths[0])
        

class ZINCDataModule(AbstractDataModule):
    """Data module that builds the train/val/test ``ZINCDataset`` splits.

    Reads ``cfg.dataset.datadir``, ``cfg.dataset.name``, ``cfg.general.eval_size``
    and ``cfg.general.target``. When ``target`` is ``None`` every graph's ``y`` is
    replaced by an empty ``(1, 0)`` tensor; otherwise ``y`` is reduced to the single
    ``(1, 1)`` column of the requested property.

    Raises:
        ValueError: if ``cfg.general.target`` is neither ``None`` nor a supported
            property name (previously any unknown name silently selected ``qed``).
    """

    # Supported ``cfg.general.target`` values -> column of the stored (1, 2) ``y``.
    _TARGET_COLUMNS = {"plogp": 0, "qed": 1}

    def __init__(self, cfg):
        self.cfg = cfg
        self.datadir = cfg.dataset.datadir
        self.vocab_size = 780

        # Validate the target before the (possibly slow) dataset processing starts.
        target = cfg.general.target
        if target is None:
            t_idx = None
        elif target in self._TARGET_COLUMNS:
            t_idx = self._TARGET_COLUMNS[target]
        else:
            raise ValueError(
                f"Unsupported target {target!r}; expected None or one of "
                f"{sorted(self._TARGET_COLUMNS)}"
            )

        # Resolve the data directory against a fixed ancestor of this file
        # (``parents[2]``) so it does not depend on the working directory.
        base_path = pathlib.Path(os.path.realpath(__file__)).parents[2]
        root_path = os.path.join(base_path, self.datadir)

        datasets = {
            split: ZINCDataset(dataset_name=self.cfg.dataset.name, split=split, root=root_path)
            for split in ('train', 'val', 'test')
        }
        if cfg.general.eval_size:
            datasets["test"] = datasets["test"][:cfg.general.eval_size]

        # Materialise each split as a list so the per-graph ``y`` edits are kept.
        for split in ["train", "val", "test"]:
            new_dataset = []
            for data in datasets[split]:
                if t_idx is None:
                    data.y = torch.zeros([1, 0]).float()  # unconditional: no targets
                else:
                    data.y = data.y[:, t_idx].unsqueeze(1)
                new_dataset.append(data)
            datasets[split] = new_dataset

        print(f'Dataset sizes: train {len(datasets["train"])}, val {len(datasets["val"])}, test {len(datasets["test"])}')

        super().__init__(cfg, datasets)
    

def get_level_distribution(datamodule, max_n_nodes):
    """Estimate empirical level statistics from the train and val loaders.

    Every batch is densified with ``utils.to_dense`` and two histograms are
    accumulated from the dense per-node ``level`` values:

    * how often each value is a graph's highest level, and
    * for each level ``l``, how many nodes a graph has at level ``l``.

    Args:
        datamodule: object providing ``train_dataloader()`` and ``val_dataloader()``
            whose batches expose ``x``, ``edge_index``, ``edge_attr``, ``batch`` and
            ``level``.
        max_n_nodes: size of the count tables; every level and every per-level
            node count must be smaller than this.

    Returns:
        level_num: 1-D tensor of length ``max_level + 1``; entry ``l`` is the
            fraction of graphs whose highest level is ``l``.
        level_size: ``(max_n_nodes, max_level_size + 1)`` tensor; row ``l`` is the
            distribution of the number of nodes at level ``l`` over graphs that
            contain level ``l``. Rows for levels never observed are all zero
            (previously they were NaN from a 0/0 normalisation).
    """
    level_num = torch.zeros(max_n_nodes)
    level_size = torch.zeros(max_n_nodes, max_n_nodes)
    max_level_num = -1
    max_level_size = -1
    for loader in (datamodule.train_dataloader(), datamodule.val_dataloader()):
        for data in loader:
            _, node_mask, level = utils.to_dense(data.x, data.edge_index, data.edge_attr, data.batch, data.level)

            # Highest level per graph. NOTE(review): the max runs over padded
            # positions too, which assumes padding never exceeds real levels
            # (e.g. is 0) -- confirm against utils.to_dense.
            graph_max_level = level.max(dim=-1)[0]
            max_level_num = max(max_level_num, int(graph_max_level.max()))
            l, cnts = torch.unique(graph_max_level, return_counts=True)
            level_num[l] += cnts  # ``l`` is unique, so plain indexed add is safe

            for graph_level, graph_mask in zip(level, node_mask):
                l, cnts = torch.unique(graph_level[graph_mask], return_counts=True)
                level_size[l, cnts] += 1
                max_level_size = max(max_level_size, int(cnts.max()))

    # clamp_min(1) only affects all-zero totals/rows (counts are integers), turning
    # 0/0 = NaN into 0 while leaving every observed distribution unchanged.
    level_num = level_num[:max_level_num + 1] / level_num.sum().clamp_min(1)
    level_size = level_size[:, :max_level_size + 1] / level_size.sum(dim=-1, keepdim=True).clamp_min(1)
    return level_num, level_size


class ZINCInfos(AbstractDatasetInfos):
    """Dataset statistics for ZINC, built from ``<datadir>/meta.pth`` and the datamodule.

    Args:
        datamodule: data module providing ``node_counts`` and, when ``level_data``
            is true, the loaders consumed by ``get_level_distribution``.
        dataset_config: config object with a ``datadir`` attribute.
        level_data: if true, also compute ``self.level_num`` / ``self.level_size``.
    """

    def __init__(self, datamodule, dataset_config, level_data = False):
        self.name = 'ZINC'
        dataset_meta = torch.load(self._meta_path(dataset_config.datadir))

        # NOTE(review): ``num_nodes[-1]`` is presumably the largest node count in
        # the dataset -- confirm; 2 extra slots are added as in the original sizing.
        max_nodes_possible = int(dataset_meta["num_nodes"][-1]) + 2

        self.n_nodes = datamodule.node_counts(max_nodes_possible = max_nodes_possible)

        self.node_types = dataset_meta["node_type"]
        self.edge_types = dataset_meta["edge_type"]

        if level_data:
            self.level_num, self.level_size = get_level_distribution(datamodule, max_nodes_possible)
        super().complete_infos(self.n_nodes, self.node_types)

    @staticmethod
    def _meta_path(datadir):
        """Return the path of ``meta.pth`` inside ``datadir``.

        Prefers ``datadir`` resolved against ``parents[2]`` of this file -- the
        same root ``ZINCDataModule`` uses for the dataset itself -- so loading
        does not depend on the working directory. Falls back to the legacy
        cwd-relative ``../../../<datadir>/meta.pth`` when that file is absent.
        """
        base_path = pathlib.Path(os.path.realpath(__file__)).parents[2]
        candidate = os.path.join(base_path, datadir, "meta.pth")
        if os.path.exists(candidate):
            return candidate
        return f"../../../{datadir}/meta.pth"

