"""
This file contains the functions for matrix and vector operations.
"""

import jax
import jax.numpy as jnp

def vec_dot(v1, v2):
   """
   Compute the dot product of two vectors stored as pytrees.

   Matching leaves are multiplied element-wise, every product is summed
   to a scalar, and the per-leaf scalars are accumulated on top of 0.0.
   No complex conjugation is applied.

   Args:
      v1 (pytree of arrays, required):
         The first vector for the dot product.
      v2 (pytree of arrays, required):
         The second vector; must have the same tree structure as `v1`,
         with leaves that can be multiplied with the leaves of `v1`.

   Returns:
      jax.Array:
         The scalar dot product of `v1` and `v2`. For a pytree without
         leaves the Python float 0.0 is returned.
   """
   # One scalar partial sum per pair of matching leaves.
   partial_sums = jax.tree_util.tree_map(
      lambda a, b: jnp.sum(a * b), v1, v2
   )

   # Accumulate the partial sums in leaf order, starting from 0.0.
   total = 0.0
   for part in jax.tree_util.tree_leaves(partial_sums):
      total = total + part
   return total

def vec_scale(alpha, v):
   """
   Multiply every leaf of a pytree vector by a scalar.

   Args:
      alpha (float, required):
         The scalar factor applied to each leaf.
      v (pytree of arrays, required):
         The input vector to be scaled.

   Returns:
      pytree of arrays:
         A pytree with the same structure as `v` whose leaves are
         `alpha * leaf`.
   """
   leaves, structure = jax.tree_util.tree_flatten(v)
   scaled_leaves = [alpha * leaf for leaf in leaves]
   return jax.tree_util.tree_unflatten(structure, scaled_leaves)

def vec_sum(v1, v2):
   """
   Add two pytree vectors leaf by leaf: v1 + v2.

   Args:
      v1 (pytree of arrays, required):
         The first vector.
      v2 (pytree of arrays, required):
         The second vector; must have the same tree structure as `v1`.

   Returns:
      pytree of arrays:
         A pytree with the structure of `v1` holding `v1 + v2`.
   """
   def _add_leaves(left, right):
      return left + right

   return jax.tree_util.tree_map(_add_leaves, v1, v2)

# Earlier two-operand variant (alpha * v1 + v2). It is not used in the code
# and is kept for reference only; see the four-argument vec_axpy below.
#def vec_axpy(alpha, v1, v2):
#   """
#   Compute the operation alpha * x + y.
#
#   Args:
#      alpha (float, required):
#         The scalar multiplier for `v1`.
#      v1 (jax.tree_util.PyTreeDef, required):
#         The first vector.
#      v2 (jax.tree_util.PyTreeDef, required):
#         The second vector.
#
#   Returns:
#      jax.tree_util.PyTreeDef:
#         The result of `alpha * v1 + v2`.
#   """
#   return jax.tree_util.tree_map(lambda x, y: alpha * x + y, v1, v2)

def vec_axpy(alpha, v1, beta, v2):
   """
   Compute the linear combination alpha * v1 + beta * v2 leaf by leaf.

   Args:
      alpha (float, required):
         The scalar multiplier for `v1`.
      v1 (pytree of arrays, required):
         The first vector.
      beta (float, required):
         The scalar multiplier for `v2`.
      v2 (pytree of arrays, required):
         The second vector; must have the same tree structure as `v1`.

   Returns:
      pytree of arrays:
         A pytree with the structure of `v1` holding
         `alpha * v1 + beta * v2`.
   """
   def _combine(x, y):
      scaled_x = alpha * x
      scaled_y = beta * y
      return scaled_x + scaled_y

   return jax.tree_util.tree_map(_combine, v1, v2)

def vec_pow(v, p):
   """
   Raise every element of a pytree vector to the power `p`.

   Args:
      v (pytree of arrays, required):
         The input vector.
      p (float, required):
         The exponent applied to each leaf via the `**` operator.

   Returns:
      pytree of arrays:
         A pytree with the same structure as `v` whose leaves are
         `leaf ** p`.
   """
   leaves, structure = jax.tree_util.tree_flatten(v)
   powered_leaves = [leaf ** p for leaf in leaves]
   return jax.tree_util.tree_unflatten(structure, powered_leaves)

def vec_norm_inf(v):
   """
   Compute the infinity norm of a vector stored as a pytree.

   The norm is the largest absolute value found in any leaf. Leaves with
   zero elements are tolerated and contribute 0 instead of raising.

   Args:
      v (pytree of arrays, required):
         The input vector.

   Returns:
      jax.Array:
         The infinity norm (maximum absolute value) of the vector. For a
         pytree without leaves the Python float 0.0 is returned.
   """
   def _fold_leaf(running_max, leaf):
      # `max` has no identity element, so a zero-size leaf would raise
      # without `initial`. Absolute values are never negative, which
      # makes 0.0 a neutral starting value for non-empty leaves.
      leaf_max = jnp.max(jnp.abs(leaf), initial=0.0)
      return jnp.maximum(running_max, leaf_max)

   return jax.tree_util.tree_reduce(_fold_leaf, v, 0.0)

def vec_normalize(v, eps=1e-03):
   """
   Normalize each leaf of a pytree by its own norm plus a small constant.

   Every leaf is divided by `jnp.linalg.norm(leaf) + eps`, so the
   normalization is done per leaf and not by one norm of the whole vector.
   NOTE(review): confirm that per-leaf normalization is the intended
   behaviour.

   Args:
      v (pytree of arrays, required):
         The input vector.
      eps (float, optional):
         The small constant added to each leaf norm before dividing.
         Defaults to 1e-03.

   Returns:
      pytree of arrays:
         A pytree with the same structure as `v` holding the normalized
         leaves.
   """
   def _normalize_leaf(leaf):
      denominator = jnp.linalg.norm(leaf) + eps
      return leaf / denominator

   return jax.tree_util.tree_map(_normalize_leaf, v)


def mat_vec(M, v):
   """
   Compute the product of a pytree matrix with a pytree vector.

   For each pair of matching leaves, the leading axis of the matrix leaf
   is kept and all remaining axes are contracted with the vector leaf.
   The per-leaf results are then added together on top of 0.0.

   Args:
      M (pytree of arrays, required):
         The matrix; each leaf has shape (r, ...).
      v (pytree of arrays, required):
         The vector; same tree structure as `M`, each leaf matching the
         trailing axes of the corresponding leaf of `M`.

   Returns:
      jax.Array:
         The result of `M * v`, shape (r,). For a pytree without leaves
         the Python float 0.0 is returned.
   """
   per_leaf = jax.tree_util.tree_map(
      lambda block, vec: jnp.einsum('r...,...->r', block, vec),
      M, v
   )

   # Sum the contributions of all leaves in leaf order.
   result = 0.0
   for contribution in jax.tree_util.tree_leaves(per_leaf):
      result = result + contribution
   return result

def mat_bvec(M, v):
   """
   Compute the product of a pytree matrix with a pytree block vector.

   For each pair of matching leaves, the leading axes of both leaves are
   kept and all remaining axes are contracted. The per-leaf results are
   then added together on top of 0.0.

   Args:
      M (pytree of arrays, required):
         The matrix; each leaf has shape (r, ...).
      v (pytree of arrays, required):
         The block vector; same tree structure as `M`, each leaf has
         shape (s, ...) with trailing axes matching those of `M`.

   Returns:
      jax.Array:
         The result of `M * v`, shape (r, s). For a pytree without leaves
         the Python float 0.0 is returned.
   """
   def _contract(block, vecs):
      return jnp.einsum('r...,s...->rs', block, vecs)

   per_leaf = jax.tree_util.tree_map(_contract, M, v)
   return sum(jax.tree_util.tree_leaves(per_leaf), 0.0)

def vec_mat(v, M):
   """
   Compute a linear combination of pytree "rows" with given weights.

   In every leaf of `M` the leading axis is contracted with `v`, so the
   result is the weighted sum of the slices along that axis.

   Args:
      v (jax.numpy.ndarray, required):
         Array of weights with shape (rank,).
      M (pytree of arrays, required):
         Pytree representing the matrix; each leaf has shape (rank, ...).

   Returns:
      pytree of arrays:
         A pytree with the structure of `M` whose leaves have lost the
         leading axis and hold the weighted sum of rows.
   """
   leaves, structure = jax.tree_util.tree_flatten(M)
   combined = [jnp.einsum('i,i...->...', v, rows) for rows in leaves]
   return jax.tree_util.tree_unflatten(structure, combined)

def bvec_mat(v, M):
   """
   Compute linear combinations of pytree "rows" with given block weights.

   In every leaf of `M` the leading axis is contracted with the second
   axis of `v`; the first axis of `v` becomes the leading axis of the
   result.

   Args:
      v (jax.numpy.ndarray, required):
         2-D array of block weights with shape (n, rank): one row of
         weights per output block.
      M (pytree of arrays, required):
         Pytree representing the matrix; each leaf has shape (rank, ...).

   Returns:
      pytree of arrays:
         A pytree with the structure of `M` whose leaves have shape
         (n, ...) and hold the block vector-matrix product.
   """
   def _combine_rows(rows):
      # Contract the weight axis of `v` with the leading axis of the leaf.
      return jnp.einsum('bi,i...->b...', v, rows)

   return jax.tree_util.tree_map(_combine_rows, M)

