import math
import torch
from torch_geometric.graphgym.register import register_metric
from torch_geometric.graphgym.config import cfg


def get_pred_int(pred_score):
    """Convert prediction scores into integer class predictions.

    Args:
        pred_score: Tensor of scores. When it is 1-D or has exactly one column,
            every score is compared with ``cfg.model.thresh`` and the boolean
            result is cast to ``long``. Otherwise the index of the largest
            score along dim 1 is taken per row.

    Returns:
        Tensor of integer predictions.
    """
    has_single_column = pred_score.dim() == 1 or pred_score.size(1) == 1
    if not has_single_column:
        # Several score columns: predict the column with the highest score.
        return pred_score.max(dim=1).indices
    # One score per entry: threshold it into a 0/1 prediction.
    return (pred_score > cfg.model.thresh).long()


@register_metric('classification_binary_with_ignore_index')
def classification_binary_with_ignore_index(true, pred, task_type):
    """Binary classification metrics that skip entries labelled ``cfg.dataset.ignore_index``.

    Args:
        true: Sequence of label tensors, concatenated along dim 0.
        pred: Sequence of prediction-score tensors, concatenated along dim 0.
            Hard predictions are derived from them with ``get_pred_int``.
        task_type: Not used by this metric.

    Returns:
        Dict with ``accuracy``, ``precision``, ``recall``, ``f1`` and ``auc``,
        each rounded to ``cfg.round`` digits. ``auc`` is ``0.0`` whenever
        ``roc_auc_score`` raises ``ValueError``.
    """
    from sklearn.metrics import (
        accuracy_score,
        f1_score,
        precision_score,
        recall_score,
        roc_auc_score,
    )

    labels = torch.cat(true)
    scores = torch.cat(pred)
    hard_preds = get_pred_int(scores)

    # Drop every position whose label equals the ignore sentinel before scoring.
    keep = labels != cfg.dataset.ignore_index
    labels = labels[keep]
    scores = scores[keep]
    hard_preds = hard_preds[keep]

    try:
        auc = roc_auc_score(labels, scores)
    except ValueError:
        # sklearn raises ValueError when AUC is undefined (e.g. only one class
        # remains after filtering); report 0.0 instead of failing.
        auc = 0.0

    threshold_metrics = (
        ('accuracy', accuracy_score),
        ('precision', precision_score),
        ('recall', recall_score),
        ('f1', f1_score),
    )
    results = {
        name: round(scorer(labels, hard_preds), cfg.round)
        for name, scorer in threshold_metrics
    }
    results['auc'] = round(auc, cfg.round)
    return results


@register_metric('classification_multi_with_ignore_index')
def classification_multi_with_ignore_index(true, pred, task_type):
    """Multi-class accuracy that skips samples labelled ``cfg.dataset.ignore_index``.

    Args:
        true: Sequence of label tensors, concatenated along dim 0. Assumed to
            hold one label per sample, shaped ``(N,)`` or ``(N, 1)`` — TODO
            confirm against the producing head.
        pred: Sequence of score tensors, concatenated along dim 0 and turned
            into class indices with ``get_pred_int``.
        task_type: Not used by this metric.

    Returns:
        ``{'accuracy': value}`` rounded to ``cfg.round`` digits.
    """
    from sklearn.metrics import accuracy_score

    true, pred_score = torch.cat(true), torch.cat(pred)
    pred_int = get_pred_int(pred_score)
    # Flatten labels to one entry per sample so a single 1-D mask selects the
    # same rows from labels and predictions. Squeezing a 1-element mask would
    # instead yield a 0-d tensor, and indexing with a 0-d bool tensor adds a
    # dimension rather than selecting rows.
    true = true.reshape(-1)
    keep = true != cfg.dataset.ignore_index
    return {'accuracy': round(accuracy_score(true[keep], pred_int[keep]), cfg.round)}


@register_metric('classification_multilabel_with_ignore_index')
def classification_multilabel_with_ignore_index(true, pred, task_type):
    """Multilabel accuracy that skips samples whose target vector is all zeros.

    Predictions are obtained by thresholding the scores at ``cfg.model.thresh``.
    Scoring uses sklearn's ``accuracy_score`` on the 0/1 label matrices.

    Args:
        true: Sequence of target tensors, concatenated along dim 0; the last
            dimension holds the per-label targets.
        pred: Sequence of score tensors, concatenated along dim 0.
        task_type: Must be ``'classification_multilabel'``.

    Returns:
        ``{'accuracy': value}`` rounded to ``cfg.round`` digits.

    Raises:
        ValueError: If ``task_type`` is not ``'classification_multilabel'``.
    """
    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if task_type != 'classification_multilabel':
        raise ValueError('task_type has to be classification_multilabel. '
                         f'{task_type} is given.')
    from sklearn.metrics import accuracy_score

    true, pred_score = torch.cat(true), torch.cat(pred)
    # Samples with an all-zero target vector are treated as unlabelled and skipped.
    keep = true.sum(dim=-1) != 0
    pred_int = (pred_score > cfg.model.thresh).long()
    return {'accuracy': round(accuracy_score(true[keep], pred_int[keep]), cfg.round)}


@register_metric('regression_with_ignore_index')
def regression_with_ignore_index(true, pred, task_type):
    """Regression errors computed only where the target is not ``cfg.dataset.ignore_index``.

    Args:
        true: Sequence of target tensors, concatenated along dim 0.
        pred: Sequence of prediction tensors, concatenated along dim 0; indexed
            with the same mask as ``true``, so shapes must match.
        task_type: Not used by this metric.

    Returns:
        Dict with ``mae``, ``mse`` and ``rmse`` as Python floats rounded to
        ``cfg.round`` digits.
    """
    from sklearn.metrics import mean_absolute_error, mean_squared_error

    true, pred = torch.cat(true), torch.cat(pred)
    keep = true != cfg.dataset.ignore_index
    true, pred = true[keep], pred[keep]
    mae = mean_absolute_error(true, pred)
    # Compute the MSE once and derive RMSE from it rather than re-scanning the data.
    mse = mean_squared_error(true, pred)
    return {
        'mae': float(round(mae, cfg.round)),
        'mse': float(round(mse, cfg.round)),
        'rmse': float(round(math.sqrt(mse), cfg.round)),
    }


