import click
from beam.serve import beam_server

from applications.common import IMG_MODELS
from applications.options import MODEL_NAME_OPTION
from applications.processor import ModelProcessorWrapper
from run_options import DEVICE_OPTION


# TCP port that main() passes to beam_server for serving the classifier.
# NOTE(review): the name suggests this endpoint is queried by an attack client
# — confirm against whatever connects to it.
ATTACK_PORT = 5426


class ClassifierModelSampler:
    """Callable pipeline: pre-process images, run the model, post-process the output.

    ``model`` must expose a ``.device`` attribute and be callable on whatever
    ``processor`` returns; ``processor`` and ``post_processor`` are plain
    callables.
    """

    def __init__(self, model, processor, post_processor):
        # Store the three pipeline stages unchanged; they are used as-is by predict().
        self.model, self.processor, self.post_processor = (
            model,
            processor,
            post_processor,
        )

    def predict(self, images):
        """Return ``post_processor(model(processor(images.to(device=model.device))))``.

        ``images`` must provide ``.to(device=...)`` (presumably a torch tensor
        batch — confirm against the client that calls the server).
        """
        target_device = self.model.device
        on_device = images.to(device=target_device)
        model_inputs = self.processor(on_device)
        raw_outputs = self.model(model_inputs)
        # NOTE(review): nothing here disables gradient tracking (e.g.
        # torch.no_grad()). If `model` is a torch module, autograd may record
        # this forward pass. That could be intentional if callers need
        # gradients; otherwise disabling it would save memory — confirm.
        return self.post_processor(raw_outputs)

    def __call__(self, images):
        """Delegate to :meth:`predict`, so the sampler can be served as a plain callable."""
        return self.predict(images)


@click.command
@MODEL_NAME_OPTION
@DEVICE_OPTION
def main(model_name: str, device: int):
    # Documented with comments rather than a docstring: click uses a command's
    # docstring as its --help text, and this command has none.
    #
    # Look up the registered model card. Build the model and its processor
    # from the same pretrained checkpoint name.
    card = IMG_MODELS[model_name]
    checkpoint = card.model_name

    network = card.model_class.from_pretrained(checkpoint).to(device)
    preprocess = ModelProcessorWrapper(
        card.processor_class.from_pretrained(checkpoint)
    )
    sampler = ClassifierModelSampler(network, preprocess, card.post_processor)

    # Serve the sampler through the waitress backend on ATTACK_PORT.
    # NOTE(review): with non_blocking=True, whether run() blocks is up to
    # beam_server — confirm the process stays alive after run() returns.
    beam_server(
        sampler, backend="waitress", non_blocking=True, port=ATTACK_PORT
    ).run()


if __name__ == "__main__":
    # Script entry point: click parses the CLI options and invokes main().
    main()
