#!/usr/bin/env python3

from __future__ import annotations

from typing import Callable

import time
import sys
from rich.console import Console
from functools import wraps
import logging

# Module-level logger; simple_timer(logging=True) emits its DEBUG timing lines here.
logger = logging.getLogger(__name__)

def simple_timer(
    prnt: bool = False,
    logging: bool = False,
    precision: int = 4,
    lpad: int = 35,
) -> Callable:
    r"""
    Create a decorator that times every call of the wrapped function.

    After each call that returns normally, the elapsed ``time.perf_counter``
    interval (in seconds) is stored on the wrapper as ``time_taken``. The
    attribute is ``None`` until the first call returns.

    Args:
        prnt (bool): If True, print "<name> took <seconds> seconds".
        logging (bool): If True, send the timing to the module ``logger`` at
            DEBUG level. (Inside this function the name shadows the
            ``logging`` module.)
        precision (int): Decimal places used to format the seconds value.
        lpad (int): Width to which the "time_taken(<name>)" label is
            left-justified in the log message.

    Returns:
        callable: The decorator to apply to the target function.
    """

    def decorator(func):
        @wraps(func)
        def timed(*args, **kwargs):
            began = time.perf_counter()
            outcome = func(*args, **kwargs)
            seconds = time.perf_counter() - began
            timed.time_taken = seconds

            label = f"time_taken({func.__name__})"
            # Formatted unconditionally so a bad `precision` fails on every call.
            shown = format(seconds, f".{precision}f")

            if prnt:
                print(f"{func.__name__} took {shown} seconds")
            if logging:
                logger.debug(f"{label.ljust(lpad)} >>> seconds={shown}")
            return outcome

        # Set after @wraps so it overrides any copied attribute of the same name.
        timed.time_taken = None
        return timed

    return decorator


def status(
    status: str = "Running...",
    spinner_style: str = "bold spring_green3",
    spinner: str = "aesthetic",
    show_timer: bool = True,
    show_func_name: bool = False,
    name: str = "track"
) -> Callable:
    r"""
    A decorator that shows a spinner with a live timer in the status
    message. Uses self.status_spinner=True/False to enable/disable.

    The spinner is shown only when the first positional argument (typically
    ``self`` on a method) has a truthy ``status_spinner`` attribute. Otherwise
    the function runs without any spinner setup.

    Args:
        status (str): Message to display in the status.
        spinner_style (str): Style of the spinner.
        spinner (str): Type of spinner to use.
        show_timer (bool): Whether to display the elapsed time.
        show_func_name (bool): Whether to show the function name.
        name (str): Name of the tracking object (a ``SimpleNamespace``).
            It is injected as a global into the decorated function's module
            on every call, so the function body can set arbitrary fields
            (e.g. loss, accuracy) with dot notation, such as
            `track.loss = 0.1`. Non-``None`` fields are shown in the status
            line. Defaults to "track".

    Returns:
        callable: The decorator.
    """
    import threading
    from types import SimpleNamespace

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            self = args[0] if args else None
            is_verbose = getattr(self, "status_spinner", False)

            tracking_obj = SimpleNamespace()
            # Inject the tracker into the decorated function's module globals
            # so its body can reference it by `name`. NOTE: this overwrites
            # any existing global of that name and is not restored afterwards.
            sys.modules[func.__module__].__dict__[name] = tracking_obj

            if not is_verbose:
                return func(*args, **kwargs)

            # Imported lazily: only the verbose path needs the spinner.
            from rich.status import Status

            console = Console(log_path=False, soft_wrap=True)
            start_time = time.perf_counter()  # monotonic, unlike time.time()
            done = threading.Event()

            status_obj = Status(
                status=f"{status} [dim](0.0s)[/dim]",
                console=console,
                spinner=spinner,
                spinner_style=spinner_style,
            )

            def render() -> str:
                # Snapshot first: the decorated function may add attributes on
                # the main thread while this worker reads them, and iterating
                # the live __dict__ could raise "dictionary changed size".
                fields = vars(tracking_obj).copy()
                var_str = " ".join(
                    f"{k}={v}" for k, v in fields.items() if v is not None
                )
                elapsed = time.perf_counter() - start_time
                timer_str = f" [dim]({elapsed:.1f}s)[/]" if show_timer else ""
                prefix = (
                    f"[dim bold spring_green3]{func.__name__}[/] "
                    if show_func_name else ""
                )
                return (
                    prefix
                    + f"[bold spring_green3]{status}[/] "
                    + f"[bold sea_green1]{var_str}[/]{timer_str}"
                )

            def update_timer() -> None:
                # Refresh roughly every 0.1 s until the wrapped call finishes.
                # The Event replaces the old `while status_obj._live` test,
                # which never became false and leaked one thread per call.
                while not done.is_set():
                    status_obj.update(render())
                    done.wait(0.1)

            with status_obj:
                ticker = threading.Thread(target=update_timer, daemon=True)
                ticker.start()
                try:
                    return func(*args, **kwargs)
                finally:
                    # Stop the refresher (also on exceptions) before the
                    # status is torn down by the `with` block. The timeout
                    # means a stuck refresh cannot block the caller.
                    done.set()
                    ticker.join(timeout=1.0)

        return wrapper

    return decorator