# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Type

import gradio as gr

from swift.ui.base import BaseUI


class Hyper(BaseUI):
    """Gradio sub-panel holding the core training hyper-parameters.

    Every widget built in ``do_build_ui`` is created with an ``elem_id`` that
    matches a key of ``locale_dict``. The zh/en ``label``/``info`` texts are
    presumably resolved by ``BaseUI`` through that id (not visible here --
    confirm in ``swift.ui.base``).
    """

    group = 'llm_train'

    locale_dict = {
        'hyper_param': {
            'label': {
                'zh': '超参数设置(更多参数->其他参数设置)',
                'en': 'Hyper settings(more params->Extra settings)',
            },
        },
        'per_device_train_batch_size': {
            'label': {
                'zh': '训练batch size',
                'en': 'Train batch size',
            },
            'info': {
                'zh': '训练的batch size',
                'en': 'Set the train batch size',
            }
        },
        'per_device_eval_batch_size': {
            'label': {
                'zh': '验证batch size',
                'en': 'Val batch size',
            },
            'info': {
                'zh': '验证的batch size',
                'en': 'Set the val batch size',
            }
        },
        'learning_rate': {
            'label': {
                'zh': '学习率',
                'en': 'Learning rate',
            },
            'info': {
                'zh': '设置学习率',
                'en': 'Set the learning rate',
            }
        },
        'eval_steps': {
            'label': {
                'zh': '交叉验证步数',
                'en': 'Eval steps',
            },
            'info': {
                'zh': '设置每隔多少步数进行一次验证',
                'en': 'Set the step interval to validate',
            }
        },
        'num_train_epochs': {
            'label': {
                'zh': '数据集迭代轮次',
                'en': 'Train epoch',
            },
            'info': {
                'zh': '设置对数据集训练多少轮次',
                'en': 'Set the max train epoch',
            }
        },
        'gradient_accumulation_steps': {
            'label': {
                'zh': '梯度累计步数',
                'en': 'Gradient accumulation steps',
            },
            'info': {
                'zh': '设置梯度累计步数以减小显存占用',
                'en': 'Set the gradient accumulation steps',
            }
        },
        'attn_impl': {
            'label': {
                'zh': 'Flash Attention类型',
                'en': 'Flash Attention Type',
            },
        },
        'neftune_noise_alpha': {
            'label': {
                'zh': 'NEFTune噪声系数',
                'en': 'NEFTune noise coefficient',
            },
            'info': {
                'zh': '使用NEFTune提升训练效果, 一般设置为5或者10',
                'en': 'Use NEFTune to improve performance, normally the value should be 5 or 10'
            }
        },
        'save_steps': {
            'label': {
                'zh': '存储步数',
                'en': 'Save steps',
            },
            'info': {
                # '每隔' (every N steps) -- matches the wording used by 'eval_steps'.
                'zh': '设置每隔多少步数进行存储',
                'en': 'Set the save steps',
            }
        },
        'output_dir': {
            'label': {
                'zh': '存储目录',
                'en': 'The output dir',
            },
            'info': {
                'zh': '设置输出模型存储在哪个文件夹下',
                'en': 'Set the output folder',
            }
        },
    }

    @classmethod
    def do_build_ui(cls, base_tab: Type['BaseUI']) -> None:
        """Create the hyper-parameter widgets inside a collapsed accordion.

        Builds two rows of widgets:
        - batch sizes, learning rate, epochs and gradient accumulation;
        - eval/save intervals, output dir, attention implementation and
          NEFTune alpha.

        Args:
            base_tab: Not referenced by this method. It is presumably kept to
                match the ``BaseUI.do_build_ui`` hook signature.
        """
        # open=False: the panel starts collapsed.
        with gr.Accordion(elem_id='hyper_param', open=False):
            with gr.Blocks():
                with gr.Row():
                    # Integer sliders use step=1. With minimum=1 and step=2,
                    # only odd values (1, 3, 5, ...) were selectable, which
                    # ruled out common sizes such as 2, 4, 8 and 16.
                    gr.Slider(elem_id='per_device_train_batch_size', minimum=1, maximum=256, step=1, scale=20)
                    gr.Slider(elem_id='per_device_eval_batch_size', minimum=1, maximum=256, step=1, scale=20)
                    # Kept as free text so values like '1e-4' can be typed directly.
                    gr.Textbox(elem_id='learning_rate', value='1e-4', lines=1, scale=20)
                    gr.Textbox(elem_id='num_train_epochs', lines=1, scale=20)
                    gr.Slider(
                        elem_id='gradient_accumulation_steps',
                        minimum=1,
                        maximum=256,
                        # step=1 also makes the default of 16 land on the grid.
                        step=1,
                        # `group` is 'llm_train' on this class, so the
                        # 'llm_grpo' default of 1 applies only when a subclass
                        # overrides `group`.
                        value=1 if cls.group == 'llm_grpo' else 16,
                        scale=20)
                with gr.Row():
                    gr.Textbox(elem_id='eval_steps', lines=1, value='500', scale=20)
                    gr.Textbox(elem_id='save_steps', value='500', lines=1, scale=20)
                    gr.Textbox(elem_id='output_dir', scale=20)
                    # No `choices` are passed here; they are presumably
                    # supplied elsewhere (e.g. by BaseUI) -- confirm.
                    gr.Dropdown(elem_id='attn_impl', scale=20, value='flash_attn')
                    gr.Slider(elem_id='neftune_noise_alpha', minimum=0.0, maximum=20.0, step=0.5, scale=20)

    @staticmethod
    def update_lr(sft_type):
        """Return the default learning rate for a given ``sft_type``.

        Args:
            sft_type: Training-type identifier. 'full' (presumably
                full-parameter fine-tuning) gets a smaller rate.

        Returns:
            ``1e-5`` when ``sft_type == 'full'``, otherwise ``1e-4``.
        """
        if sft_type == 'full':
            return 1e-5
        return 1e-4
