#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import numpy as np
from datetime import datetime
import torch

class Logistic_AUC_proximal:
    """Linear scorer trained for a partial-AUC objective with proximal DCA.

    For every positive sample the logistic pairwise losses
    ``log(1 + exp(-(s_pos - s_neg)))`` against all negatives are ranked in
    descending order, and the objective is the average of ranks
    ``k2_value + 1`` .. ``k_value`` (1-based), averaged over positives.  That
    equals "top-``k_value`` sum minus top-``k2_value`` sum", which ``fit``
    minimizes as a difference of convex functions: each outer iteration
    linearizes the top-``k2_value`` part at the current weights, then runs
    mini-batch SGD on the top-``k_value`` part plus a proximal term
    ``L/2 * ||theta - theta_t||^2``.

    Attributes written by ``fit``:
        theta: learned weights (intercept weight first when ``fit_intercept``).
        loss_list / pauc_list: objective value and ``1 - (0/1 version of the
            objective)`` after initialization and after each outer iteration.
        time_list / data_pass: cumulative training seconds and passes over the
            training set at the same checkpoints.
    The history lists are created in ``__init__`` and appended to (not reset)
    by every ``fit`` call.
    """

    def __init__(self, L=0.1, lr_0=1, num_iter=100, T0=1000, batch=100, c=0,
                 fit_intercept=True, verbose=True, k_value=5, k2_value=4,
                 seed=1234, dataname='real', Model_name='LogisticRegression'):
        """Store hyper-parameters and create empty training-history lists.

        Args:
            L: weight of the proximal term pulling ``theta`` towards the
                iterate of the previous outer iteration.
            lr_0: base step size; outer iteration ``i`` uses ``lr_0 / (i + 1)``.
            num_iter: number of outer (DCA) iterations.
            T0: inner-loop scale; outer iteration ``i`` runs
                ``(i + 1) ** 2 * T0`` SGD steps.
            batch: number of positives and of negatives sampled per SGD step.
            c: stored only; not used by ``fit``.
            fit_intercept: prepend a column of ones to the features.
            verbose: print progress information during ``fit``.
            k_value, k2_value: the objective averages ranks
                ``k2_value + 1`` .. ``k_value`` of each positive's losses.
            seed: passed to ``torch.manual_seed`` at the start of ``fit``.
            dataname, Model_name: stored only; not used by ``fit``.
        """
        self.lr = 0  # current step size, overwritten each outer iteration
        self.lr_0 = lr_0
        self.num_iter = num_iter
        self.T0 = T0
        self.fit_intercept = fit_intercept
        self.verbose = verbose
        self.k_value = k_value
        self.k2_value = k2_value
        self.seed = seed
        self.dataname = dataname
        self.Model_name = Model_name
        self.batch = batch
        # training history, appended to by fit()
        self.time_list = []
        self.loss_list = []
        self.data_pass = []
        self.pauc_list = []
        self.c = c
        self.L = L

    def __add_intercept(self, X):
        """Return ``X`` with a leading column of ones (bias feature).

        The ones column uses ``X``'s dtype and device so the concatenation
        also works for float64 or GPU tensors.
        """
        intercept = torch.ones((X.shape[0], 1), dtype=X.dtype, device=X.device)
        return torch.cat((intercept, X), 1)

    def __individualloss(self, h):
        """Logistic loss ``log(1 + exp(-h))`` of margin(s) ``h``.

        ``softplus`` evaluates it without overflowing to ``inf`` for large
        negative margins.
        """
        return torch.nn.functional.softplus(-h)

    def __grad_loss(self, z):
        """Derivative of the logistic loss, ``-exp(-z) / (1 + exp(-z))``.

        Written as ``-sigmoid(-z)``, which is algebraically identical but does
        not produce ``inf / inf = NaN`` for large negative ``z``.
        """
        return -torch.sigmoid(-z)

    def __diff_matrix(self, x, y):
        """Return the matrix ``D[i, j] = x[i] - y[j]`` for 1-D ``x`` and ``y``."""
        return x.unsqueeze(1) - y.unsqueeze(0)

    def __split_pos_index(self, y):
        """Indices (``torch.where`` tuple) of samples with ``y > 0``."""
        index = torch.where(y > 0)
        return index

    def __split_neg_index(self, y):
        """Indices (``torch.where`` tuple) of samples with ``y < 0.1``.

        NOTE(review): labels strictly between 0 and 0.1 also satisfy
        ``y > 0`` and would be counted in both classes; fine for {0, 1} or
        {-1, 1} labels - confirm the intended label encoding.
        """
        index = torch.where(y < 0.1)
        return index

    def __record_metrics(self, X_pos, X_neg):
        """Append the objective and partial AUC of ``self.theta`` to the histories.

        The 0/1 version of the objective replaces the logistic loss with
        ``1[s_pos < s_neg]``; ``pauc_list`` stores one minus that value.
        """
        z_pos = torch.matmul(X_pos, self.theta)
        z_neg = torch.matmul(X_neg, self.theta)
        lo, hi = self.k2_value, self.k_value
        sum_loss = torch.zeros((), dtype=z_neg.dtype, device=z_neg.device)
        sum_miss = torch.zeros_like(sum_loss)
        for z in z_pos:
            diff = z - z_neg
            # topk returns the `hi` largest values in descending order, so
            # [lo:] selects ranks lo + 1 .. hi
            sum_loss += torch.topk(self.__individualloss(diff), hi).values[lo:].sum()
            misranked = (diff < 0).to(diff.dtype)
            sum_miss += torch.topk(misranked, hi).values[lo:].sum()
        scale = (hi - lo) * z_pos.shape[0]
        sum_l = (sum_loss / scale).float().cpu().numpy()
        sum_p = (sum_miss / scale).cpu().numpy()
        self.loss_list.append(sum_l)
        self.pauc_list.append(1 - sum_p)
        if self.verbose:
            print(sum_l)
            print(1 - sum_p)

    def __linearize(self, X_pos, X_neg):
        """Linearize the concave (top-``k2_value``) part at ``self.theta``.

        Returns:
            subgradient: subgradient of the summed top-``k2_value`` pairwise
                losses over all positives, divided by ``n_pos * n_neg``.
            lamb_2: per-positive warm start for the top-``k_value`` threshold,
                i.e. each positive's ``k_value``-th largest pairwise loss.
        """
        n_pos, n_neg = X_pos.shape[0], X_neg.shape[0]
        z_pos = torch.matmul(X_pos, self.theta)
        z_neg = torch.matmul(X_neg, self.theta)
        subgradient = torch.zeros_like(self.theta)
        lamb_2 = torch.zeros(n_pos, dtype=self.theta.dtype, device=self.theta.device)
        for p in range(n_pos):
            diff = z_pos[p] - z_neg
            loss = self.__individualloss(diff)
            top = torch.topk(loss, self.k_value).values
            # threshold selecting the k2 largest losses; a top-0 sum selects nothing
            lamb_1 = top[self.k2_value - 1] if self.k2_value > 0 else float('inf')
            lamb_2[p] = top[self.k_value - 1]
            u = self.__grad_loss(diff) * (loss >= lamb_1)
            # sum_j u_j (x_p - x_j) without materializing the (n_neg, d) difference
            subgradient += u.sum() * X_pos[p] - torch.matmul(u, X_neg)
        return subgradient / (n_pos * n_neg), lamb_2

    def fit(self, X, y, X_val, y_val, X_test, y_test):
        """Train ``self.theta`` on ``(X, y)`` and record the training history.

        Runs on CUDA when available, otherwise on the CPU.

        Args:
            X: 2-D feature tensor, one row per sample.
            y: 1-D label tensor; ``y > 0`` marks positives, ``y < 0.1``
                negatives.
            X_val, y_val, X_test, y_test: accepted for interface compatibility;
                not used by the current implementation.

        Raises:
            ValueError: if a class is empty, ``batch`` is not in
                ``1 .. min(n_pos, n_neg)``, or ``0 <= k2_value < k_value <= n_neg``
                does not hold.
        """
        torch.manual_seed(self.seed)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        if self.verbose:
            print('X and y shape:', X.shape, y.shape)
            print(f"Using {device} device")

        if self.fit_intercept:
            X = self.__add_intercept(X)
        # weights and thresholds follow X's floating dtype so float64 data works
        dtype = X.dtype if X.is_floating_point() else torch.get_default_dtype()
        X, y = X.to(device=device, dtype=dtype), y.to(device)

        X_pos = X[self.__split_pos_index(y)]
        X_neg = X[self.__split_neg_index(y)]
        n_pos, n_neg = X_pos.shape[0], X_neg.shape[0]
        d = X.shape[1]
        if self.verbose:
            print(n_pos)
            print(n_neg)

        # fail early instead of an IndexError / modulo-by-zero deep in the loop
        if n_pos == 0 or n_neg == 0:
            raise ValueError(
                f"fit needs both classes, got {n_pos} positives and {n_neg} negatives")
        if not 1 <= self.batch <= min(n_pos, n_neg):
            raise ValueError(
                f"batch={self.batch} must be between 1 and min(n_pos, n_neg)={min(n_pos, n_neg)}")
        if not 0 <= self.k2_value < self.k_value <= n_neg:
            raise ValueError(
                f"need 0 <= k2_value < k_value <= n_neg, got k2_value={self.k2_value}, "
                f"k_value={self.k_value}, n_neg={n_neg}")

        # drawn from the CPU generator seeded above, then moved to the device
        self.theta = torch.rand(d, dtype=dtype).to(device)
        theta_t = torch.clone(self.theta)  # proximal center

        self.__record_metrics(X_pos, X_neg)
        self.data_pass.append(0)
        self.time_list.append(0)

        n = y.shape[0]
        num_bat_pos = n_pos // self.batch
        num_bat_neg = n_neg // self.batch
        time_spent = 0
        datapass = 0

        for i in range(self.num_iter):
            start_time = datetime.now()

            subgradient, lamb_2 = self.__linearize(X_pos, X_neg)
            self.lr = self.lr_0 / (i + 1)
            inner_iter = (i + 1) ** 2 * self.T0

            for j in range(inner_iter):
                # reshuffle a class whenever its sweep of mini-batches restarts
                m_pos = j % num_bat_pos
                m_neg = j % num_bat_neg
                if m_pos == 0:
                    perm_pos = torch.randperm(n_pos)
                if m_neg == 0:
                    perm_neg = torch.randperm(n_neg)
                idx_pos = perm_pos[m_pos * self.batch:(m_pos + 1) * self.batch]
                idx_neg = perm_neg[m_neg * self.batch:(m_neg + 1) * self.batch]
                X_pos_b = X_pos[idx_pos]
                X_neg_b = X_neg[idx_neg]

                margins = self.__diff_matrix(torch.matmul(X_pos_b, self.theta),
                                             torch.matmul(X_neg_b, self.theta))
                # pairs whose loss reaches their positive's threshold lamb_2
                active = self.__individualloss(margins) >= lamb_2[idx_pos].unsqueeze(1)
                u = self.__grad_loss(margins) * active
                # sum_{p,q} u[p,q] (x_p - x_q) as two mat-vec products
                grad = (torch.matmul(u.sum(1), X_pos_b)
                        - torch.matmul(u.sum(0), X_neg_b)) / self.batch ** 2
                # proximal DCA: subtract the linearized top-k2 part, add the proximal pull
                grad += self.L * (self.theta - theta_t) - subgradient

                # derivative of (k * lamb + sum_j [loss_j - lamb]_+) / n_neg; lamb_2 is
                # the threshold of the top-k_value term, hence k_value (not k2_value)
                grad_lamb = self.k_value / n_neg - active.sum(1) / self.batch
                self.theta -= self.lr * grad
                lamb_2[idx_pos] -= self.lr * grad_lamb
                lamb_2.clamp_(0.0, 1e7)

            theta_t = torch.clone(self.theta)

            time_spent += (datetime.now() - start_time).total_seconds()
            self.time_list.append(time_spent)

            # each inner step reads `batch` positives and `batch` negatives
            datapass += inner_iter * 2 * self.batch
            self.data_pass.append(datapass / n)

            self.__record_metrics(X_pos, X_neg)
