import math

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from sklearn.metrics import r2_score, mean_absolute_error
from sklearn.metrics import mean_squared_log_error


def flatten_extra_dims(quant):
    """Return *quant* collapsed to a single dimension.

    Delegates to the object's own ``flatten`` method, so it works for any
    array-like exposing one (e.g. NumPy arrays, PyTorch tensors).
    """
    flat_view = quant.flatten()
    return flat_view

def r2(y_true, y_pred):
    """Return the coefficient of determination (R^2) over all elements.

    Both inputs are flattened before scoring, so multi-dimensional arrays are
    treated as one long sequence of paired observations.

    Args:
        y_true: Ground-truth values (any array-like with ``.flatten()``).
        y_pred: Predicted values, same number of elements as ``y_true``.

    Returns:
        The R^2 score as a NumPy scalar.
    """
    y_true = flatten_extra_dims(y_true)
    y_pred = flatten_extra_dims(y_pred)
    # np.mean rather than ``.mean()``: depending on the scikit-learn version,
    # r2_score may return a plain Python float, which has no ``.mean`` method.
    return np.mean(r2_score(y_true, y_pred))


def rmsle(y_true, y_pred):
    """Return the root-mean-squared logarithmic error over all elements.

    Computes ``sqrt(mean((log(1 + |y_pred|) - log(1 + y_true)) ** 2))``.
    The absolute value is applied to the predictions only, so negative
    predictions do not produce NaNs. ``y_true`` is used as-is and is assumed
    to be > -1.

    Args:
        y_true: Ground-truth values (array-like with ``.flatten()``).
        y_pred: Predicted values with the same number of elements.

    Returns:
        The RMSLE as a NumPy scalar.

    Raises:
        ValueError: If the inputs do not contain the same number of elements.
    """
    y_true = flatten_extra_dims(y_true)
    y_pred = flatten_extra_dims(y_pred)
    # Validate after flattening: comparing only the leading dimension would
    # let e.g. (1, 5) vs (1, 1) silently broadcast. Raise instead of assert so
    # the check survives ``python -O``.
    if y_true.shape[0] != y_pred.shape[0]:
        raise ValueError(
            f"y_true and y_pred must have the same number of elements, "
            f"got {y_true.shape[0]} and {y_pred.shape[0]}"
        )
    # log1p is more accurate than log(x + 1) for values near zero.
    log_diff = np.log1p(np.abs(y_pred)) - np.log1p(y_true)
    return np.sqrt(np.mean(log_diff ** 2.0))


def rmse(y_true, y_pred):
    """Return the root-mean-squared error between the flattened inputs."""
    truth = flatten_extra_dims(y_true)
    guess = flatten_extra_dims(y_pred)
    squared_residuals = (guess - truth) ** 2
    return np.sqrt(squared_residuals.mean())


def mae(y_true, y_pred):
    """Return the mean absolute error over all elements.

    Both inputs are flattened before scoring.

    Args:
        y_true: Ground-truth values (array-like with ``.flatten()``).
        y_pred: Predicted values with the same number of elements.

    Returns:
        The MAE as a NumPy scalar.
    """
    y_true = flatten_extra_dims(y_true)
    y_pred = flatten_extra_dims(y_pred)
    # np.mean rather than ``.mean()``: depending on the scikit-learn version,
    # mean_absolute_error may return a plain Python float, which has no
    # ``.mean`` method.
    return np.mean(mean_absolute_error(y_true, y_pred))


def smape(y_true, y_pred):
    """Return the symmetric mean absolute percentage error (in percent).

    Each element contributes ``2|p - t| / (|t| + |p| + 1e-5)``. The small
    constant keeps the ratio finite when both values are zero. The sum is
    scaled by ``100 / n``, where ``n`` is the number of flattened elements.
    """
    truth = flatten_extra_dims(y_true)
    guess = flatten_extra_dims(y_pred)
    numerator = 2.0 * np.abs(guess - truth)
    denominator = np.abs(truth) + np.abs(guess) + 1e-5
    total = np.sum(numerator / denominator)
    scale = 100.0 / len(truth)
    return (scale * total).mean()


def sc(signal):
    """Return the total variation of *signal*.

    This is the sum of absolute differences between consecutive elements of
    the flattened signal.
    """
    flat = flatten_extra_dims(signal)
    step_sizes = abs(flat[1:] - flat[:-1])
    return np.sum(step_sizes)


def smape_vs_sc(y_true, y_pred, window):
    """Pair per-window SMAPE with the per-window total variation of ``y_true``.

    Slides a window of length ``window`` (stride 1) over the flattened
    inputs. For every complete window it records ``[smape, sc]``: the SMAPE
    of the predictions in that window, and the total variation of the true
    signal in that window.

    Args:
        y_true: Ground-truth values (array-like with ``.flatten()``).
        y_pred: Predicted values with the same number of elements.
        window: Window length; must be a positive integer.

    Returns:
        A NumPy array of shape ``(num_windows, 2)`` with columns
        ``[smape, sc]``. It has ``(0, 2)`` shape when the signal is shorter
        than ``window``.

    Raises:
        ValueError: If ``window`` is not positive.
    """
    if window < 1:
        raise ValueError(f"window must be a positive integer, got {window}")

    y_true = flatten_extra_dims(y_true)
    y_pred = flatten_extra_dims(y_pred)

    # Window [i, i + window) is complete while i + window <= n, i.e. there
    # are n - window + 1 windows. (The previous guard, i + window + 1 < n,
    # silently dropped the last two complete windows.)
    num_windows = max(y_true.shape[0] - window + 1, 0)
    rows = []
    for start in range(num_windows):
        stop = start + window
        rows.append([smape(y_true[start:stop], y_pred[start:stop]),
                     sc(y_true[start:stop])])

    # reshape keeps the 2-column layout even when no window fits.
    return np.asarray(rows).reshape(-1, 2)


def sc_mse(y_pred, y_true):
    """MSE along dim 2, weighted by the total variation of ``y_true``.

    Both tensors are indexed along dimension 2, so they need at least three
    dims (presumably ``(batch, channels, time)`` — TODO confirm with callers).
    For each slice, the mean squared error over dim 2 is multiplied by the
    sum of absolute consecutive differences of ``y_true`` over dim 2. The
    products are then averaged into a scalar.

    Note the argument order is ``(y_pred, y_true)``.
    """
    step_changes = (y_true[:, :, 1:] - y_true[:, :, :-1]).abs()
    variation = step_changes.sum(dim=2)
    squared_err = (y_pred - y_true).pow(2.0).mean(dim=2)
    weighted = variation * squared_err
    return weighted.mean()


def binary_cross_entropy_with_logits(input, target, weight=None, size_average=True, reduce=True):
    """Numerically stable binary cross-entropy on raw logits.

    Per element this evaluates ``(1 - t) * x + log(1 + exp(-x))``. The
    ``log(1 + exp(-x))`` term is written as ``m + log(exp(-m) + exp(-x - m))``
    with ``m = max(-x, 0)`` so that neither exponential overflows.

    Args:
        input: Logits tensor.
        target: Targets tensor; must have the same size as ``input``.
        weight: Optional tensor multiplied element-wise into the loss.
        size_average: When reducing, average if True, otherwise sum.
        reduce: If False, return the unreduced element-wise loss.

    Raises:
        ValueError: If ``target`` and ``input`` differ in size.
    """
    if target.size() != input.size():
        raise ValueError(
            f"Target size ({target.size()}) must be the same as input size ({input.size()})"
        )

    shift = (-input).clamp(min=0)
    linear_part = input - input * target + shift
    log_part = ((-shift).exp() + (-input - shift).exp()).log()
    elementwise = linear_part + log_part

    if weight is not None:
        elementwise = elementwise * weight

    if not reduce:
        return elementwise
    return elementwise.mean() if size_average else elementwise.sum()


class CEExtended(nn.Module):
    """Cross-entropy over a 3-D tensor of per-position class scores.

    ``forward`` permutes ``y_pred`` with ``(0, 2, 1)``, so it must be exactly
    3-D. Dimension 1 is treated as the class dimension. Every position along
    dims 0 and 2 becomes an independent sample. ``y_true`` is flattened to
    match, so it must hold ``y_pred.shape[0] * y_pred.shape[2]`` class indices.
    """

    def __init__(self):
        super().__init__()
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, y_pred, y_true):
        num_classes = y_pred.shape[1]
        # (d0, C, d2) -> (d0, d2, C) -> (d0 * d2, C); rows line up with the
        # row-major flattening of y_true.
        logits = y_pred.permute(0, 2, 1).reshape(-1, num_classes).float()
        labels = y_true.flatten().long()
        return self.criterion(logits, labels)


class GaussianKernelMMD(nn.Module):
    """Mean Gaussian-kernel similarity between paired rows of two tensors.

    For paired rows ``a_i`` of ``y_pred`` and ``b_i`` of ``y_true`` (dim 1 is
    the feature dimension that gets reduced), computes
    ``mean_i exp(-||a_i - b_i||^2 / (2 * sigma^2))``. The result lies in
    ``(0, 1]`` and equals 1 when the inputs are identical.

    NOTE(review): despite the name, this is not the full MMD estimator,
    which would also need within-set kernel terms. It is the paired-kernel
    quantity the original formula was built from. Confirm with callers
    whether a similarity, rather than a distance, is what they minimise or
    maximise.

    Args:
        sigma: Kernel bandwidth.
    """

    def __init__(self, sigma):
        super(GaussianKernelMMD, self).__init__()
        self.sigma = sigma

    def forward(self, y_pred, y_true):
        # Per-row squared distance, computed directly. The previous version
        # combined per-row squared norms with a dot product over the *fully
        # flattened* tensors (a single scalar). That is only correct for a
        # single row, and it could yield kernel values > 1. Taking the
        # difference first also avoids the cancellation error of the
        # expanded form ||a||^2 + ||b||^2 - 2 a.b.
        sq_dist = ((y_pred - y_true) ** 2).sum(dim=1)
        return torch.exp(-1.0 / (2 * self.sigma ** 2) * sq_dist).mean()