"""
Tools for manipulating sets of variables.
"""

import numpy as np
import tensorflow as tf

def interpolate_vars(old_vars, new_vars, epsilon):
    """
    Move each of ``old_vars`` a fraction ``epsilon`` toward ``new_vars``.

    Entry-wise this computes ``old + (new - old) * epsilon``, so ``epsilon=0``
    reproduces ``old_vars`` and ``epsilon=1`` reproduces ``new_vars``.
    Entries are paired positionally by the sequence helpers.
    """
    delta = subtract_vars(new_vars, old_vars)
    step = scale_vars(delta, epsilon)
    return add_vars(old_vars, step)

def eigvector_vars(old_vars, new_vars, epsilon):
    """
    Add a scaled, flat vector to a sequence of variables.

    ``new_vars`` is treated as one vector spanning all parameters. It is
    flattened and multiplied by ``epsilon``, then cut into consecutive chunks
    whose sizes match the entries of ``old_vars`` (in order). Each chunk is
    reshaped to its entry's shape and added to that entry.

    Args:
      old_vars: sequence of arrays (anything ``np.shape`` understands).
      new_vars: array-like holding at least as many elements as all entries
        of ``old_vars`` combined. Trailing extra elements are ignored.
      epsilon: scalar multiplier applied to ``new_vars``.

    Returns:
      A list with one updated value per entry of ``old_vars``.

    Raises:
      ValueError: if ``new_vars`` has too few elements to cover ``old_vars``.
    """
    # Flatten and scale once, then slice, rather than pulling scalars one by
    # one out of a Python generator.
    flat = np.asarray(new_vars).ravel() * epsilon
    result = []
    offset = 0
    for parameter in old_vars:
        shape = np.shape(parameter)
        size = int(np.prod(shape))  # np.prod(()) == 1 covers scalar params
        end = offset + size
        if end > flat.size:
            # Fail with a clear error instead of an exhausted-iterator
            # StopIteration escaping to the caller.
            raise ValueError(
                'new_vars has %d elements but old_vars needs at least %d'
                % (flat.size, end))
        result.append(parameter + flat[offset:end].reshape(shape))
        offset = end
    return result



def average_vars(var_seqs):
    """
    Return the position-wise mean of several variable sequences.

    The i-th output is ``np.mean`` (over axis 0) of the i-th entries of every
    sequence in ``var_seqs``. Sequences are paired positionally with ``zip``.
    """
    return [np.mean(group, axis=0) for group in zip(*var_seqs)]

def subtract_vars(var_seq_1, var_seq_2):
    """
    Return the entry-wise difference ``var_seq_1 - var_seq_2`` as a list.

    Entries are paired positionally; any surplus in the longer sequence is
    ignored.
    """
    differences = []
    for minuend, subtrahend in zip(var_seq_1, var_seq_2):
        differences.append(minuend - subtrahend)
    return differences

def add_vars(var_seq_1, var_seq_2):
    """
    Return the entry-wise sum of two variable sequences as a list.

    Entries are paired positionally; any surplus in the longer sequence is
    ignored.
    """
    sums = []
    for left, right in zip(var_seq_1, var_seq_2):
        sums.append(left + right)
    return sums

def scale_vars(var_seq, scale):
    """
    Return a list with every entry of ``var_seq`` multiplied by ``scale``.
    """
    scaled = []
    for value in var_seq:
        scaled.append(value * scale)
    return scaled

def weight_decay(rate, variables=None):
    """
    Build a grouped Op that multiplies each variable in place by ``rate``.

    Args:
      rate: multiplier applied to every variable when the Op runs.
      variables: variables to decay; when None, uses
        ``tf.trainable_variables()``.

    Returns:
      A ``tf.group`` of one ``tf.assign(var, var * rate)`` per variable.
    """
    targets = tf.trainable_variables() if variables is None else variables
    decay_ops = []
    for var in targets:
        decay_ops.append(tf.assign(var, var * rate))
    return tf.group(*decay_ops)

def get_grad(a):
    """
    Lazily yield the items of ``a`` one at a time, in iteration order.
    """
    iterator = iter(a)
    while True:
        try:
            item = next(iterator)
        except StopIteration:
            return
        yield item


def num_para(x):
    """
    Return the product of the entries of ``x``; 1 when ``x`` is empty.

    Given a shape tuple, this is the number of elements in that shape.
    """
    total = 1
    for dim in x:
        total *= dim
    return total

class VariableState:
    """
    Read and overwrite the values of a fixed set of TensorFlow variables.

    At construction it builds one placeholder per variable (same base dtype
    and static shape) plus a single grouped assign Op, so values can later be
    loaded via ``feed_dict`` without adding new graph nodes.
    """
    def __init__(self, session, variables):
        self._session = session
        self._variables = variables
        holders = []
        for var in variables:
            holders.append(tf.placeholder(var.dtype.base_dtype,
                                          shape=var.get_shape()))
        self._placeholders = holders
        assign_ops = [tf.assign(var, holder)
                      for var, holder in zip(self._variables, self._placeholders)]
        self._assign_op = tf.group(*assign_ops)

    def export_variables(self):
        """
        Return the current values of the managed variables from the session.
        """
        current = self._session.run(self._variables)
        return current

    def import_variables(self, values):
        """
        Overwrite the managed variables with ``values`` (paired positionally).
        """
        feed = dict(zip(self._placeholders, values))
        self._session.run(self._assign_op, feed_dict=feed)
