import numpy as np
import tensorflow as tf

from datetime import datetime as dt
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score

def load_model(path, custom_objects=None):
    """Load a saved Keras model and compile it with a fixed configuration.

    The model is deserialised with ``compile=False``, so any training
    configuration stored with it is ignored. It is then compiled with:

    - the ``'adam'`` optimizer,
    - ``tf.nn.sigmoid_cross_entropy_with_logits`` as the loss (the model
      outputs are treated as logits),
    - ``'accuracy'`` as the only metric.

    Args:
        path: Location of the saved model, passed straight to
            ``tf.keras.models.load_model``.
        custom_objects: Optional mapping of custom layer/function names
            needed to deserialise the model.

    Returns:
        The loaded and compiled Keras model.
    """
    loaded = tf.keras.models.load_model(
        path, custom_objects=custom_objects, compile=False
    )
    loaded.compile(
        optimizer='adam',
        loss=tf.nn.sigmoid_cross_entropy_with_logits,
        metrics=['accuracy'],
    )
    return loaded

def compute_gradients(model, data, batch_size=200):
    """Compute one unit-norm parameter-gradient vector per input sample.

    ``data`` is processed in consecutive slices of ``batch_size`` rows. The
    last slice may be shorter, so every sample receives a gradient row.
    (Previously the ``len(data) % batch_size`` trailing samples were
    silently dropped.)

    For each slice, the Jacobian of the model output with respect to every
    model variable is computed. For each sample, the per-variable Jacobian
    slices are flattened and concatenated into a single row.

    Args:
        model: Callable Keras model whose ``variables`` are differentiated.
        data: Sliceable input supporting ``len`` and ``data[a:b]``.
        batch_size: Samples per Jacobian call. ``None`` falls back to the
            default of 200.

    Returns:
        Tuple ``(gradients, lower_bound_vec_time)``:

        - ``gradients`` is a 2-D array with one row per sample, each row
          scaled to unit Euclidean norm. All-zero rows are left as zeros
          instead of becoming NaN. An empty ``data`` yields an array of
          shape ``(0, 0)``.
        - ``lower_bound_vec_time`` is the total wall-clock seconds spent
          in the forward pass plus Jacobian computation only. NumPy
          conversion is excluded, hence it is a lower bound.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size is None:
        batch_size = 200
    if batch_size <= 0:
        raise ValueError(f'batch_size must be positive, got {batch_size}')

    batch_rows = []
    lower_bound_vec_time = 0
    for start in range(0, len(data), batch_size):
        t1 = dt.now()
        with tf.GradientTape() as tape:
            tape.watch(model.variables)
            val = model(data[start:start + batch_size])
        # One Jacobian per variable, shaped val.shape + var.shape;
        # None for variables the output does not depend on.
        grads = tape.jacobian(val, model.variables)
        t2 = dt.now()
        lower_bound_vec_time += (t2 - t1).total_seconds()
        # Flatten every per-sample slice in one reshape per variable.
        # Row j equals gr[j].reshape(-1) because NumPy reshape is
        # row-major.
        flat = [
            gr.numpy().reshape(gr.shape[0], -1)
            for gr in grads
            if gr is not None
        ]
        batch_rows.append(np.concatenate(flat, axis=1))

    if not batch_rows:
        return np.empty((0, 0)), lower_bound_vec_time

    gradients = np.concatenate(batch_rows, axis=0)
    # Unit-normalise so that Euclidean distance is a monotone function of
    # cosine distance: ||u - v||^2 = 2 - 2 cos(u, v).
    norms = np.linalg.norm(gradients, axis=1, keepdims=True)
    norms[norms == 0] = 1  # avoid 0/0 -> NaN for zero-gradient rows
    gradients = gradients / norms
    return gradients, lower_bound_vec_time

def compute_lh_outputs(model, data):
    """Return the output of the model's last 'dense'-named layer for ``data``.

    The selected layer is the last entry of ``model.layers`` whose name
    contains ``'dense'``.
    NOTE(review): depending on the architecture this may be the output layer
    rather than a hidden layer. Confirm it matches the intended
    "last hidden" representation.

    Args:
        model: Keras model. Its first input node and the selected layer's
            first output node define a temporary sub-model.
        data: Input accepted by ``model``. It is fed through in a single
            call, with no batching.

    Returns:
        Tuple ``(lh_outputs, vec_time)``. ``lh_outputs`` is the sub-model
        output as a NumPy array. ``vec_time`` is the wall-clock seconds of
        the forward pass, including the ``.numpy()`` conversion.

    Raises:
        ValueError: If no layer name contains ``'dense'``. Previously the
            loop fell through and silently used ``model.layers[0]``, or
            raised ``NameError`` for a model without layers.
    """
    layer = next(
        (lyr for lyr in reversed(model.layers) if 'dense' in lyr.name),
        None,
    )
    if layer is None:
        raise ValueError("Model has no layer whose name contains 'dense'")

    tmp_model = tf.keras.Model(
        inputs=model.get_input_at(0),
        outputs=layer.get_output_at(0),
        name='tmp_model'
    )
    t1 = dt.now()
    lh_outputs = tmp_model(data).numpy()
    t2 = dt.now()
    vec_time = (t2 - t1).total_seconds()
    return lh_outputs, vec_time

def get_vecs_and_time(model, data, method='last_hidden', batch_size=None):
    """Turn ``data`` into vectors using ``model`` and report the time taken.

    Args:
        model: Keras model used for vectorisation. If ``None``, ``data``
            is returned unchanged.
        data: Input samples.
        method: ``'last_hidden'`` delegates to ``compute_lh_outputs``.
            ``'gradient'`` delegates to ``compute_gradients``.
        batch_size: Used only by ``'gradient'``. ``None`` keeps
            ``compute_gradients``' own default. Previously ``None`` was
            forwarded as-is and crashed inside it.

    Returns:
        Tuple ``(vectors, seconds)`` as produced by the selected method, or
        ``(data, None)`` when ``model`` is ``None``.

    Raises:
        ValueError: If ``method`` is not recognised.
    """
    if model is None:
        return data, None
    if method == 'last_hidden':
        return compute_lh_outputs(model, data)
    if method == 'gradient':
        if batch_size is None:
            return compute_gradients(model, data)
        return compute_gradients(model, data, batch_size)
    raise ValueError(f'Unknown method {method}')
        

def get_tril(a):
    """Return the strictly-lower-triangular entries of a 2-D array.

    Entries ``a[i, j]`` with ``j < i`` are returned in row-major order, as a
    1-D array. Non-square inputs are supported: only the columns that exist
    are used.

    Uses ``np.tril_indices`` instead of a Python double loop. ``np.asarray``
    lets array-like inputs (e.g. tensors) be indexed the same way.

    Args:
        a: 2-D array-like.

    Returns:
        1-D NumPy array of the selected entries.

    Example:
        >>> get_tril(np.arange(9).reshape(3, 3)).tolist()
        [3, 6, 7]
    """
    arr = np.asarray(a)
    n_rows, n_cols = arr.shape
    rows, cols = np.tril_indices(n_rows, k=-1, m=n_cols)
    return arr[rows, cols]