# Usage (from the repository root):
#   mpirun --mca btl tcp,self -np 2 python ./benchmark/bench_jax.py

import os
import sys
import logging

import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh
from mpi4py import MPI

from meshflow.jax import get_opt_strategy, meshflow_shard, set_device_mesh
from meshflow.jax.api import shard_module, to_shape_array
from meshflow.jax.model import GATLayer
from meshflow.jax.model.gpt import GPTSimple
from meshflow.jax.model.wresnet import wresnet50
from meshflow.utils.testing import setup_testing
from meshflow.utils.timer import MFTimer
from meshflow.utils.memory import MemTracking

# Make the repository root (the parent of this file's directory) importable so
# the `benchmark` package resolves when this file is run directly as a script.
# Appending the path of the file itself would have no effect on imports.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from benchmark.bench_case import GPTCase, ResNetCase, GATCase

# Configure the root logger: INFO level, "<MM/DD HH:MM:SS> - <LEVEL> - <message>".
logging.basicConfig(
    level=logging.INFO,
    datefmt='%m/%d %H:%M:%S',
    format='%(asctime)s - %(levelname)s - %(message)s',
)


def get_gpt_case():
    """Build the GPT benchmark case.

    Creates a ``GPTSimple`` model from a default ``GPTCase``, initializes its
    parameters against a random dummy input of shape
    ``(case.batch_size, case.seq_size, case.hidden_dim)``, and defines one
    plain-SGD training step on the mean of the model output.

    Returns:
        A ``(train_step, [params, input_])`` pair: the un-jitted step function
        and the list of positional arguments it should be called with.
    """
    case = GPTCase()
    model = GPTSimple(case)

    root_key = jax.random.PRNGKey(seed=0)
    main_key, params_key = jax.random.split(key=root_key)
    # Dummy input data
    input_ = jax.random.normal(main_key, (case.batch_size, case.seq_size, case.hidden_dim))
    variables = model.init(params_key, input_, deterministic=True)
    params = variables['params']

    def train_step(params, input_):
        """Run one SGD update and return the new parameters."""
        lr = 0.0001

        def loss_fn(params):
            # The dropout key is fixed, so every step draws from the same key.
            dropout_key = jax.random.PRNGKey(seed=0)
            output = model.apply({'params': params},
                                 input_,
                                 deterministic=False,
                                 rngs={'dropout': dropout_key})
            return output.mean()

        grads = jax.grad(loss_fn)(params)
        # Use jax.tree_util.tree_map rather than the deprecated top-level
        # `jax.tree_map` alias, which newer JAX releases no longer provide.
        params = jax.tree_util.tree_map(lambda x, y: x - lr * y, params, grads)
        return params

    return train_step, [params, input_]


def get_resnet_case():
    """Build the wide-ResNet-50 benchmark case.

    Initializes ``wresnet50()`` against a random dummy input of shape
    ``(case.batch_size, 224, 224, 3)`` and defines one plain-SGD training step
    on the mean of the model output.

    Returns:
        A ``(train_step, [params, batch_stats, input_])`` pair: the un-jitted
        step function and the list of positional arguments to call it with.
    """
    case = ResNetCase()
    model = wresnet50()

    key1, key2 = jax.random.split(jax.random.PRNGKey(0), num=2)
    input_ = jax.random.normal(key1, (case.batch_size, 224, 224, 3))  # Dummy input data
    variables = model.init(key2, input_)  # Initialization call
    params, batch_stats = variables['params'], variables['batch_stats']

    def train_step(params, batch_stats, input_):
        """Run one SGD update and return the new parameters only."""
        lr = 0.0001

        def loss_fn(params, batch_stats):
            # `mutable=['batch_stats']` makes apply return (output, updated
            # collections); the updated statistics are not used by this step.
            out_, _ = model.apply({
                'params': params,
                'batch_stats': batch_stats
            },
                                  input_,
                                  mutable=['batch_stats'])
            return out_.mean()

        # jax.grad differentiates with respect to the first argument (params).
        grads = jax.grad(loss_fn)(params, batch_stats)
        # Use jax.tree_util.tree_map rather than the deprecated top-level
        # `jax.tree_map` alias, which newer JAX releases no longer provide.
        params = jax.tree_util.tree_map(lambda x, y: x - lr * y, params, grads)
        return params

    return train_step, [params, batch_stats, input_]


def get_gat_case():
    """Build the graph-attention-layer benchmark case.

    Initializes a ``GATLayer(case.in_feature, case.out_feature)`` against
    random dummy node features of shape ``(case.num_node, case.in_feature)``
    and a random dummy adjacency input of shape
    ``(case.num_node, case.num_node)``, and defines one plain-SGD training
    step on the mean of the layer output.

    Returns:
        A ``(train_step, [params, h, adj])`` pair: the un-jitted step function
        and the list of positional arguments to call it with.
    """
    case = GATCase()
    model = GATLayer(case.in_feature, case.out_feature)

    key1, key2, key3 = jax.random.split(jax.random.PRNGKey(0), num=3)
    h = jax.random.normal(key1, (case.num_node, case.in_feature))  # Dummy input data
    adj = jax.random.normal(key2, (case.num_node, case.num_node))  # Dummy input data
    variables = model.init(key3, h, adj)  # Initialization call
    params = variables['params']

    def train_step(params, h, adj):
        """Run one SGD update and return the new parameters."""
        lr = 0.0001

        def loss_fn(params):
            return model.apply({'params': params}, h, adj).mean()

        grads = jax.grad(loss_fn)(params)
        # Use jax.tree_util.tree_map rather than the deprecated top-level
        # `jax.tree_map` alias, which newer JAX releases no longer provide.
        params = jax.tree_util.tree_map(lambda x, y: x - lr * y, params, grads)
        return params

    return train_step, [params, h, adj]


def bench_naive(func, args):
    """Benchmark ``func`` under plain ``jax.jit`` without meshflow sharding.

    Jits ``func``, runs it once as a warm-up, then times it with ``MFTimer``
    while ``MemTracking`` is active, and prints the memory summary and the
    elapsed time.

    Args:
        func: The step function to benchmark; called as ``func(*args)``.
        args: Sequence of positional arguments for ``func``.
    """
    jit_func = jax.jit(func)

    def train_step():
        jit_func(*args)

    # warmup: run once before timing, as bench_meshflow does, so that the
    # first (tracing/compiling) call is outside the measured region.
    train_step()

    timer = MFTimer(train_step, in_ms=False)

    mem_track_ = MemTracking()
    mem_track_.start()

    elaps_time = timer.time()

    # NOTE(review): assumes MemTracking.summary() returns a byte count - confirm.
    GB = 1024**3
    print(f"mem_track_.summry(): {mem_track_.summary() / GB:.2f} GB")
    print(f"Time: {elaps_time}")


def bench_meshflow(func, args):
    """Benchmark ``func`` after sharding it with meshflow.

    Builds a ``1 x N`` device mesh over all ``jax.device_count()`` devices
    (axes ``'a'`` and ``'b'``), registers it with meshflow, asks meshflow for a
    sharding strategy for ``func(*args)``, and jits the sharded function. The
    flattened arguments are passed through ``shard_module`` and rebuilt into
    the original tree structure. After one warm-up call the step is timed with
    ``MFTimer`` while ``MemTracking`` is active, and the memory summary and
    elapsed time are printed.

    Args:
        func: The step function to benchmark; called as ``func(*args)``.
        args: Sequence (pytree) of positional arguments for ``func``.
    """
    # All devices go on the second mesh axis; the first axis has size 1.
    num_devices = jax.device_count()
    device_grid = mesh_utils.create_device_mesh((1, num_devices))
    set_device_mesh(Mesh(device_grid, axis_names=('a', 'b')))

    strategy = get_opt_strategy(func, *args)
    shard_func = jax.jit(meshflow_shard(func, strategy))

    # Flatten, transform the leaves, and restore the original tree structure.
    leaves, treedef = jax.tree_util.tree_flatten(args)
    shard_args = jax.tree_util.tree_unflatten(treedef, shard_module(leaves))

    def train_step():
        with jax.spmd_mode('allow_all'):
            shard_func(*shard_args)

    # warmup
    train_step()

    timer = MFTimer(train_step, in_ms=False)

    tracker = MemTracking()
    tracker.start()

    elapsed = timer.time()

    bytes_per_gb = 1024**3
    mem_gb = tracker.summary() / bytes_per_gb
    print(f"mem_track_.summry(): {mem_gb:.2f} GB")
    print(f"Time: {elapsed}")


def main():
    """Initialize meshflow and JAX distributed, then run the GAT benchmark.

    Uses the MPI world size and rank to initialize ``jax.distributed`` (one
    process per rank), prints the visible devices, builds the GAT case, and
    benchmarks it through ``bench_meshflow``.
    """
    # setup meshflow
    setup_testing(backend="jax", device="cuda")

    # setup distributed
    comm = MPI.COMM_WORLD
    size = comm.Get_size()
    rank = comm.Get_rank()

    # NOTE(review): the coordinator address is localhost and the global rank is
    # used as the local device id, which presumably assumes all ranks run on a
    # single host - confirm before using across multiple nodes.
    jax.distributed.initialize(coordinator_address="localhost:19705",
                               num_processes=size,
                               process_id=rank,
                               local_device_ids=rank)

    print(
        f"[Rank {rank}], Global Devices: {jax.device_count()}, Local Devices: {jax.local_device_count()}"
    )
    print(jax.devices())

    func, args = get_gat_case()
    # Map every leaf of the argument tree through to_shape_array. Uses
    # jax.tree_util.tree_map rather than the deprecated top-level
    # `jax.tree_map` alias, which newer JAX releases no longer provide.
    args = jax.tree_util.tree_map(to_shape_array, args)

    bench_meshflow(func, args)


# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
