import numpy as np
import torch


def mse_with_mask_torch(x, x_hat, mask):
    """Return the masked mean squared error between two torch tensors.

    Residuals are multiplied by ``mask`` before squaring, and the summed
    squared error is divided by ``mask.sum()``. With a 0/1 mask this is the
    mean squared error over the entries where ``mask`` is 1.
    NOTE(review): a non-binary (weighted) mask ends up squared in the
    numerator but not in the denominator; confirm masks are 0/1.

    Args:
        x: Reference tensor.
        x_hat: Prediction tensor, broadcast-compatible with ``x``.
        mask: Tensor broadcast-compatible with ``x``.

    Returns:
        A 0-dim tensor. If ``mask`` sums to zero this divides by zero.
    """
    masked_residual = (x - x_hat) * mask
    squared_error_total = (masked_residual ** 2).sum()
    observed_count = mask.sum()
    return squared_error_total / observed_count


def mse_with_mask(x, x_hat, mask):
    """Return the masked mean squared error between two numpy arrays.

    Residuals are multiplied by ``mask`` and squared. Their total is divided
    by ``mask.sum()``, so with a 0/1 mask this averages only over the
    entries where ``mask`` is 1.

    Args:
        x: Reference array.
        x_hat: Prediction array, broadcast-compatible with ``x``.
        mask: Array broadcast-compatible with ``x``.

    Returns:
        A numpy scalar. If ``mask`` sums to zero this divides by zero.
    """
    observed_count = mask.sum()
    masked_residual = (x - x_hat) * mask
    squared_error_total = np.power(masked_residual, 2).sum()
    return squared_error_total / observed_count


def mae_with_mask(x, x_hat, mask):
    """Return the masked mean absolute error between two numpy arrays.

    Residuals are multiplied by ``mask`` and their absolute values summed.
    The total is divided by ``mask.sum()``, so with a 0/1 mask this averages
    only over the entries where ``mask`` is 1.

    Args:
        x: Reference array.
        x_hat: Prediction array, broadcast-compatible with ``x``.
        mask: Array broadcast-compatible with ``x``.

    Returns:
        A numpy scalar. If ``mask`` sums to zero this divides by zero.
    """
    masked_residual = (x - x_hat) * mask
    absolute_error_total = np.abs(masked_residual).sum()
    observed_count = mask.sum()
    return absolute_error_total / observed_count


def rmse_with_mask(x, x_hat, mask):
    """Return the masked root mean squared error between two numpy arrays.

    This is the square root of the masked MSE. Residuals are multiplied by
    ``mask`` and squared, their total is divided by ``mask.sum()``, and the
    square root is taken.

    Args:
        x: Reference array.
        x_hat: Prediction array, broadcast-compatible with ``x``.
        mask: Array broadcast-compatible with ``x``.

    Returns:
        A numpy scalar. If ``mask`` sums to zero this divides by zero.
    """
    masked_residual = (x - x_hat) * mask
    mean_squared = np.power(masked_residual, 2).sum() / mask.sum()
    return np.sqrt(mean_squared)


def obtain_forecast_task(loader, model, head, metrics, device):
    """Run ``model`` + ``head`` over ``loader`` and score the forecasts.

    Each batch is unpacked as ``(x_time, x, x_mask, y_time, y, y_mask)`` and
    moved to ``device``. Processing per batch:

    - ``x`` and ``x_mask`` are concatenated along the last dimension and fed
      to ``model`` together with ``x_time``.
    - The model's fifth (last) output is passed to ``head`` with ``y_time``
      to produce predictions.

    Targets, predictions and masks from all batches are concatenated along
    axis 0 as numpy arrays. Each ``metric(truth, preds, masks)`` is then
    evaluated on them.

    ``model`` and ``head`` are put in eval mode for the pass. Their previous
    training modes are restored afterwards, even if an exception is raised.
    NOTE(review): assumes both are ``torch.nn.Module`` (uses ``.training``
    and ``.train(mode)``).

    Args:
        loader: Iterable of 6-tuples of tensors.
        model: Encoder returning a 5-tuple whose last element is the
            representation consumed by ``head``.
        head: Forecasting head called as ``head(representation, y_time)``.
        metrics: Mapping of metric name to ``callable(truth, preds, masks)``.
        device: Target device for ``Tensor.to``.

    Returns:
        Dict mapping each metric name to its result.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    # Remember the caller's modes so they are restored exactly rather than
    # forcing both modules into train mode.
    model_was_training = model.training
    head_was_training = head.training
    model.eval()
    head.eval()

    truths, preds, masks = [], [], []
    try:
        with torch.no_grad():
            for batch in loader:
                x_time, x, x_mask, y_time, y, y_mask = (item.to(device) for item in batch)
                # Only the last of the model's five outputs is used by the head.
                _, _, _, _, representation = model(
                    torch.cat([x, x_mask], dim=-1), x_time, None
                )
                pred = head(representation, y_time)

                truths.append(y.detach().cpu().numpy())
                preds.append(pred.detach().cpu().numpy())
                masks.append(y_mask.detach().cpu().numpy())
    finally:
        model.train(model_was_training)
        head.train(head_was_training)

    if not truths:
        raise ValueError("loader yielded no batches; cannot compute metrics")

    truth = np.concatenate(truths, axis=0)
    pred_all = np.concatenate(preds, axis=0)
    mask_all = np.concatenate(masks, axis=0)
    return {name: metric(truth, pred_all, mask_all) for name, metric in metrics.items()}

