from typing import Optional, Tuple

import torch
import click
from beam.distributed import RayClient
from fsspec.registry import default

from applications.attack_run import attack, DATASETS
from applications.common import MODELS
from applications.experiment import main
from applications.options import NUM_OF_WINDOWS_OPTION, WIDTH_OPTION, HEIGHT_OPTION
from run_options import CONCURRENCY_OPTION


@click.group()
def experiment():
    """Experiment command group (subcommands include llm and attack)."""
    # The group itself does nothing; click dispatches to the chosen subcommand.


# `llm` subcommand: gathers the options below and hands them, positionally and
# in declaration order, to applications.experiment.main.
# --device is an integer index; its default is evaluated once, when this
# decorator runs at import time: 0 if torch reports CUDA as available, else None.
@experiment.command("llm")
@click.option(
    "--model_name",
    type=click.Choice(list(MODELS.keys())),
    default="code_queen_chat",
)
@click.option("--loss", type=str, default="speed")
@click.option("-p", "--parameters", type=str, default="soft_tune")
@click.option("--func", type=str, default="sort")
@click.option(
    "--device",
    type=int,
    default=(0 if torch.cuda.is_available() else None),
)
def llm_experiment(
    model_name: str,
    loss: str,
    parameters: str,
    func: str,
    device,
):
    # main() is called positionally, so this tuple's order must match its signature.
    forwarded_args = (model_name, loss, parameters, func, device)
    main(*forwarded_args)


@experiment.command("attack")
@click.option("--server", type=str, default="dsisarit04")
@CONCURRENCY_OPTION
@HEIGHT_OPTION
@WIDTH_OPTION
@click.option(
    "--dataset_name", type=click.Choice(list(DATASETS.keys())), default="imagenet"
)
@NUM_OF_WINDOWS_OPTION
@click.option("--top_check", type=int, default=5)
@click.option("--classification", type=int, default=None)
@click.option("--d_param", type=click.Tuple([str, str]), multiple=True)
def attack_with_ray(
    server: str,
    concurrency: int,
    height: int,
    width: int,
    dataset_name: str,
    num_of_windows: int,
    top_check: int,
    classification: Optional[int],
    d_param: Tuple[Tuple[str, str], ...],
):
    """Run applications.attack_run.attack with the given command-line options.

    --d_param may be repeated; each occurrence takes two string values, and all
    occurrences are forwarded together as a tuple of pairs (empty if omitted).
    --classification is None unless given.
    """
    # NOTE(review): this command does not create a Ray client itself; in this
    # file RayClient() is only instantiated under the __main__ guard, so invoking
    # the command through another entry point skips that step — confirm intent.
    # attack() is called positionally: argument order must match its signature.
    attack(
        server,
        concurrency,
        height,
        width,
        dataset_name,
        num_of_windows,
        top_check,
        classification,
        d_param,
    )


if __name__ == "__main__":
    # Instantiate the Ray client before handing control to click. This happens
    # before argument parsing, so it runs for every subcommand (including `llm`)
    # and even for `--help`.
    # NOTE(review): the returned object is discarded — presumably the constructor
    # sets up Ray state as a side effect; confirm in beam.distributed.
    RayClient()
    # Parse sys.argv and dispatch to the selected subcommand.
    experiment()
