#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import numpy as np
from datetime import datetime

class LogisticRegression_TOPK_DC:
    """Logistic regression trained on a ranked-range (top-k) loss.

    The reported objective is the average of the individual logistic losses
    ranked ``k1_value + 1`` through ``k2_value`` (1-based, descending order).
    Training keeps two weight vectors: ``w_1`` is driven by a top-``k1_value``
    surrogate and ``w_2`` by a top-``k2_value`` surrogate. Both are pulled
    towards a shared proximal centre ``zz`` with strength ``1 / mu``, and
    ``zz`` is moved by ``lr_outer * (avg_w1 - avg_w2)`` after every outer
    iteration.

    NOTE(review): the objective is only meaningful when
    ``k1_value < k2_value``; the default values (5, 4) give an empty rank
    range (objective -0.0). Confirm the intended defaults.

    Attributes populated by :meth:`fit`:
        w_1, w_2:  weight vectors (intercept first when ``fit_intercept``).
        loss_list: objective at initialisation and after each outer iteration.
        data_pass: cumulative passes over the training set, aligned with
                   ``loss_list``.
        time_list: cumulative wall-clock seconds after each outer iteration
                   (one entry shorter than ``loss_list``).
    """

    def __init__(self, lr_0=10, lr_outer=0.01, mu=1, num_iter=100, T0=100,
                 batch=100, fit_intercept=True, verbose=True, k1_value=5,
                 k2_value=4, seed=1234, dataname='real',
                 Model_name='LogisticRegression'):
        """Store hyper-parameters; no computation happens until :meth:`fit`.

        Args:
            lr_0: initial inner step size; outer iteration ``i`` (0-based)
                uses ``lr_0 / (i + 1)``.
            lr_outer: step size of the proximal-centre update.
            mu: proximal parameter; inner gradients add ``(w - zz) / mu``.
            num_iter: number of outer iterations.
            T0: base inner-iteration count; outer iteration ``i`` runs
                ``(i + 1)**2 * T0`` inner steps.
            batch: mini-batch size (clamped to the sample count in fit).
            fit_intercept: prepend a column of ones to every design matrix.
            verbose: print the training data shapes at the start of fit.
            k1_value, k2_value: rank bounds of the ranked-range objective.
            seed: value passed to ``np.random.seed`` at the start of fit.
            dataname, Model_name: stored only; not used by this class.
        """
        self.lr = 0  # current inner step size; set per outer iteration in fit
        self.lr_0 = lr_0
        self.lr_outer = lr_outer
        self.num_iter = num_iter
        self.T0 = T0
        self.fit_intercept = fit_intercept
        self.verbose = verbose
        self.k1_value = k1_value
        self.k2_value = k2_value
        self.seed = seed
        self.dataname = dataname
        self.Model_name = Model_name
        self.mu = mu
        self.batch = batch
        # Training history; appended to (not cleared) by every call to fit.
        self.time_list = []
        self.loss_list = []
        self.data_pass = []

    def __add_intercept(self, X):
        """Return ``X`` with a leading column of ones."""
        intercept = np.ones((X.shape[0], 1))
        return np.concatenate((intercept, X), axis=1)

    def __sigmoid(self, z):
        """Logistic function ``1 / (1 + exp(-z))``, evaluated without overflow.

        ``exp(-logaddexp(0, -z))`` is algebraically identical but never forms
        ``exp(-z)`` for large negative ``z``, which would overflow and warn.
        """
        return np.exp(-np.logaddexp(0.0, -z))

    def __loss(self, h, y):
        """Mean binary cross-entropy of predictions ``h`` against labels ``y``."""
        h = np.clip(h, 1e-7, 1.0 - 1e-7)  # keep log() finite when h hits 0 or 1
        return (-y * np.log(h) - (1 - y) * np.log(1 - h)).mean()

    def __individualloss(self, h, y):
        """Element-wise binary cross-entropy of predictions ``h`` against ``y``."""
        h = np.clip(h, 1e-7, 1.0 - 1e-7)  # keep log() finite when h hits 0 or 1
        return -y * np.log(h) - (1 - y) * np.log(1 - h)

    def __ranked_range_loss(self, sorted_desc):
        """Mean of the losses ranked k1_value+1 .. k2_value.

        ``sorted_desc`` must already be sorted in descending order.
        """
        window = sorted_desc[self.k1_value:self.k2_value]
        return np.sum(window) / (self.k2_value - self.k1_value)

    def __topk_step(self, w, lamb, k, X_batch, y_batch, zz, n):
        """One stochastic (sub)gradient step on ``w`` (in place) and ``lamb``.

        The step uses the (sub)gradients of
        ``k/n * lamb + mean_batch([loss_i - lamb]_+) + ||w - zz||^2 / (2*mu)``.
        Returns the updated threshold, clipped to ``[0, 1e7]``.
        """
        h = self.__sigmoid(np.dot(X_batch, w))
        loss = self.__individualloss(h, y_batch)
        # Only examples whose loss reaches the threshold contribute a gradient.
        active = (loss - lamb) >= 0
        residual = np.where(active, h - y_batch, 0.0)
        batch_size = y_batch.size
        grad_w = np.dot(X_batch.T, residual) / batch_size + (w - zz) / self.mu
        grad_lamb = k / n - np.count_nonzero(active) / batch_size
        w -= self.lr * grad_w
        lamb -= self.lr * grad_lamb
        return np.clip(lamb, 0.0, 1e7)

    def fit(self, X, y, X_val, y_val, X_test, y_test):
        """Train on ``(X, y)`` and record the objective trajectory.

        Args:
            X: training design matrix, indexed as ``X[rows, :]``.
            y: training labels. The loss formula is binary cross-entropy,
                presumably with 0/1 labels; confirm against callers.
            X_val, y_val, X_test, y_test: accepted for interface compatibility.
                ``X_val``/``X_test`` receive the intercept column, but none of
                them are otherwise used by this method.

        Side effects:
            Reseeds NumPy's global RNG. Sets ``w_1``, ``w_2`` and ``lr``.
            Appends to ``loss_list``, ``data_pass`` and ``time_list``.
        """
        np.random.seed(self.seed)
        if self.verbose:
            print('X and y shape:', X.shape, y.shape)

        if self.fit_intercept:
            X = self.__add_intercept(X)
            X_val = self.__add_intercept(X_val)
            X_test = self.__add_intercept(X_test)

        n = y.size
        # Clamp so at least one full batch fits per epoch; with batch > n the
        # epoch would contain zero batches and gradients would be mis-scaled.
        batch = min(self.batch, n)
        num_bat_per_epoch = n // batch

        # Random init of both weight copies; proximal centre starts at origin.
        self.w_1 = np.random.rand(X.shape[1])
        self.w_2 = self.w_1.copy()
        zz = np.zeros(X.shape[1])

        # Thresholds start at the k-th largest individual loss at w_1.
        sorted_losses = np.sort(
            self.__individualloss(self.__sigmoid(np.dot(X, self.w_1)), y))[::-1]
        lamb_1 = sorted_losses[self.k1_value - 1]
        lamb_2 = sorted_losses[self.k2_value - 1]

        index = np.arange(X.shape[0])
        time_spent = 0
        datapass = 0

        self.loss_list.append(self.__ranked_range_loss(sorted_losses))
        self.data_pass.append(0)

        for i in range(self.num_iter):
            start_time = datetime.now()

            self.lr = self.lr_0 / (i + 1)
            inner_iter = (i + 1) ** 2 * self.T0

            sum_w1 = np.zeros(self.w_1.size)
            sum_w2 = np.zeros(self.w_1.size)
            sum_lambda1 = 0
            sum_lambda2 = 0

            for j in range(inner_iter):
                m = j % num_bat_per_epoch
                if m == 0:  # start of a new epoch: reshuffle sample order
                    np.random.shuffle(index)
                index_batch = index[m * batch:(m + 1) * batch]
                X_batch = X[index_batch, :]
                y_batch = y[index_batch]

                lamb_1 = self.__topk_step(self.w_1, lamb_1, self.k1_value,
                                          X_batch, y_batch, zz, n)
                lamb_2 = self.__topk_step(self.w_2, lamb_2, self.k2_value,
                                          X_batch, y_batch, zz, n)

                sum_lambda1 += lamb_1
                sum_lambda2 += lamb_2
                sum_w1 += self.w_1
                sum_w2 += self.w_2

            # NOTE(review): thresholds restart from their averages, but w_1/w_2
            # keep their last iterates; averages only drive the zz update.
            avg_w1 = sum_w1 / inner_iter
            avg_w2 = sum_w2 / inner_iter
            lamb_1 = sum_lambda1 / inner_iter
            lamb_2 = sum_lambda2 / inner_iter
            zz -= self.lr_outer * (avg_w1 - avg_w2)

            datapass += inner_iter * batch
            self.data_pass.append(datapass / n)

            # Objective is evaluated at the last iterate of w_2.
            sorted_losses = np.sort(
                self.__individualloss(self.__sigmoid(np.dot(X, self.w_2)), y))[::-1]
            self.loss_list.append(self.__ranked_range_loss(sorted_losses))

            end_time = datetime.now()
            time_spent += (end_time - start_time).total_seconds()
            self.time_list.append(time_spent)
