import tensorflow as tf
import numpy as np
from . import min_norm_solvers

# Learning rate handed to the tf.keras SGD optimizer that is built in
# distangle_optimize_weight_with_distance.
lr = 50

def ComputeGradient(gradients):
    """Blend per-task gradients into a single gradient per layer.

    Args:
        gradients: indexed as ``gradients[task][layer]``. The number of
            layers is read from ``gradients[0]``, so every task must offer at
            least that many entries.

    Returns:
        A list with one entry per layer; each entry is the sum over tasks of
        ``weights[task] * gradients[task][layer]``, where ``weights`` comes
        from ``distangle_optimize_weight_with_distance`` applied to the
        flattened per-task gradients.
    """
    # Step 1: collapse every task's layer gradients into one 1-D vector.
    flat_per_task = [
        tf.concat([tf.reshape(layer_grad, [-1]) for layer_grad in task_grads], 0)
        for task_grads in gradients
    ]

    # Step 2: one scalar weight per task.
    weights = distangle_optimize_weight_with_distance(flat_per_task)

    # Step 3: weighted sum across tasks, layer by layer.
    num_layers = len(gradients[0])
    return [
        sum(
            weights[task] * task_grads[layer]
            for task, task_grads in enumerate(gradients)
        )
        for layer in range(num_layers)
    ]


def distangle_optimize_weight_with_distance(grads):
    """Compute one combination weight per flattened gradient.

    For more than two gradients, an ``n x n`` float64 matrix ``W`` is tuned
    with five SGD steps (learning rate: module-level ``lr``) so that each row
    of ``masked_softmax(W, off_diagonal_mask) @ grads`` minimises
    ``0.1 * asymmetric_distance(., grads)``. The diagonal is then filled with
    the result of ``min_norm_solvers.find_min_norm_element_independent`` and
    the column sums are divided by ``n + 1``.

    For two or fewer gradients only the min-norm result is used: it is placed
    on a diagonal, summed column-wise, shifted by 1 and passed through a
    softmax.

    Args:
        grads: list of flat (1-D) gradients, one per task, all of the same
            length.

    Returns:
        A 1-D tensor with ``len(grads)`` weights.
    """
    num_grads = len(grads)

    if num_grads <= 2:  # only for two tasks (or fewer)
        # There is no off-diagonal term to optimise here, so W -- whose
        # initial value 1/(n-1) is undefined for n == 1 -- is never built.
        # NOTE(review): whether n < 2 is meaningful depends on the min-norm
        # solver, which is not visible from this file.
        W_diag = min_norm_solvers.find_min_norm_element_independent(grads)
        W = tf.linalg.diag(W_diag)
        return tf.nn.softmax(tf.reduce_sum(W, 0) + 1., -1)

    # Uniform start: every entry is 1/(n-1).
    W = tf.Variable(
        (1 / (num_grads - 1)) * tf.ones([num_grads, num_grads], dtype=tf.float64),
        trainable=True,
    )

    # Mask that keeps the off-diagonal entries and zeroes the diagonal.
    W_mask = 1 - np.identity(num_grads)

    # Pack the gradients once, in W's dtype, instead of re-packing the Python
    # list on every tape step. They are constants w.r.t. W.
    stacked_grads = tf.cast(tf.stack(grads, axis=0), W.dtype)

    optimizer = tf.keras.optimizers.SGD(lr)

    # update off-diagonal
    for _ in range(5):
        with tf.GradientTape() as tape:
            masked_W = masked_softmax(W, W_mask)
            # Row i combines the gradients of the other tasks.
            G_combine = tf.matmul(masked_W, stacked_grads)
            # AGD loss
            maxdo_loss = 0.1 * asymmetric_distance(G_combine, stacked_grads)
        g_W = tape.gradient(maxdo_loss, W)
        optimizer.apply_gradients([(g_W, W)])

    # update diagonal
    W_diag = min_norm_solvers.find_min_norm_element_independent(grads)

    W = masked_softmax(W, W_mask) + tf.linalg.diag(W_diag)
    # Each softmax row sums to ~1 (n rows in total); presumably W_diag also
    # sums to 1, which would make n + 1 the normaliser -- TODO confirm.
    return tf.reduce_sum(W, 0) / (num_grads + 1)

def asymmetric_distance(x, y):
    """Mean relative distance of ``x`` from ``y``, asymmetric in its arguments.

    For each vector along the last axis the ratio
    ``||x - y|| / (||x - y|| + ||y||)`` is computed, and the ratios are then
    averaged over all remaining axes.

    Args:
        x: tensor of candidate vectors.
        y: reference vectors, broadcast-compatible with ``x``.

    Returns:
        A scalar float tensor in ``[0, 1]``. A ratio whose denominator is
        zero (``x == y`` and ``y == 0``) contributes 0 instead of NaN.
    """
    diff_norm = tf.math.reduce_euclidean_norm(x - y, axis=-1)
    ref_norm = tf.math.reduce_euclidean_norm(y, axis=-1)
    # divide_no_nan yields 0 where the denominator is 0, so an all-zero pair
    # of vectors does not turn the whole mean into NaN.
    ratio = tf.math.divide_no_nan(diff_norm, diff_norm + ref_norm)
    return tf.reduce_mean(ratio)

def masked_softmax(scores, mask):
    """Softmax along axis 1 with entries weighted by ``mask``.

    Each row is shifted by its maximum, exponentiated, multiplied
    element-wise by ``mask`` and divided by the row sum of the masked
    exponentials plus ``1e-7``. A row whose mask is all zero therefore yields
    zeros rather than NaN.

    Args:
        scores: tensor (or ``tf.Variable``) of logits; normalisation runs
            along axis 1.
        mask: array or tensor broadcastable to ``scores``; it is cast to the
            dtype of ``scores``, so bool, int and float masks are accepted.

    Returns:
        Tensor with the shape and dtype of ``scores``.
    """
    scores = tf.convert_to_tensor(scores)
    mask = tf.cast(mask, scores.dtype)

    # keepdims=True lets broadcasting align the per-row statistics with
    # `scores`; no explicit tf.tile copy is needed.
    # NOTE(review): the max is taken over masked-out entries as well. If a
    # masked entry is far larger than the kept ones, the kept exponentials can
    # underflow and the row collapses to zeros.
    shifted = scores - tf.reduce_max(scores, axis=1, keepdims=True)
    weighted_exp = tf.exp(shifted) * mask
    row_sum = tf.reduce_sum(weighted_exp, axis=1, keepdims=True)
    return weighted_exp / (row_sum + 1e-7)