# -*- coding:utf-8 -*-

import numpy as np
import torch
import os

def save_model(model, model_dir, seq_length, exp_id=None):
    '''
    Serialize ``model`` with ``torch.save`` into ``model_dir``.
    The file is named ``<seq_length>_rul<exp_id>.pt``. A falsy ``exp_id``
    (None, 0, '') contributes no suffix, giving ``<seq_length>_rul.pt``.
    :param model: object to serialize (handed directly to ``torch.save``).
    :param model_dir: target directory; created when it does not exist.
                      When None, nothing is written.
    :param seq_length: first component of the file name.
    :param exp_id: optional experiment identifier appended after ``_rul``.
    :return: None.
    '''
    if model_dir is None:
        return
    directory_missing = not os.path.exists(model_dir)
    if directory_missing:
        os.makedirs(model_dir)
    suffix = str(exp_id) if exp_id else ''
    base_name = ''.join([str(seq_length), '_rul', suffix, '.pt'])
    target_path = os.path.join(model_dir, base_name)
    with open(target_path, 'wb') as handle:
        torch.save(model, handle)


def load_model(model_dir, seq_length, epoch, device):
    '''
    Load a model stored as ``<model_dir>/<seq_length>_rul.pt``.
    This is the name ``save_model`` produces when called without an ``exp_id``.
    :param model_dir: directory holding the checkpoint. Falsy values (None, '')
                      short-circuit and return None. The directory is created
                      if it does not exist.
    :param seq_length: first component of the file name.
    :param epoch: currently unused; kept so existing callers keep working.
    :param device: forwarded to ``torch.load`` as ``map_location``.
    :return: the deserialized object, or None when ``model_dir`` is falsy or
             the file does not exist.
    '''
    import inspect

    if not model_dir:
        return
    # NOTE(review): ``epoch`` is not part of the file name (an earlier,
    # epoch-prefixed naming scheme was dropped), so it is ignored here.
    file_name = os.path.join(model_dir, str(seq_length) + '_rul' + '.pt')

    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    if not os.path.exists(file_name):
        return

    load_kwargs = {'map_location': device}
    # ``save_model`` pickles the whole module object. Since torch 2.6,
    # ``torch.load`` defaults to ``weights_only=True``, which rejects such
    # files. Opt out explicitly when the installed torch knows the flag.
    # SECURITY: unpickling can execute arbitrary code, so only load
    # checkpoints from a trusted ``model_dir``.
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    with open(file_name, 'rb') as f:
        model = torch.load(f, **load_kwargs)

    return model


def masked_MAPE(v, v_, axis=None):
    '''
    Mean absolute percentage error that ignores positions where the ground truth is zero.
    :param v: np.ndarray or scalar, ground truth.
    :param v_: np.ndarray or scalar, prediction.
    :param axis: axis to do calculation (None averages over all elements).
    :return: float64 scalar, or a float64 array when ``axis`` is given.
             A reduction whose ground-truth values are all zero yields NaN.
    '''
    mask = (v == 0)
    # Zero ground truth produces inf/nan here. Those entries are masked out
    # below, so silence the expected divide/invalid RuntimeWarnings.
    with np.errstate(divide='ignore', invalid='ignore'):
        percentage = np.abs(v_ - v) / np.abs(v)
    if np.any(mask):
        masked_array = np.ma.masked_array(percentage, mask=mask)  # mask the dividing-zero as invalid
        result = masked_array.mean(axis=axis)
        if isinstance(result, np.ma.MaskedArray):
            # Fully-masked reductions have no valid value; report them as NaN.
            result = result.filled(np.nan)
        # Cast like the unmasked path so the return dtype does not depend on
        # whether any zeros were present.
        return result.astype(np.float64)
    return np.mean(percentage, axis).astype(np.float64)


def MAPE(v, v_, axis=None):
    '''
    Mean absolute percentage error with a stabilised denominator.
    A small epsilon (1e-5) is added to |v| to avoid division by zero. Each
    per-element ratio is capped at 5 before averaging.
    :param v: np.ndarray or scalar, ground truth.
    :param v_: np.ndarray or scalar, prediction.
    :param axis: axis to do calculation (None averages over all elements).
    :return: float64 value (or array when ``axis`` is given) of the capped mean.
    '''
    denominator = (np.abs(v) + 1e-5).astype(np.float64)
    ratio = np.abs(v_ - v) / denominator
    capped = np.minimum(ratio, 5)
    return np.mean(capped, axis)



def RMSE(v, v_, axis=None):
    '''
    Root mean squared error.
    :param v: np.ndarray or scalar, ground truth.
    :param v_: np.ndarray or scalar, prediction.
    :param axis: axis to do calculation (None averages over all elements).
    :return: float64 value (or array when ``axis`` is given), the square root
             of the mean squared residual.
    '''
    residual = v_ - v
    mean_square = np.mean(residual ** 2, axis)
    root = np.sqrt(mean_square)
    return root.astype(np.float64)


def Score(v, v_):
    '''
    Asymmetric exponential error score, summed over all elements.
    With h = v_ - v, an element contributes exp(-h / 13) - 1 when h < 0
    (prediction below ground truth) and exp(h / 10) - 1 otherwise, so
    over-predictions are penalised more steeply.
    NOTE(review): this looks like the common RUL scoring function; the
    constants 13 and 10 are kept exactly as in the original.
    :param v: np.ndarray (or array-like / CPU tensor), ground truth.
    :param v_: np.ndarray (or array-like / CPU tensor), prediction.
    :return: float, the summed score (0.0 for empty input).
    '''
    # Work in float64: the old in-place loop wrote float penalties back into
    # ``v_ - v``, which truncated them to ints for integer inputs. Converting
    # before subtracting also avoids unsigned wrap-around.
    h = np.asarray(v_, dtype=np.float64) - np.asarray(v, dtype=np.float64)
    # Each branch only overflows where it is selected, so evaluating both
    # branches eagerly adds no spurious overflow.
    penalties = np.where(h < 0, np.exp(-h / 13), np.exp(h / 10)) - 1
    return penalties.sum().item()


def evaluate(y, y_hat, by_step=False, by_node=False):
    '''
    Compute the aggregate metrics for a prediction.
    :param y: ground truth, passed to ``MAPE``, ``Score`` and ``RMSE``.
    :param y_hat: prediction, passed to the same metrics.
    :param by_step: per-step breakdown flag (not implemented).
    :param by_node: per-node breakdown flag (not implemented).
    :return: tuple ``(MAPE, Score, RMSE)`` when both flags are falsy;
             None otherwise.
    '''
    if by_step or by_node:
        # NOTE(review): the per-step / per-node breakdowns were never
        # implemented, so these flags make the function return None.
        return None
    mape = MAPE(y, y_hat)
    score = Score(y, y_hat)
    rmse = RMSE(y, y_hat)
    return mape, score, rmse
