# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# The file is modified based on the original Graphormer's source code.

from functools import partial

import ogb
import ogb.graphproppred
import ogb.lsc
from pytorch_lightning import LightningDataModule
from sklearn.metrics import accuracy_score
from torch.nn import functional as F
from torch.utils.data import DataLoader

from collator import collator
from wrapper import *

# Module-level cache used by get_dataset(): None until the first successful
# load, then the config dict of that dataset (returned by every later call).
dataset = None


def my_acc_calculator(input_dict):
    """Compute classification accuracy of argmax predictions.

    Args:
        input_dict: mapping with ``'y_true'`` (label tensor) and ``'y_pred'``
            (score tensor whose last dimension indexes classes).

    Returns:
        ``{'acc': accuracy}`` as computed by ``sklearn.metrics.accuracy_score``
        on the integer-cast labels and predicted class indices.
    """
    labels = input_dict['y_true']
    scores = input_dict['y_pred']

    # Predicted class = index of the largest score; keepdim leaves a trailing
    # size-1 axis on the result.
    predicted_class = torch.argmax(scores, dim=-1, keepdim=True)
    accuracy = accuracy_score(labels.int(), predicted_class.int())
    return {'acc': accuracy}


def my_QM9_evaluator(input_dict):
    """Sum over QM9's 12 targets of the per-target mean absolute error.

    Args:
        input_dict: mapping with ``'y_true'`` and ``'y_pred'`` tensors; each is
            reshaped to ``(-1, 12)``.

    Returns:
        ``{'mae': value}`` where value is a Python float.
    """
    targets = input_dict['y_true'].view(-1, 12)
    predictions = input_dict['y_pred'].view(-1, 12)

    # Average the absolute error over samples (dim 0) for each target,
    # then add the 12 per-target averages together.
    per_target_mae = (predictions - targets).abs().mean(dim=0)
    return {'mae': per_target_mae.sum().item()}


def my_QM8_evaluator(input_dict):
    """Mean absolute error over QM8's 16 targets.

    Args:
        input_dict: mapping with ``'y_true'`` and ``'y_pred'`` tensors; each is
            reshaped to ``(-1, 16)``.

    Returns:
        ``{'mae': value}``: the per-sample MAE (averaged over the 16 targets),
        then averaged over samples, as a Python float.
    """
    targets = input_dict['y_true'].view(-1, 16)
    predictions = input_dict['y_pred'].view(-1, 16)

    per_sample_mae = (predictions - targets).abs().mean(dim=-1)
    return {'mae': per_sample_mae.mean().item()}


def my_MoleculeNet_evaluator(input_dict):
    """Root-mean-squared error between ``y_pred`` and ``y_true``.

    Args:
        input_dict: mapping with ``'y_true'`` and ``'y_pred'`` tensors of
            broadcast-compatible shapes.

    Returns:
        ``{'rmse': value}``, averaging the squared error over every element,
        as a Python float.
    """
    residual = input_dict['y_pred'] - input_dict['y_true']
    mean_squared_error = residual.pow(2).mean()
    return {'rmse': mean_squared_error.sqrt().item()}


def _standardize_columns(y):
    """Return ``y`` with every column shifted to zero mean and unit std.

    Mean and std are taken over dim 0 (all rows of ``y``) with ``torch.std``'s
    default estimator.
    """
    y_mean = torch.mean(y, dim=0).view(1, -1)
    y_std = torch.std(y, dim=0).view(1, -1)
    return (y - y_mean) / y_std


def _tu_dataset_config(tu_name):
    """Build the config dict for a two-class TUDataset stored under ./dataset."""
    return {
        'num_class': 2,
        'loss_fn': F.binary_cross_entropy_with_logits,
        'metric': 'acc',
        'metric_mode': 'max',
        'evaluator': my_acc_calculator,
        'dataset': MyTUDataset(root='./dataset', name=tu_name),
        'max_node': 128,
    }


# Public dataset name -> name passed to MyTUDataset.
_TU_DATASET_NAMES = {
    'PTC-MR': 'PTC_MR',
    'COX2': 'COX2',
    'PROTEINS': 'PROTEINS',
    'MUTAG': 'MUTAG',
}


def get_dataset(dataset_name='abaaba'):
    """Load (once per process) and return the config dict for ``dataset_name``.

    The returned dict always has ``num_class``, ``loss_fn``, ``metric``,
    ``metric_mode``, ``evaluator`` and ``max_node``, plus either ``dataset``
    (a single dataset split later by the caller) or, for ZINC,
    ``train_dataset`` / ``valid_dataset`` / ``test_dataset``.

    Targets are post-processed in place:
      * esol: every target column is standardized.
      * QM9: only the first 12 target columns are kept, then standardized.
      * TU datasets: targets are cast to float.

    NOTE(review): the module-level cache is not keyed by name. Once any
    dataset is loaded, every later call returns it regardless of
    ``dataset_name`` (the default ``'abaaba'`` relies on this to mean
    "whatever is cached"). Loading two different datasets in one process
    therefore silently returns the first one.

    Raises:
        NotImplementedError: if nothing is cached and ``dataset_name`` is
            not one of the supported names.
    """
    global dataset
    if dataset is not None:
        return dataset

    if dataset_name == 'esol':
        config = {
            'num_class': 1,
            'loss_fn': F.mse_loss,
            'metric': 'rmse',
            'metric_mode': 'min',
            'evaluator': my_MoleculeNet_evaluator,
            'dataset': MyMoleculeNetDataset(name='ESOL', root='./dataset'),
            'max_node': 128,
        }
    elif dataset_name == 'QM9':
        config = {
            'num_class': 12,
            'loss_fn': F.l1_loss,
            'metric': 'mae',
            'metric_mode': 'min',
            'evaluator': my_QM9_evaluator,
            'dataset': MyQM9Dataset(root='./dataset/qm9'),
            'max_node': 128,
        }
    elif dataset_name == 'QM8':
        config = {
            'num_class': 16,
            'loss_fn': F.l1_loss,
            'metric': 'mae',
            'metric_mode': 'min',
            'evaluator': my_QM8_evaluator,
            'dataset': MyQM8Dataset(),
            'max_node': 128,
        }
    elif dataset_name == 'ZINC':
        config = {
            'num_class': 1,
            'loss_fn': F.l1_loss,
            'metric': 'mae',
            'metric_mode': 'min',
            'evaluator': ogb.lsc.PCQM4MEvaluator(),
            'train_dataset': MyZINCDataset(subset=True, root='./dataset/pyg_zinc', split='train'),
            'valid_dataset': MyZINCDataset(subset=True, root='./dataset/pyg_zinc', split='val'),
            'test_dataset': MyZINCDataset(subset=True, root='./dataset/pyg_zinc', split='test'),
            'max_node': 128,
        }
    elif dataset_name in _TU_DATASET_NAMES:
        config = _tu_dataset_config(_TU_DATASET_NAMES[dataset_name])
    else:
        raise NotImplementedError(f'Unsupported dataset: {dataset_name!r}')

    if dataset_name == 'esol':
        config['dataset'].data.y = _standardize_columns(config['dataset'].data.y)
    elif dataset_name == 'QM9':
        # Keep only the first 12 target columns, then standardize them.
        config['dataset'].data.y = _standardize_columns(config['dataset'].data.y[:, :12])
    elif dataset_name in _TU_DATASET_NAMES:
        # binary_cross_entropy_with_logits expects floating-point targets.
        config['dataset'].data.y = config['dataset'].data.y.float()

    # Publish to the cache only once the config is fully prepared.
    dataset = config

    print(f' > {dataset_name} loaded!')
    print(dataset)
    print(' > dataset info ends')

    return dataset


class GraphDataModule(LightningDataModule):
    """Lightning data module serving one of the datasets known to ``get_dataset``.

    Splits built by ``setup``:
      * ZINC: the predefined train / val / test datasets.
      * esol, QM8, QM9: a seeded random 80% / 10% / 10% split.
      * PTC-MR, COX2, PROTEINS, MUTAG: seeded 10-fold cross-validation. Fold
        ``k`` is the test set and the other nine folds are training data.
        There is no separate validation split; ``val_dataloader`` serves the
        test fold.

    Args:
        k: index (0-9) of the cross-validation test fold; only used for the
            10-fold datasets.
        dataset_name: name passed to ``get_dataset``.
        num_workers: DataLoader worker processes (0 = load in the main process).
        batch_size: graphs per batch for every loader.
        seed: seed for the split permutation.
        multi_hop_max_dist: forwarded to ``collator``.
        spatial_pos_max: forwarded to ``collator``.
    """

    def __init__(
            self,
            k: int = 0,
            dataset_name: str = 'esol',
            num_workers: int = 0,
            batch_size: int = 256,
            seed: int = 2,
            multi_hop_max_dist: int = 5,
            spatial_pos_max: int = 1024,
            *args,
            **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.dataset_name = dataset_name
        self.dataset = get_dataset(self.dataset_name)
        self.seed = seed

        self.num_workers = num_workers
        self.batch_size = batch_size
        # Placeholders; setup() replaces them with the actual subsets.
        self.dataset_train = ...
        self.dataset_val = ...
        self.dataset_test = ...
        self.multi_hop_max_dist = multi_hop_max_dist
        self.spatial_pos_max = spatial_pos_max

        self.k = k

    @staticmethod
    def _random_split_indices(num_graphs, seed, train_ratio=0.8, val_ratio=0.1):
        """Return seeded (train, val, test) index arrays that cover every graph once."""
        # Seeds NumPy's global RNG (kept for reproducibility with earlier runs).
        np.random.seed(seed)
        permute = np.random.permutation(num_graphs)
        train_end = int(train_ratio * num_graphs)
        val_end = int((train_ratio + val_ratio) * num_graphs)
        # The test split starts exactly where validation ends, so no graph is
        # dropped and none overlaps another split.
        return permute[:train_end], permute[train_end:val_end], permute[val_end:]

    @staticmethod
    def _kfold_indices(num_graphs, seed, k, num_folds=10):
        """Return seeded (train, test) index arrays for fold ``k`` of ``num_folds``.

        Raises:
            ValueError: if ``k`` is not in ``[0, num_folds)``.
        """
        if not 0 <= k < num_folds:
            raise ValueError(f'k must be in [0, {num_folds}), got {k}')
        # Seeds NumPy's global RNG (kept for reproducibility with earlier runs).
        np.random.seed(seed)
        permute = np.random.permutation(num_graphs)
        fold_frac = 1 / num_folds
        # Fold i spans [int(i * frac * n), int((i + 1) * frac * n)); the last
        # fold runs to the end so every graph lands in exactly one fold.
        bounds = [int(i * fold_frac * num_graphs) for i in range(num_folds)] + [num_graphs]
        folds = [permute[bounds[i]:bounds[i + 1]] for i in range(num_folds)]
        test_idx = folds[k]
        train_idx = np.concatenate(folds[:k] + folds[k + 1:])
        return train_idx, test_idx

    def setup(self, stage: str = None):
        """Materialize the train / val / test subsets for ``self.dataset_name``.

        Raises:
            ValueError: for a 10-fold dataset when ``self.k`` is not in 0..9.
            NotImplementedError: when no split rule exists for the dataset name.
        """
        if self.dataset_name in ['ZINC']:
            self.dataset_train = self.dataset['train_dataset']
            self.dataset_val = self.dataset['valid_dataset']
            self.dataset_test = self.dataset['test_dataset']
        elif self.dataset_name in ['esol', 'QM8', 'QM9']:
            full = self.dataset['dataset']
            train_idx, val_idx, test_idx = self._random_split_indices(len(full), self.seed)
            self.dataset_train = full[train_idx]
            self.dataset_val = full[val_idx]
            self.dataset_test = full[test_idx]
        elif self.dataset_name in ['PTC-MR', 'COX2', 'PROTEINS', 'MUTAG']:
            full = self.dataset['dataset']
            train_idx, test_idx = self._kfold_indices(len(full), self.seed, self.k)
            self.dataset_train = full[train_idx]
            self.dataset_test = full[test_idx]
        else:
            raise NotImplementedError(f'No split rule for dataset {self.dataset_name!r}')

    def _make_loader(self, data, shuffle, pin_memory):
        """Build a DataLoader over ``data`` with this module's batching and collation."""
        return DataLoader(
            data,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            # DataLoader rejects persistent_workers=True when num_workers == 0.
            persistent_workers=self.num_workers > 0,
            pin_memory=pin_memory,
            collate_fn=partial(
                collator,
                max_node=self.dataset['max_node'],
                multi_hop_max_dist=self.multi_hop_max_dist,
                spatial_pos_max=self.spatial_pos_max,
            ),
        )

    def train_dataloader(self):
        """Return the shuffled training loader (with pinned memory)."""
        loader = self._make_loader(self.dataset_train, shuffle=True, pin_memory=True)
        print('len(train_dataloader)', len(loader))
        return loader

    def val_dataloader(self):
        """Return the validation loader.

        The 10-fold datasets have no validation split, so the test fold is used.
        """
        if self.dataset_name in ['PTC-MR', 'COX2', 'PROTEINS', 'MUTAG']:
            data = self.dataset_test
        else:
            data = self.dataset_val
        loader = self._make_loader(data, shuffle=False, pin_memory=False)
        print('len(val_dataloader)', len(loader))
        return loader

    def test_dataloader(self):
        """Return the (unshuffled) test loader."""
        loader = self._make_loader(self.dataset_test, shuffle=False, pin_memory=False)
        print('len(test_dataloader)', len(loader))
        return loader
