#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import torch
from datetime import datetime
import numpy as np

class Logistic_Partial_AUC:
    """Linear scorer trained for two-way partial AUC with a pairwise logistic loss.

    For every positive example, the pairwise logistic losses against the
    negative examples are ranked from largest to smallest, and only the ranks
    ``k1_value`` .. ``k2_value - 1`` enter the objective. ``fit`` optimises two
    weight vectors ``w_1`` / ``w_2`` (each pulled toward an anchor ``zz`` by a
    ``(w - zz) / mu`` gradient term) together with per-positive thresholds
    ``lamb_1`` / ``lamb_2``. It records the training history in ``time_list``,
    ``loss_list``, ``data_pass`` and ``pauc_list``.

    NOTE(review): the defaults ``k1_value=5, k2_value=4`` select an empty rank
    range (k1 > k2); callers presumably always pass their own values — confirm.
    """

    def __init__(self, lr_0=10, lr_outer=0.01, mu=1, num_iter=100, T0=100,
                 batch=100, c=0, fit_intercept=True, verbose=True, k1_value=5,
                 k2_value=4, seed=1234, dataname='real', Model_name='Logistic'):
        """Store the hyper-parameters; no computation happens until ``fit``.

        Args:
            lr_0: initial inner step size; outer iteration ``i`` uses
                ``lr_0 / (i + 1)``.
            lr_outer: step size of the anchor (``zz``) update.
            mu: weight of the pull of ``w_1`` / ``w_2`` toward ``zz``
                (gradient term ``(w - zz) / mu``).
            num_iter: number of outer iterations.
            T0: outer iteration ``i`` runs ``(i + 1) ** 2 * T0`` inner steps.
            batch: per-class mini-batch size.
            c: extra L2 weight applied to ``w_1`` only.
            fit_intercept: prepend a column of ones to every design matrix.
            verbose: print progress information during ``fit``.
            k1_value, k2_value: rank range ``[k1_value, k2_value)`` of the
                negatives (per positive) that enters the objective.
            seed: passed to ``torch.manual_seed`` at the start of ``fit``.
            dataname, Model_name: stored as metadata.
        """
        self.lr = 0  # recomputed from lr_0 at every outer iteration of fit
        self.lr_outer = lr_outer
        self.num_iter = num_iter
        self.T0 = T0
        self.fit_intercept = fit_intercept
        self.verbose = verbose
        self.k1_value = k1_value
        self.k2_value = k2_value
        self.seed = seed
        self.dataname = dataname
        self.Model_name = Model_name
        self.mu = mu
        self.batch = batch
        # Training history, appended to by fit().
        self.time_list = []
        self.loss_list = []
        self.data_pass = []
        self.pauc_list = []
        self.lr_0 = lr_0
        self.c = c

    def __add_intercept(self, X):
        """Return ``X`` with a leading column of ones (the bias feature).

        The ones column is created on ``X``'s device and with ``X``'s dtype, so
        the concatenation neither mixes devices nor silently changes the dtype.
        Non-floating inputs are first cast to the default floating dtype.
        """
        if not torch.is_floating_point(X):
            X = X.to(torch.get_default_dtype())
        intercept = torch.ones((X.shape[0], 1), dtype=X.dtype, device=X.device)
        return torch.cat((intercept, X), 1)

    def __individualloss(self, h):
        """Logistic loss ``log(1 + exp(-h))`` evaluated without overflow.

        ``softplus(-h)`` is mathematically identical, but stays finite for very
        negative ``h`` (the naive form returns ``inf``). It also keeps tiny
        positive values for large ``h`` (the naive form rounds them to 0).
        """
        return torch.nn.functional.softplus(-h)

    def __grad_loss(self, z):
        """Derivative of the logistic loss: ``-exp(-z) / (1 + exp(-z))``.

        Written as ``-sigmoid(-z)``, which is the same function but never
        evaluates ``inf / inf`` (NaN) for very negative ``z``.
        """
        return -torch.sigmoid(-z)

    def __diff_matrix(self, x, y):
        """Return the pairwise differences ``out[i, j] = x[i] - y[j]``.

        ``x`` and ``y`` are 1-D tensors; the result has shape
        ``(len(x), len(y))`` and is produced by broadcasting (no repeat copies).
        """
        return x.reshape(-1, 1) - y.reshape(1, -1)

    def __split_pos_index(self, y):
        """Indices (as a ``torch.where`` tuple) of entries with ``y > 0``."""
        index = torch.where(y > 0)
        return index

    def __split_neg_index(self, y):
        """Indices (as a ``torch.where`` tuple) of entries with ``y < 0.1``."""
        index = torch.where(y < 0.1)
        return index

    def _ranked_objective(self, z_pos, z_neg):
        """Evaluate the rank-restricted objective on precomputed scores.

        For each positive score, the logistic losses and the 0/1 mis-ranking
        indicators (negative scores higher than the positive) against all
        negatives are each sorted largest first. Only the slice
        ``[k1_value:k2_value]`` contributes.

        Args:
            z_pos: 1-D tensor of positive-example scores.
            z_neg: 1-D tensor of negative-example scores.

        Returns:
            ``(mean_loss, mean_err, lamb_1, lamb_2)``:
            - ``mean_loss``: the averaged logistic loss over the selected ranks.
            - ``mean_err``: the averaged 0/1 error over the same ranks, so
              ``1 - mean_err`` is the reported partial AUC.
            - ``lamb_1`` / ``lamb_2``: per positive, the loss at rank
              ``k1_value - 1`` / ``k2_value - 1``.
        """
        k1, k2 = self.k1_value, self.k2_value
        n_pos = z_pos.shape[0]
        lamb_1 = torch.zeros(n_pos, dtype=z_pos.dtype, device=z_pos.device)
        lamb_2 = torch.zeros_like(lamb_1)
        sum_loss = torch.zeros((), dtype=z_pos.dtype, device=z_pos.device)
        sum_err = torch.zeros_like(sum_loss)
        for i in range(n_pos):
            diff = z_pos[i] - z_neg
            loss_sort = torch.sort(self.__individualloss(diff), descending=True).values
            lamb_1[i] = loss_sort[k1 - 1]
            lamb_2[i] = loss_sort[k2 - 1]
            sum_loss += torch.sum(loss_sort[k1:k2])
            err_sort = torch.sort((diff < 0).to(diff.dtype), descending=True).values
            sum_err += torch.sum(err_sort[k1:k2])
        denom = (k2 - k1) * n_pos
        return sum_loss / denom, sum_err / denom, lamb_1, lamb_2

    def _pair_gradient(self, X_pos_batch, X_neg_batch, w, lamb_batch):
        """Data part of the gradient for one ``(w, lambda)`` pair on a mini-batch.

        Pair ``(p, q)`` is active when its logistic loss is non-zero and at
        least ``lamb_batch[p]``; only active pairs contribute.

        Returns:
            ``(grad, active)``:
            - ``grad``: the sum over active pairs of ``loss'(margin) * (x_p - x_q)``,
              divided by the number of pairs in the batch.
            - ``active``: the count of active pairs for every positive in the batch.
        """
        margin = self.__diff_matrix(torch.matmul(X_pos_batch, w),
                                    torch.matmul(X_neg_batch, w))
        loss = self.__individualloss(margin)
        loss[loss - lamb_batch.reshape(-1, 1) < 0] = 0
        u = self.__grad_loss(margin)
        u[loss == 0] = 0
        # sum_{p,q} u[p,q] * (x_p - x_q) == u.sum(1) @ X_pos - u.sum(0) @ X_neg,
        # computed without a Python loop over the positives.
        grad = (torch.matmul(u.sum(dim=1), X_pos_batch)
                - torch.matmul(u.sum(dim=0), X_neg_batch))
        grad = grad / (X_pos_batch.shape[0] * X_neg_batch.shape[0])
        return grad, torch.count_nonzero(loss, dim=1)

    def fit(self, X, y, X_val, y_val, X_test, y_test):
        """Train ``w_1`` / ``w_2`` on ``(X, y)`` and record the training history.

        Args:
            X: 2-D tensor of training features, one row per example.
            y: 1-D label tensor. ``y > 0`` marks positives and ``y < 0.1``
                marks negatives (presumably 0/1 or +/-1 labels — confirm).
            X_val, y_val, X_test, y_test: ``X_val`` / ``X_test`` only receive
                the intercept column when ``fit_intercept``; they are not used
                for training.

        Side effects:
            Sets ``self.w_1``, ``self.w_2`` and ``self.lr``. Appends one entry
            per evaluation (the initial one plus one per outer iteration) to
            ``loss_list``, ``pauc_list``, ``data_pass`` and ``time_list``.
            Prints progress when ``verbose``.

        Raises:
            ValueError: if ``y`` has no positive or no negative example.
        """
        torch.manual_seed(self.seed)
        if self.verbose:
            print('X and y shape:', X.shape, y.shape)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        if self.verbose:
            print(f"Using {device} device")

        if self.fit_intercept:
            X = self.__add_intercept(X)
            X_val = self.__add_intercept(X_val)
            X_test = self.__add_intercept(X_test)

        # Every parameter shares X's floating dtype, so the matmuls below never
        # mix float32 and float64.
        dtype = X.dtype if torch.is_floating_point(X) else torch.get_default_dtype()
        X, y = X.to(device=device, dtype=dtype), y.to(device)

        # split positive and negative data
        X_pos = X[self.__split_pos_index(y)]
        X_neg = X[self.__split_neg_index(y)]
        n_pos = X_pos.shape[0]
        n_neg = X_neg.shape[0]
        d = X.shape[1]
        if n_pos == 0 or n_neg == 0:
            raise ValueError(
                f"fit needs both classes; got {n_pos} positive and "
                f"{n_neg} negative examples")
        if self.verbose:
            print(n_pos)
            print(n_neg)

        # Initialization: drawn from the CPU generator, then moved to the device.
        self.w_1 = torch.rand(d, dtype=dtype).to(device)
        self.w_2 = torch.clone(self.w_1)
        zz = torch.zeros(d, dtype=dtype, device=device)

        # Initial objective; each lambda starts at the loss of rank k - 1.
        z_neg = torch.matmul(X_neg, self.w_2)
        z_pos = torch.matmul(X_pos, self.w_2)
        sum_loss, sum_loss_pauc, lamb_1, lamb_2 = self._ranked_objective(z_pos, z_neg)
        sum_l = sum_loss.float().cpu().numpy()
        sum_p = sum_loss_pauc.cpu().numpy()
        self.loss_list.append(sum_l)
        self.pauc_list.append(1-sum_p)
        if self.verbose:
            print(sum_l)
            print(1-sum_p)
        self.data_pass.append(0)
        self.time_list.append(0)

        n = y.shape[0]
        # Per-class batch sizes: a class smaller than `batch` is used whole
        # instead of making the batch count 0 (and `j % 0` fail below).
        bs_pos = min(self.batch, n_pos)
        bs_neg = min(self.batch, n_neg)
        num_bat_pos = n_pos // bs_pos
        num_bat_neg = n_neg // bs_neg

        time_spent = 0
        datapass = 0

        for i in range(self.num_iter):

            start_time = datetime.now()

            # update learning rate
            self.lr = self.lr_0 / (i+1)
            inner_iter = (i+1)**2 * self.T0

            sum_w1 = torch.zeros(d, dtype=dtype, device=device)
            sum_w2 = torch.zeros(d, dtype=dtype, device=device)
            sum_lambda1 = torch.zeros(n_pos, dtype=dtype, device=device)
            sum_lambda2 = torch.zeros(n_pos, dtype=dtype, device=device)

            for j in range(inner_iter):
                # Reshuffle a class whenever a new pass over it begins.
                m_pos = j % num_bat_pos
                m_neg = j % num_bat_neg
                if m_pos == 0:
                    index_pos = torch.randperm(n_pos)
                if m_neg == 0:
                    index_neg = torch.randperm(n_neg)

                index_pos_batch = index_pos[m_pos * bs_pos:(m_pos+1) * bs_pos]
                index_neg_batch = index_neg[m_neg * bs_neg:(m_neg+1) * bs_neg]
                X_pos_batch = X_pos[index_pos_batch, :]
                X_neg_batch = X_neg[index_neg_batch, :]

                # Both gradients are taken at the current (pre-update) iterates.
                grad_1, active_1 = self._pair_gradient(
                    X_pos_batch, X_neg_batch, self.w_1, lamb_1[index_pos_batch])
                grad_2, active_2 = self._pair_gradient(
                    X_pos_batch, X_neg_batch, self.w_2, lamb_2[index_pos_batch])
                gradient_1 = grad_1 + (self.w_1-zz) / self.mu + self.c * self.w_1
                gradient_2 = grad_2 + (self.w_2-zz) / self.mu

                gradient_lam_1 = self.k1_value / n_neg - active_1 / bs_neg
                self.w_1 -= self.lr * gradient_1
                lamb_1[index_pos_batch] -= self.lr * gradient_lam_1
                lamb_1 = torch.clip(lamb_1, 0.0, 1e7)

                gradient_lam_2 = self.k2_value / n_neg - active_2 / bs_neg
                self.w_2 -= self.lr * gradient_2
                lamb_2[index_pos_batch] -= self.lr * gradient_lam_2
                lamb_2 = torch.clip(lamb_2, 0.0, 1e7)

                sum_lambda1 += lamb_1
                sum_lambda2 += lamb_2
                sum_w1 += self.w_1
                sum_w2 += self.w_2

            avg_w1 = sum_w1 / inner_iter
            avg_w2 = sum_w2 / inner_iter
            # NOTE(review): lambdas restart from their averages, while w_1/w_2
            # keep their last iterates — confirm this asymmetry is intended.
            lamb_1 = sum_lambda1 / inner_iter
            lamb_2 = sum_lambda2 / inner_iter

            # Move the anchor by the gap between the averaged iterates.
            zz -= self.lr_outer * (avg_w1-avg_w2)

            end_time = datetime.now()
            time_spent += (end_time - start_time).total_seconds()
            self.time_list.append(time_spent)
            if self.verbose:
                print(time_spent)

            # calculate num of data pass
            datapass += inner_iter * (bs_pos + bs_neg)
            self.data_pass.append(datapass/n)

            # calculate objective value
            z_neg = torch.matmul(X_neg, self.w_2)
            z_pos = torch.matmul(X_pos, self.w_2)
            sum_loss, sum_loss_pauc, _, _ = self._ranked_objective(z_pos, z_neg)
            sum_l = sum_loss.float().cpu().numpy()
            sum_p = sum_loss_pauc.cpu().numpy()
            self.loss_list.append(sum_l)
            self.pauc_list.append(1-sum_p)
            if self.verbose:
                print(1-sum_p)
            