import os

import meshflow as mf

# Backend names known to the test helpers. setup_testing() below has special
# setup for "jax" and "torch"; "tvm" is passed straight to the backend init.
# NOTE(review): presumably used to parametrize tests over backends — confirm.
ALL_PLATFORM = ["torch", "jax", "tvm"]


class MockDeviceMesh:
    """Common base class for the backend-specific mock device meshes.

    It holds no state. Subclasses call ``super().__init__()`` and then set
    their own backend-shaped attributes.
    """

    def __init__(self):
        # Real constructor (dunder spelling) so ``super().__init__()`` from
        # subclasses resolves here rather than skipping to ``object``.
        pass


class TorchMockDeviceMesh(MockDeviceMesh):
    """Mock torch device mesh that carries only its shape.

    Positional arguments give the extent of each mesh dimension.
    """

    def __init__(self, *arg):
        super().__init__()
        # Store the per-dimension extents as an immutable tuple.
        self.shape = tuple(arg)

    def size(self, i=None):
        """Return the extent of mesh dimension ``i``.

        When ``i`` is ``None``, return the total number of devices (the
        product of all extents; 1 for a zero-dimensional mesh). This mirrors
        ``torch.distributed.device_mesh.DeviceMesh.size(mesh_dim=None)``.

        Raises:
            IndexError: if ``i`` is out of range for ``shape``.
        """
        if i is None:
            total = 1
            for extent in self.shape:
                total *= extent
            return total
        return self.shape[i]


class JaxDeviceID:
    """Shape-only placeholder used as a mock mesh's ``device_ids``.

    The positional arguments are the extent of each mesh axis. They are
    exposed as the tuple ``shape``.
    """

    def __init__(self, *arg):
        extents = list(arg)
        self.shape = tuple(extents)


class JaxMockDeviceMesh(MockDeviceMesh):
    """Mock JAX device mesh whose ``device_ids`` records only a shape.

    All positional arguments are forwarded unchanged to ``JaxDeviceID``.
    """

    def __init__(self, *arg):
        super().__init__()
        ids = JaxDeviceID(*arg)
        self.device_ids = ids


def setup_testing(backend, device="cpu"):
    """Prepare the process for numerical tests, then initialise a backend.

    Args:
        backend: Backend name forwarded to ``mf.platform.init_backend``.
            ``"jax"`` and ``"torch"`` get extra TF32/platform setup first;
            any other name is forwarded unchanged.
        device: Device string exported as ``MESHFLOW_DEVICE``. For JAX it is
            also used as the ``jax_platforms`` config value.
    """
    os.environ["MESHFLOW_DEVICE"] = device
    if backend == "jax":
        # Export the TF32 override before importing jax, so it is already in
        # the environment when any NVIDIA library gets loaded.
        os.environ["NVIDIA_TF32_OVERRIDE"] = "0"
        import jax
        jax.config.update('jax_platforms', device)
    elif backend == "torch":
        import torch
        # Disable TF32 so CUDA matmuls and cuDNN convolutions use full FP32.
        torch.backends.cuda.matmul.allow_tf32 = False
        torch.backends.cudnn.allow_tf32 = False
    mf.platform.init_backend(backend)


def assert_partial_func_equal(func1, func2):
    """Assert that two ``functools.partial``-like objects are equivalent.

    They must have equal bound positional arguments (``args``), equal bound
    keyword arguments (``keywords``), and wrapped callables with the same
    ``__name__``. Checks run in that order; the first mismatch raises
    ``AssertionError``.
    """
    for bound_attr in ("args", "keywords"):
        assert getattr(func1, bound_attr) == getattr(func2, bound_attr)
    # Wrapped callables are compared by name only, not by identity.
    assert func1.func.__name__ == func2.func.__name__