#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import numpy as np
from datetime import datetime
import torch
import math

class Logistic_AUC_base:
    def __init__(self, lr_0=1, num_iter=100, T0=1000, batch=100, c=0, fit_intercept=True, verbose=True, alpha=0.05, beta=0.5, seed=1234, dataname='real', Model_name='LogisticRegression'):
        """Store hyper-parameters and create empty training-history lists.

        Args:
            lr_0: initial step size; ``fit`` uses ``lr_0 / (i + 1)`` in outer
                iteration ``i``.
            num_iter: number of outer iterations run by ``fit``.
            T0: inner-iteration scale; outer iteration ``i`` runs
                ``(i + 1) ** 2 * T0`` minibatch steps.
            batch: nominal minibatch size drawn from each class.
            c: weight of the ``c * theta`` (L2 regularisation) gradient term.
            fit_intercept: if True, ``fit`` prepends a column of ones to the
                feature matrices.
            verbose: stored; not read by ``fit`` (presumably used elsewhere).
            alpha, beta: the pAUC window covers, for each positive, the
                negatives' losses ranked (descending) from
                ``floor(alpha * n_neg)`` up to ``ceil(beta * n_neg)`` (exclusive).
            seed: passed to ``torch.manual_seed`` at the start of ``fit``.
            dataname, Model_name: stored as metadata; not used by ``fit``.
        """
        self.lr = 0  # current step size; overwritten every outer iteration of fit
        self.lr_0 = lr_0
        self.num_iter = num_iter
        self.T0 = T0
        self.fit_intercept = fit_intercept
        self.verbose = verbose
        self.alpha = alpha
        self.beta = beta
        self.seed = seed
        self.dataname = dataname
        self.Model_name = Model_name
        self.batch = batch
        # training history, one entry per outer iteration (plus the initial point)
        self.time_list = []
        self.loss_list = []
        self.data_pass = []
        self.pauc_list = []
        self.c = c

    def __add_intercept(self, X):
        """Return ``X`` with a leading column of ones (the intercept feature).

        The ones column is allocated on ``X``'s device, so tensors that already
        live on a GPU are supported. It uses the default floating dtype, as before,
        and ``torch.cat`` type-promotes.
        """
        intercept = torch.ones((X.shape[0], 1), device=X.device)
        return torch.cat((intercept, X), 1)

    def __individualloss(self, h):
        """Elementwise logistic loss ``log(1 + exp(-h))`` of the margins ``h``.

        Computed as ``softplus(-h)``. The naive formula overflows to ``inf`` once
        ``-h`` exceeds the float range of ``exp`` (about 88 in float32). It also
        rounds to exactly 0 for moderately large positive ``h``.
        """
        return torch.nn.functional.softplus(-h)
    
    def __grad_loss(self, z):
        """Derivative of the logistic loss ``log(1 + exp(-z))`` with respect to ``z``.

        Mathematically ``-exp(-z) / (1 + exp(-z)) == -sigmoid(-z)``. The sigmoid
        form stays finite for large negative ``z``, where the ratio form
        evaluates to ``-inf / inf = nan``.
        """
        return -torch.sigmoid(-z)
    
    def __diff_matrix(self, x, y):
        """Return the pairwise difference matrix ``D[a, b] = x[a] - y[b]``.

        Expects 1-D ``x`` (length n) and ``y`` (length m) and returns an (n, m)
        tensor. Broadcasting avoids materialising two repeated (n, m) copies of
        the inputs.
        """
        return x.unsqueeze(1) - y.unsqueeze(0)
    
    def __split_pos_index(self, y):
        """Index tuple (one tensor per dimension of ``y``) of entries with ``y > 0``."""
        is_positive = y > 0
        return torch.nonzero(is_positive, as_tuple=True)
    
    def __split_neg_index(self, y):
        """Index tuple of entries treated as negative, i.e. ``y < 0.1`` (covers 0 and -1 labels)."""
        is_negative = y < 0.1
        return torch.nonzero(is_negative, as_tuple=True)

    def fit(self, X, y, X_val, y_val, X_test, y_test):
        """Train ``self.theta`` on a logistic surrogate of the one-way partial-AUC loss.

        For every positive sample, the logistic losses of its pairs with all
        negatives are sorted in descending order. The training objective averages
        the losses ranked ``k2_value = floor(alpha * n_neg)`` up to
        ``k_value = ceil(beta * n_neg)`` (exclusive). The objective is treated as
        a difference of two terms:
        (sum of top-``k_value`` losses) - (sum of top-``k2_value`` losses).
        Each outer iteration ``i`` does two things:
          * computes a full-data subgradient of the subtracted top-``k2_value``
            term (threshold ``lamb_1``);
          * runs ``(i + 1) ** 2 * T0`` minibatch steps on the top-``k_value``
            term. These steps jointly update ``theta`` and the per-positive
            thresholds ``lamb_2``, with step size ``lr_0 / (i + 1)``.
        After each outer iteration, the pAUC on the validation set is computed.
        ``theta`` is copied to ``self.best_theta`` whenever that pAUC is at least
        as good as the best seen so far.

        Args:
            X, y: training features and labels as torch tensors (assumed
                (n, d) and (n,) -- confirm). ``y > 0`` marks positives and
                ``y < 0.1`` negatives.
            X_val, y_val: validation data, used only for model selection.
            X_test, y_test: test data; only the intercept column is added here.

        Side effects:
            Sets ``self.theta``, ``self.best_theta``, ``self.lr``,
            ``self.k_value``/``self.k2_value`` and their ``*_val``
            counterparts. Appends to ``self.loss_list``/``self.pauc_list``
            (initial point) and ``self.time_list``/``self.data_pass``.

        Raises:
            ValueError: if the training or validation set lacks positives or
                negatives, or if the ``[alpha, beta]`` rank window is empty.
        """
        torch.manual_seed(self.seed)
        print('X and y shape:', X.shape, y.shape)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using {device} device")

        if self.fit_intercept:
            X = self.__add_intercept(X)
            X_val = self.__add_intercept(X_val)
            X_test = self.__add_intercept(X_test)

        X, y = X.to(device), y.to(device)
        X_val, y_val = X_val.to(device), y_val.to(device)
        # split positive and negative data
        index_pos = self.__split_pos_index(y)
        index_neg = self.__split_neg_index(y)
        X_pos = X[index_pos]
        y_pos = y[index_pos]
        X_neg = X[index_neg]
        y_neg = y[index_neg]

        n_pos = y_pos.shape[0]
        n_neg = y_neg.shape[0]
        d = X.shape[1]
        print(n_pos)
        print(n_neg)

        self.k2_value = int(math.floor(self.alpha * n_neg))
        self.k_value = int(math.ceil(self.beta * n_neg))
        # Without this check, these cases fail later with ZeroDivisionError/IndexError.
        if n_pos == 0 or n_neg == 0 or self.k_value <= self.k2_value:
            raise ValueError('training data needs positive and negative samples '
                             'and a non-empty [alpha, beta] rank window')

        index_val_pos = self.__split_pos_index(y_val)
        index_val_neg = self.__split_neg_index(y_val)
        X_val_pos = X_val[index_val_pos]
        y_val_pos = y_val[index_val_pos]
        X_val_neg = X_val[index_val_neg]
        y_val_neg = y_val[index_val_neg]

        n_val_pos = y_val_pos.shape[0]
        n_val_neg = y_val_neg.shape[0]
        print(n_val_pos)
        print(n_val_neg)

        self.k2_value_val = int(math.floor(self.alpha * n_val_neg))
        self.k_value_val = int(math.ceil(self.beta * n_val_neg))
        if n_val_pos == 0 or n_val_neg == 0 or self.k_value_val <= self.k2_value_val:
            raise ValueError('validation data needs positive and negative samples '
                             'and a non-empty [alpha, beta] rank window')

        # initialization
        self.best_theta = torch.zeros(d)
        self.theta = torch.rand(d)

        self.theta = self.theta.to(device)

        pauc_initial = 0
        pauc_list = []

        # surrogate loss and pAUC of the random starting point (training data)
        z_neg = torch.matmul(X_neg, self.theta)
        z_pos = torch.matmul(X_pos, self.theta)
        sum_loss = 0
        sum_loss_pauc = 0
        for i in range(n_pos):
            diff = z_pos[i] - z_neg
            loss = self.__individualloss(diff)
            loss_sort, _ = torch.sort(loss, descending=True)
            sum_loss += torch.sum(loss_sort[self.k2_value:self.k_value])
            # 0/1 indicator of a mis-ordered pair (negative scored above the positive)
            loss_pauc = (diff < 0).to(diff.dtype)
            loss_pauc_sort, _ = torch.sort(loss_pauc, descending=True)
            sum_loss_pauc += torch.sum(loss_pauc_sort[self.k2_value:self.k_value])
        sum_loss /= (self.k_value - self.k2_value) * n_pos
        sum_loss_pauc /= (self.k_value - self.k2_value) * n_pos
        sum_l = sum_loss.float().cpu().numpy()
        sum_p = sum_loss_pauc.cpu().numpy()
        self.loss_list.append(sum_l)
        self.pauc_list.append(1 - sum_p)
        print(sum_l)
        print(1 - sum_p)
        self.data_pass.append(0)
        self.time_list.append(0)

        n = y.shape[0]
        # At least one minibatch per epoch, so `j % num_bat_*` never divides
        # by zero when a class has fewer than `self.batch` samples.
        num_bat_neg = max(1, n_neg // self.batch)
        num_bat_pos = max(1, n_pos // self.batch)

        time_spent = 0
        datapass = 0

        for i in range(self.num_iter):

            start_time = datetime.now()

            lamb_1 = torch.zeros(n_pos, device=device)
            lamb_2 = torch.zeros(n_pos, device=device)

            z_neg = torch.matmul(X_neg, self.theta)
            z_pos = torch.matmul(X_pos, self.theta)

            # Full-data subgradient of the subtracted top-k2 term; the loop
            # also warm-starts lamb_2 for the minibatch phase below.
            subgradient = torch.zeros(d, device=device)
            for p in range(n_pos):
                diff = z_pos[p] - z_neg
                loss = self.__individualloss(diff)
                loss_sort, _ = torch.sort(loss, descending=True)
                lamb_2[p] = loss_sort[self.k_value - 1]
                if self.k2_value == 0:
                    # The top-0 sum is identically zero and contributes no
                    # subgradient (loss_sort[-1] would select every negative).
                    continue
                lamb_1[p] = loss_sort[self.k2_value - 1]
                hinge = loss - lamb_1[p]
                loss[hinge < 0] = 0
                u = self.__grad_loss(diff)
                u[loss == 0] = 0
                # sum_j u_j * (x_p - x_j) without building the (n_neg, d) difference matrix
                subgradient += u.sum() * X_pos[p, :] - torch.matmul(u, X_neg)
            subgradient /= n_pos * n_neg

            self.lr = self.lr_0 / (i + 1)
            inner_iter = (i + 1) ** 2 * self.T0

            for j in range(inner_iter):
                # sample mini batch; reshuffle a class at the start of each of its epochs
                m_pos = j % num_bat_pos
                m_neg = j % num_bat_neg
                if m_pos == 0:
                    index_pos = torch.randperm(n_pos)
                if m_neg == 0:
                    index_neg = torch.randperm(n_neg)

                index_pos_batch = index_pos[m_pos * self.batch:(m_pos + 1) * self.batch]
                index_neg_batch = index_neg[m_neg * self.batch:(m_neg + 1) * self.batch]
                X_pos_batch = X_pos[index_pos_batch, :]
                X_neg_batch = X_neg[index_neg_batch, :]
                # actual batch sizes (smaller than self.batch for a small class)
                b_pos = X_pos_batch.shape[0]
                b_neg = X_neg_batch.shape[0]

                z_2_pos = torch.matmul(X_pos_batch, self.theta)
                z_2_neg = torch.matmul(X_neg_batch, self.theta)
                m_2 = self.__diff_matrix(z_2_pos, z_2_neg)
                loss_2 = self.__individualloss(m_2)
                hinge_2 = loss_2 - lamb_2[index_pos_batch].unsqueeze(1)
                loss_2[hinge_2 < 0] = 0
                u_2 = self.__grad_loss(m_2)
                u_2[loss_2 == 0] = 0

                # sum_{p,j} u_2[p, j] * (x_p - x_j), vectorised over the batch
                gradient_2 = (torch.matmul(u_2.sum(dim=1), X_pos_batch)
                              - torch.matmul(u_2.sum(dim=0), X_neg_batch))
                gradient_2 /= b_pos * b_neg
                gradient_2 += self.c * self.theta - subgradient

                # lamb_2 is the threshold of the top-k_value term, so its
                # gradient is k_value / n_neg minus the fraction of active pairs.
                gradient_lam_2 = self.k_value / n_neg - torch.count_nonzero(loss_2, dim=1) / b_neg
                self.theta -= self.lr * gradient_2
                lamb_2[index_pos_batch] -= self.lr * gradient_lam_2
                lamb_2 = torch.clip(lamb_2, 0.0, 1e7)

            end_time = datetime.now()
            time_spent += (end_time - start_time).total_seconds()
            self.time_list.append(time_spent)

            # calculate num of data pass
            datapass += inner_iter * 2 * self.batch
            self.data_pass.append(datapass / n)

            # validation pAUC: 1 - fraction of mis-ordered pairs inside the rank window
            z_neg = torch.matmul(X_val_neg, self.theta)
            z_pos = torch.matmul(X_val_pos, self.theta)
            sum_loss_pauc = 0
            for q in range(n_val_pos):
                diff = z_pos[q] - z_neg
                loss_pauc = (diff < 0).to(diff.dtype)
                loss_pauc_sort, _ = torch.sort(loss_pauc, descending=True)
                sum_loss_pauc += torch.sum(loss_pauc_sort[self.k2_value_val:self.k_value_val])
            sum_loss_pauc /= (self.k_value_val - self.k2_value_val) * n_val_pos
            sum_p = sum_loss_pauc.cpu().numpy()

            # keep the parameters with the best validation pAUC seen so far
            if 1 - sum_p >= pauc_initial:
                pauc_list.append(1 - sum_p)
                pauc_initial = max(pauc_list)
                self.best_theta = torch.clone(self.theta)
            