import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions.normal as norm
import time
from _base import *


class QuantileLoss(torch.nn.Module):
    """
    Pinball (quantile regression) loss, averaged over every element.
    """

    def __init__(self):
        super().__init__()

    def forward(self, yhat, y, tau):
        """
        Return ``mean((1[yhat >= y] - tau) * (yhat - y))``.

        The weight ``1[yhat >= y] - tau`` is detached, so gradients flow
        only through the residual ``yhat - y``.
        """
        residual = yhat - y
        # 1.0 where the prediction is at or above the target, 0.0 elsewhere.
        at_or_above = (residual >= 0).float()
        weight = (at_or_above - tau).detach()
        return torch.mean(weight * residual)


class SQR(torch.nn.Module):
    """
    Two-hidden-layer MLP that receives the quantile level as an extra input.

    The first linear layer takes ``n_feature + 1`` inputs: the features plus
    one column holding the (rescaled) quantile level ``tau``. ``predict`` and
    ``loss`` evaluate the network at the two levels ``alpha / 2`` and
    ``1 - alpha / 2``.

    Args:
        n_feature: number of input features (without the tau column).
        n_hidden1: width of the first hidden layer.
        n_hidden2: width of the second hidden layer.
        n_output: width of the output layer.
        alpha: the lower/upper quantile levels are ``alpha / 2`` and
            ``1 - alpha / 2``.
        dropout_prob: dropout probability applied after each hidden layer.
    """

    def __init__(self, n_feature, n_hidden1, n_hidden2, n_output, alpha=0.1, dropout_prob=0):
        super(SQR, self).__init__()

        self.loss_function = QuantileLoss()
        self.alpha = alpha
        self.hidden1 = torch.nn.Linear(n_feature + 1, n_hidden1)  # hidden layer
        self.hidden2 = torch.nn.Linear(n_hidden1, n_hidden2)  # hidden layer
        self.pred = torch.nn.Linear(n_hidden2, n_output)  # output layer
        self.dropout = torch.nn.Dropout(dropout_prob)

        torch.nn.init.xavier_uniform_(self.hidden1.weight)
        torch.nn.init.xavier_uniform_(self.hidden2.weight)
        torch.nn.init.xavier_uniform_(self.pred.weight)

    def forward(self, x):
        """Run the MLP on ``x``, which must already contain the tau column."""
        x = F.relu(self.hidden1(x))
        x = self.dropout(x)
        x = F.relu(self.hidden2(x))
        x = self.dropout(x)
        x = self.pred(x)  # linear output

        return x

    def _tau_column(self, x, level):
        """Return an ``(x.size(0), 1)`` column filled with ``level`` on ``x``'s device."""
        # Built on x.device instead of a hard-coded .cuda() so the model also
        # works on CPU-only machines and on whichever GPU x lives on.
        return torch.zeros(x.size(0), 1, device=x.device) + level

    def _quantile_forward(self, x, tau):
        """Append the rescaled tau column to ``x`` and run the network."""
        # (tau - 0.5) * 12 maps tau in [0, 1] onto [-6, 6] before it is fed in.
        return self.forward(torch.cat((x, (tau - 0.5) * 12), 1))

    def predict(self, x):
        """
        Predict the lower and upper quantiles for ``x``.

        Returns:
            A tuple ``(mid, lower, upper)`` of detached tensors, where
            ``lower``/``upper`` are the network outputs at levels
            ``alpha / 2`` and ``1 - alpha / 2`` and ``mid`` is their average.
        """
        tau_l = self._tau_column(x, self.alpha / 2)
        tau_u = self._tau_column(x, 1 - self.alpha / 2)

        preds_l = self._quantile_forward(x, tau_l).detach()
        preds_u = self._quantile_forward(x, tau_u).detach()

        return (preds_l + preds_u) / 2, preds_l, preds_u

    def loss(self, x, y):
        """
        Return the sum of the quantile losses at levels ``alpha / 2`` and
        ``1 - alpha / 2`` for inputs ``x`` and targets ``y``.
        """
        tau_l = self._tau_column(x, self.alpha / 2)
        tau_u = self._tau_column(x, 1 - self.alpha / 2)

        preds_l = self._quantile_forward(x, tau_l)
        preds_u = self._quantile_forward(x, tau_u)

        return self.loss_function(preds_l, y, tau_l) + self.loss_function(preds_u, y, tau_u)


class GMM(torch.nn.Module):
    """
    Two-hidden-layer MLP whose output is split into three groups of
    ``n_gmm`` columns: ``m`` (raw), ``s`` (squared) and ``p`` (squared and
    normalised so that each row sums to one).
    """

    def __init__(self, n_feature, n_hidden1, n_hidden2, n_gmm, dropout_prob=0):
        super().__init__()
        self.hidden1 = nn.Linear(n_feature, n_hidden1)
        self.hidden2 = nn.Linear(n_hidden1, n_hidden2)
        self.output = nn.Linear(n_hidden2, n_gmm * 3)
        self.dropout = nn.Dropout(dropout_prob)
        self.n_gmm = n_gmm

        # Xavier-uniform weights, applied in layer order.
        for layer in (self.hidden1, self.hidden2, self.output):
            nn.init.xavier_uniform_(layer.weight)

    def forward(self, x):
        """Return the tuple ``(m, s, p)``, each with ``n_gmm`` columns."""
        hidden = self.dropout(F.relu(self.hidden1(x)))
        hidden = self.dropout(F.relu(self.hidden2(hidden)))
        raw = self.output(hidden)

        k = self.n_gmm
        means = raw[:, :k]
        # Squaring keeps both the scale and the weight outputs non-negative.
        scales = raw[:, k:2 * k].pow(2)
        weights = raw[:, 2 * k:].pow(2)
        # Divide by the row sum so the weights form a probability vector.
        weights = weights / weights.sum(dim=1, keepdim=True)

        return means, scales, weights

    def loss(self, x, y):
        """
        Mean negative log of ``sum_k p_k * norm_pdf(y, m_k, s_k) + 1e-7``.

        ``norm_pdf`` is defined outside this class.
        """
        means, scales, weights = self.forward(x)
        mixture = (weights * norm_pdf(y, means, scales)).sum(dim=1)
        # The 1e-7 offset keeps the log finite when the mixture value is zero.
        nll = -(mixture + 1e-7).log()
        return nll.mean()


class DDD(torch.nn.Module):
    """
    Network mapping mixture parameters ``(m, s, p)`` to ``n_output``
    ordered, non-overlapping intervals ``[L_i, U_i]`` per sample.

    Args:
        n_feature: width of the concatenation of ``m``, ``s`` and ``p``.
        n_hidden1: width of the first hidden layer.
        n_hidden2: width of the second hidden layer.
        n_output: number of intervals produced per sample.
        tau: one-element tensor holding the target value that the mean of
            ``prob`` is compared against in ``loss``.
    """

    def __init__(self, n_feature, n_hidden1, n_hidden2, n_output, tau=torch.FloatTensor([0.95])):
        super(DDD, self).__init__()

        self.hidden1 = torch.nn.Linear(n_feature, n_hidden1)  # hidden layer
        self.hidden2 = torch.nn.Linear(n_hidden1, n_hidden2)  # hidden layer

        # predict_h / predict_l are not used by any method of this class;
        # they are kept so the parameter set and state_dict keys are unchanged.
        self.predict_h = torch.nn.Linear(n_hidden2, n_output)  # output layer
        self.predict_l = torch.nn.Linear(n_hidden2, n_output)  # output layer

        self.predict_h1 = torch.nn.Linear(n_hidden2, n_output)
        self.predict_l1 = torch.nn.Linear(n_hidden2, n_output)

        # Keep tau on the GPU when one exists (as before), but do not crash on
        # CPU-only machines. The clone avoids sharing the default-argument
        # tensor between instances.
        self.tau = tau.cuda() if torch.cuda.is_available() else tau.clone()
        self.dropout = torch.nn.Dropout(0)

        torch.nn.init.xavier_normal_(self.hidden1.weight)
        torch.nn.init.xavier_normal_(self.hidden2.weight)
        torch.nn.init.xavier_normal_(self.predict_h.weight)
        torch.nn.init.xavier_normal_(self.predict_l.weight)
        torch.nn.init.xavier_normal_(self.predict_h1.weight)
        torch.nn.init.xavier_normal_(self.predict_l1.weight)

    def forward(self, m, s, p):
        """
        Produce interval bounds from the mixture parameters.

        Returns:
            ``(L, U)``, both of shape ``(batch, n_output)``, with
            ``L[:, i] <= U[:, i] <= L[:, i + 1]`` for every row.
        """
        # Only the hidden-layer input is detached; the columns of ``m`` used
        # as interval centres below are not.
        x = torch.cat((m, s, p), 1).detach()
        x = F.relu((self.hidden1(x)))
        x = self.dropout(x)
        x = F.relu((self.hidden2(x)))
        x = self.dropout(x)

        # Squared outputs give non-negative offsets below/above the centre.
        l = self.predict_l1(x) ** 2
        u = self.predict_h1(x) ** 2

        centers = m[:, :l.shape[1]]
        L_ = centers - l
        U_ = centers + u

        # Sort the intervals of each row by their lower bound. The
        # permutation is computed once so L and U are always reordered
        # identically.
        order = torch.argsort(L_)
        U = torch.gather(U_, 1, order)
        L = torch.gather(L_, 1, order)

        # Eliminate redundancy: clamp each upper bound to be >= its lower
        # bound and each next lower bound to be >= the previous upper bound.
        for i in range(U.shape[1] - 1):
            U[:, i] = L[:, i] + F.relu(U[:, i] - L[:, i])
            L[:, i + 1] = U[:, i] + F.relu(L[:, i + 1] - U[:, i])

        U[:, -1] = L[:, -1] + F.relu(U[:, -1] - L[:, -1])

        return L, U

    def cal_prob(self, L, U, mu, std, pi):
        """
        Per-sample probability mass that the ``pi``-weighted normal mixture
        ``(mu, std)`` assigns to the union of the intervals ``[L_i, U_i]``.
        """
        # Standard normal created on the inputs' device (was hard-coded CUDA).
        norm_dis = norm.Normal(torch.zeros(1, device=mu.device),
                               torch.ones(1, device=mu.device))
        scale = std + 1e-7  # avoids division by zero; same for every interval

        prob = 0
        for i in range(U.shape[1]):
            prob += torch.sum(pi * (norm_dis.cdf((U[:, i].reshape(-1, 1) - mu) / scale)), dim=1)
            prob -= torch.sum(pi * (norm_dis.cdf((L[:, i].reshape(-1, 1) - mu) / scale)), dim=1)
        return prob

    def cal_mpiw(self, L, U):
        """Mean over samples of the summed interval widths ``sum_i (U_i - L_i)``."""
        return torch.mean(torch.sum(U - L, dim=1))

    def cal_acc(self, L, U, Y):
        """Mean over samples of the number of intervals with ``L_i < Y < U_i``."""
        inside = L.lt(Y) & Y.lt(U)
        return torch.mean(torch.sum(inside.float(), dim=1))

    def loss(self, L, U, prob, k):
        """
        Return ``mpiw + k * relu(tau - mean(prob)) ** 2``.

        ``k`` weights the penalty applied when the mean of ``prob`` falls
        below ``tau``.
        """
        L_mpiw = self.cal_mpiw(L, U)
        # Compare on prob's device instead of forcing CUDA.
        L_picp = F.relu(self.tau.to(prob.device) - torch.mean(prob))**2

        return L_mpiw + k * L_picp


class MVE(torch.nn.Module):
    """
    Two-hidden-layer MLP whose output is split into ``m`` (first ``n_gmm``
    columns, raw) and ``s`` (remaining columns, squared).
    """

    def __init__(self, n_feature, n_hidden1, n_hidden2, n_gmm, dropout_prob=0):
        super().__init__()
        self.hidden1 = nn.Linear(n_feature, n_hidden1)
        self.hidden2 = nn.Linear(n_hidden1, n_hidden2)
        self.output = nn.Linear(n_hidden2, n_gmm * 2)
        self.dropout = nn.Dropout(dropout_prob)
        self.n_gmm = n_gmm

        # Xavier-uniform weights, applied in layer order.
        for layer in (self.hidden1, self.hidden2, self.output):
            nn.init.xavier_uniform_(layer.weight)

    def forward(self, x):
        """Return the tuple ``(m, s)``."""
        hidden = self.dropout(F.relu(self.hidden1(x)))
        hidden = self.dropout(F.relu(self.hidden2(hidden)))
        raw = self.output(hidden)

        k = self.n_gmm
        means = raw[:, :k]
        # Squaring keeps the scale output non-negative.
        scales = raw[:, k:].pow(2)

        return means, scales

    def loss(self, x, y):
        """
        Mean negative log of ``sum_k norm_pdf(y, m_k, s_k) + 1e-7``.

        ``norm_pdf`` is defined outside this class.
        """
        means, scales = self.forward(x)
        density = norm_pdf(y, means, scales).sum(dim=1)
        # The 1e-7 offset keeps the log finite when the density is zero.
        nll = -(density + 1e-7).log()
        return nll.mean()