from functools import wraps

import torch

try:
    # NOTE(review): torch is also imported unconditionally at the top of this
    # file, so when torch is missing the ImportError is raised there and the
    # fallback stubs in the `except ImportError` branch are never reached.
    import torch

    # Module-level handle to the profiler session opened by
    # start_torch_profile(); stop_torch_profile() exits it and resets it to None.
    _torch_profiler = None

    def profile_torch(export_trace_filename="export_trace.json",
                      activities=None, **profiler_kwargs):
        """
        Decorator factory that profiles every call of a function with torch.profiler.

        Each call of the decorated function runs inside a fresh
        ``torch.profiler.profile`` session. After the call returns, the top 10
        operators sorted by total CUDA time are printed and a Chrome trace is
        written to ``export_trace_filename`` (overwritten on every call). If the
        wrapped function raises, the exception propagates and nothing is
        printed or exported.

        Args:
            export_trace_filename (str): The filename to export the Chrome Trace.
            activities (iterable): ProfilerActivity values to monitor
                (default: [CPU, CUDA]).
            **profiler_kwargs: Additional arguments to pass to
                torch.profiler.profile. ``profile_memory`` defaults to True
                but may be overridden here.

        Returns:
            A decorator that wraps a function with profiling enabled.
        """
        if activities is None:
            activities = [
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ]
        else:
            # Materialise once so a one-shot iterable is not exhausted after
            # the first profiled call.
            activities = list(activities)

        # Defaults first, caller kwargs last: an explicit profile_memory=...
        # overrides the default instead of raising
        # "got multiple values for keyword argument".
        options = {"profile_memory": True, **profiler_kwargs}

        def decorator(func):
            @wraps(func)
            def wrapper(*args, **kwargs):
                # Profile only the wrapped call itself.
                with torch.profiler.profile(activities=activities, **options) as prof:
                    result = func(*args, **kwargs)

                print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
                # Export the trace to a file
                prof.export_chrome_trace(export_trace_filename)
                print(f"Profiler trace exported to {export_trace_filename}")
                return result

            return wrapper
        return decorator
    
    def start_torch_profile(activities=None, **profiler_kwargs):
        """
        Start a module-level PyTorch profiler session.

        The session stays open until stop_torch_profile() is called. That call
        prints summary tables and exports the Chrome trace.

        Args:
            activities (list): List of ProfilerActivity to monitor (default: [CPU, CUDA]).
            **profiler_kwargs: Additional arguments to pass to torch.profiler.profile.
                ``profile_memory``, ``with_stack`` and ``record_shapes`` default
                to True but may be overridden here.

        Raises:
            RuntimeError: If a session started by this function is still active.
        """
        global _torch_profiler

        # Refuse to silently orphan a running session: the old profiler would
        # never be exited and its results would be lost.
        if _torch_profiler is not None:
            raise RuntimeError(
                "A profiler session is already active; call stop_torch_profile() first."
            )

        if activities is None:
            activities = [
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ]

        # Defaults first, caller kwargs last, so an explicit override replaces
        # the default instead of raising "got multiple values for keyword argument".
        options = {
            "profile_memory": True,
            "with_stack": True,  # Enable stack trace recording
            "record_shapes": True,  # Record tensor shapes
            **profiler_kwargs,
        }

        profiler = torch.profiler.profile(activities=activities, **options)
        # Manually enter the context manager. The handle is published only
        # after a successful enter, so a failed start leaves no half-initialised
        # session for stop_torch_profile() to exit.
        profiler.__enter__()
        _torch_profiler = profiler
        print("PyTorch profiler started.")

    def stop_torch_profile(export_trace_filename="export_trace.json"):
        """
        Stop the PyTorch profiler session and export the trace.

        Exits the session opened by start_torch_profile(). It then prints the
        top 10 operations by CUDA time, by CPU time and by self CUDA memory
        usage, and writes a Chrome trace. The module-level session handle is
        cleared even if exiting, printing or exporting fails, so a later
        start/stop is not blocked by a stale session.

        Args:
            export_trace_filename (str): The filename to export the Chrome Trace.

        Raises:
            RuntimeError: If no profiler session is active.
        """
        global _torch_profiler

        if _torch_profiler is None:
            raise RuntimeError("No active profiler session to stop.")

        profiler = _torch_profiler
        try:
            profiler.__exit__(None, None, None)  # Manually exit the context manager

            # Aggregate once and reuse for all three tables.
            averages = profiler.key_averages()

            # Print detailed profiling information
            print("\n=== Top 10 operations by CUDA time ===")
            print(averages.table(sort_by="cuda_time_total", row_limit=10))

            print("\n=== Top 10 operations by CPU time ===")
            print(averages.table(sort_by="cpu_time_total", row_limit=10))

            print("\n=== Memory usage statistics ===")
            print(averages.table(sort_by="self_cuda_memory_usage", row_limit=10))

            profiler.export_chrome_trace(export_trace_filename)
            print(f"Profiler trace exported to {export_trace_filename}")
        finally:
            # Always drop the handle; an exited profiler must not be reused.
            _torch_profiler = None

except ImportError:
    def profile_torch(*args, **kwargs):
        """Fallback used when PyTorch cannot be imported; always raises ImportError."""
        message = "PyTorch is not installed. Please install torch to use this function."
        raise ImportError(message)
    
    def start_torch_profile(*args, **kwargs):
        """Fallback used when PyTorch cannot be imported; always raises ImportError."""
        message = "PyTorch is not installed. Please install torch to use this function."
        raise ImportError(message)

    def stop_torch_profile(*args, **kwargs):
        """Fallback used when PyTorch cannot be imported; always raises ImportError."""
        message = "PyTorch is not installed. Please install torch to use this function."
        raise ImportError(message)
