from dataclasses import dataclass


@dataclass
class LinearSolveArgs:
    """
    Arguments for solving linear problem on the gradient of the network (using automatic differentiation).

    Args:
        mode            Automatic differentiation mode, either 'forward' or 'reverse'
        driver          Linear solver type, some information:
                        driver can be 'gels', 'gelsy', 'gelsd' or 'gelss' for CPU, 'gels' for CUDA
                        - 'gels' is QR solve, is fast but needs full rank
                        - 'gelsy' is QR with pivoting, more stable but slower than 'gels', still in general should be faster than gelsd or gelss
                        - 'gelsd' is SVD via divide&conquer, best for ill-conditioned but expensive, default on many library
                        - 'gelss' is full SVD, most stable, very slow and might be memory intensive
        rcond           Regularization
        device          'cpu' or 'cuda'
    """
    mode: str = "forward"
    driver: str = "gelsd"
    rcond: float = 1e-10
    device: str = "cpu"
    batch_size: int | None = None
    pin_memory: bool = False                # Enable if loading from CPU memory to CUDA memory, by default set to default value to provide a fair comparison
    empty_cuda_cache: bool = False          # Enforces to empty the cuda cache for minimal cuda memory usage, by default set to default value to provide a fair comparison

    def __post_init__(self):
        assert self.mode in {'forward', 'reverse'}, "Unknown automatic differentiation mode: use 'forward' or 'reverse'"
        assert self.driver in {'gels', 'gelsy', 'gelsd', 'gelss'}, "Unknown lstsq driver: you can use 'gelsd' for cpu 'gels' for cuda"
        assert self.rcond > 0, "Regularization is negative."
        assert self.device in {'cpu', 'cuda'}, "Unknown device"
        if self.device == "cuda":
            assert self.driver == "gels", "Only driver 'gels' is supported by 'cuda' device"
        if not self.batch_size is None: assert self.batch_size > 0
