import argparse
import os

import numpy as np
import pandas as pd
import torch
import yaml


def masked_mae(preds: torch.Tensor, labels: torch.Tensor, null_val: float = np.nan, return_used_sample_num=False):
    """Mean absolute error restricted to the valid labels.

    A label is considered valid when it is not NaN (if ``null_val`` is NaN),
    or when it is strictly greater than ``null_val`` (for any other
    ``null_val``). Note that the second rule also discards labels that are
    smaller than ``null_val``, not only those equal to it.

    Args:
        preds (torch.Tensor): predicted values
        labels (torch.Tensor): ground-truth values
        null_val (float, optional): value marking missing labels. Defaults to np.nan.
        return_used_sample_num (bool, optional): when True, also return the
            number of valid labels. Defaults to False.

    Returns:
        torch.Tensor: the masked mean absolute error, or the tuple
        ``(error, number_of_valid_labels)`` when ``return_used_sample_num``
        is True.
    """

    valid = ~torch.isnan(labels) if np.isnan(null_val) else labels > null_val
    used_sample_num = valid.sum()

    # Rescale the 0/1 mask by its own mean so that a plain mean over every
    # element equals the mean over the valid elements only.
    weights = valid.float()
    weights = weights / weights.mean()
    # With no valid label the rescaling is 0 / 0; fall back to all-zero weights.
    weights = torch.where(torch.isnan(weights), torch.zeros_like(weights), weights)

    abs_err = (preds - labels).abs() * weights
    # NaN inputs yield NaN errors even under a zero weight; count them as zero.
    abs_err = torch.where(torch.isnan(abs_err), torch.zeros_like(abs_err), abs_err)
    mae = abs_err.mean()

    if return_used_sample_num:
        return mae, used_sample_num
    return mae


def masked_mape(prediction: torch.Tensor, target: torch.Tensor, null_val: float = 0.0) -> torch.Tensor:
    """Masked mean absolute percentage error.

    Targets whose magnitude is below 1e-4 are set to zero, and zero targets
    are excluded from the average, so the percentage error is only measured
    where the target is meaningfully non-zero. If no target is valid the
    result is 0.

    Args:
        prediction (torch.Tensor): predicted values
        target (torch.Tensor): labels
        null_val (float, optional): accepted for signature consistency with
            the other metrics but ignored: the null value of MAPE is always
            0.0.

    Returns:
        torch.Tensor: masked mean absolute percentage error (as a ratio, not
        multiplied by 100)
    """
    # TODO: support multiple null values
    eps = 5e-5
    # Flush tiny targets to zero so they are masked instead of blowing up the ratio.
    target = torch.where(torch.abs(target) < 1e-4, torch.zeros_like(target), target)
    # Compare against zeros of the target's own dtype/device: torch.isclose
    # rejects operands whose dtypes differ (e.g. float64 targets).
    mask = ~torch.isclose(target, torch.zeros_like(target), atol=eps, rtol=0.)
    mask = mask.float()
    # Dividing by the mask mean turns the final mean into a mean over valid entries.
    mask /= torch.mean(mask)
    # 0 / 0 when nothing is valid: use all-zero weights.
    mask = torch.where(torch.isnan(mask), torch.zeros_like(mask), mask)
    loss = torch.abs(torch.abs(prediction - target) / target)
    loss = loss * mask
    # Division by a zero target gives inf/NaN, which becomes NaN once
    # multiplied by the zero weight; those entries contribute nothing.
    loss = torch.where(torch.isnan(loss), torch.zeros_like(loss), loss)
    return torch.mean(loss)


def masked_mse(preds: torch.Tensor, labels: torch.Tensor, null_val: float = np.nan) -> torch.Tensor:
    """Masked mean squared error.

    Labels equal to ``null_val`` (NaN labels when ``null_val`` is NaN, labels
    within 5e-5 of ``null_val`` otherwise) are excluded from the average. If
    no label is valid the result is 0.

    Args:
        preds (torch.Tensor): predicted values
        labels (torch.Tensor): labels
        null_val (float, optional): null value. Defaults to np.nan.

    Returns:
        torch.Tensor: masked mean squared error
    """

    if np.isnan(null_val):
        mask = ~torch.isnan(labels)
    else:
        eps = 5e-5
        # Build the reference with the labels' own dtype/device: torch.isclose
        # rejects operands whose dtypes differ (e.g. float64 labels).
        mask = ~torch.isclose(labels, torch.full_like(labels, null_val), atol=eps, rtol=0.)
    mask = mask.float()
    # Dividing by the mask mean turns the final mean into a mean over valid entries.
    mask /= torch.mean(mask)
    # 0 / 0 when nothing is valid: use all-zero weights.
    mask = torch.where(torch.isnan(mask), torch.zeros_like(mask), mask)
    loss = torch.square(torch.sub(preds, labels))
    loss = loss * mask
    # NaN labels give NaN errors even under a zero weight; count them as zero.
    loss = torch.where(torch.isnan(loss), torch.zeros_like(loss), loss)
    return torch.mean(loss)


def masked_rmse(preds: torch.Tensor, labels: torch.Tensor, null_val: float = np.nan) -> torch.Tensor:
    """Masked root mean squared error.

    Computed as the square root of ``masked_mse`` evaluated with the same
    arguments, so the masking rule is the one applied by ``masked_mse``.

    Args:
        preds (torch.Tensor): predicted values
        labels (torch.Tensor): labels
        null_val (float, optional): null value. Defaults to np.nan.

    Returns:
        torch.Tensor: masked root mean squared error
    """

    mse = masked_mse(preds, labels, null_val)
    return torch.sqrt(mse)


def masked_wape(preds: torch.Tensor, labels: torch.Tensor, null_val: float = np.nan) -> torch.Tensor:
    """Masked weighted absolute percentage error (WAPE).

    WAPE is ``sum(|preds - labels|) / sum(|labels|)`` taken over the valid
    labels only. A label is valid when it is not NaN (if ``null_val`` is NaN)
    or when it is not within 5e-5 of ``null_val`` (otherwise). When no label
    is valid, or all valid labels are zero, the denominator is zero and the
    result is NaN/inf.

    Args:
        preds (torch.Tensor): predicted values
        labels (torch.Tensor): labels
        null_val (float, optional): null value. Defaults to np.nan.

    Returns:
        torch.Tensor: masked weighted absolute percentage error (as a ratio,
        not multiplied by 100)
    """

    if np.isnan(null_val):
        mask = ~torch.isnan(labels)
    else:
        eps = 5e-5
        # Build the reference with the labels' own dtype/device: torch.isclose
        # rejects operands whose dtypes differ (e.g. float64 labels).
        mask = ~torch.isclose(labels, torch.full_like(labels, null_val), atol=eps, rtol=0.)
    # Select with torch.where rather than multiplying by the mask: NaN * 0 is
    # still NaN, so multiplication would let masked NaN entries poison the sums.
    preds = torch.where(mask, preds, torch.zeros_like(preds))
    labels = torch.where(mask, labels, torch.zeros_like(labels))
    return torch.sum(torch.abs(preds - labels)) / torch.sum(torch.abs(labels))


def eval_metrics(preds: torch.Tensor, labels: torch.Tensor, null_val: float = np.nan):
    """Compute MAE, RMSE, MAPE and WAPE for one set of predictions.

    Args:
        preds (torch.Tensor): predicted values
        labels (torch.Tensor): labels
        null_val (float, optional): value marking labels to be ignored; it is
            forwarded to MAE, RMSE and WAPE. MAPE always ignores zero labels.
            Defaults to np.nan.

    Returns:
        tuple: ``(mae, rmse, mape, wape)`` where ``mape`` and ``wape`` are
        expressed in percent.
    """
    mae = masked_mae(preds, labels, null_val=null_val)
    rmse = masked_rmse(preds, labels, null_val=null_val)
    mape = masked_mape(preds, labels) * 100
    # Forward null_val so WAPE ignores the same entries as MAE/RMSE; otherwise
    # labels zeroed out by a caller's mask would still add |pred| to the error.
    wape = masked_wape(preds, labels, null_val=null_val) * 100
    return mae, rmse, mape, wape


def report_trial_result(trial_dir, reported_param=None):
    """Evaluate every seed run of one trial and average the metrics.

    Each sub-directory of ``trial_dir`` whose name contains ``'seed'`` must
    hold a ``test_outputs.npz`` file with the arrays ``prediction``,
    ``label`` and ``time_marker``. Metrics are computed twice per seed: over
    all labels (``all_*``) and over the labels selected by the event/peak
    mask (``ep_*``); in both cases zero labels are treated as null.

    Args:
        trial_dir (str): directory containing the ``seed*`` sub-directories.
        reported_param (list, optional): hyper-parameter names to look up in
            ``<trial_dir>/seed_0/hparams.yaml``. When None the file is not
            read.

    Returns:
        tuple: ``(result_df, special_params)``. ``result_df`` has one row per
        seed plus a final row labelled ``'mean'``; ``special_params`` lists
        the values of ``reported_param`` in the same order (empty when
        ``reported_param`` is None).

    Raises:
        KeyError: if a requested hyper-parameter is missing from the YAML file.
    """
    metric_names = ['all_mae', 'all_rmse', 'all_mape', 'all_wape', 'ep_mae', 'ep_rmse', 'ep_mape', 'ep_wape']

    special_params = []
    if reported_param is not None:
        trial_param_path = os.path.join(trial_dir, 'seed_0', 'hparams.yaml')
        with open(trial_param_path, 'r') as f:
            hparams = yaml.safe_load(f)
        special_params = [hparams[hparam_name] for hparam_name in reported_param]

    results = []
    # os.listdir returns entries in arbitrary order; sort for reproducible rows.
    for version_name in sorted(os.listdir(trial_dir)):
        if 'seed' not in version_name:
            continue

        # Close the archive as soon as the arrays have been read.
        with np.load(os.path.join(trial_dir, version_name, 'test_outputs.npz')) as test_outputs:
            prediction = test_outputs['prediction']
            label = test_outputs['label']
            # Keep index 0 of the third axis only.
            # NOTE(review): presumably the time markers are identical along that axis - confirm.
            time_marker = test_outputs['time_marker'][:, :, 0, :]

        # NOTE(review): feature 0 is used as a time-of-day slot index and
        # feature 4 as an event indicator - confirm against the data pipeline.
        tod = time_marker[..., 0]
        eod = time_marker[..., 4]
        # Peak slots: 41 < tod < 60 or 101 < tod < 120 (presumably 07:00-10:00
        # and 17:00-20:00 with six slots per hour - TODO confirm).
        peak_mask = ((tod > 7 * 6 - 1) & (tod < 10 * 6)) | ((tod > 17 * 6 - 1) & (tod < 20 * 6))
        event_mask = eod > 0
        event_peak_mask = peak_mask & event_mask

        prediction = torch.tensor(prediction)
        label = torch.tensor(label)
        # Trailing singleton axis so the mask broadcasts over the last label axis.
        event_peak_mask = torch.tensor(event_peak_mask, dtype=torch.bool).unsqueeze(-1)

        overall = eval_metrics(prediction, label, null_val=0.0)
        # Labels outside the event/peak mask are zeroed, i.e. turned into the
        # null value, so they are excluded from the ep_* metrics.
        event_peak = eval_metrics(prediction, label * event_peak_mask, null_val=0.0)
        # Store plain floats so the DataFrame columns are numeric rather than
        # object columns of 0-dim tensors.
        results.append([float(metric) for metric in (*overall, *event_peak)])

    result_df = pd.DataFrame(results, columns=metric_names)

    for metric in metric_names:
        result_df.loc['mean', metric] = result_df[metric].mean()

    # TODO: the standard deviation is not reported. Computing it after the
    # 'mean' row has been appended is wrong, because it would mix the average
    # into the per-seed values; it has to be computed from the per-seed rows only.

    return result_df, special_params


def report_exp_abstract(exp_dir, reported_param=None):
    """Summarise every trial of an experiment into ``<exp_dir>/abstract.csv``.

    Each sub-directory of ``exp_dir`` is treated as one trial (named by its
    config hash) and evaluated with ``report_trial_result``. The CSV holds
    one row per trial with the config hash, the requested hyper-parameter
    values and the seed-averaged metrics formatted with two decimals.

    Args:
        exp_dir (str): experiment directory containing one sub-directory per
            trial; the CSV is written into this directory.
        reported_param (list, optional): hyper-parameter names to include as
            columns. Defaults to None, meaning no hyper-parameter columns.
    """
    metric_names = ['all_mae', 'all_rmse', 'all_mape', 'all_wape', 'ep_mae', 'ep_rmse', 'ep_mape', 'ep_wape']
    reported_param = [] if reported_param is None else list(reported_param)
    abstract_header = ['config_hash'] + reported_param + metric_names

    abstract = []
    # os.listdir returns entries in arbitrary order; sort for a reproducible CSV.
    for config_hash in sorted(os.listdir(exp_dir)):
        config_cp_dir = os.path.join(exp_dir, config_hash)
        # Skip plain files such as a previously written abstract.csv.
        if not os.path.isdir(config_cp_dir):
            continue

        result_df, special_params = report_trial_result(config_cp_dir, reported_param)
        metric_means = ['{:.2f}'.format(result_df.loc['mean', metric]) for metric in metric_names]
        abstract.append([config_hash] + special_params + metric_means)

    abstract_df = pd.DataFrame(abstract, columns=abstract_header)
    save_path = os.path.join(exp_dir, 'abstract.csv')
    abstract_df.to_csv(save_path)


if __name__ == '__main__':
    # Command-line entry point: summarise the experiment stored under --exp_dir.
    cli = argparse.ArgumentParser()
    cli.add_argument("--exp_dir", type=str, default=None, help="config save dir")
    exp_dir = cli.parse_args().exp_dir

    # Without --exp_dir there is nothing to summarise, so do nothing.
    if exp_dir is not None:
        # The learning rate is the only hyper-parameter reported per trial.
        report_exp_abstract(exp_dir=exp_dir, reported_param=["lr"])
