# Copyright (c) Alibaba, Inc. and its affiliates.
import re
from copy import deepcopy
from dataclasses import dataclass, field
from typing import List, Literal, Optional

from swift.llm.utils import update_generation_config_eos_token
from swift.plugin import extra_tuners
from swift.tuners import Swift
from swift.utils import get_logger
from ..utils import Messages

logger = get_logger()


@dataclass
class InferCliState:
    """Mutable state of an interactive command-line inference session.

    Holds the chat history, the optional system prompt, the multimodal
    resources collected so far and the two input-mode flags that control how
    the next line typed by the user is read and interpreted.
    """
    # None: use default-system. '': not use system.
    system: Optional[str] = None
    messages: Messages = field(default_factory=list)  # not including system

    # Paths/URLs entered by the user, one entry per multimodal tag.
    images: List[str] = field(default_factory=list)
    audios: List[str] = field(default_factory=list)
    videos: List[str] = field(default_factory=list)

    # True: read input until a line ending with `#`.
    multiline_mode: bool = False
    # True: the next input is taken as the new system prompt.
    input_system: bool = False

    def clear(self) -> None:
        """Drop the chat history and all collected multimodal resources.

        The system prompt and the input-mode flags are left untouched.
        """
        self.messages = []
        self.images = []
        self.audios = []
        self.videos = []

    def add_query(self, query: str) -> None:
        """Append `query` to the history.

        A query starting with `tool:` is stored with role `tool` (prefix
        removed); anything else is stored with role `user`.
        """
        role = 'user'
        if query.startswith('tool:'):
            role = 'tool'
            query = query[len('tool:'):]
        self.messages.append({'role': role, 'content': query})

    def add_response(self, response: str) -> None:
        """Append `response` to the history with role `assistant`."""
        self.messages.append({'role': 'assistant', 'content': response})

    def to_dict(self) -> dict:
        """Return a deep-copied snapshot of the conversation.

        The result has the keys `messages`, `images`, `audios` and `videos`.
        When `system` is not None it is prepended to the copied messages as a
        `system` message; the state itself is not modified.
        """
        infer_state = deepcopy(self)
        if infer_state.system is not None:
            infer_state.messages.insert(0, {'role': 'system', 'content': infer_state.system})
        return {
            'messages': infer_state.messages,
            'images': infer_state.images,
            'audios': infer_state.audios,
            'videos': infer_state.videos
        }

    def input_mm_data(self) -> None:
        """Prompt for one path/URL per multimodal tag in the latest message.

        Every `<image>`, `<video>` and `<audio>` tag found in the content of
        the last message triggers one `input()` call; the answer is appended
        to `images`, `videos` or `audios` respectively, in tag order.
        Does nothing when the history is empty.
        """
        if not self.messages:
            # Nothing has been queried yet, so there are no tags to resolve.
            return

        def _input_mm_file(mm_type: Literal['image', 'video', 'audio']) -> str:
            a_an = 'an' if mm_type[0] in {'i', 'a'} else 'a'
            return input(f'Input {a_an} {mm_type} path or URL <<< ')

        mm_types = ['image', 'video', 'audio']
        query = self.messages[-1]['content']
        # The single capture group makes findall return the bare type names.
        mm_pattern = '<(' + '|'.join(mm_types) + ')>'
        for mm_type in re.findall(mm_pattern, query):
            # 'image' -> self.images, 'video' -> self.videos, 'audio' -> self.audios
            getattr(self, f'{mm_type}s').append(_input_mm_file(mm_type))

    @staticmethod
    def _input_multiline(prompt: str) -> str:
        """Read lines until one ends with `#` and return them joined.

        `prompt` is shown for the first line only. Every line keeps its
        trailing newline except the terminating one, whose `#` and newline
        are both removed.
        """
        stop_words = '#\n'
        chunks = []
        while True:
            text = f'{input(prompt)}\n'
            prompt = ''
            if text.endswith(stop_words):
                chunks.append(text[:-len(stop_words)])
                break
            chunks.append(text)
        return ''.join(chunks)

    def input_text(self) -> str:
        """Read the next user input according to the current input mode.

        The prompt is `<<<` followed by `[M]` in multi-line mode, `[S]` when
        a system prompt is expected, or `[MS]` when both apply.
        """
        if self.multiline_mode:
            addi_prompt = '[MS]' if self.input_system else '[M]'
            text = InferCliState._input_multiline(f'<<<{addi_prompt} ')
        else:
            addi_prompt = '[S]' if self.input_system else ''
            text = input(f'<<<{addi_prompt} ')
        return text

    def check_query(self, query: str) -> Optional[str]:
        """Interpret `query` as a CLI command or pass it through.

        Commands are matched after stripping whitespace and lower-casing:
        `clear`, `reset-system`, `multi-line`, `single-line` and the empty
        string. When `input_system` is set, the input is consumed as the new
        system prompt (`default-system` restores the default) and the
        history is cleared.

        Returns:
            `query` unchanged when it is a regular message, otherwise None.
        """
        query_std = query.strip().lower()
        if self.input_system:
            # Compare the normalised text: in multi-line mode the raw query may
            # carry a trailing newline, which previously made the keyword be
            # stored as a literal system prompt instead of restoring the default.
            self.system = None if query_std == 'default-system' else query
            self.input_system = False
            # A changed system prompt invalidates the current history.
            query_std = 'clear'
        if query_std == 'clear':
            self.clear()
            return None
        if query_std == '':
            return None
        if query_std == 'reset-system':
            self.input_system = True
            return None
        if query_std == 'multi-line':
            self.multiline_mode = True
            logger.info('End multi-line input with `#`.')
            logger.info('Input `single-line` to switch to single-line input mode.')
            return None
        if query_std == 'single-line':
            self.multiline_mode = False
            return None
        return query


def prepare_adapter(args, model, adapters=None):
    """Load the configured adapters onto `model` for inference.

    With the `unsloth` tuner backend the model is only switched to inference
    mode via the matching unsloth class and returned as is. Otherwise every
    adapter is applied in order with the tuner registered in `extra_tuners`
    for `args.train_type`, falling back to `Swift`.

    Args:
        args: Argument object; `tuner_backend`, `model_meta.is_multimodal`,
            `train_type` and `adapters` are read from it.
        model: The model the adapters are loaded onto.
        adapters: Optional adapters overriding `args.adapters`.

    Returns:
        The model after all adapters have been loaded.
    """
    if args.tuner_backend == 'unsloth':
        # Imported lazily so unsloth is only required when it is selected.
        if args.model_meta.is_multimodal:
            from unsloth import FastVisionModel as UnslothModel
        else:
            from unsloth import FastLanguageModel as UnslothModel
        UnslothModel.for_inference(model)
        return model

    train_type = args.train_type
    tuner = extra_tuners[train_type] if train_type in extra_tuners else Swift

    # compat deploy: explicitly passed adapters take precedence over args.adapters
    adapter_list = adapters if adapters else args.adapters
    for adapter_item in adapter_list:
        model = tuner.from_pretrained(model, adapter_item)

    if train_type == 'bone':
        # Bone has a problem of float32 matmul with bfloat16 in `peft==0.14.0`
        model.to(model.dtype)
    return model


def prepare_model_template(args, **kwargs):
    """Build the model (with adapters loaded) and its template.

    Args:
        args: Argument object providing `get_model_processor` and
            `get_template`.
        **kwargs: Forwarded unchanged to `args.get_model_processor`.

    Returns:
        A `(model, template)` tuple.
    """
    base_model, processor = args.get_model_processor(**kwargs)
    adapted_model = prepare_adapter(args, base_model)
    chat_template = args.get_template(processor)
    # Update the EOS token settings of the generation config from the template.
    update_generation_config_eos_token(adapted_model.generation_config, chat_template)
    return adapted_model, chat_template
