import jax
import jax.numpy as jnp
import jax.lax as lax


def f(v0,v1,v2,v4):
    v5 = jnp.power(v2,3)
    v6 = lax.dot_general(v4,v5,dimension_numbers=(((0,), (0,)), ((), ())))
    v7 = lax.dot_general(v5,v6,dimension_numbers=(((0,), (1,)), ((), ())))
    v8 = jnp.power(v6,2)
    v9 = lax.dot_general(v0,v8,dimension_numbers=(((0,), (1,)), ((), ())))
    v10 = jnp.sqrt(v0)
    v11 = jnp.arccos(v7)
    v12 = jnp.transpose(v1,axes=[0, 1])
    v13 = lax.dot_general(v11,v8,dimension_numbers=(((1,), (0,)), ((), ())))
    _v9 = v9.reshape([1, 1])
    v14 = jnp.add(v11,_v9)
    _v9 = v9.reshape([1, 1])
    v15 = jnp.subtract(v14,_v9)
    v16 = jnp.cos(v9)
    _v10 = jnp.reshape(v10, [4, 1])
    v17 = jnp.add(_v10,v13)
    v18 = jnp.arcsinh(v17)
    _v16 = v16.reshape([1, 1])
    v19 = jnp.add(v15,_v16)
    v20 = jnp.negative(v12)
    v21 = jnp.sqrt(v10)
    v22 = lax.dot_general(v21,v18,dimension_numbers=(((0,), (1,)), ((), ())))
    v23 = jnp.tanh(v12)
    v24 = jnp.arctan(v19)
    v25 = jnp.square(v24)
    v26 = jnp.cos(v22)
    v27 = jnp.subtract(_v16,v25)
    v28 = jnp.cosh(v23)
    v29 = jnp.sum(v20, axis=0)
    v30 = jnp.transpose(v28,axes=[1, 0])
    v31 = jnp.sum(v16, axis=0)
    v32 = jnp.squeeze(v26)
    _v29 = jnp.reshape(v29, [3, 1])
    v33 = jnp.subtract(_v29,v30)
    v34 = jnp.divide(v31,v33)
    _v10 = v10.reshape([4, 1])
    v35 = jnp.add(v13,_v10)
    _v32 = jnp.reshape(v32, [4, 1])
    v36 = jnp.multiply(_v32,v35)
    v37 = jnp.power(v32,2)
    v38 = jnp.amax(v34, axis=1)
    v39 = jnp.sinh(v36)
    v40 = jnp.exp(v27)
    v41 = lax.slice(v39, start_indices=[3, 3], limit_indices=[4, 4])
    v42 = jnp.sum(v37, axis=0)
    v43 = jnp.add(v42,v38)
    v44 = lax.dot_general(v32,v40,dimension_numbers=(((0,), (0,)), ((), ())))
    v45 = jnp.arcsinh(v43)
    v46 = lax.dot_general(v24,v15,dimension_numbers=(((1,), (1,)), ((), ())))
    v47 = lax.stop_gradient(v41)
    v48 = lax.logistic(v45)
    v49 = lax.dot_general(v29,v48,dimension_numbers=(((0,), (0,)), ((), ())))
    v50 = lax.dot_general(v22,v21,dimension_numbers=(((0,), (0,)), ((), ())))
    v51 = lax.dot_general(v44,v47,dimension_numbers=(((0,), (1,)), ((), ())))
    v52 = lax.dot_general(v30,v33,dimension_numbers=(((0,), (0,)), ((), ())))
    v53 = jnp.multiply(v46,v50)
    v54 = jnp.sqrt(v37)
    v55 = jnp.sqrt(jnp.abs(v51))
    v56 = jnp.subtract(v52,v49)
    v57 = jnp.cosh(v54)
    v58 = jnp.negative(v37)
    _v57 = jnp.reshape(v57, [4, 1])
    v59 = jnp.add(_v57,v53)
    v60 = jnp.square(v57)
    v61 = jnp.sin(v58)
    v62 = jnp.arctan(v59)
    v63 = jnp.sum(v62, axis=0)
    v64 = jnp.amin(v63, axis=0)
    v65 = jnp.subtract(v64,v55)
    v66 = jnp.sqrt(v2)
    v67 = jnp.sum(v56, axis=1)
    _v61 = jnp.reshape(v61, [4, 1])
    v68 = jnp.subtract(_v61,v66)
    v69 = lax.dot_general(v60,v61,dimension_numbers=(((0,), (0,)), ((), ())))
    v70 = lax.dot_general(v68,v62,dimension_numbers=(((0,), (0,)), ((), ())))
    _v22 = jnp.reshape(v22, [4, 1])
    v71 = jnp.add(_v22,v2)
    v72 = lax.logistic(v70)
    v73 = lax.dot_general(v72,v71,dimension_numbers=(((0,), (0,)), ((), ())))
    _v37 = v37.reshape([1, 4])
    v74 = jnp.divide(v6,_v37)
    return v74,v65,v67,v69,v73


# fwd: 632, rev: 566, mM: 451
def g(v0,v1,v2,v3,v4,v5,v6,v7,v8,v9,v10,v11,v12,v13,v14):
    v15 = jnp.ones(())
    v16 = jnp.ones(())
    v17 = jnp.arctan2(v15,v3)
    v18 = jnp.power(v15,v14)
    v19 = jnp.arctan2(v13,v17)
    v20 = jnp.power(v17,v5)
    v21 = jnp.subtract(v9,v8)
    v22 = jnp.arctan(v10)
    v23 = jnp.multiply(v21,v0)
    v24 = jnp.subtract(v22,v18)
    v25 = jnp.power(v24,v11)
    v26 = jnp.add(v21,v12)
    v27 = jnp.exp(v2)
    v28 = jnp.add(v21,v26)
    v29 = jnp.power(v16,v17)
    v30 = jnp.sin(v3)
    v31 = jnp.divide(v2,v23)
    v32 = jnp.arctan2(v24,v26)
    v33 = jnp.subtract(v29,v25)
    v34 = jnp.arctan(v33)
    v35 = jnp.multiply(v7,v27)
    v36 = jnp.tan(v13)
    v37 = jnp.add(v27,v19)
    v38 = jnp.divide(v7,v6)
    v39 = jnp.subtract(v25,v34)
    v40 = jnp.subtract(v29,v6)
    v41 = jnp.power(v35,v4)
    v42 = jnp.cos(v0)
    v43 = jnp.subtract(v40,v19)
    v44 = jnp.divide(v41,v7)
    v45 = jnp.subtract(v40,v8)
    v46 = jnp.arctan2(v16,v43)
    v47 = jnp.divide(v8,v30)
    v48 = jnp.power(v37,v36)
    v49 = jnp.power(v47,v44)
    v50 = jnp.arctan2(v45,v40)
    v51 = jnp.arctan2(v44,v13)
    v52 = jnp.divide(v4,v34)
    v53 = jnp.arctan2(v41,v1)
    v54 = jnp.arctan2(v25,v53)
    v55 = jnp.add(v12,v51)
    v56 = jnp.power(v3,v48)
    v57 = jnp.subtract(v24,v32)
    v58 = jnp.arctanh(v48)
    v59 = jnp.arctanh(v54)
    v60 = jnp.arctan2(v2,v26)
    v61 = jnp.subtract(v52,v18)
    v62 = jnp.cos(v38)
    v63 = jnp.divide(v8,v32)
    v64 = jnp.arctan2(v42,v46)
    v65 = jnp.sinh(v36)
    v66 = jnp.subtract(v61,v50)
    v67 = jnp.power(v13,v39)
    v68 = jnp.power(v37,v26)
    v69 = jnp.subtract(v68,v43)
    v70 = jnp.log(v65)
    v71 = jnp.power(v23,v58)
    v72 = jnp.arctan2(v50,v69)
    v73 = jnp.divide(v20,v72)
    v74 = jnp.multiply(v41,v56)
    v75 = jnp.multiply(v39,v33)
    v76 = jnp.multiply(v61,v15)
    v77 = jnp.power(v66,v64)
    v78 = jnp.arctan2(v53,v25)
    v79 = jnp.subtract(v59,v60)
    v80 = jnp.arctan2(v73,v41)
    v81 = jnp.multiply(v74,v46)
    v82 = jnp.square(v5)
    v83 = jnp.arctan2(v28,v62)
    v84 = jnp.arctan2(v39,v9)
    v85 = jnp.multiply(v4,v12)
    v86 = jnp.divide(v57,v56)
    v87 = jnp.arctan2(v61,v63)
    v88 = jnp.arcsinh(v14)
    v89 = jnp.power(v86,v83)
    v90 = jnp.arcsin(v66)
    v91 = jnp.subtract(v70,v79)
    v92 = jnp.arctan2(v37,v65)
    v93 = jnp.multiply(v67,v77)
    v94 = jnp.power(v87,v55)
    v95 = jnp.square(v49)
    v96 = jnp.divide(v92,v94)
    v97 = jnp.add(v15,v95)
    v98 = jnp.divide(v91,v76)
    v99 = jnp.arctanh(v46)
    v100 = jnp.multiply(v75,v45)
    v101 = jnp.divide(v100,v28)
    v102 = jnp.add(v90,v29)
    v103 = jnp.arctan2(v80,v84)
    v104 = jnp.subtract(v20,v91)
    v105 = jnp.add(v82,v90)
    v106 = jnp.divide(v102,v103)
    v107 = jnp.power(v75,v42)
    v108 = jnp.multiply(v68,v62)
    v109 = jnp.divide(v102,v1)
    v110 = jnp.power(v98,v101)
    v111 = jnp.negative(v100)
    v112 = jnp.divide(v31,v89)
    v113 = jnp.divide(v105,v108)
    v114 = jnp.subtract(v22,v70)
    v115 = jnp.multiply(v106,v44)
    v116 = jnp.subtract(v107,v81)
    v117 = jnp.power(v58,v115)
    v118 = jnp.subtract(v78,v108)
    v119 = jnp.tan(v87)
    v120 = jnp.power(v87,v48)
    v121 = jnp.add(v105,v38)
    return v121,v71,v85,v88,v93,v96,v97,v99,v104,v109,v110,v111,v112,v113,v114,v116,v117,v118,v119,v120

