# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from typing import List

from swift.llm import ExportArguments, PtEngine, RequestConfig, Template, prepare_model_template
from swift.utils import get_logger

logger = get_logger()


def replace_and_concat(template: 'Template', template_list: List, placeholder: str, keyword: str):
    """Render a list of template fragments into a single string.

    Each element of ``template_list`` is handled according to its type:

    * ``str``: appended with every occurrence of ``placeholder`` replaced by
      ``keyword``.
    * non-empty ``list``/``tuple`` whose first item is an ``int``: the whole
      sequence is passed to ``template.tokenizer.decode`` and the result is
      appended.
    * any other non-empty ``list``/``tuple``: every item must be the name
      ``'bos_token_id'`` or ``'eos_token_id'``; the tokenizer's ``bos_token``
      or ``eos_token`` string is appended for each.

    Elements of any other type, and empty sequences, contribute nothing.

    Args:
        template: Object exposing a ``tokenizer`` attribute. The tokenizer is
            only touched when a token fragment is actually encountered.
        template_list: Fragments to render. ``None`` or an empty list yields
            an empty string.
        placeholder: Substring to look for inside string fragments.
        keyword: Replacement text for ``placeholder``.

    Returns:
        The concatenation of all rendered fragments.

    Raises:
        ValueError: If a non-integer sequence contains an item other than
            ``'bos_token_id'`` or ``'eos_token_id'``.
    """
    # A missing (None) or empty fragment list renders to nothing instead of
    # failing on iteration.
    if not template_list:
        return ''

    # Maps the special-token names accepted in fragments to the tokenizer
    # attribute that holds the corresponding token string.
    special_token_attrs = {'bos_token_id': 'bos_token', 'eos_token_id': 'eos_token'}

    parts = []
    for fragment in template_list:
        if isinstance(fragment, str):
            parts.append(fragment.replace(placeholder, keyword))
            continue
        if not isinstance(fragment, (tuple, list)):
            continue
        if not fragment:
            # An empty sequence has no first item to inspect; skip it.
            continue
        if isinstance(fragment[0], int):
            parts.append(template.tokenizer.decode(fragment))
            continue
        for attr in fragment:
            token_attr = special_token_attrs.get(attr) if isinstance(attr, str) else None
            if token_attr is None:
                raise ValueError(f'Unknown token: {attr}')
            parts.append(getattr(template.tokenizer, token_attr))
    return ''.join(parts)


def export_to_ollama(args: ExportArguments):
    """Write an Ollama ``Modelfile`` for the model described by ``args``.

    The model and its chat template are loaded via ``prepare_model_template``,
    the template parts (system prefix, prefix, prompt, suffix) are rendered
    into an Ollama ``TEMPLATE`` using the ``.System``, ``.Prompt`` and
    ``.Response`` variables, and stop strings plus sampling parameters are
    emitted as ``PARAMETER`` lines.

    Side effects:
        * ``args.device_map`` is overwritten with ``'meta'``.
        * ``args.output_dir`` is created if it does not exist.
        * ``<args.output_dir>/Modelfile`` is written (UTF-8), replacing any
          existing file.

    Args:
        args: Export arguments. This function reads ``output_dir``,
            ``temperature``, ``top_k``, ``top_p`` and ``repetition_penalty``
            directly; the whole object is also forwarded to
            ``prepare_model_template``.
    """
    args.device_map = 'meta'  # Accelerate load speed.
    logger.info('Exporting to ollama:')
    os.makedirs(args.output_dir, exist_ok=True)
    model, template = prepare_model_template(args)
    pt_engine = PtEngine.from_model_template(model, template)
    logger.info(f'Using model_dir: {pt_engine.model_dir}')
    template_meta = template.template_meta
    modelfile_path = os.path.join(args.output_dir, 'Modelfile')

    # Everything that can fail is computed before the file is opened, so an
    # error here does not leave a truncated Modelfile behind.
    system_part = replace_and_concat(template, template_meta.system_prefix, '{{SYSTEM}}', '{{ .System }}')
    prefix_part = replace_and_concat(template, template_meta.prefix, '', '')
    prompt_part = replace_and_concat(template, template_meta.prompt, '{{QUERY}}', '{{ .Prompt }}')
    # The suffix is used both at the end of the template and as a stop string; render it once.
    suffix = replace_and_concat(template, template_meta.suffix, '', '')

    request_config = RequestConfig(
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        repetition_penalty=args.repetition_penalty)
    generation_config = pt_engine._prepare_generation_config(request_config)
    pt_engine._add_stop_words(generation_config, request_config, template.template_meta)

    # Suffix first, then the engine's stop words; drop empty strings and
    # repeats so each stop string is emitted exactly once, in order.
    # NOTE(review): `stop_words` is assumed to be a list of strings (or None) -- confirm against PtEngine.
    stop_words = []
    for stop_word in [suffix, *(generation_config.stop_words or [])]:
        if stop_word and stop_word not in stop_words:
            stop_words.append(stop_word)

    # (Modelfile parameter name, value). Note that Ollama's name for the
    # repetition penalty is `repeat_penalty`.
    sampling_params = (
        ('temperature', generation_config.temperature),
        ('top_k', generation_config.top_k),
        ('top_p', generation_config.top_p),
        ('repeat_penalty', generation_config.repetition_penalty),
    )

    template_text = ''.join([
        '{{ if .System }}',
        system_part,
        '{{ else }}',
        prefix_part,
        '{{ end }}',
        '{{ if .Prompt }}',
        prompt_part,
        '{{ end }}',
        '{{ .Response }}',
        suffix,
    ])

    lines = [
        f'FROM {pt_engine.model_dir}\n',
        f'TEMPLATE """{template_text}"""\n',
    ]
    lines.extend(f'PARAMETER stop "{stop_word}"\n' for stop_word in stop_words)
    # An unset value would otherwise be written as the literal text `None`;
    # omit the line so the parameter is simply left unspecified.
    lines.extend(f'PARAMETER {name} {value}\n' for name, value in sampling_params if value is not None)

    with open(modelfile_path, 'w', encoding='utf-8') as f:
        f.writelines(lines)

    logger.info('Save Modelfile done, you can start ollama by:')
    logger.info('> ollama serve')
    logger.info('In another terminal:')
    logger.info(f'> ollama create my-custom-model -f {modelfile_path}')
    logger.info('> ollama run my-custom-model')
