"""First pass at small-scale computation of derivative of coefficient wrt to PEF/input."""
import dataclasses

import tensorflow as tf


# eq=False keeps the default identity-based __eq__/__hash__.  With the
# dataclass default (eq=True) __hash__ is set to None, and an unhashable
# instance cannot be used with the tf.function-decorated method below, which
# is bound per instance.  Field-wise equality on tensors would not yield a
# plain bool anyway.
@dataclasses.dataclass(eq=False)
class CoeffGrad1:
    """Fit a weight vector ``w`` so that ``H^T w`` approximates ``fisher``.

    Construct with ``CoeffGrad1(H, fisher, n_w_fit_steps)``, then call
    ``run()``.

    Attributes set after construction:
        rank, m: the two dimensions of ``H``.
        A, b, c: constant arrays built by ``_construct_abc``.
        w: ``tf.Variable`` of shape ``[rank]``, randomly initialised in [0, 1).

    Attributes set by ``run()``:
        x: tensor of shape ``[1 + rank]`` holding the residual norm followed
            by the fitted ``w``.
    """

    # shape = [rank, m]
    H: tf.Tensor

    # shape = [m]
    fisher: tf.Tensor

    # Number of update iterations performed by ``_fit_w``.
    n_w_fit_steps: int

    def __post_init__(self):
        """Derive sizes, build ``A``/``b``/``c`` and initialise ``w``.

        This has to be ``__post_init__`` rather than ``__init__``: a
        hand-written ``__init__`` suppresses the dataclass-generated one, so
        the fields would never be assigned before being read here.
        """
        self.rank, self.m = self.H.shape

        self._construct_abc()

        # shape = [rank]
        self.w = tf.Variable(tf.random.uniform([self.rank]), dtype=tf.float32)

    def _construct_abc(self):
        """Assemble the constant arrays ``A``, ``b`` and ``c``.

        With ``x = [x0, w]`` (length ``1 + rank``), ``b - A @ x`` stacks
        ``x0``, ``H^T w - fisher`` and ``w``.

        Shapes: ``A`` is ``[1 + m + rank, 1 + rank]``, ``b`` is
        ``[1 + m + rank]`` and ``c`` is ``[1 + rank]``.

        NOTE(review): this looks like the data of a conic program
        ``min c^T x  s.t.  b - A x in K``; nothing in this class consumes
        these arrays, so confirm against the caller.
        """
        zero_r = tf.zeros([self.rank], dtype=tf.float32)
        zero_m_col = tf.zeros([self.m, 1], dtype=tf.float32)

        # Float literals keep every concatenated piece in floating point.
        top_row = tf.concat([[[-1.0]], zero_r[None, :]], axis=1)
        h_rows = tf.concat([zero_m_col, -tf.transpose(self.H)], axis=1)
        eye_rows = tf.concat([zero_r[:, None], -tf.eye(self.rank)], axis=1)
        self.A = tf.concat([top_row, h_rows, eye_rows], axis=0)

        self.b = tf.concat([
            [0.0],
            -self.fisher,
            zero_r,
        ], axis=0)

        self.c = tf.concat([
            [1.0],
            zero_r,
        ], axis=0)

    def run(self):
        """Fit ``w`` and store ``x = [||H^T w - fisher||, w]`` on ``self.x``."""
        self._fit_w()

        # Residual of the fit: H^T w - fisher, shape = [m].
        residual = tf.einsum('ji,j->i', self.H, self.w) - self.fisher
        residual_norm = tf.linalg.norm(residual)
        self.x = tf.concat([[residual_norm], self.w], axis=0)

    @tf.function
    def _fit_w(self):
        """Update ``self.w`` in place for ``n_w_fit_steps`` iterations.

        Each step applies ``w <- w * (H fisher) / (H H^T w + 1e-9)``
        element-wise.

        NOTE(review): this has the form of a multiplicative update for
        non-negative least squares, which presumably assumes ``H`` and
        ``fisher`` are non-negative -- confirm.
        """
        # Loop invariants: H @ fisher (shape = [rank]) and H @ H^T
        # (shape = [rank, rank]).
        h_fisher = tf.einsum('i,ji->j', self.fisher, self.H)
        gram = tf.einsum('ij,kj->ik', self.H, self.H)
        for _ in tf.range(self.n_w_fit_steps):
            w_gram = tf.einsum('i,ij->j', self.w, gram)
            # 1e-9 guards against division by zero.
            self.w.assign(self.w * h_fisher / (w_gram + 1e-9))

