import numpy as np
import torch
import torch.optim as opt
from torch import nn
from torchvision import models

# Module-wide compute device: the default CUDA device when one is visible,
# otherwise the CPU. Models and batches below are moved here with ``.to(device)``.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


class classifier_nn(nn.Module):
    """Scoring head: ReLU -> Linear(D, 1) -> Sigmoid.

    Maps each ``D``-dimensional feature row to a single sigmoid score.
    """

    def __init__(self, D):
        super().__init__()
        layers = [nn.ReLU(), nn.Linear(D, 1), nn.Sigmoid()]
        # The attribute name ``model`` and this layer order determine the
        # state_dict keys (``model.1.weight`` / ``model.1.bias``); keep them.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        scores = self.model(x)
        return scores


def get_resnet(pre_trained_flag=True):
    """Build a ResNet-18 whose ``fc`` head emits 4096-dimensional features.

    Args:
        pre_trained_flag: when True, load torchvision's ImageNet weights;
            otherwise the network is randomly initialized.

    Returns:
        A torchvision ResNet-18 with ``fc`` replaced by
        ``Linear(in_features, 4096) -> ReLU -> Dropout(0.5) -> Linear(4096, 4096)``.
    """
    weights_enum = getattr(models, "ResNet18_Weights", None)
    if weights_enum is None:
        # torchvision < 0.13 only understands the legacy boolean flag.
        model = models.resnet18(pretrained=pre_trained_flag)
    else:
        # ``pretrained=True`` is a deprecated alias for IMAGENET1K_V1; passing
        # the enum explicitly keeps the same weights without the warning.
        weights = weights_enum.IMAGENET1K_V1 if pre_trained_flag else None
        model = models.resnet18(weights=weights)

    in_features = model.fc.in_features
    model.fc = nn.Sequential(nn.Linear(in_features, 4096),
                             nn.ReLU(),
                             nn.Dropout(0.5),
                             nn.Linear(4096, 4096))
    return model


class CNNClassifier:
    """Base class for CNN-based classifiers.

    Holds a ResNet-18 feature extractor (``self.model``, from ``get_resnet``)
    and a sigmoid scoring head (``self.classifier``, a ``classifier_nn`` of
    input width ``D``), both moved to ``device``. Subclasses implement
    ``forward``, ``decision_function``, ``predict`` and, for the non-PU
    training path of ``run_train``, ``get_labels``.
    """

    def __init__(self, D=4096, pi=0.5, **kwargs):
        self.model = get_resnet()
        self.classifier = classifier_nn(D)

        self.model.to(device)
        self.classifier.to(device)

        self.device = device
        self.D = D

        self.pi = pi
        # Subclasses set ``self.pu`` before calling this constructor; keep
        # their value rather than clobbering it (default: None).
        self.pu = getattr(self, "pu", None)
        self.inm = nn.InstanceNorm1d(1, affine=False)

    def get_labels(self, batch_x, batch_s=None):
        """Return BCE targets matching ``self.forward(batch_x)`` (non-PU training)."""
        raise NotImplementedError()

    def forward(self, batch_x):
        """Return classifier scores for ``batch_x``."""
        raise NotImplementedError()

    def decision_function(self, x):
        raise NotImplementedError()

    def predict(self, x, pi):
        raise NotImplementedError()

    def _iter_batches(self, lab_loader, unl_loader):
        """Yield ``(batch_x, batch_s)`` training pairs for one epoch.

        With ``unl_loader`` None, every labeled batch is used and targets come
        from ``self.get_labels``. Otherwise labeled and unlabeled batches are
        paired (stopping at the shorter loader), size-mismatched pairs are
        skipped, and targets are 1 for labeled / 0 for unlabeled samples.
        """
        if unl_loader is None:
            for lab_batch in lab_loader:
                batch_x = lab_batch[0]
                yield batch_x, self.get_labels(batch_x)
            return

        for lab_batch, unl_batch in zip(lab_loader, unl_loader):
            n_lab, n_unl = len(lab_batch[0]), len(unl_batch[0])
            if n_lab != n_unl:
                continue
            batch_x = torch.cat((lab_batch[0], unl_batch[0]))
            batch_s = torch.cat((torch.ones(n_lab), torch.zeros(n_unl)))
            yield batch_x, batch_s

    def run_train(self,
                  train_data,
                  num_epochs=10,
                  lr=1e-4,
                  batch_size=256,
                  test=False,
                  test_data=None,
                  **kwargs):
        """Train ``self.model.fc`` and ``self.classifier`` with ``nn.BCELoss``.

        Args:
            train_data: dataset exposing ``lab_data(lab=...)``; ``lab=1`` selects
                the labeled subset, ``lab=0`` the unlabeled one.
            num_epochs: number of passes over the training loader(s).
            lr: Adam learning rate for both optimizers.
            batch_size: batch size of each loader.
            test: if True, ``test_data`` must be supplied.
            test_data: only validated here; it is not evaluated.

        Raises:
            ValueError: if ``test`` is True and ``test_data`` is None.
        """
        if test and test_data is None:
            raise ValueError("test_data can't be None if test=True")

        lab_loader = torch.utils.data.DataLoader(train_data.lab_data(lab=1),
                                                 batch_size=batch_size,
                                                 shuffle=True)
        # The unlabeled split is consumed only on the PU path. Zipping the
        # non-PU path with it capped the number of training steps, and a
        # shuffled loader over an empty split raises at construction.
        unl_loader = None
        if self.pu:
            unl_loader = torch.utils.data.DataLoader(train_data.lab_data(lab=0),
                                                     batch_size=batch_size,
                                                     shuffle=True)

        loss_function = nn.BCELoss()
        # Only the replaced ResNet head (``fc``) and the classifier are optimized.
        self.opt_model = opt.Adam(self.model.fc.parameters(), lr=lr)
        self.opt_class = opt.Adam(self.classifier.parameters(), lr=lr)

        for _epoch in range(num_epochs):
            for batch_x, batch_s in self._iter_batches(lab_loader, unl_loader):
                # NOTE(review): only the classifier is switched to train mode;
                # the backbone keeps whatever mode it was last left in.
                self.classifier.train()

                labels = batch_s.unsqueeze(1).to(device)
                out = self.forward(batch_x)

                self.opt_model.zero_grad()
                self.opt_class.zero_grad()

                loss = loss_function(out, labels.float())
                loss.backward()

                self.opt_model.step()
                self.opt_class.step()

    def eval(self):
        """Put both the backbone and the classifier in eval mode."""
        self.model.eval()
        self.classifier.eval()


class OC_CNN(CNNClassifier):
    """One-class CNN (OC-CNN).

    The instance-normalized backbone features of real inputs form the
    positive class. Zero-mean Gaussian vectors (std ``sigma``) of the same
    width ``D`` are appended in ``forward`` as synthetic negatives.
    """

    def name(self):
        return "OC-CNN"

    def __init__(self, D=4096, pi=0.5, sigma=0.1, **kwargs):
        self.ae = False
        self.pu = False
        self.sigma = sigma
        super().__init__(D, pi)
        # The base constructor may reset ``pu``; this model is never PU.
        self.pu = False

    def get_labels(self, batch_x, batch_s=None):
        """Return targets for ``forward(batch_x)``: ``n`` ones then ``n`` zeros.

        ``n`` is the batch size. The ones label the real features and the
        zeros the Gaussian negatives appended by ``forward``. ``batch_s`` is
        ignored. The result is float64 on ``self.device``.
        """
        n = int(batch_x.shape[0])
        labels = torch.cat((torch.ones(n, dtype=torch.float64),
                            torch.zeros(n, dtype=torch.float64)))
        return labels.to(self.device)

    def _features(self, batch_x):
        """Run the backbone and instance-normalize each ``D``-wide feature row."""
        if batch_x.shape[1] == 1:
            # Replicate a single input channel to the 3 channels ResNet expects.
            batch_x = batch_x.repeat(1, 3, 1, 1)
        n = int(batch_x.shape[0])
        feats = self.model(batch_x.to(device))
        feats = self.inm(feats.view(n, 1, self.D))
        return feats.view(n, self.D)

    def forward(self, batch_x):
        """Score real features followed by ``n`` Gaussian negatives (``2n`` rows)."""
        n = int(batch_x.shape[0])
        # Synthetic negatives, drawn from NumPy's global RNG.
        noise = np.random.normal(0, self.sigma, (n, self.D))
        noise = torch.from_numpy(noise).float().to(device)

        feats = self._features(batch_x)
        self.res = feats
        return self.classifier(torch.cat((feats, noise), 0))

    def decision_function(self, x, pi=None):
        """Return a 1-D float64 array of classifier scores for ``x.data``.

        ``x.data`` is batched with a DataLoader (batch size 100); ``pi`` is
        ignored. Inference runs without autograd. The module modes are left
        untouched, so call ``eval()`` first if eval-mode behavior is wanted.
        """
        loader = torch.utils.data.DataLoader(x.data, batch_size=100)
        chunks = []
        with torch.no_grad():
            for batch_x in loader:
                scores = self.classifier(self._features(batch_x))
                chunks.append(scores.cpu().numpy().reshape(-1))
        if not chunks:
            return np.array([])
        return np.concatenate(chunks).astype(np.float64)

    def predict(self, x, pi=None):
        """Return 0/1 predictions for ``x``.

        When ``pi`` is given (and the model is not PU), the threshold is the
        ``1 - self.pi`` quantile of the scores; otherwise it is 0.5.
        """
        y_pred = self.decision_function(x)
        if not self.pu and pi is not None:
            # NOTE(review): ``pi`` only switches quantile thresholding on; the
            # quantile itself uses ``self.pi`` — confirm this is intended.
            threshold = np.percentile(y_pred, 100 * (1 - self.pi))
        else:
            threshold = 0.5
        return (y_pred > threshold) * 1


class PULoss:
    """Unbiased (uPU) / non-negative (nnPU) positive-unlabeled risk.

    Args:
        prior: class prior ``pi`` of positives among the unlabeled data.
        loss: elementwise-reduced criterion ``loss(outputs, targets)``,
            e.g. ``nn.BCELoss()``.
        beta: slack for the non-negativity check of the negative risk.
        nnpu: if True, fall back to the positive risk alone whenever the
            estimated negative risk drops below ``-beta``.
    """

    def __init__(self, prior, loss, beta=0, nnpu=True):
        self.prior = prior
        self.loss = loss
        self.beta = beta
        self.nnpu = nnpu

    def _risk(self, outputs, target_value):
        """Return ``self.loss`` of ``outputs`` against a constant target.

        An empty selection yields a zero that stays on the autograd graph.
        Without this guard, a mean-reduced loss over zero elements would
        return NaN and poison the weights.
        """
        if outputs.numel() == 0:
            return outputs.sum() * 0.0
        return self.loss(outputs, torch.full_like(outputs, target_value))

    def __call__(self, batch_x, batch_s):
        """Compute the PU risk.

        Args:
            batch_x: model scores for the batch (e.g. shape ``(n, 1)``).
            batch_s: per-sample labels, 1 = labeled positive, 0 = unlabeled.
        """
        batch_s = batch_s.to(batch_x.device)
        out_pos = batch_x[batch_s == 1]
        out_unl = batch_x[batch_s == 0]

        pos_risk = self.prior * self._risk(out_pos, 1.0)
        neg_risk_1 = self._risk(out_unl, 0.0)
        neg_risk_2 = self.prior * self._risk(out_pos, 0.0)

        if self.nnpu and neg_risk_1 - neg_risk_2 <= -self.beta:
            return pos_risk
        return pos_risk + neg_risk_1 - neg_risk_2


class PU_CNN(CNNClassifier):
    """PU-learning CNN trained with the uPU / nnPU risk (``PULoss``).

    Scores are ``self.classifier`` outputs on instance-normalized backbone
    features. ``predict`` thresholds them at ``self.tr``, an exponential
    moving average of the ``1 - pi`` quantile of training-batch scores.
    """
    # def name(self):
    #     if self.nnpu:
    #         return "nnPU-CNN"
    #     return "uPU-CNN"

    def __init__(self, D=4096, nnpu=True, beta=0, pi=0.5, **kwargs):
        self.nnpu = nnpu
        self.beta = beta

        self.ae = False
        self.pu = True
        # Initial decision threshold; updated as an EMA in ``run_train``.
        self.tr = 0.5

        # if pi is None or not 0 <= pi <= 1:
        #     raise ValueError("prior must be initialized and be in [0; 1]")
        super().__init__(D, pi)
        # The base constructor may reset ``pu``; restore the PU flag.
        self.pu = True

    def forward(self, batch_x):
        """Return classifier scores of shape ``(n, 1)`` for a batch of inputs."""
        if batch_x.shape[1] == 1:
            # Replicate a single input channel to the 3 channels ResNet expects.
            batch_x = batch_x.repeat(1, 3, 1, 1)
        n = int(batch_x.shape[0])
        feats = self.model(batch_x.to(device))
        feats = self.inm(feats.view(n, 1, self.D)).view(n, self.D)
        return self.classifier(feats)

    def decision_function(self, x, pi=None):
        """Return a 1-D float64 array of scores for ``x.data`` (``pi`` ignored).

        Inference runs without autograd. Module modes are not changed, so call
        ``eval()`` first if eval-mode behavior is wanted.
        """
        loader = torch.utils.data.DataLoader(x.data,
                                             batch_size=100,
                                             num_workers=2)
        chunks = []
        with torch.no_grad():
            for batch_x in loader:
                scores = self.forward(batch_x)
                chunks.append(scores.cpu().numpy().reshape(-1))
        if not chunks:
            return np.array([])
        return np.concatenate(chunks).astype(np.float64)

    def predict(self, x, pi=None):
        """Return 0/1 predictions by thresholding scores at ``self.tr`` (``pi`` ignored)."""
        y_pred = self.decision_function(x)
        return (y_pred > self.tr) * 1

    def run_train(self,
                  train_data,
                  num_epochs=10,
                  lr=1e-4,
                  batch_size=256,
                  test_data=None,
                  **kwargs):
        """Train ``self.model.fc`` and ``self.classifier`` on ``PULoss``.

        Args:
            train_data: dataset whose items unpack as ``(x, _, s)``, where ``s``
                is 1 for labeled positives and 0 for unlabeled samples.
            num_epochs: number of passes over ``train_data``.
            lr: Adam learning rate for both optimizers.
            batch_size: loader batch size.
            test_data: unused; accepted for interface compatibility.
        """
        trainloader = torch.utils.data.DataLoader(train_data,
                                                  batch_size=batch_size,
                                                  shuffle=True)

        self.opt_model = opt.Adam(self.model.fc.parameters(), lr=lr)
        self.opt_class = opt.Adam(self.classifier.parameters(), lr=lr)
        loss_function = PULoss(self.pi, nn.BCELoss(), self.beta, self.nnpu)

        for _epoch in range(num_epochs):
            for batch_x, _, batch_s in trainloader:
                # NOTE(review): only the classifier is switched to train mode;
                # the backbone keeps whatever mode it was last left in.
                self.classifier.train()

                # Single-sample batches are skipped.
                if len(batch_s) == 1:
                    continue

                out = self.forward(batch_x)

                self.opt_model.zero_grad()
                self.opt_class.zero_grad()

                loss = loss_function(out, batch_s)
                loss.backward()

                self.opt_model.step()
                self.opt_class.step()

                # EMA of the (1 - pi) score quantile, used by ``predict``.
                scores = out.detach().cpu().numpy()
                self.tr = 0.9 * self.tr + 0.1 * np.percentile(scores, q=100 * (1 - self.pi))

# class EN_CNN(CNNClassifier):
#
#     def __init__(self, model_type="resnet18", D=4096, pi=0.5, **kwargs):
#         self.ae = False
#         self.pu = True
#         self.c = None
#         super().__init__(model_type, D, pi)
#
#     def get_labels(self, batch_x, batch_s):
#         return torch.autograd.Variable(batch_s.to(self.device)).unsqueeze(1)
#
#     def run_train(self,
#                   train_data,
#                   num_epochs=10,
#                   lr=1e-4,
#                   batch_size=256,
#                   test=False,
#                   test_data=None,
#                   **kwargs):
#
#         train_data, val_data, train_y, val_y, train_s, val_s = train_test_split(train_data.data,
#                                                                                 train_data.y,
#                                                                                 train_data.s,
#                                                                                 test_size=0.2)
#
#         self.val_data = Dataset(val_data, val_y, val_s)
#         train_data = Dataset(train_data, train_y, train_s)
#
#         super().run_train(train_data,
#                           num_epochs=num_epochs,
#                           lr=lr,
#                           batch_size=batch_size,
#                           test_data=test_data)
#         self.set_c()
#
#     def set_c(self):
#         self.eval()
#
#         # data = self.val_data.lab_data(lab=1)
#
#         valloader = torch.utils.data.DataLoader(self.val_data,
#                                                 batch_size=100,
#                                                 shuffle=True)
#
#         c = 1
#         sm = 0
#         nm = 0
#         for (batch_x, _, batch_s) in valloader:
#             labeled = batch_s == 0
#
#             batch_x = batch_x[labeled]
#
#             if len(batch_x) == 0:
#                 continue
#
#             nm += len(batch_x)
#             out = self.decision_function(batch_x)
#             out = (1 - out) / out
#             # sm += out.sum()
#             c = min(c, np.min(out))
#
#         # self.c = sm / nm
#         self.c = c
#         self.pi_est = c
#
#     def forward(self, batch_x):
#         if batch_x.shape[1] == 1:
#             batch_x = batch_x.repeat(1, 3, 1, 1)
#         out1 = self.model(batch_x.to(device))
#         out1 = out1.view(int(batch_x.shape[0]), 1, self.D)
#         out1 = self.inm(out1)
#         out1 = out1.view(int(batch_x.shape[0]), self.D)
#         out = self.classifier(out1)
#         return out
#
#     def decision_function(self, x, pi=None):
#
#         if isinstance(x, Dataset):
#             x = x.data
#
#         y_pred = np.array([])
#         testloader = torch.utils.data.DataLoader(x.data,
#                                                  batch_size=100)
#
#         for batch_x in testloader:
#             if batch_x.shape[1] == 1:
#                 batch_x = batch_x.repeat(1, 3, 1, 1)
#
#             res = self.forward(batch_x)
#             y_pred = np.hstack((y_pred, res.squeeze().detach().cpu().numpy()))
#
#         return y_pred
#
#     def predict(self, x, pi=None):
#
#         if pi is None:
#             # if self.c is None:
#             #     self.set_c()
#             pi = self.pi_est
#
#         y_pred = self.decision_function(x)
#         y_pred = y_pred / (1 - y_pred)
#         return (pi * y_pred > 0.5) * 1
