import time
import logging

import rich
import jax
from jax import random

from meshflow.jax import get_opt_strategy, set_device_mesh
from meshflow.utils.testing import setup_testing, JaxMockDeviceMesh
from meshflow.jax.model.resnet import ResNet18
from meshflow.jax.model import GPTBlock, GPTConfig, GPTSimple
from meshflow.jax.model import GATLayer

# Root logger configuration for this script: DEBUG and above, prefixed with
# a month/day timestamp and the level name.
logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%m/%d %H:%M:%S',
)


def test_gat():
    """Run meshflow's strategy search on one GAT layer, then execute it once.

    Dummy data: a (1024, 512) feature matrix (presumably nodes x in_features,
    matching ``in_features=512``) and a dense (1024, 1024) random-valued
    adjacency matrix. Prints the shape of the layer output.
    """
    layer = GATLayer(in_features=512, out_features=512)
    feat_key, adj_key, init_key = random.split(random.PRNGKey(0), num=3)

    # Dummy inputs only; the values carry no graph structure.
    node_feats = random.normal(feat_key, (1024, 512))
    adjacency = random.normal(adj_key, (1024, 1024))

    layer_params = layer.init(init_key, node_feats, adjacency)['params']

    def forward(params, h, adj):
        # Pure function of (params, inputs) so meshflow can analyse it.
        return layer.apply({'params': params}, h, adj)

    opt_strategy = get_opt_strategy(forward, layer_params, node_feats,
                                    adjacency)

    out_ = forward(layer_params, node_feats, adjacency)
    print(f"out_.shape: {out_.shape}")


def test_resnet():
    """Smoke-test ``get_opt_strategy`` on a ResNet18 forward pass.

    Builds a 1000-class ResNet18 and initializes it on a random batch of shape
    (16, 224, 224, 3) (presumably NHWC images). It then asks meshflow for an
    optimization strategy for ``forward``, runs ``forward`` once and prints
    the output shape.
    """
    model = ResNet18(num_classes=1000)

    key1, key2 = random.split(random.PRNGKey(0))
    x = random.normal(key1, (16, 224, 224, 3))  # Dummy input data
    variables = model.init(key2, x)  # Initialization call
    params, batch_stats = variables['params'], variables['batch_stats']

    # The original discarded this tree; log it so the shape check is visible.
    param_shapes = jax.tree_util.tree_map(lambda leaf: leaf.shape, params)
    logging.debug("resnet param shapes: %s", param_shapes)

    def forward(params, x):
        # ``batch_stats`` is captured from the enclosing scope. With
        # ``mutable=[...]`` apply returns a pair (output, mutated collections)
        # -- assuming Flax linen semantics for ResNet18.
        return model.apply({
            'params': params,
            'batch_stats': batch_stats
        },
                           x,
                           mutable=['batch_stats'])

    opt_strategy = get_opt_strategy(forward, params, x)

    # Do NOT rebind ``batch_stats`` here. The second element is the mutated
    # collections dict ({'batch_stats': ...}, one level deeper). Rebinding
    # would change what ``forward`` closes over for any later call.
    out_, _ = forward(params, x)

    print(f"out_.shape: {out_.shape}")


def test_gpt():
    """Smoke-test ``get_opt_strategy`` on a GPTBlock forward pass with dropout.

    Initializes a ``GPTBlock`` (config with ``num_layers=4``) on a random
    input of shape (4, 256, 768). It then asks meshflow for an optimization
    strategy for ``forward`` (dropout enabled), runs ``forward`` once and
    prints the output shape.
    """
    gpt_config = GPTConfig(num_layers=4)
    model = GPTBlock(gpt_config)

    root_key = jax.random.PRNGKey(seed=0)
    main_key, params_key, dropout_key = jax.random.split(key=root_key, num=3)
    # Dummy input; presumably (batch, seq_len, hidden) with 768 matching the
    # config's hidden size -- TODO confirm against GPTConfig.
    x = random.normal(main_key, (4, 256, 768))
    # Init with dropout disabled; only a params key is supplied here.
    variables = model.init(params_key, x, deterministic=True)
    params = variables['params']

    # The original discarded this tree; log it so the shape check is visible.
    param_shapes = jax.tree_util.tree_map(lambda leaf: leaf.shape, params)
    logging.debug("gpt param shapes: %s", param_shapes)

    def forward(params, x):
        # Dropout active; the fixed ``dropout_key`` keeps runs reproducible.
        return model.apply({'params': params},
                           x,
                           deterministic=False,
                           rngs={'dropout': dropout_key})

    opt_strategy = get_opt_strategy(forward, params, x)

    out_ = forward(params, x)
    print(f"out_.shape: {out_.shape}")


if __name__ == '__main__':
    import sys

    # Smoke tests selectable by name on the command line, e.g.
    # ``python this_script.py gat gpt``. With no arguments only the ResNet
    # test runs, which matches the previous hard-coded behaviour.
    available_tests = {
        'gat': test_gat,
        'resnet': test_resnet,
        'gpt': test_gpt,
    }
    selected = sys.argv[1:] or ['resnet']
    unknown = [name for name in selected if name not in available_tests]
    if unknown:
        # Fail before any device/mesh setup is done.
        sys.exit(f"unknown test(s) {unknown}; "
                 f"choose from {sorted(available_tests)}")

    setup_testing(backend="jax", device="cpu")

    # Mock device mesh built from (1, 4); presumably a 1x4 logical mesh.
    mesh = JaxMockDeviceMesh(1, 4)
    set_device_mesh(mesh)

    for name in selected:
        available_tests[name]()
