import os
import time
import logging
import gradio as gr
from threading import Thread
from gradio.components import Component # cannot use TYPE_CHECKING here
from typing import TYPE_CHECKING, Any, Dict, Generator, Optional, Tuple

import transformers
from transformers.trainer import TRAINING_ARGS_NAME

from llmtuner.extras.callbacks import LogCallback
from llmtuner.extras.constants import TRAINING_STAGES
from llmtuner.extras.logging import LoggerHandler
from llmtuner.extras.misc import get_device_count, torch_gc
from llmtuner.train import run_exp
from llmtuner.webui.common import get_module, get_save_dir, load_config
from llmtuner.webui.locales import ALERTS
from llmtuner.webui.utils import gen_cmd, get_eval_results, update_process_bar

if TYPE_CHECKING:
    from llmtuner.webui.manager import Manager


class Runner:

    def __init__(self, manager: "Manager", demo_mode: Optional[bool] = False) -> None:
        self.manager = manager
        self.demo_mode = demo_mode
        """ Resume """
        self.thread: "Thread" = None
        self.do_train = True
        self.running_data: Dict["Component", Any] = None
        """ State """
        self.aborted = False
        self.running = False
        """ Handler """
        self.logger_handler = LoggerHandler()
        self.logger_handler.setLevel(logging.INFO)
        logging.root.addHandler(self.logger_handler)
        transformers.logging.add_handler(self.logger_handler)

    @property
    def alive(self) -> bool:
        return self.thread is not None

    def set_abort(self) -> None:
        self.aborted = True

    def _initialize(self, data: Dict[Component, Any], do_train: bool, from_preview: bool) -> str:
        get = lambda name: data[self.manager.get_elem_by_name(name)]
        lang, model_name, model_path = get("top.lang"), get("top.model_name"), get("top.model_path")
        dataset = get("train.dataset") if do_train else get("eval.dataset")

        if self.running:
            return ALERTS["err_conflict"][lang]

        if not model_name:
            return ALERTS["err_no_model"][lang]

        if not model_path:
            return ALERTS["err_no_path"][lang]

        if len(dataset) == 0:
            return ALERTS["err_no_dataset"][lang]

        if self.demo_mode and (not from_preview):
            return ALERTS["err_demo"][lang]

        if not from_preview and get_device_count() > 1:
            return ALERTS["err_device_count"][lang]

        self.aborted = False
        self.logger_handler.reset()
        self.trainer_callback = LogCallback(self)
        return ""

    def _finalize(self, lang: str, finish_info: str) -> str:
        self.thread = None
        self.running_data = None
        self.running = False
        torch_gc()
        if self.aborted:
            return ALERTS["info_aborted"][lang]
        else:
            return finish_info

    def _parse_train_args(self, data: Dict[Component, Any]) -> Dict[str, Any]:
        get = lambda name: data[self.manager.get_elem_by_name(name)]
        user_config = load_config()

        if get("top.adapter_path"):
            adapter_name_or_path = ",".join([
                get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
            for adapter in get("top.adapter_path")])
        else:
            adapter_name_or_path = None

        args = dict(
            stage=TRAINING_STAGES[get("train.training_stage")],
            do_train=True,
            model_name_or_path=get("top.model_path"),
            adapter_name_or_path=adapter_name_or_path,
            cache_dir=user_config.get("cache_dir", None),
            finetuning_type=get("top.finetuning_type"),
            quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
            template=get("top.template"),
            rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
            flash_attn=(get("top.booster") == "flash_attn"),
            use_unsloth=(get("top.booster") == "unsloth"),
            dataset_dir=get("train.dataset_dir"),
            dataset=",".join(get("train.dataset")),
            cutoff_len=get("train.cutoff_len"),
            learning_rate=float(get("train.learning_rate")),
            num_train_epochs=float(get("train.num_train_epochs")),
            max_samples=int(get("train.max_samples")),
            per_device_train_batch_size=get("train.batch_size"),
            gradient_accumulation_steps=get("train.gradient_accumulation_steps"),
            lr_scheduler_type=get("train.lr_scheduler_type"),
            max_grad_norm=float(get("train.max_grad_norm")),
            logging_steps=get("train.logging_steps"),
            save_steps=get("train.save_steps"),
            warmup_steps=get("train.warmup_steps"),
            neftune_noise_alpha=get("train.neftune_alpha") or None,
            sft_packing=get("train.sft_packing"),
            upcast_layernorm=get("train.upcast_layernorm"),
            lora_rank=get("train.lora_rank"),
            lora_dropout=get("train.lora_dropout"),
            lora_target=get("train.lora_target") or get_module(get("top.model_name")),
            additional_target=get("train.additional_target") or None,
            create_new_adapter=get("train.create_new_adapter"),
            output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("train.output_dir")),
            fp16=(get("train.compute_type") == "fp16"),
            bf16=(get("train.compute_type") == "bf16")
        )
        args["disable_tqdm"] = True

        if TRAINING_STAGES[get("train.training_stage")] in ["rm", "ppo", "dpo"]:
            args["create_new_adapter"] = (args["quantization_bit"] is None)

        if args["stage"] == "ppo":
            args["reward_model"] = get_save_dir(
                get("top.model_name"), get("top.finetuning_type"), get("train.reward_model")
            )
            args["reward_model_type"] = "lora" if get("top.finetuning_type") == "lora" else "full"

        if args["stage"] == "dpo":
            args["dpo_beta"] = get("train.dpo_beta")

        if get("train.val_size") > 1e-6 and args["stage"] != "ppo":
            args["val_size"] = get("train.val_size")
            args["evaluation_strategy"] = "steps"
            args["eval_steps"] = get("train.save_steps")
            args["load_best_model_at_end"] = True

        return args

    def _parse_eval_args(self, data: Dict[Component, Any]) -> Dict[str, Any]:
        get = lambda name: data[self.manager.get_elem_by_name(name)]
        user_config = load_config()

        if get("top.adapter_path"):
            adapter_name_or_path = ",".join([
                get_save_dir(get("top.model_name"), get("top.finetuning_type"), adapter)
            for adapter in get("top.adapter_path")])
        else:
            adapter_name_or_path = None

        args = dict(
            stage="sft",
            model_name_or_path=get("top.model_path"),
            adapter_name_or_path=adapter_name_or_path,
            cache_dir=user_config.get("cache_dir", None),
            finetuning_type=get("top.finetuning_type"),
            quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
            template=get("top.template"),
            rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
            flash_attn=(get("top.booster") == "flash_attn"),
            use_unsloth=(get("top.booster") == "unsloth"),
            dataset_dir=get("eval.dataset_dir"),
            dataset=",".join(get("eval.dataset")),
            cutoff_len=get("eval.cutoff_len"),
            max_samples=int(get("eval.max_samples")),
            per_device_eval_batch_size=get("eval.batch_size"),
            predict_with_generate=True,
            max_new_tokens=get("eval.max_new_tokens"),
            top_p=get("eval.top_p"),
            temperature=get("eval.temperature"),
            output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("eval.output_dir"))
        )

        if get("eval.predict"):
            args["do_predict"] = True
        else:
            args["do_eval"] = True

        return args

    def _preview(self, data: Dict[Component, Any], do_train: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
        error = self._initialize(data, do_train, from_preview=True)
        if error:
            gr.Warning(error)
            yield error, gr.update(visible=False)
        else:
            args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
            yield gen_cmd(args), gr.update(visible=False)

    def _launch(self, data: Dict[Component, Any], do_train: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
        error = self._initialize(data, do_train, from_preview=False)
        if error:
            gr.Warning(error)
            yield error, gr.update(visible=False)
        else:
            args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
            run_kwargs = dict(args=args, callbacks=[self.trainer_callback])
            self.do_train, self.running_data = do_train, data
            self.thread = Thread(target=run_exp, kwargs=run_kwargs)
            self.thread.start()
            yield from self.monitor()

    def preview_train(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
        yield from self._preview(data, do_train=True)

    def preview_eval(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
        yield from self._preview(data, do_train=False)

    def run_train(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
        yield from self._launch(data, do_train=True)

    def run_eval(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
        yield from self._launch(data, do_train=False)

    def monitor(self) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
        get = lambda name: self.running_data[self.manager.get_elem_by_name(name)]
        self.running = True
        lang = get("top.lang")
        output_dir = get_save_dir(get("top.model_name"), get("top.finetuning_type"), get(
            "{}.output_dir".format("train" if self.do_train else "eval")
        ))

        while self.thread.is_alive():
            time.sleep(2)
            if self.aborted:
                yield ALERTS["info_aborting"][lang], gr.update(visible=False)
            else:
                yield self.logger_handler.log, update_process_bar(self.trainer_callback)

        if self.do_train:
            if os.path.exists(os.path.join(output_dir, TRAINING_ARGS_NAME)):
                finish_info = ALERTS["info_finished"][lang]
            else:
                finish_info = ALERTS["err_failed"][lang]
        else:
            if os.path.exists(os.path.join(output_dir, "all_results.json")):
                finish_info = get_eval_results(os.path.join(output_dir, "all_results.json"))
            else:
                finish_info = ALERTS["err_failed"][lang]

        yield self._finalize(lang, finish_info), gr.update(visible=False)
