import logging
import time

import jax
import jax.numpy as jnp

from meshflow.jax import get_opt_strategy, set_device_mesh
from meshflow.utils.testing import setup_testing, JaxMockDeviceMesh
from meshflow.utils.memory import MemTracking

# Root logging setup for this script: INFO and above, timestamped lines.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%m/%d %H:%M:%S',
)


def f(x, y):
    """Return the mean of ``exp(tanh(x)) @ y`` over all elements.

    ``tanh`` and ``exp`` are applied elementwise to ``x``; the result is
    matrix-multiplied with ``y`` and then reduced with ``jnp.mean`` (no axis),
    so the return value is a 0-d array.
    """
    activated = jnp.exp(jnp.tanh(x))
    projected = activated @ y
    return jnp.mean(projected)


def main():
    """Smoke-test memory tracking and strategy search for ``f``.

    Steps:
      1. install a 2x2 ``JaxMockDeviceMesh`` via ``set_device_mesh``;
      2. call ``setup_testing(backend="jax", device="cuda")``;
      3. start a ``MemTracking`` tracker;
      4. draw ``x`` (2048x1024) and ``y`` (1024x1024) from a standard normal;
      5. evaluate ``f(x, y)`` eagerly and print the tracker's summary in MB;
      6. call ``get_opt_strategy(f, x, y)`` and log the result;
      7. print the shape of the eager result.
    """
    # NOTE(review): the mesh is installed before setup_testing(); whether this
    # ordering matters depends on meshflow internals -- confirm.
    mesh = JaxMockDeviceMesh(2, 2)
    set_device_mesh(mesh)

    setup_testing(backend="jax", device="cuda")
    mem_track_ = MemTracking()
    mem_track_.start()

    # One split yields two independent keys, one for x and one for y.
    key = jax.random.PRNGKey(0)
    key, subkey = jax.random.split(key)
    x = jax.random.normal(key, (2048, 1024))
    y = jax.random.normal(subkey, (1024, 1024))

    z = f(x, y)
    # JAX dispatches asynchronously; wait for the computation to finish so the
    # memory summary below is read after f has actually run.
    z.block_until_ready()

    # Presumably summary() reports bytes (divided twice by 1024 -> MB) -- verify.
    print(f"mem_track_.summary(): {mem_track_.summary() / 1024 / 1024} MB")

    opt_strategy = get_opt_strategy(f, x, y)
    logging.info("opt_strategy: %s", opt_strategy)

    # f reduces with jnp.mean over all axes, so z is 0-d and this prints ().
    print(z.shape)


# Script entry point: run main() only when executed directly, not when imported.
if __name__ == '__main__':
    main()
