from sklearn.metrics import accuracy_score
from sklearn.svm import SVC
from sklearn.svm import LinearSVC
from sklearn.linear_model import LogisticRegression
from sklearn import datasets
from sklearn import preprocessing
import numpy as np

# import liboptpy.base_optimizer as base
# import liboptpy.constr_solvers as cs
# import liboptpy.step_size as ss
import matplotlib.pyplot as plt
# import liboptpy.unconstr_solvers as us
import tensorflow as tf
from tensorflow import keras
import shutil
import sys
import os
import time

import matplotlib.pyplot as plt
from dataset import read_dataset
import argparse
from utils import ensure_dir

# For every point: minimise, with a gradient-based optimizer, the exponential
# loss in which that point carries extra weight and has its label sign flipped.

def solve_for_x(X, Y, weights
                , init_slope, init_icept, optimizer
                , n_iter=10000, tolerance = 1e-5
                , verbose = False):
    """Minimise a weighted exponential loss of a linear model.

    The objective is::

        sum_i weights[i] * exp(-Y[i] * (w . X[i] + b)) + 0.01 * ||w||_2

    Args:
        X: 2-D array-like of features, one row per sample.
        Y: 1-D array-like of labels, one per row of ``X``.
        weights: 1-D array-like of per-sample loss weights.
        init_slope: Initial value of the weight vector ``w``.
        init_icept: Initial value of the intercept ``b``.
        optimizer: Keras optimizer used as a template. The optimisation runs
            on a fresh copy built from ``optimizer.get_config()``, so the
            instance passed in is not modified and can be reused for any
            number of calls.
        n_iter: Maximum number of optimisation steps.
        tolerance: Stop early once the joint L2 norm of the gradient with
            respect to ``(w, b)`` drops below this value.
        verbose: If true, report when the tolerance stop is triggered.

    Returns:
        Tuple ``(w, b, loss_hist)``: the fitted slope and intercept as NumPy
        values and the list of scalar loss tensors, one per executed step
        (each evaluated before that step's update).
    """
    # Convert the data once, through NumPy, so every step works on float32
    # tensors instead of re-casting individual elements inside the loop.
    features = tf.constant(np.asarray(X, dtype=np.float32))
    labels = tf.constant(np.asarray(Y, dtype=np.float32).reshape(-1))
    sample_weights = tf.constant(np.asarray(weights, dtype=np.float32).reshape(-1))

    tf_slope = tf.Variable(init_slope, dtype='float32')
    tf_icept = tf.Variable(init_icept, dtype='float32')
    trainable_params = (tf_slope, tf_icept)

    # The variables above are new on every call. Keras optimizers build their
    # internal state for the first set of variables they see, so optimise
    # with a fresh instance that has the same hyper-parameters.
    optimizer = type(optimizer).from_config(optimizer.get_config())

    def loss_exponential(w, b):
        # Row-wise dot product w . X[i] plus intercept, for all samples at once.
        scores = tf.reduce_sum(features * w, axis=1) + b
        data_term = tf.reduce_sum(sample_weights * tf.math.exp(-labels * scores))
        reg_l2 = tf.norm(w)
        return data_term + 0.01*reg_l2

    loss_hist = []

    for _ in range(n_iter):

        # Trainable tf.Variables are tracked by the tape automatically.
        with tf.GradientTape() as tape:
            loss = loss_exponential(tf_slope, tf_icept)

        # Derivative of loss value w.r.t. model parameters
        dloss_dparams = tape.gradient(loss, trainable_params)
        loss_hist.append(loss)

        optimizer.apply_gradients(zip(dloss_dparams, trainable_params))

        # Joint norm over the slope and intercept gradients.
        if tf.linalg.global_norm(list(dloss_dparams)) < tolerance:
            if verbose:
                print("Required tolerance achieved!")
            break

    return tf_slope.numpy(), tf_icept.numpy(), loss_hist

def analyze_points(points_range, X, Y, threshold
                       , mass
                       , w_init_slope, w_init_intercept
                       , optimizer, num_iter = 300, tol = 0.00001
                       , verbose = False, folder = './figures'):
    """Find the points whose label flip pushes the zero-one loss past a threshold.

    For each index in ``points_range`` the label of that point is negated, the
    point receives weight ``mass`` (every other point gets
    ``(1 - mass) / (n - 1)``), and a linear model is fitted with
    ``solve_for_x``. The fitted model is then scored with the zero-one loss
    on the ORIGINAL labels ``Y``; the point is discarded when that loss is
    strictly greater than ``threshold``.

    Args:
        points_range: Iterable of sample indices to analyse.
        X: 2-D NumPy array of features, one row per sample. The verbose plots
            use columns 0 and 1.
        Y: 1-D NumPy array of labels; the plot colouring maps -1 to red and
            +1 to blue.
        threshold: Zero-one loss above which a point is discarded.
        mass: Loss weight given to the point under analysis.
        w_init_slope: Initial weight vector handed to ``solve_for_x``.
        w_init_intercept: Initial intercept handed to ``solve_for_x``.
        optimizer: Optimizer handed to ``solve_for_x``.
        num_iter: Maximum number of optimisation steps per point.
        tol: Tolerance handed to ``solve_for_x``.
        verbose: If true, print diagnostics and write two PDF figures per
            point into ``folder``: ``<index>.pdf`` (loss history) and
            ``<index>_boundary.pdf`` (data with the fitted decision line).
        folder: Existing directory that receives the figures.

    Returns:
        List of the indices from ``points_range`` that were discarded.

    Raises:
        ValueError: If fewer than two samples are given (the remaining weight
            is spread over ``n - 1`` points).
    """
    n = len(Y)
    if n <= 1:
        raise ValueError(f"analyze_points needs at least 2 samples, got {n}")

    labels = np.asarray(Y)

    def zero_one_loss(w, b):
        # Fraction of samples with a non-positive margin under the original labels.
        margins = labels * (np.dot(X, w) + b)
        return float(np.mean(margins <= 0))

    discarded_indices = []

    for index in points_range:
        print("Index", index)
        weights = np.full(n, (1 - mass) / (n - 1))
        weights[index] = mass

        # Flip only the label of the point under analysis.
        Y_new = Y.copy()
        Y_new[index] = Y_new[index] * -1

        w, b, loss_hist = solve_for_x(X, Y_new, weights
                                     , w_init_slope, w_init_intercept
                                     , optimizer, num_iter, tol
                                     , verbose)

        # Computed once and shared by the diagnostics and the decision below.
        point_loss = zero_one_loss(w, b)

        if verbose:
            loss_fig = plt.figure()
            plt.plot(range(len(loss_hist)), loss_hist, 'b-', marker='o', label='Computed Loss')
            plt.savefig(folder+"/"+str(index)+".pdf", dpi=150)
            # Close explicitly: two figures per point would otherwise pile up.
            plt.close(loss_fig)

            boundary_fig = plt.figure()
            colors = ['r', 'b']
            point_colors = [colors[int((label + 1)/2)] for label in labels]
            plt.scatter(X[:, 0], X[:, 1], c = point_colors)
            plt.scatter(X[index, 0], X[index, 1], c = 'g', s= 100)

            print("zero-one loss", point_loss, threshold)
            print("w  ", w)
            print("b  ", b)
            # Weighted exponential loss of the analysed point, evaluated
            # against its original (unflipped) label.
            print("    exp point loss ", weights[index] * np.exp(-  Y[index] * (np.dot(w, X[index,:]) + b) ))
            plt.plot([0, 1], [-b/w[1], -b/w[1] - w[0]/w[1]], c = 'k')
            plt.savefig(folder+"/"+str(index)+"_boundary.pdf", dpi=150)
            plt.close(boundary_fig)

        if point_loss > threshold:
            if verbose:
                print("point", index, "discarded")
            discarded_indices.append(index)

    discarded_pct = 100 * len(discarded_indices) / n
    print(f"\n Discarded {len(discarded_indices)} points out of {n}, {discarded_pct:.2f}%")

    return discarded_indices
    
# main
#
# Command-line driver: fits a logistic-regression baseline, derives the
# zero-one-loss threshold from its training accuracy, analyses this job's
# share of the points and writes the discarded indices to a CSV file.

print(f'Using Python  = {sys.version.split()[0]}')
print(f'Tensorflow    = {tf.__version__}')
print(f'Keras Version = {keras.__version__}')

parser = argparse.ArgumentParser(description='Finds discard points.')
parser.add_argument('datasetdir')
parser.add_argument('runid')
parser.add_argument('-o', '--outdir', default='./jobs_out')
parser.add_argument('-fig', '--figuresdir', default='./figures')
parser.add_argument('-m', '--mass', default=0.1, type=float)
parser.add_argument('-i', '--iterations', default=100, type = int)
parser.add_argument('-lr', '--learningrate', default=0.01, type = float)
parser.add_argument('-t', '--theta', default=0.03, type = float,
                    help='slack added to the baseline error to form the discard threshold')
parser.add_argument('-v', '--verbose', action='store_true', default=False)

args = parser.parse_args()

outdir = os.path.join(args.outdir, args.runid)
figuresdir = os.path.join(args.figuresdir, args.runid)

ensure_dir(outdir)
ensure_dir(figuresdir)

X, Y, Y1 = read_dataset(args.datasetdir)

print(X.shape)
print(Y1.shape)

clf = LogisticRegression(random_state=0, penalty = 'l2', fit_intercept = True).fit(X, Y1)

# Threshold = training error of the baseline model plus the slack theta.
theta = args.theta
optimal_model_acc = accuracy_score(Y1, clf.predict(X))
threshold = 1 - optimal_model_acc + theta
print("accuracy", optimal_model_acc)
print('threshold', threshold)

# The baseline solution is the starting point of every per-point optimisation.
coef_init = clf.coef_[0]
coef_intercept = clf.intercept_[0]

print(coef_init, coef_intercept)

# Zero-based position of this task inside the SLURM job array. The defaults
# make the script run as a single job covering every point when it is
# started outside of a job array.
array_task_id = int(os.getenv('SLURM_ARRAY_TASK_ID', '1'))
array_task_min = int(os.getenv('SLURM_ARRAY_TASK_MIN', '1'))
job_id = array_task_id - array_task_min
job_count = int(os.getenv('SLURM_ARRAY_TASK_COUNT', '1'))
print("Job ID: ", job_id, job_count)

# Split the points into contiguous chunks; the last job also takes the
# remainder of the integer division.
n = X.shape[0]
points_per_job = n//job_count
start = job_id*points_per_job
end = (job_id+1)*points_per_job if job_id < job_count-1 else n
points_range = range(start, end)

# Alternatives tried earlier: Adam, Adagrad, and a PolynomialDecay
# learning-rate schedule (0.1 -> 0.001 over 20 steps).
opt = keras.optimizers.RMSprop(learning_rate=args.learningrate)

tic = time.perf_counter()
discarded_points = analyze_points(points_range, X, Y1, threshold
                       , args.mass
                       , coef_init, coef_intercept
                       , opt
                       , num_iter=args.iterations
                       , verbose=args.verbose
                       , folder=figuresdir)

# save discarded points to file (comma-separated indices; integer dtype is
# forced so that an empty result is handled the same way as a non-empty one)
np.array(discarded_points, dtype=int).tofile(
    os.path.join(outdir, 'discarded_points_{}.csv'.format(job_id)),
    sep = ',')
toc = time.perf_counter()

print(f"discarded points in {toc - tic:0.4f} seconds")