import os
import warnings

import torch

# Compilation is enabled unless ENABLE_TORCH_COMPILE is set to something other
# than "true" (case-insensitive). The variable is read once, at import time.
should_compile = os.getenv("ENABLE_TORCH_COMPILE", "True").lower() == "true"


def _compile_or_passthrough(func, compile_args, compile_kwargs):
    """Return ``torch.compile(func, *compile_args, **compile_kwargs)`` or ``func``.

    ``func`` is returned unchanged when ``should_compile`` is false or when this
    torch build cannot provide torch.compile (a warning is emitted in that case).
    """
    if not should_compile:
        return func
    try:
        import torch._dynamo

        # Let Dynamo capture ops returning Python scalars (e.g. Tensor.item())
        # instead of graph-breaking on them. This is a process-wide setting.
        torch._dynamo.config.capture_scalar_outputs = True
        return torch.compile(func, *compile_args, **compile_kwargs)
    except (ImportError, AttributeError, RuntimeError) as e:
        # ImportError: torch builds without torch._dynamo (pre-2.0).
        # AttributeError: torch.compile or the config flag is missing.
        # RuntimeError: some torch versions reject unsupported platforms or
        # Python versions when torch.compile is called.
        # stacklevel=3 points at the caller of vticompile / the decorated def.
        warnings.warn(f"torch.compile not available: {e}", RuntimeWarning, stacklevel=3)
        return func


def vticompile(*dargs, **dkwargs):
    """Conditionally apply ``torch.compile``.

    Supported forms::

        @vticompile                          # compile with default options
        @vticompile(mode="max-autotune")     # options forwarded to torch.compile
        fn = vticompile(fn, mode="max-autotune")  # direct call, like torch.compile

    If ENABLE_TORCH_COMPILE disables compilation, or torch.compile is
    unavailable, the original function is returned unchanged.

    Args:
        *dargs: Either the single function to compile, or positional arguments
            forwarded to torch.compile after the decorated function.
        **dkwargs: Keyword options forwarded to torch.compile.

    Returns:
        The compiled (or original) function, or a decorator when no function
        was passed.
    """
    if len(dargs) == 1 and callable(dargs[0]):
        # Bare decorator or direct call. torch.compile takes no other
        # positional argument, so a lone callable is always the target.
        return _compile_or_passthrough(dargs[0], (), dkwargs)

    # Decorator factory: @vticompile(...) with options only.
    def decorator(func):
        return _compile_or_passthrough(func, dargs, dkwargs)

    return decorator



# Legacy implementations, superseded by vticompile above; kept for reference only.
#import os
#import torch
#
#should_compile = os.getenv("ENABLE_TORCH_COMPILE", "True") == "True"
#
#if should_compile:
#    try:
#        import torch._dynamo
#
#        # Enable scalar output capture
#        torch._dynamo.config.capture_scalar_outputs = True
#        vticompile = torch.compile
#    except AttributeError:
#
#        def vticompile(f, *args, **kwargs):
#            return f
#else:
#    def vticompile(f, *args, **kwargs):
#            return f
#
#
#def vticompile(fn):
#    if should_compile:
#        return torch.compile(fn)
#    else:
#        return fn

