from typing import Type

import gradio as gr

from swift.ui.base import BaseUI


class Hyper(BaseUI):
    """Hyper-parameter section of the LLM training UI.

    Renders an accordion (``elem_id='hyper_param'``) with two rows of
    controls: batch sizes, learning rate, epoch/step limits, gradient
    accumulation, grad-norm clipping, and the ``predict_with_generate`` /
    ``use_flash_attn`` toggles. It also makes the learning-rate box follow
    the ``sft_type`` component of the enclosing tab.
    """

    group = 'llm_train'

    # Localized label/info text keyed by component ``elem_id``.
    # NOTE(review): presumably applied by ``BaseUI`` when components are
    # created with a matching ``elem_id`` -- confirm in swift.ui.base.
    locale_dict = {
        'hyper_param': {
            'label': {
                'zh': '超参数',
                'en': 'Hyper settings',
            },
        },
        'batch_size': {
            'label': {
                'zh': '训练batch size',
                'en': 'Train batch size',
            },
            'info': {
                'zh': '训练的batch size',
                'en': 'Set the train batch size',
            }
        },
        'eval_batch_size': {
            'label': {
                'zh': '验证batch size',
                'en': 'Val batch size',
            },
            'info': {
                'zh': '验证的batch size',
                'en': 'Set the val batch size',
            }
        },
        'learning_rate': {
            'label': {
                'zh': '学习率',
                'en': 'Learning rate',
            },
            'info': {
                'zh': '设置学习率',
                'en': 'Set the learning rate',
            }
        },
        'eval_steps': {
            'label': {
                'zh': '交叉验证步数',
                'en': 'Eval steps',
            },
            'info': {
                'zh': '设置每隔多少步数进行一次验证',
                'en': 'Set the step interval to validate',
            }
        },
        'num_train_epochs': {
            'label': {
                'zh': '数据集迭代轮次',
                'en': 'Train epoch',
            },
            'info': {
                'zh': '设置对数据集训练多少轮次',
                'en': 'Set the max train epoch',
            }
        },
        'max_steps': {
            'label': {
                'zh': '最大迭代步数',
                'en': 'Max steps',
            },
            'info': {
                'zh':
                '设置最大迭代步数，该值如果大于零则数据集迭代次数不生效',
                'en':
                'Set the max steps, if the value > 0 then num_train_epochs has no effects',
            }
        },
        'gradient_accumulation_steps': {
            'label': {
                'zh': '梯度累计步数',
                'en': 'Gradient accumulation steps',
            },
            'info': {
                'zh': '设置梯度累计步数以减小显存占用',
                'en': 'Set the gradient accumulation steps',
            }
        },
        'max_grad_norm': {
            'label': {
                'zh': '梯度裁剪',
                'en': 'Max grad norm',
            },
            'info': {
                'zh': '设置梯度裁剪',
                'en': 'Set the max grad norm',
            }
        },
        'predict_with_generate': {
            'label': {
                'zh': '使用生成指标代替loss',
                'en': 'Use generate metric instead of loss',
            },
            'info': {
                'zh': '验证时使用generate/Rouge代替loss',
                'en': 'Use model.generate/Rouge instead of loss',
            }
        },
        'use_flash_attn': {
            'label': {
                'zh': '使用Flash Attention',
                'en': 'Use Flash Attention',
            },
            'info': {
                'zh': '使用Flash Attention减小显存占用',
                'en': 'Use Flash Attention to reduce memory',
            }
        },
    }

    @classmethod
    def do_build_ui(cls, base_tab: Type['BaseUI']):
        """Create the hyper-parameter widgets in the current Gradio context.

        Args:
            base_tab: The enclosing tab class. It is used here only to look
                up its ``sft_type`` component, so that the learning-rate
                default can track that selection.
        """
        with gr.Accordion(elem_id='hyper_param', open=True):
            with gr.Blocks():
                with gr.Row():
                    # The slider steps are 1: with minimum=1 a step of 2
                    # only allows odd values (1, 3, 5, ...). That would make
                    # 2/4/8/16 unselectable and the maximum unreachable.
                    gr.Slider(
                        elem_id='batch_size',
                        minimum=1,
                        maximum=256,
                        step=1,
                        scale=20)
                    learning_rate = gr.Textbox(
                        elem_id='learning_rate',
                        value='1e-4',
                        lines=1,
                        scale=20)
                    gr.Textbox(elem_id='num_train_epochs', lines=1, scale=20)
                    gr.Textbox(elem_id='max_steps', lines=1, scale=20)
                    # The default of 16 lies on the slider grid only with
                    # step=1 (it was off-grid with minimum=1, step=2).
                    gr.Slider(
                        elem_id='gradient_accumulation_steps',
                        minimum=1,
                        maximum=256,
                        step=1,
                        value=16,
                        scale=20)
                with gr.Row():
                    gr.Slider(
                        elem_id='eval_batch_size',
                        minimum=1,
                        maximum=256,
                        step=1,
                        scale=20)
                    gr.Textbox(
                        elem_id='eval_steps', lines=1, value='500', scale=20)
                    gr.Textbox(elem_id='max_grad_norm', lines=1, scale=20)
                    gr.Checkbox(elem_id='predict_with_generate', scale=20)
                    gr.Checkbox(elem_id='use_flash_attn', scale=20)

            def update_lr(sft_type):
                """Return the default learning-rate text for ``sft_type``.

                'full' gets a 10x smaller default than every other type.
                Strings are returned (not floats) so that the textbox keeps
                the same notation as its initial value '1e-4', instead of
                showing str(1e-05) == '1e-05'.
                """
                return '1e-5' if sft_type == 'full' else '1e-4'

            sft_type = base_tab.element('sft_type')
            sft_type.change(
                update_lr, inputs=[sft_type], outputs=[learning_rate])
