from typing import Optional
from dataclasses import dataclass, field


@dataclass
class PruningArguments:
    """Hyper-parameters that configure a pruning run.

    Every field carries a ``"help"`` entry in its dataclass metadata that
    describes the option and its default. The values are range-checked in
    ``__post_init__``, so an invalid combination fails at construction time.

    Raises:
        ValueError: If ``num_candidate_blocks`` or ``num_blocks_to_prune`` is
            outside ``[0.0, 1.0]``, if ``inference_speedup`` is negative, or
            if ``inference_speedup`` is on the wrong side of 1.0 for the
            chosen ``use_latency`` mode (see ``__post_init__``).
    """

    inference_speedup: float = field(
        default=3.0, metadata={"help": "The constraint inference speedup. (Default: 3.0)"}
    )

    use_latency: bool = field(
        default=True,
        metadata={"help": "A boolean value defines whether to use latency as the metric for inference speedup in pruning. (Default: True)"}
    )

    do_module_reconstruction: bool = field(
        default=False,
        metadata={"help": "A boolean value defines whether to apply module reconstruction after reconstruction metric becomes true. (Default: False)"}
    )

    # Validated to lie in [0.0, 1.0]; presumably a fraction rather than an
    # absolute count, despite the wording of the help text -- confirm with caller.
    num_blocks_to_prune: float = field(
        default=0.05, metadata={"help": "Number of blocks to prune at each pruning step. (Default: 0.05)"}
    )

    # Validated to lie in [0.0, 1.0].
    num_candidate_blocks: float = field(
        default=0.1, metadata={"help": "Number of candidate blocks to prune at each pruning step in terms of percentage of number of blocks in the model. (Default: 0.1)"}
    )

    max_iterations: int = field(
        default=100, metadata={"help": "Maximum number of pruning iterations to prevent infinite loops. (Default: 100)"}
    )

    alpha: float = field(
        default=1.0, metadata={"help": "The alpha parameter for the HALPE algorithm. (Default: 1.0)"}
    )

    min_damping: float = field(
        default=0.0001, metadata={"help": "The minimum damping parameter for the HALPE algorithm. (Default: 0.0001)"}
    )

    max_damping: float = field(
        default=1.0, metadata={"help": "The maximum damping parameter for the HALPE algorithm. (Default: 1.0)"}
    )

    max_try_head: int = field(
        default=4, metadata={"help": "The maximum number of heads to try to prune at each pruning step. (Default: 4)"}
    )

    max_try_ffn: int = field(
        default=128, metadata={"help": "The maximum number of ffn intermediate dimensions to try to prune at each pruning step. (Default: 128)"}
    )

    use_chunking: bool = field(
        default=False, metadata={"help": "A boolean value defines whether to use chunking for matrix multiplication. (Default: False)"}
    )

    chunk_size: int = field(
        default=32, metadata={"help": "The chunk size for matrix multiplication. (Default: 32)"}
    )

    max_iterative_iterations: int = field(
        default=20, metadata={"help": "Maximum iterations for iterative solver. (Default: 20)"}
    )

    iterative_tolerance: float = field(
        default=1e-6, metadata={"help": "Convergence tolerance for iterative solver. (Default: 1e-6)"}
    )

    conditioned_score_max_chunk_size: int = field(
        default=256, metadata={"help": "The chunk size for conditioned score computation. (Default: 256)"}
    )

    def __post_init__(self) -> None:
        """Validate field ranges once the dataclass has been initialised.

        Raises:
            ValueError: If ``num_candidate_blocks`` or ``num_blocks_to_prune``
                is outside ``[0.0, 1.0]``; if ``inference_speedup`` is
                negative; if ``use_latency`` is true and ``inference_speedup``
                is below 1.0; or if ``use_latency`` is false and
                ``inference_speedup`` is above 1.0.
        """
        if self.num_candidate_blocks > 1.0:
            raise ValueError("num_candidate_blocks must be less than or equal to 1.0")
        if self.num_candidate_blocks < 0.0:
            raise ValueError("num_candidate_blocks must be greater than or equal to 0.0")
        if self.num_blocks_to_prune < 0.0:
            raise ValueError("num_blocks_to_prune must be greater than or equal to 0.0")
        if self.num_blocks_to_prune > 1.0:
            raise ValueError("num_blocks_to_prune must be less than or equal to 1.0")
        if self.inference_speedup < 0.0:
            raise ValueError("inference_speedup must be greater than or equal to 0.0")
        # The valid side of 1.0 depends on the mode: at least 1.0 when
        # use_latency is true, at most 1.0 when it is false.
        if self.use_latency and self.inference_speedup < 1.0:
            raise ValueError("inference_speedup must be greater than or equal to 1.0 when use_latency is true")
        # NOTE: the default inference_speedup (3.0) fails this check, so a
        # caller that disables use_latency must also pass a value <= 1.0.
        if not self.use_latency and self.inference_speedup > 1.0:
            raise ValueError("inference_speedup must be less than or equal to 1.0 when use_latency is false")