# Copyright (c) Alibaba, Inc. and its affiliates.
import os

import gradio as gr

from swift.ui.llm_train.runtime import Runtime
from swift.utils import get_logger

logger = get_logger()


class GRPORuntime(Runtime):

    group = 'llm_grpo'

    locale_dict = {
        'runtime_tab': {
            'label': {
                'zh': '运行时',
                'en': 'Runtime'
            },
        },
        'tb_not_found': {
            'value': {
                'zh': 'tensorboard未安装,使用`pip install tensorboard`进行安装',
                'en': 'tensorboard not found, install it by `pip install tensorboard`',
            }
        },
        'running_cmd': {
            'label': {
                'zh': '运行命令',
                'en': 'Command line'
            },
            'info': {
                'zh': '执行的实际命令',
                'en': 'The actual command'
            }
        },
        'show_running_cmd': {
            'value': {
                'zh': '展示运行命令',
                'en': 'Show running command line'
            },
        },
        'show_sh': {
            'label': {
                'zh': '展示sh命令行',
                'en': 'Show sh command line'
            },
        },
        'cmd_sh': {
            'label': {
                'zh': '训练命令行',
                'en': 'Training command line'
            },
            'info': {
                'zh':
                '如果训练命令行没有展示请再次点击"展示运行命令"，点击下方的"保存训练命令"可以保存sh脚本',
                'en': ('Please press "Show running command line" if the content is none, '
                       'click the "Save training command" below to save the sh script')
            }
        },
        'save_cmd_as_sh': {
            'value': {
                'zh': '保存训练命令',
                'en': 'Save training command'
            }
        },
        'save_cmd_alert': {
            'value': {
                'zh': '训练命令行将被保存在：{}',
                'en': 'The training command line will be saved in: {}'
            }
        },
        'close_cmd_show': {
            'value': {
                'zh': '关闭训练命令展示',
                'en': 'Close training command show'
            }
        },
        'show_log': {
            'value': {
                'zh': '展示运行状态',
                'en': 'Show running status'
            },
        },
        'stop_show_log': {
            'value': {
                'zh': '停止展示运行状态',
                'en': 'Stop showing running status'
            },
        },
        'logging_dir': {
            'label': {
                'zh': '日志路径',
                'en': 'Logging dir'
            },
            'info': {
                'zh': '支持手动传入文件路径',
                'en': 'Support fill custom path in'
            }
        },
        'log': {
            'label': {
                'zh': '日志输出',
                'en': 'Logging content'
            },
            'info': {
                'zh': '如果日志无更新请再次点击"展示运行状态"',
                'en': 'Please press "Show running status" if the log content is not updating'
            }
        },
        'running_tasks': {
            'label': {
                'zh': '运行中任务',
                'en': 'Running Tasks'
            },
            'info': {
                'zh': '运行中的任务（所有的`swift rlhf --rlhf_type grpo`命令）',
                'en': 'All running tasks(started by `swift rlhf --rlhf_type grpo`)'
            }
        },
        'refresh_tasks': {
            'value': {
                'zh': '找回运行时任务',
                'en': 'Find running tasks'
            },
        },
        'kill_task': {
            'value': {
                'zh': '杀死任务',
                'en': 'Kill running task'
            },
        },
        'tb_url': {
            'label': {
                'zh': 'Tensorboard链接',
                'en': 'Tensorboard URL'
            },
            'info': {
                'zh': '仅展示，不可编辑',
                'en': 'Not editable'
            }
        },
        'start_tb': {
            'value': {
                'zh': '打开TensorBoard',
                'en': 'Start TensorBoard'
            },
        },
        'close_tb': {
            'value': {
                'zh': '关闭TensorBoard',
                'en': 'Close TensorBoard'
            },
        },
    }

    @classmethod
    def save_cmd(cls, cmd):
        if len(cmd) > 0:
            cmd_sh, output_dir = cls.cmd_to_sh_format(cmd)
            os.makedirs(output_dir, exist_ok=True)
            sh_file_path = os.path.join(output_dir, 'grpo.sh')
            gr.Info(cls.locale('save_cmd_alert', cls.lang)['value'].format(sh_file_path))
            with open(sh_file_path, 'w', encoding='utf-8') as f:
                f.write(cmd_sh)
