import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as opt
from sklearn.svm import OneClassSVM


class OC_SVM(OneClassSVM):
    """One-class SVM fitted on the samples labelled ``1`` of a dataset object.

    Thin wrapper around :class:`sklearn.svm.OneClassSVM` whose methods take
    dataset objects instead of raw arrays: every dataset passed in must
    expose a ``.data`` attribute convertible with ``np.array``; the training
    dataset must additionally provide ``lab_data(lab=...)``.

    Note that ``predict`` and ``decision_function`` override the scikit-learn
    methods with a different argument type (dataset object, not array).
    """

    def __init__(self, dim, nu=0.5, gamma=None, kernel="linear", **kwargs):
        """Configure the estimator.

        Args:
            dim: Input dimensionality; only used to derive the default
                ``gamma``.
            nu: Forwarded to ``OneClassSVM``.
            gamma: Kernel coefficient forwarded to ``OneClassSVM``. When
                ``None`` it defaults to ``0.2 / dim``.
            kernel: Kernel name forwarded to ``OneClassSVM``.
            **kwargs: Accepted and ignored.
        """
        if gamma is None:
            gamma = 0.2 / dim
        # scikit-learn's get_params()/clone()/repr() read every named
        # __init__ parameter back from the instance, so ``dim`` has to be
        # stored or those calls raise AttributeError.
        self.dim = dim
        # Flags consumed outside this class; their exact meaning is not
        # visible here (NOTE(review): presumably "positive-unlabeled" and
        # "anomaly estimator" markers - confirm against the callers).
        self.pu = False
        self.ae = True
        super().__init__(nu=nu, gamma=gamma, kernel=kernel)

    def run_train(self, dataset):
        """Fit the SVM on ``dataset.lab_data(lab=1).data``."""
        lab_data = dataset.lab_data(lab=1)
        self.fit(np.array(lab_data.data))

    def predict(self, dataset, pi=None):
        """Return +1 / -1 labels for every sample in ``dataset``.

        Args:
            dataset: Object exposing ``.data``.
            pi: Optional fraction in ``[0, 1]``. When given, the decision
                threshold is the ``100 * (1 - pi)``-th percentile of the
                scores, so roughly a fraction ``pi`` of the samples is
                labelled +1. When ``None`` the threshold is 0.

        Returns:
            Integer array with +1 where the score is strictly above the
            threshold and -1 elsewhere.
        """
        # Hand over the dataset itself: decision_function() does the
        # ``.data`` lookup, so unwrapping it here as well would dereference
        # ``.data`` twice.
        scores = self.decision_function(dataset)
        if pi is None:
            threshold = 0
        else:
            threshold = np.percentile(scores, 100 * (1 - pi))
        return (scores > threshold) * 2 - 1

    def decision_function(self, dataset):
        """Return the signed one-class SVM scores for ``dataset.data``."""
        return super().decision_function(np.array(dataset.data))


class RE_SVM:
    """Kernel scorer trained by gradient descent on positive/unlabeled data.

    The model holds a weight tensor ``w`` (optimised with Adam) and a scalar
    offset ``b`` (re-estimated from score percentiles inside :meth:`loss`).
    Scores are ``k(x, w) - b`` where ``k`` is either a polynomial or an RBF
    expression, see :meth:`predict_one`.
    """

    def __init__(self, dim, kernel="poly", gamma=1, degree=1, coef0=0.0, pi=0.5, lam=0.01, **kwargs):
        """Create the model.

        Args:
            dim: Number of input features.
            kernel: ``"poly"`` or ``"rbf"``.
            gamma: Kernel scale factor.
            degree: Exponent used by the ``"poly"`` kernel.
            coef0: Additive constant used by the ``"poly"`` kernel.
            pi: Weight of the positive-sample terms in the loss; also
                selects the percentile used to update ``b``.
            lam: Weight of the squared-norm penalty on ``w``.
            **kwargs: Accepted and ignored.

        Raises:
            ValueError: If ``kernel`` is neither ``"poly"`` nor ``"rbf"``.
        """
        if kernel == "poly":
            # Column vector so that ``x @ w`` yields one score per row.
            self.w = torch.rand(dim, 1, requires_grad=True)
        elif kernel == "rbf":
            # Acts as a single centre that is broadcast against the batch.
            self.w = torch.rand(dim, requires_grad=True)
        else:
            raise ValueError("wrong kernel")
        # Plain Python float: ``b`` is never part of the autograd graph.
        self.b = 1.0

        self.coef0 = coef0
        self.degree = degree
        self.kernel = kernel
        self.gamma = gamma
        self.dim = dim
        self.pi = pi
        self.lam = lam

        # Flags consumed outside this class; their exact meaning is not
        # visible here (NOTE(review): confirm against the callers).
        self.pu = True
        self.ae = True

    def decision_function(self, x, **kwargs):
        """Score every sample of ``x.data`` and return a flat float64 array.

        ``x.data`` is iterated through a ``DataLoader`` in batches of 50.
        """
        testloader = torch.utils.data.DataLoader(x.data,
                                                 batch_size=50,
                                                 num_workers=1)

        # The empty float64 seed keeps the result float64 and makes an empty
        # dataset return an empty array; collecting chunks and concatenating
        # once avoids re-copying the accumulated result for every batch.
        chunks = [np.array([])]
        with torch.no_grad():
            for x_test in testloader:
                res = self.predict_one(x_test.float())
                chunks.append(res.reshape(-1).cpu().numpy())

        return np.concatenate(chunks)

    def predict_one(self, x, **kwargs):
        """Return the scores ``k(x, w) - b`` for one batch.

        Args:
            x: Batch of shape ``(n, dim)``; anything ``torch.as_tensor``
                accepts. It is converted to the dtype of ``w``.

        Returns:
            Tensor of shape ``(n, 1)`` for the ``"poly"`` kernel and
            ``(n,)`` for the ``"rbf"`` kernel.

        Raises:
            ValueError: If ``self.kernel`` is not a supported kernel.
        """
        x = torch.as_tensor(x, dtype=self.w.dtype)
        if self.kernel == "rbf":
            sq_dist = torch.pow(torch.norm(x - self.w, dim=-1), 2)
            return torch.exp(-self.gamma * sq_dist) - self.b
        if self.kernel == "poly":
            return torch.pow(self.coef0 + self.gamma * torch.mm(x, self.w), self.degree) - self.b
        # Only reachable if ``self.kernel`` was changed after construction.
        raise ValueError("wrong kernel")

    def predict(self, x, **kwargs):
        """Return +1 where the score of a sample is positive, else -1."""
        return (self.decision_function(x) > 0) * 2 - 1

    def loss(self, y_pred, y_true):
        """Compute the training loss for one batch and update ``self.b``.

        Args:
            y_pred: Scores from :meth:`predict_one`. The second half of the
                batch is used to update ``b``; ``run_train`` puts the
                unlabeled samples there.
            y_true: Labels, ``1`` for labelled-positive and ``-1`` for
                unlabeled samples.

        Returns:
            Scalar tensor: penalty + weighted hinge-style terms, minus the
            updated offset ``b``.

        Side effects:
            ``self.b`` is incremented by the ``100 * (1 - pi)``-th percentile
            of the second half of ``y_pred``.
        """
        pos_pred = y_pred[y_true == 1]
        unl_pred = y_pred[y_true == -1]

        L_pos = 0
        L_neg = 0
        L_unl = 0

        # Test the number of *selected* samples, not the mask length: the
        # mean of an empty selection is NaN and would poison the whole loss.
        if pos_pred.numel():
            L_pos = F.relu(torch.max(1 - pos_pred, -2 * pos_pred)).mean()
            L_neg = -F.relu(torch.max(1 + pos_pred, 2 * pos_pred)).mean()

        if unl_pred.numel():
            L_unl = F.relu(torch.max(1 + unl_pred, 2 * unl_pred)).mean()

        L = self.lam * torch.pow(torch.norm(self.w), 2) + L_pos * self.pi
        # The unlabeled-minus-positive part is only kept while it is
        # positive; otherwise it is dropped from the objective.
        if L_unl + self.pi * L_neg > 0:
            L = L + L_unl + self.pi * L_neg

        tail = y_pred[len(y_pred) // 2:]
        if tail.numel():
            # The scores already contain ``-b``, so adding their percentile
            # moves ``b`` onto the percentile of the raw kernel values.
            self.b += float(np.percentile(tail.detach().cpu().numpy(), q=100 * (1 - self.pi)))

        # ``b`` is a plain float, so this shifts the returned value without
        # contributing any gradient.
        return L - self.b

    def run_train(self,
                  train_data,
                  num_epochs=300,
                  lr=5e-3,
                  batch_size=512,
                  lr_gamma=0.995,
                  verbose=False,
                  test_data=None):
        """Optimise ``w`` with Adam on paired labelled/unlabeled batches.

        Args:
            train_data: Object providing ``lab_data(lab=1)`` and
                ``lab_data(lab=-1)``; each must return a dataset whose
                collated batches carry the feature tensor at index 0.
            num_epochs: Number of passes over the paired loaders.
            lr: Initial Adam learning rate.
            batch_size: Batch size of each of the two loaders.
            lr_gamma: Per-epoch multiplicative learning-rate decay.
            verbose: When true, print the summed loss after every epoch.
            test_data: Unused; kept for interface compatibility.
        """
        lab_data = train_data.lab_data(lab=1)
        unl_data = train_data.lab_data(lab=-1)

        lab_loader = torch.utils.data.DataLoader(lab_data,
                                                 batch_size=batch_size,
                                                 shuffle=True,
                                                 num_workers=1)

        unl_loader = torch.utils.data.DataLoader(unl_data,
                                                 batch_size=batch_size,
                                                 shuffle=True,
                                                 num_workers=1)

        optim = opt.Adam([self.w], lr=lr)
        scheduler = opt.lr_scheduler.ExponentialLR(optim, gamma=lr_gamma)

        for epoch in range(num_epochs):
            running_loss = 0.0
            for lab_batch, unl_batch in zip(lab_loader, unl_loader):
                lab_x, unl_x = lab_batch[0], unl_batch[0]
                # loss() treats the second half of the batch as unlabeled,
                # so unevenly sized pairs (the last batches) are skipped.
                if len(lab_x) != len(unl_x):
                    continue
                data = torch.cat((lab_x, unl_x)).float()
                label = torch.cat((torch.ones(len(lab_x)), -torch.ones(len(unl_x))))

                y_pred = self.predict_one(data)
                L = self.loss(y_pred, label)

                # Skip exact-zero and NaN losses.
                if L != 0 and not torch.isnan(L):
                    optim.zero_grad()
                    L.backward()
                    optim.step()
                    running_loss += L.item()

            # Apply the exponential decay once per epoch; without this call
            # ``lr_gamma`` would have no effect.
            scheduler.step()

            if verbose:
                print(f"epoch {epoch + 1}/{num_epochs}: loss={running_loss:.6f}")

# class EN_SVM:
#     def __init__(self, dim, gamma=None, kernel="rbf", **kwargs):
#         if gamma is None:
#             gamma = 2 / dim
#         self.pu = True
#         self.ae = True
#         self.c = None
#         self.pi = None
#         self.clf = SVC(gamma=gamma, kernel=kernel)
#
#     def run_train(self, dataset, test_size=0.2):
#         data = np.array(dataset.data)
#         targets = np.array(dataset.s)
#         # lab_portion = (targets == 1).sum() / len(targets)
#
#         data, val_data, targets, val_tar = train_test_split(data, targets, test_size=test_size)
#
#         lab, unl = targets == 1, targets == -1
#         lab_data, unl_data = data[lab], data[unl]
#
#         if len(lab_data) > len(unl_data):
#             lab_data = lab_data[:len(unl_data)]
#         else:
#             unl_data = unl_data[:len(lab_data)]
#
#         dataset = np.concatenate((lab_data, unl_data))
#         targets = np.concatenate((np.ones(len(lab_data)), np.zeros(len(unl_data))))
#
#         self.val = val_data, val_tar
#         self.clf.fit(dataset, targets)
#         self.clf_plt = CalibratedClassifierCV(self.clf, method="sigmoid").fit(dataset, targets)
#
#         self.c = self.clf_plt.predict_proba(val_data[val_tar == -1])[:, 1].max()
#         self.pi_est = (1 - self.c) / self.c
#
#     def predict(self, dataset, pi=None):
#         if pi is None:
#             pi = self.pi_est
#         return (self.decision_function(dataset.data) * pi > 0.5) * 2 - 1
#
#     def decision_function(self, dataset):
#         # return self.clf.decision_function(dataset)
#         res = self.clf_plt.predict_proba(dataset.data)[:, 1]
#         return res / (1 - res)
