import os
from typing import Literal

# Expose only GPU 0 to CUDA. This runs at module import time, i.e. before the
# swift imports in the `__main__` block at the bottom of the file are executed.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'


def infer_multilora(infer_request: 'InferRequest', infer_backend: Literal['vllm', 'pt']):
    """Answer one request with two dynamically selected LoRA adapters and with no adapter.

    One engine is created for the model named in the arguments loaded from the
    first adapter directory. The same request is then sent three times -- with
    ``swift/test_lora``, without any adapter, and with ``swift/test_lora2`` --
    and each reply is printed.

    Note: apart from ``VllmEngine``, the swift names used here are bound by the
    ``__main__`` block of this file, so the function expects to be called from it.

    Args:
        infer_request: The request forwarded to ``engine.infer``.
        infer_backend: ``'pt'`` selects ``PtEngine``; ``'vllm'`` selects ``VllmEngine``.

    Raises:
        ValueError: If ``infer_backend`` is neither ``'pt'`` nor ``'vllm'``.
    """
    # Fail fast on an unknown backend. Without this check `engine` would never be
    # bound and the error would only surface later as an UnboundLocalError, after
    # both adapters had already been downloaded.
    if infer_backend not in ('pt', 'vllm'):
        raise ValueError(f"infer_backend must be 'pt' or 'vllm', got {infer_backend!r}")

    # Dynamic LoRA
    adapter_path = safe_snapshot_download('swift/test_lora')
    adapter_path2 = safe_snapshot_download('swift/test_lora2')
    args = BaseArguments.from_pretrained(adapter_path)
    if infer_backend == 'pt':
        engine = PtEngine(args.model)
    else:
        # Imported lazily so that vLLM is only required when this backend is chosen.
        from swift.llm import VllmEngine
        engine = VllmEngine(args.model, enable_lora=True, max_loras=1, max_lora_rank=16)
    template = get_template(args.template, engine.processor, args.system)
    request_config = RequestConfig(max_tokens=512, temperature=0)
    adapter_request = AdapterRequest('lora1', adapter_path)
    adapter_request2 = AdapterRequest('lora2', adapter_path2)

    # use lora
    resp_list = engine.infer([infer_request], request_config, template=template, adapter_request=adapter_request)
    response = resp_list[0].choices[0].message.content
    print(f'lora1-response: {response}')
    # origin model: no adapter_request and no explicit template are passed here
    resp_list = engine.infer([infer_request], request_config)
    response = resp_list[0].choices[0].message.content
    print(f'response: {response}')
    # use lora
    resp_list = engine.infer([infer_request], request_config, template=template, adapter_request=adapter_request2)
    response = resp_list[0].choices[0].message.content
    print(f'lora2-response: {response}')


def infer_lora(infer_request: 'InferRequest'):
    """Answer one request with a single LoRA adapter applied and print the reply.

    Three ways of building the engine are shown. The first two are kept as
    commented-out alternatives; the third is the one that actually runs.

    Args:
        infer_request: The request forwarded to ``engine.infer``.
    """
    lora_dir = safe_snapshot_download('swift/test_lora')
    args = BaseArguments.from_pretrained(lora_dir)

    # Option 1: pass the adapter directory to the engine constructor.
    # engine = PtEngine(args.model, adapters=[lora_dir])
    # template = get_template(args.template, engine.processor, args.system)
    # engine.default_template = template

    # Option 2: obtain model, processor and template from the loaded arguments.
    # model, processor = args.get_model_processor()
    # model = Swift.from_pretrained(model, lora_dir)
    # template = args.get_template(processor)
    # engine = PtEngine.from_model_template(model, template)

    # Option 3 (active): load model and tokenizer, apply the adapter with Swift,
    # then build the engine from the resulting model and a matching template.
    base_model, tokenizer = get_model_tokenizer(args.model)
    lora_model = Swift.from_pretrained(base_model, lora_dir)
    engine = PtEngine.from_model_template(lora_model, get_template(args.template, tokenizer, args.system))

    replies = engine.infer([infer_request], RequestConfig(max_tokens=512, temperature=0))
    response = replies[0].choices[0].message.content
    print(f'lora-response: {response}')


if __name__ == '__main__':
    # These names are bound at module level only when the file is run as a
    # script; the functions above look them up as globals.
    from swift.llm import AdapterRequest, BaseArguments, InferRequest, PtEngine, RequestConfig
    from swift.llm import get_model_tokenizer, get_template, safe_snapshot_download
    from swift.tuners import Swift

    messages = [{'role': 'user', 'content': 'who are you?'}]
    infer_request = InferRequest(messages=messages)
    # Uncomment the next line to run the single-adapter example.
    # infer_lora(infer_request)
    infer_multilora(infer_request, 'pt')
