import os.path
from typing import Type

import gradio as gr
import json

from swift.llm import MODEL_MAPPING, TEMPLATE_MAPPING, ModelType
from swift.ui.base import BaseUI
from swift.ui.llm_infer.generate import Generate


class Model(BaseUI):
    """Model-selection panel of the ``llm_infer`` tab.

    Lets the user pick a registered model type (or a locally trained
    checkpoint), edit the model id/path, choose the prompt template and the
    system prompt, and then load/deploy the model. Generation options are
    provided by the ``Generate`` sub-UI.
    """

    llm_train = 'llm_infer'

    sub_ui = [Generate]

    # True when ``USE_INFERENCE=1`` is set or when running in a ModelScope
    # studio environment; used here to pick the label of the load button.
    is_inference = os.environ.get('USE_INFERENCE') == '1' or os.environ.get(
        'MODELSCOPE_ENVIRONMENT') == 'studio'

    # Localized UI text. Keys mostly match the ``elem_id`` of the components
    # built in ``do_build_ui``; each entry maps a component property
    # ('label', 'info' or 'value') to its 'zh'/'en' text. 'checkpoint' is the
    # dropdown entry that stands for a locally trained model.
    locale_dict = {
        'checkpoint': {
            'value': {
                'zh': '训练后的模型',
                'en': 'Trained model'
            }
        },
        'model_type': {
            'label': {
                'zh': '选择模型',
                'en': 'Select Model'
            },
            'info': {
                'zh': 'SWIFT已支持的模型名称',
                'en': 'Base model supported by SWIFT'
            }
        },
        'load_checkpoint': {
            'value': {
                'zh': '加载模型' if is_inference else '部署模型',
                'en': 'Load model' if is_inference else 'Deploy model',
            }
        },
        'model_id_or_path': {
            'label': {
                'zh': '模型id或路径',
                'en': 'Model id or path'
            },
            'info': {
                'zh': '实际的模型id',
                'en': 'The actual model id or model path'
            }
        },
        'template_type': {
            'label': {
                'zh': '模型Prompt模板类型',
                'en': 'Prompt template type'
            },
            'info': {
                'zh': '选择匹配模型的Prompt模板',
                'en': 'Choose the template type of the model'
            }
        },
        'system': {
            'label': {
                'zh': 'system字段',
                'en': 'system'
            },
            'info': {
                'zh': 'system字段支持在加载模型后修改',
                'en': 'system can be modified after the model weights loaded'
            }
        },
        'more_params': {
            'label': {
                'zh': '更多参数',
                'en': 'More params'
            },
            'info': {
                'zh': '以json格式填入',
                'en': 'Fill in with json format'
            }
        },
        'reset': {
            'value': {
                'zh': '恢复初始值',
                'en': 'Reset to default'
            },
        },
    }

    @classmethod
    def do_build_ui(cls, base_tab: Type['BaseUI']):
        """Build the model-selection widgets and wire their events.

        Args:
            base_tab: The owning tab class. It is used to look up localized
                text and is forwarded to ``Generate.build_ui``.
        """
        # Dropdown entry meaning "a locally trained checkpoint" rather than a
        # registered model type. Computed once so the callbacks compare
        # against exactly the text shown in the dropdown.
        checkpoint_label = base_tab.locale('checkpoint', cls.lang)['value']

        with gr.Row():
            model_type = gr.Dropdown(
                elem_id='model_type',
                choices=[checkpoint_label] + ModelType.get_model_name_list()
                + cls.get_custom_name_list(),
                value=checkpoint_label,
                scale=20)
            model_id_or_path = gr.Textbox(
                elem_id='model_id_or_path',
                lines=1,
                scale=20,
                interactive=True)
            template_type = gr.Dropdown(
                elem_id='template_type',
                choices=list(TEMPLATE_MAPPING.keys()) + ['AUTO'],
                scale=20)
            reset_btn = gr.Button(elem_id='reset', scale=2)
            # Per-session map: dropdown choice -> model path the user picked
            # for it (filled when a directory with sft_args.json is entered).
            model_state = gr.State({})
        with gr.Row():
            system = gr.Textbox(elem_id='system', lines=4, scale=20)
        Generate.build_ui(base_tab)
        with gr.Row():
            # Not bound to locals here; presumably looked up by elem_id
            # elsewhere (e.g. by the owning tab). TODO confirm.
            gr.Textbox(elem_id='more_params', lines=1, scale=20)
            gr.Button(elem_id='load_checkpoint', scale=2, variant='primary')

        def template_info(name):
            """Return ``(default_system, template_type)`` for a registered
            model type, or ``(None, None)`` if ``name`` is not in
            ``MODEL_MAPPING``."""
            if name not in MODEL_MAPPING:
                return None, None
            template = MODEL_MAPPING[name]['template']
            default_system = getattr(TEMPLATE_MAPPING[template]['template'],
                                     'default_system', None)
            return default_system, template

        def update_input_model(choice, model_state=None):
            """Refresh path, system and template after the dropdown changes.

            A path remembered in ``model_state`` for ``choice`` takes
            precedence over the registered model id.

            Returns:
                ``(model_id_or_path, default_system, template_type)``.
            """
            saved = model_state if model_state else {}
            if choice == checkpoint_label:
                # No registered model behind this entry: restore a remembered
                # checkpoint path (if any) and clear system/template.
                return saved.get(choice), None, None
            if choice in saved:
                model_id_or_path = saved[choice]
            else:
                model_id_or_path = MODEL_MAPPING[choice]['model_id_or_path']
            default_system, template = template_info(choice)
            return model_id_or_path, default_system, template

        def update_model_id_or_path(model_type, path, system, template_type,
                                    model_state):
            """Refresh system and template when the model id/path is edited.

            Only paths that exist locally are considered. If the directory
            contains an ``sft_args.json`` (presumably saved by SWIFT
            training), its ``system`` and ``template_type`` are used, and the
            path is remembered in ``model_state`` under the current dropdown
            choice. A missing or empty ``system`` falls back to the base
            model's default system. Otherwise, the registered defaults of
            ``model_type`` are applied when it is a registered model type.

            Returns:
                ``(system, template_type, model_state)``.
            """
            if not path or not os.path.exists(path):
                return system, template_type, model_state
            local_path = os.path.join(path, 'sft_args.json')
            if not os.path.exists(local_path):
                if model_type not in MODEL_MAPPING:
                    # E.g. the trained-checkpoint entry while the user is
                    # still typing a path: there are no registered defaults,
                    # so keep the current values instead of raising KeyError.
                    return system, template_type, model_state
                default_system, template = template_info(model_type)
                return default_system, template, model_state

            with open(local_path, 'r', encoding='utf-8') as f:
                sft_args = json.load(f)
            base_system, _ = template_info(sft_args.get('model_type'))
            model_state[model_type] = path
            return (sft_args.get('system') or base_system,
                    sft_args.get('template_type', template_type), model_state)

        model_type.change(
            update_input_model,
            inputs=[model_type, model_state],
            outputs=[model_id_or_path, system, template_type])

        model_id_or_path.change(
            update_model_id_or_path,
            inputs=[
                model_type, model_id_or_path, system, template_type,
                model_state
            ],
            outputs=[system, template_type, model_state])

        def reset(model_type):
            """Restore the defaults for ``model_type`` and forget all
            remembered paths."""
            model_id_or_path, default_system, template = update_input_model(
                model_type)
            return model_id_or_path, default_system, template, {}

        reset_btn.click(
            reset,
            inputs=[model_type],
            outputs=[model_id_or_path, system, template_type, model_state])
