# Copyright (c) Alibaba, Inc. and its affiliates.
from functools import partial
from typing import Dict, Type

import gradio as gr
from packaging import version

from swift.llm.argument.base_args.base_args import get_supported_tuners
from swift.ui.base import BaseUI
from swift.ui.llm_grpo.advanced import GRPOAdvanced
from swift.ui.llm_grpo.dataset import GRPODataset
from swift.ui.llm_grpo.external_rollout import LLMRollout
from swift.ui.llm_grpo.grpo_advanced import GrpoAdvanced
from swift.ui.llm_grpo.hyper import GRPOHyper
from swift.ui.llm_grpo.model import GRPOModel
from swift.ui.llm_grpo.optimizer import GRPOOptimizer
from swift.ui.llm_grpo.quantization import GRPOQuantization
from swift.ui.llm_grpo.report_to import GRPOReportTo
from swift.ui.llm_grpo.reward import Reward
from swift.ui.llm_grpo.rollout import Rollout
from swift.ui.llm_grpo.runtime import GRPORuntime
from swift.ui.llm_grpo.save import GRPOSave
from swift.ui.llm_grpo.tuner import GRPOTuner
from swift.ui.llm_train.llm_train import LLMTrain
from swift.ui.llm_train.runtime import Runtime
from swift.utils import get_device_count, get_logger

logger = get_logger()


class LLMGRPO(LLMTrain):
    group = "llm_grpo"

    sub_ui = [
        GRPOModel,
        GRPODataset,
        Reward,
        GRPORuntime,
        Rollout,
        GRPOSave,
        GRPOTuner,
        GRPOOptimizer,
        GRPOHyper,
        GRPOQuantization,
        GRPOAdvanced,
        GrpoAdvanced,
        GRPOReportTo,
        LLMRollout,
    ]

    locale_dict: Dict[str, Dict] = {
        "llm_grpo": {
            "label": {
                "zh": "LLM GRPO",
                "en": "LLM GRPO",
            }
        },
        "external_alert": {
            "value": {
                "zh": "Err: {} \nRollout模型部署未完成，请检查日志，稍后开始训练！",
                "en": "Err: {} \nRollout model deployment is incomplete, "
                "please check the logs and start training later!",
            }
        },
        "submit_alert": {
            "value": {
                "zh": "任务已开始，请查看tensorboard或日志记录，关闭本页面不影响训练过程",
                "en": "Task started, please check the tensorboard or log file, "
                "closing this page does not affect training",
            }
        },
        "dataset_alert": {
            "value": {
                "zh": "请选择或填入一个数据集",
                "en": "Please input or select a dataset",
            }
        },
        "submit": {"value": {"zh": "🚀 开始训练", "en": "🚀 Begin"}},
        "dry_run": {
            "label": {"zh": "仅生成运行命令", "en": "Dry-run"},
            "info": {
                "zh": "仅生成运行命令，开发者自行运行",
                "en": "Generate run command only, for manually running",
            },
        },
        "gpu_id": {
            "label": {"zh": "选择可用GPU", "en": "Choose GPU"},
            "info": {
                "zh": "选择训练使用的GPU号，如CUDA不可用只能选择CPU",
                "en": "Select GPU to train",
            },
        },
        "train_type": {
            "label": {"zh": "训练方式", "en": "Train type"},
            "info": {"zh": "选择训练的方式", "en": "Select the training type"},
        },
        "seed": {
            "label": {"zh": "随机数种子", "en": "Seed"},
            "info": {"zh": "选择随机数种子", "en": "Select a random seed"},
        },
        "torch_dtype": {
            "label": {"zh": "训练精度", "en": "Training Precision"},
            "info": {"zh": "选择训练精度", "en": "Select the training precision"},
        },
        "envs": {
            "label": {"zh": "环境变量", "en": "Extra env vars"},
        },
        "use_ddp": {
            "label": {"zh": "使用DDP", "en": "Use DDP"},
            "info": {
                "zh": "是否使用数据并行训练",
                "en": "Use Distributed Data Parallel to train",
            },
        },
        "ddp_num": {
            "label": {"zh": "DDP分片数量", "en": "Number of DDP sharding"},
            "info": {
                "zh": "启用多少进程的数据并行",
                "en": "The data parallel size of DDP",
            },
        },
        "use_liger_kernel": {
            "label": {"zh": "使用Liger kernel", "en": "Use Liger kernel"},
            "info": {
                "zh": "Liger kernel可以有效降低显存使用",
                "en": "Liger kernel can reduce memory usage",
            },
        },
        "sequence_parallel_size": {
            "label": {
                "zh": "序列并行大小",
                "en": "Sequence parallel size",
            },
            "info": {
                "zh": "当前支持CPT/SFT/DPO/GRPO",
                "en": "Currently supports CPT/SFT/DPO/GRPO",
            },
        },
        "deepspeed": {
            "label": {
                "zh": "DeepSpeed",
                "en": "DeepSpeed",
            },
            "info": {
                "zh": "可以选择下拉列表，也支持传入路径",
                "en": "Choose from the dropbox or fill in a valid path",
            },
        },
        "resume_checkpoint_alert": {
            "value": {
                "zh": '检测到"args.json"在{}中，将从此检查点开始断点续训',
                "en": 'Detected that "args.json" is in {}, will start breakpoint resume training from this checkpoint',
            }
        },
        "resume_only_model_alert": {
            "value": {
                "zh": '检测到"args.json"在{}中，但未检测到优化器参数，将仅加载模型参数开始断点续训',
                "en": '"args.json" is detected in {}, but optimizer parameters are not detected. '
                "Only model parameters will be loaded to start breakpoint continuation training",
            }
        },
        "more_params": {
            "label": {"zh": "其他高级参数", "en": "Other params"},
            "info": {
                "zh": "以json格式或--xxx xxx命令行格式填入",
                "en": "Fill in with json format or --xxx xxx cmd format",
            },
        },
        "extra_params": {
            "label": {"zh": "其他参数设置", "en": "Extra settings"},
        },
        "train_param": {
            "label": {"zh": "训练参数设置", "en": "Train settings"},
        },
    }

    @classmethod
    def do_build_ui(cls, base_tab: Type["BaseUI"]):
        with gr.TabItem(elem_id="llm_grpo", label=""):
            default_device = "cpu"
            device_count = get_device_count()
            if device_count > 0:
                default_device = "0"
            with gr.Blocks():
                GRPOModel.build_ui(base_tab)
                GRPODataset.build_ui(base_tab)
                Reward.build_ui(base_tab)
                with gr.Accordion(elem_id="train_param", open=True):
                    with gr.Row():
                        gr.Dropdown(
                            elem_id="train_type",
                            scale=4,
                            choices=list(get_supported_tuners()),
                        )
                        gr.Textbox(elem_id="seed", scale=4)
                        gr.Dropdown(elem_id="torch_dtype", scale=4)
                        gr.Checkbox(elem_id="use_liger_kernel", scale=4)
                        gr.Textbox(elem_id="sequence_parallel_size", lines=1, scale=4)
                    with gr.Row():
                        gr.Dropdown(
                            elem_id="gpu_id",
                            multiselect=True,
                            choices=[str(i) for i in range(device_count)] + ["cpu"],
                            value=default_device,
                            scale=8,
                        )
                        gr.Checkbox(elem_id="use_ddp", value=False, scale=4)
                        gr.Textbox(elem_id="ddp_num", value="1", scale=4)
                        gr.Dropdown(
                            elem_id="deepspeed",
                            scale=4,
                            allow_custom_value=True,
                            value=None,
                            choices=[
                                "zero0",
                                "zero1",
                                "zero2",
                                "zero3",
                                "zero2_offload",
                                "zero3_offload",
                            ],
                        )
                GRPOHyper.build_ui(base_tab)
                GRPORuntime.build_ui(base_tab)
                with gr.Row(equal_height=True):
                    gr.Textbox(elem_id="envs", scale=12)
                    gr.Checkbox(elem_id="dry_run", value=False, scale=4)
                    submit = gr.Button(elem_id="submit", scale=4, variant="primary")

                Rollout.build_ui(base_tab)
                LLMRollout.set_lang(cls.lang)
                LLMRollout.build_ui(LLMRollout)
                GRPOTuner.build_ui(base_tab)
                with gr.Accordion(elem_id="extra_params", open=False):
                    with gr.Tabs():
                        GrpoAdvanced.build_ui(base_tab)
                        GRPOAdvanced.build_ui(base_tab)
                        GRPOQuantization.build_ui(base_tab)
                        GRPOSave.build_ui(base_tab)
                        GRPOReportTo.build_ui(base_tab)
                    with gr.Row():
                        gr.Textbox(elem_id="more_params", lines=4, scale=20)

                cls.element("train_type").change(
                    GRPOHyper.update_lr,
                    inputs=[base_tab.element("train_type")],
                    outputs=[cls.element("learning_rate")],
                )

                submit.click(
                    cls.train_local,
                    list(cls.valid_elements().values()),
                    [
                        cls.element("running_cmd"),
                        cls.element("logging_dir"),
                        cls.element("runtime_tab"),
                        cls.element("running_tasks"),
                        cls.element("train_record"),
                    ],
                    queue=True,
                )
                Rollout.element("vllm_mode").change(
                    LLMRollout.external_rollout_display,
                    Rollout.element("vllm_mode"),
                    LLMRollout.element("llm_rollout"),
                )
                LLMRollout.element("rollout").click(
                    LLMRollout.rollout_model,
                    list(LLMRollout.valid_elements().values())
                    + [
                        cls.element("model"),
                        cls.element("model_type"),
                        cls.element("template"),
                    ],
                    [
                        LLMRollout.element("rollout_runtime_tab"),
                        LLMRollout.element("rollout_running_tasks"),
                    ],
                )

                GRPORuntime.element("kill_task").click(
                    GRPORuntime.kill_task,
                    [GRPORuntime.element("running_tasks")],
                    [GRPORuntime.element("running_tasks")]
                    + [GRPORuntime.element("log")]
                    + GRPORuntime.all_plots,
                ).then(
                    GRPORuntime.reset,
                    [],
                    [GRPORuntime.element("logging_dir")]
                    + [GRPOHyper.element("output_dir")],
                )

                base_tab.element("gpu_id").change(
                    cls.update_ddp_num,
                    [base_tab.element("gpu_id"), base_tab.element("use_ddp")],
                    base_tab.element("ddp_num"),
                )
                base_tab.element("use_ddp").change(
                    cls.update_ddp_num,
                    [base_tab.element("gpu_id"), base_tab.element("use_ddp")],
                    base_tab.element("ddp_num"),
                )
                base_tab.element("ddp_num").change(
                    Rollout.update_num_gen,
                    [
                        GRPOHyper.element("per_device_train_batch_size"),
                        GRPOHyper.element("gradient_accumulation_steps"),
                        cls.element("ddp_num"),
                    ],
                    [Rollout.element("num_generations")],
                )
                GRPOHyper.element("gradient_accumulation_steps").change(
                    Rollout.update_num_gen,
                    [
                        GRPOHyper.element("per_device_train_batch_size"),
                        GRPOHyper.element("gradient_accumulation_steps"),
                        cls.element("ddp_num"),
                    ],
                    [Rollout.element("num_generations")],
                )
                GRPOHyper.element("per_device_train_batch_size").change(
                    Rollout.update_num_gen,
                    [
                        GRPOHyper.element("per_device_train_batch_size"),
                        GRPOHyper.element("gradient_accumulation_steps"),
                        cls.element("ddp_num"),
                    ],
                    [Rollout.element("num_generations")],
                )

    @classmethod
    def prepare_sub_to_filter(cls):
        tabs_relation_dict = {
            key: val
            for key, val in zip(
                ["train_type", "optimizer", "vllm_mode"],
                [
                    GRPOTuner.tabs_to_filter,
                    GRPOOptimizer.tabs_to_filter,
                    Rollout.tabs_to_filter,
                ],
            )
        }
        return tabs_relation_dict
