#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
file  : LISTA_cpss.py
author: Xiaohan Chen
email : chernxh@tamu.edu
date  : 2019-02-17

Implementation of Learned ISTA with support selection and coupled weights.
"""

import numpy as np
import tensorflow as tf

import utils.train
from models.LISTA_base import LISTA_base
from utils.tf import shrink_ss


class LISTA_cpss(LISTA_base):
    """
    LISTA with coupled weights and support selection (LISTA-CPSS), extended
    with a perturbation term controlled by `rho` and `beta`.

    Layer t (0-based) maps the current estimate xh to a new one:

        zh = xh + epsilon
        uh = xh + W_t (y - A zh)
        xh = shrink_ss(uh, theta_t, p_t)
        g  = uh - xh

    epsilon is zero in the first layer. In later layers it is
    rho_t * (1 - beta_t + beta_t / ||g||) * g, where g comes from the previous
    layer and ||g|| is its column-wise L2 norm.
    """

    def __init__(
        self, A, T, lam, rho, beta, percent, max_percent, untied, coord, scope, train=True
    ):
        """
        :A           : Measurement matrix of shape (M, N); stored as float32.
        :T           : Number of layers (depth) of this LISTA model.
        :lam         : Initial threshold numerator. Every layer's threshold is
                       initialised to lam / (1.001 * ||A||_2^2).
        :rho         : Perturbation step size. It is the initial value of the
                       trainable per-layer `rho_t` when `train` is True, and a
                       fixed constant otherwise.
        :beta        : Weight between the raw and the norm-scaled perturbation
                       direction. It is the initial value of the trainable
                       per-layer `beta_t` when `train` is True, and a fixed
                       constant otherwise.
        :percent     : Per-layer increment of the support-selection
                       percentage. Layer t (0-based) uses
                       clip((t + 1) * percent, 0, max_percent).
        :max_percent : Upper bound of the support-selection percentage.
        :untied      : If True, every layer gets its own weight matrix W_t.
                       If False, a single W is shared by all layers.
        :coord       : If True, thresholds are per-coordinate vectors of shape
                       (N, 1); otherwise each layer has a scalar threshold.
        :scope       : TensorFlow variable scope under which all variables live.
        :train       : If True, create trainable per-layer `rho_t`/`beta_t`
                       variables; otherwise use the constants `rho`/`beta`.
        """
        self._A = A.astype(np.float32)
        self._T = T
        self._p = percent
        self._maxp = max_percent
        self._lam = lam
        self._M = self._A.shape[0]
        self._N = self._A.shape[1]
        self._rho = rho
        self._beta = beta
        self._train_sam = train

        # ||A||_2^2 (with a 0.1% margin) is the step-size normaliser. Its
        # inverse scales both the initial W (= A^T / scale) and the threshold.
        self._scale = 1.001 * np.linalg.norm(A, ord=2) ** 2
        self._theta = (self._lam / self._scale).astype(np.float32)
        if coord:
            # One threshold per coordinate of the N-dimensional code.
            self._theta = np.ones((self._N, 1), dtype=np.float32) * self._theta

        # Support-selection percentage grows linearly with depth, then saturates.
        self._ps = [(t + 1) * self._p for t in range(self._T)]
        self._ps = np.clip(self._ps, 0.0, self._maxp)

        self._untied = untied
        self._coord = coord
        self._scope = scope

        # Set up layers.
        self.setup_layers()

    def setup_layers(self):
        """
        Create the TensorFlow variables of every layer under `self._scope`.

        Per layer t (variable names are 1-based):
            theta_t       : threshold, initialised to `self._theta`.
            W_t           : weight matrix initialised to A^T / scale when
                            untied; in the tied case a single variable `W`
                            is shared by all T layers.
            rho_t, beta_t : only when `self._train_sam` is True; initialised
                            to `self._rho` and `self._beta`.

        Sets:
            self._kA_         : constant tensor holding A.
            self.vars_in_layer: list of T tuples, indexed as
                                `vars_in_layer[t]`. Each tuple is
                                (W, theta, rho, beta) when rho/beta are
                                trained, otherwise (W, theta).
        """
        Ws_ = []
        thetas_ = []
        rhos_ = []
        betas_ = []

        W = (np.transpose(self._A) / self._scale).astype(np.float32)

        with tf.variable_scope(self._scope, reuse=False):
            # constant
            self._kA_ = tf.constant(value=self._A, dtype=tf.float32)

            if not self._untied:  # tied model: the same W object in every layer
                Ws_.append(tf.get_variable(name="W", dtype=tf.float32, initializer=W))
                Ws_ = Ws_ * self._T

            for t in range(self._T):
                thetas_.append(
                    tf.get_variable(
                        name="theta_%d" % (t + 1),
                        dtype=tf.float32,
                        initializer=self._theta,
                    )
                )
                if self._untied:  # untied model
                    Ws_.append(
                        tf.get_variable(
                            name="W_%d" % (t + 1), dtype=tf.float32, initializer=W
                        )
                    )
                if self._train_sam:
                    rhos_.append(
                        tf.get_variable(
                            name="rho_%d" % (t + 1),
                            dtype=tf.float32,
                            initializer=self._rho,
                        )
                    )
                    betas_.append(
                        tf.get_variable(
                            name="beta_%d" % (t + 1),
                            dtype=tf.float32,
                            initializer=self._beta,
                        )
                    )

        # Collection of all trainable variables in the model layer by layer.
        # We name it as `vars_in_layer` because we will use it in the manner:
        # vars_in_layer [t]
        if self._train_sam:
            self.vars_in_layer = list(zip(Ws_, thetas_, rhos_, betas_))
        else:
            self.vars_in_layer = list(zip(Ws_, thetas_))

    def inference(self, y_, x0_=None):
        """
        Build the unrolled T-layer graph.

        :y_ : Measurement tensor. The code implies shape (M, batch): its last
              dimension is the batch size and it is compared against A @ x.
        :x0_: Optional initial estimate. Defaults to zeros of shape (N, batch).
        :returns: List of T + 1 tensors. The first is the initial estimate;
                  entry t + 1 is the output of layer t.
        """
        xhs_ = []  # collection of the regressed sparse codes

        if x0_ is None:
            batch_size = tf.shape(y_)[-1]
            xh_ = tf.zeros(shape=(self._N, batch_size), dtype=tf.float32)
        else:
            xh_ = x0_
        xhs_.append(xh_)
        g = None  # (uh - xh) from the previous layer; None before layer 0

        # Only affects op naming: no variables are created or fetched here.
        with tf.variable_scope(self._scope, reuse=True):
            for t in range(self._T):
                if self._train_sam:
                    W_, theta_, rho_, beta_ = self.vars_in_layer[t]
                else:
                    W_, theta_ = self.vars_in_layer[t]
                    rho_ = self._rho
                    beta_ = self._beta
                percent = self._ps[t]
                if g is None:
                    # Zero perturbation in the first layer. rho_/beta_ are kept
                    # in the expression, presumably so that (when trainable)
                    # they stay connected to the graph and get zero rather
                    # than None gradients.
                    epsilon = rho_ * (1.0 - beta_) * tf.zeros_like(xh_)
                else:
                    # Column-wise 1/||g||_2; 1e-20 guards against division by 0.
                    one_over_norm = 1. / tf.sqrt(
                        tf.reduce_sum(tf.square(g), axis=0, keepdims=True) + 1e-20
                    )
                    epsilon = rho_ * (1.0 - beta_ + beta_ * one_over_norm) * g

                # The residual is evaluated at the perturbed point zh_, but
                # the correction is applied to the unperturbed xh_.
                zh_ = xh_ + epsilon

                res_ = y_ - tf.matmul(self._kA_, zh_)
                uh_ = xh_ + tf.matmul(W_, res_)
                xh_ = shrink_ss(uh_, theta_, percent)
                g = uh_ - xh_  # part of uh_ removed by the shrinkage
                xhs_.append(xh_)

        return xhs_
