# mpirun -np 2 python ./examples/jax/test_sharding_simple.py

import logging

import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils, multihost_utils
from jax.sharding import Mesh, PartitionSpec
from mpi4py import MPI

import meshflow as mf
from meshflow.jax import get_opt_strategy, meshflow_shard, set_device_mesh

# Root-logger setup for this script: INFO and above, prefixed with a short
# month/day timestamp and the level name.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%m/%d %H:%M:%S',
)


def f(x, y):
    """Apply ``exp(tanh(.))`` elementwise to ``x`` and matrix-multiply by ``y``.

    This is the toy computation whose sharding strategy MeshFlow searches for.
    ``x`` may carry leading batch dimensions; they broadcast through ``@``.

    Args:
        x: Left operand; its last axis is contracted against ``y``.
        y: Right operand of the matrix product.

    Returns:
        ``jnp.exp(jnp.tanh(x)) @ y``.
    """
    activated = jnp.exp(jnp.tanh(x))
    return activated @ y


def main(mesh_shape=(2, 2)):
    """Run a small MeshFlow sharding demo with one JAX process per MPI rank.

    Steps: initialize the JAX distributed runtime from the MPI rank/size,
    build a device mesh with axes ``('a', 'b')``, ask MeshFlow for a sharding
    strategy for ``f``, then run ``f`` both unsharded and through the
    MeshFlow-sharded, jitted version. Rank 0 prints the local shard shapes
    and the traced jaxpr of the sharded function.

    Args:
        mesh_shape: Shape of the ``('a', 'b')`` device mesh passed to
            ``mesh_utils.create_device_mesh``; its product must equal
            ``jax.device_count()``. Defaults to ``(2, 2)`` (four devices).
    """
    mf.platform.init_backend("jax")

    comm = MPI.COMM_WORLD
    size = comm.Get_size()
    rank = comm.Get_rank()

    # One JAX process per MPI rank. The coordinator address is localhost, so
    # this presumably assumes a single-node launch — TODO confirm for multi-node.
    # local_device_ids=rank restricts this process to its own local device.
    jax.distributed.initialize(coordinator_address="localhost:19705",
                               num_processes=size,
                               process_id=rank,
                               local_device_ids=rank)

    print(
        f"[Rank {rank}], Global Devices: {jax.device_count()}, Local Devices: {jax.local_device_count()}"
    )
    print(jax.devices())

    # NOTE(review): with a single local device per rank, the default (2, 2)
    # mesh needs 4 ranks; a `mpirun -np 2` launch needs e.g. mesh_shape=(1, 2).
    devices = mesh_utils.create_device_mesh(mesh_shape)
    mesh = Mesh(devices, axis_names=('a', 'b'))

    set_device_mesh(mesh)

    # Every rank uses the same fixed seed, so the host-local inputs are
    # identical across processes, which the replicated global arrays built
    # below rely on.
    key = jax.random.PRNGKey(1)
    key, subkey = jax.random.split(key)
    x = jax.random.normal(key, (4, 10, 10))
    y = jax.random.normal(subkey, (10, 10))

    opt_strategy = get_opt_strategy(f, x, y)

    print(" =========== Sharding ==============")

    # PartitionSpec(None, None) partitions no axis, i.e. the inputs are
    # replicated over the whole mesh (x's third axis is implicitly unpartitioned).
    x_shard = multihost_utils.host_local_array_to_global_array(x, mesh, PartitionSpec(None, None))
    y_shard = multihost_utils.host_local_array_to_global_array(y, mesh, PartitionSpec(None, None))

    sharded_f = jax.jit(meshflow_shard(f, opt_strategy))

    # 'allow_all' permits eager operations on arrays that are not fully
    # addressable by this process.
    with jax.spmd_mode('allow_all'):
        z = f(x, y)
        z_sharded = sharded_f(x_shard, y_shard)
        if rank == 0:
            # addressable_data(0) is the first shard held by this process.
            # Unlike the deprecated ``device_buffer`` it does not raise when a
            # process holds more than one shard.
            print("z.addressable_data(0):", z.addressable_data(0).shape)
            print("z_sharded.addressable_data(0):", z_sharded.addressable_data(0).shape)

    if rank == 0:
        # Trace only (nothing is executed) to show the program MeshFlow built.
        closed_jaxpr = jax.make_jaxpr(sharded_f)(x, y)
        print(closed_jaxpr.jaxpr)
        print(closed_jaxpr.literals)


if __name__ == "__main__":
    # Run the demo only when executed as a script (e.g. under mpirun),
    # not when this module is imported.
    main()
