import csv
import math
import os

import torch
from torch_geometric.data import Data, InMemoryDataset
from torch_geometric.graphgym.config import cfg


def parse(features):
    """Split a comma-separated feature-name string into a list of names.

    ``'none'`` (the "no features" sentinel), an empty/blank string and
    ``None`` all yield ``[]``. Whitespace around each name is stripped and
    empty entries (e.g. from a trailing comma) are dropped, so the result
    only contains names that can be looked up in an encoding-function dict.

    >>> parse('a, b,')
    ['a', 'b']
    """
    if features is None:
        return []
    features = features.strip()
    if features in ('', 'none'):
        return []
    return [name.strip() for name in features.split(',') if name.strip()]


class LPGDataset(InMemoryDataset):
    """In-memory PyG dataset built from ``nodes.csv`` / ``edges.csv``.

    What is encoded and predicted is driven by the global GraphGym
    ``cfg.dataset`` options:

    * ``node_features`` / ``edge_features``: comma-separated names, or
      ``'none'``.
    * ``task``: ``'node'`` or ``'edge'``.
    * ``node_prediction_target`` / ``edge_prediction_target``.
    * ``split``: 2 (train, val) or 3 (train, val, test) ratios summing to 1.
    * ``task_type``: selects the label dtype.

    Every feature/target name must be a key of ``node_encoding_functions`` /
    ``edge_encoding_functions``. An encoding function is called with the list
    of parsed CSV rows (dicts mapping column name to the raw string value)
    and must return data accepted by ``torch.tensor``. The edges file must
    have integer ``start`` and ``end`` columns.
    """

    # Name -> encoding function (see class docstring). Empty here;
    # presumably populated by subclasses — verify against callers.
    node_encoding_functions = {}
    edge_encoding_functions = {}

    def __init__(self,
                 root,
                 transform=None,
                 pre_transform=None,
                 pre_filter=None):
        # Validate first: ``super().__init__`` may call ``process()``, which
        # would otherwise fail with an opaque KeyError/IndexError on a bad
        # name or task before these friendly checks are reached.
        self._check_config()
        super().__init__(root, transform, pre_transform, pre_filter)
        self.data, self.slices = torch.load(self.processed_paths[0])

    def _check_config(self):
        """Check that every configured feature/target name has an encoder.

        Raises:
            AssertionError: on an unknown feature or prediction-target name.
            ValueError: if ``cfg.dataset.task`` is neither 'node' nor 'edge'.
        """
        for name in parse(cfg.dataset.node_features):
            assert name in self.node_encoding_functions, \
                f'Invalid node feature name: {name}.'
        for name in parse(cfg.dataset.edge_features):
            assert name in self.edge_encoding_functions, \
                f'Invalid edge feature name: {name}.'
        task = cfg.dataset.task
        if task == 'node':
            target = cfg.dataset.node_prediction_target
            assert target in self.node_encoding_functions, \
                f'Invalid node prediction target {target}'
        elif task == 'edge':
            target = cfg.dataset.edge_prediction_target
            assert target in self.edge_encoding_functions, \
                f'Invalid edge prediction target {target}'
        else:
            raise ValueError(
                f'cfg.dataset.task has to be node or edge. {task} is given.'
            )

    @property
    def raw_file_names(self):
        return ['nodes.csv', 'edges.csv']

    @property
    def processed_file_names(self):
        """Name of the processed cache file, derived from ``cfg.dataset``.

        Every ``cfg.dataset`` option read by ``process()`` (features, task and
        target, split ratios, task type) is part of the name, so changing any
        of them triggers reprocessing instead of reusing a stale file.
        """
        node_part = ''.join(
            f'_{name}' for name in parse(cfg.dataset.node_features))
        edge_part = ''.join(
            f'_{name}' for name in parse(cfg.dataset.edge_features))
        if cfg.dataset.task == 'node':
            target_part = f'__nodetarget_{cfg.dataset.node_prediction_target}'
        else:
            target_part = f'__edgetarget_{cfg.dataset.edge_prediction_target}'
        split_part = '_'.join(str(ratio) for ratio in cfg.dataset.split)
        return (f'data__node{node_part}__edge{edge_part}{target_part}'
                f'__split_{split_part}__tasktype_{cfg.dataset.task_type}.pt')

    def process(self):
        """Read the raw CSVs, encode features/labels, split, and save."""

        def read_csv(file_path, required_fields):
            """Return one dict per data row, keeping only required columns.

            Values are the raw strings from the file. A required field that
            is not a header is simply absent from the dicts. An empty file
            yields no rows.
            """
            wanted = set(required_fields)
            with open(file_path, newline='', encoding='latin-1') as csvfile:
                reader = csv.reader(csvfile, delimiter=',', quotechar='"')
                headers = next(reader, [])
                # Resolve the kept column positions once, not once per row.
                columns = [(idx, field) for idx, field in enumerate(headers)
                           if field in wanted]
                return [{field: row[idx] for idx, field in columns}
                        for row in reader]

        def construct_feature_vectors(items, feature_names,
                                      encoding_functions):
            """Stack a constant-1 column followed by each encoded feature.

            A 1-D encoding is treated as a single column.
            """
            columns = [torch.ones(len(items), 1)]
            for name in feature_names:
                features = torch.tensor(encoding_functions[name](items))
                if features.dim() == 1:
                    features = features.unsqueeze(1)
                columns.append(features)
            # One concatenation instead of re-copying after every feature.
            return torch.hstack(columns)

        node_file_name, edge_file_name = self.raw_file_names
        task = cfg.dataset.task
        node_feature_names = parse(cfg.dataset.node_features)
        edge_feature_names = parse(cfg.dataset.edge_features)

        # Nodes: feature columns, plus the label column for node tasks.
        required_fields = list(node_feature_names)
        if task == 'node':
            required_fields.append(cfg.dataset.node_prediction_target)
        nodes = read_csv(os.path.join(self.raw_dir, node_file_name),
                         required_fields)
        x = construct_feature_vectors(nodes, node_feature_names,
                                      self.node_encoding_functions)
        num_nodes = len(nodes)
        prediction_target = []
        if task == 'node':
            # node-level prediction
            prediction_target = self.node_encoding_functions[
                cfg.dataset.node_prediction_target](nodes)
        del nodes

        # Edges: endpoints and feature columns, plus the label for edge tasks.
        required_fields = ['start', 'end'] + edge_feature_names
        if task == 'edge':
            required_fields.append(cfg.dataset.edge_prediction_target)
        edges = read_csv(os.path.join(self.raw_dir, edge_file_name),
                         required_fields)
        edge_attr = construct_feature_vectors(edges, edge_feature_names,
                                              self.edge_encoding_functions)
        if task == 'edge':
            prediction_target = self.edge_encoding_functions[
                cfg.dataset.edge_prediction_target](edges)

        # Set up edge index from the integer ``start`` / ``end`` columns.
        edge_index = torch.tensor(
            [[int(edge['start']) for edge in edges],
             [int(edge['end']) for edge in edges]],
            dtype=torch.long)
        num_edges = len(edges)
        del edges

        # Randomly assign each prediction item (node or edge) to a split.
        split = cfg.dataset.split
        assert len(split) in [2, 3], \
            f'len(cfg.dataset.split) needs to be 2 or 3. {split} is given.'
        # Tolerant comparison: e.g. 0.7 + 0.2 + 0.1 is not exactly 1.0.
        assert math.isclose(sum(split), 1.0), \
            f'sum of ratios has to be 1. {sum(split)} is given.'
        probs = torch.rand(num_nodes if task == 'node' else num_edges)
        train_end = split[0]
        train_mask = probs < train_end
        if len(split) == 3:
            # train, val, test -> [0, s0), [s0, s0 + s1), [s0 + s1, 1)
            val_end = split[0] + split[1]
            val_mask = (probs >= train_end) & (probs < val_end)
            test_mask = probs >= val_end
        else:
            # train, val
            val_mask = probs >= train_end

        task_type = cfg.dataset.task_type
        if task_type == 'regression':
            label_dtype = torch.float
        elif task_type in ["classification_binary", "classification_multilabel"]:
            label_dtype = torch.bool
        else:
            # classification_multi
            label_dtype = torch.long

        if task == 'node':
            # node-level prediction
            kwargs = dict(
                y=torch.tensor(prediction_target, dtype=label_dtype),
                train_mask=train_mask,
                val_mask=val_mask)
            if len(split) == 3:
                kwargs['test_mask'] = test_mask
        else:
            # edge-level prediction: labeled edges are split into subsets
            edge_label = torch.tensor(prediction_target, dtype=label_dtype)
            kwargs = dict(
                train_edge_index=edge_index[:, train_mask],
                train_edge_label=edge_label[train_mask],
                val_edge_index=edge_index[:, val_mask],
                val_edge_label=edge_label[val_mask])
            if len(split) == 3:
                kwargs['test_edge_index'] = edge_index[:, test_mask]
                kwargs['test_edge_label'] = edge_label[test_mask]
        del prediction_target

        data = Data(x=x,
                    edge_index=edge_index,
                    edge_attr=edge_attr,
                    **kwargs)
        data_list = [data]

        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        # Store data
        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])
