import argparse
import json
import logging
import os
import sys
from typing import Any, Dict, List, Tuple, Union

import torch
from transformers.utils import is_flash_attn_2_available

import moe_peft
import moe_peft.adapters

# Command Line Arguments
# Parsed once at import time; the helper functions below read the resulting
# module-level `args` namespace directly.
parser = argparse.ArgumentParser(description="MoE-PEFT main program")
parser.add_argument(
    "--base_model", type=str, required=True, help="Path to or name of base model"
)
parser.add_argument(
    "--inference",
    action="store_true",
    help="Run in interactive inference mode (for testing only)",
)
parser.add_argument(
    "--evaluate", action="store_true", help="Run in evaluation mode (for testing only)"
)
parser.add_argument(
    "--disable_prompter",
    action="store_true",
    help="Disable the prompt template during inference",
)
parser.add_argument(
    "--load_adapter",
    action="store_true",
    help="Load adapter from file instead of initializing it randomly",
)
parser.add_argument(
    "--disable_adapter", action="store_true", help="Disable the adapter modules"
)
parser.add_argument(
    "--attn_impl",
    type=str,
    help="Specify the implementation of attention (default: flash_attn for "
    "inference/evaluation on CUDA when available, otherwise eager)",
)
parser.add_argument(
    "--sliding_window",
    action="store_true",
    help="Use sliding window attention (requires flash attention)",
)
parser.add_argument(
    "--disable_cache",
    action="store_true",
    help="Disable cache during inference",
)
parser.add_argument(
    "--cache_implementation",
    type=str,
    help="Specify the implementation of cache",
)
parser.add_argument(
    "--fp16", action="store_true", help="Load base model in float16 precision"
)
parser.add_argument(
    "--bf16",
    action="store_true",
    help="Load base model in bfloat16 precision (takes precedence over --fp16)",
)
parser.add_argument(
    "--tf32",
    action="store_true",
    help="Use TensorFloat-32 (TF32) instead of float32 if available",
)
parser.add_argument(
    "--load_8bit",
    action="store_true",
    help="Load base model with 8bit quantization (takes precedence over --load_4bit)",
)
parser.add_argument(
    "--load_4bit", action="store_true", help="Load base model with 4bit quantization"
)
parser.add_argument("--device", type=str, help="Specify which device to use")
parser.add_argument(
    "--config",
    type=str,
    required=True,
    help="Path to the fine-tuning configuration (JSON)",
)
parser.add_argument(
    "--seed", type=int, default=42, help="Random seed in integer, default is 42"
)
parser.add_argument(
    "--dir", type=str, default=".", help="Path to read or save checkpoints"
)
parser.add_argument("--disable_log", action="store_true", help="Disable logging")
parser.add_argument("--log_file", type=str, help="Save log to specific file")
parser.add_argument(
    "--verbose",
    action="store_true",
    help="Show extra information such as parameters",
)
parser.add_argument(
    "--overwrite",
    action="store_true",
    help="Overwrite the adapter model if an older one already exists",
)
parser.add_argument("--debug", action="store_true", help="Enable debugging mode")
parser.add_argument(
    "--deterministic",
    action="store_true",
    help="Use deterministic algorithms to improve reproducibility",
)

args = parser.parse_args()


def query_yes_no(question: str, default: Union[str, None] = "no") -> bool:
    """Ask a yes/no question on stdin and return the answer as a bool.

    Args:
        question: Text shown to the user; a ``[y/n]``-style hint is appended.
        default: Answer assumed when the user just presses Enter, or when
            stdin is exhausted. Must be ``"yes"``, ``"no"`` or ``None``; with
            ``None`` an explicit answer is required.

    Returns:
        True for a "yes" answer, False for a "no" answer.

    Raises:
        ValueError: If ``default`` is not one of the accepted values.
        EOFError: If stdin is exhausted and ``default`` is ``None``.
    """
    valid = {"yes": True, "y": True, "ye": True, "no": False, "n": False}
    if default is None:
        prompt = " [y/n] "
    elif default == "yes":
        prompt = " [Y/n] "
    elif default == "no":
        prompt = " [y/N] "
    else:
        raise ValueError("invalid default answer: '%s'" % default)

    while True:
        sys.stdout.write(question + prompt)
        try:
            # Tolerate stray whitespace such as " y " or "yes\t".
            choice = input().strip().lower()
        except EOFError:
            # Non-interactive run (stdin closed or redirected from an empty
            # source): use the default rather than crashing, when one exists.
            if default is None:
                raise
            sys.stdout.write("\n")
            return valid[default]
        if default is not None and choice == "":
            return valid[default]
        if choice in valid:
            return valid[choice]
        sys.stdout.write("Please respond with 'yes' or 'no' (or 'y' or 'n').\n")


def load_base_model() -> Tuple[moe_peft.Tokenizer, moe_peft.LLMModel]:
    """Build the tokenizer and base LLM described by the command-line args.

    Reads the module-level ``args``. ``--bf16`` wins over ``--fp16`` (float32
    when neither is set), and ``--load_8bit`` wins over ``--load_4bit`` (no
    quantization when neither is set).

    Returns:
        A ``(tokenizer, model)`` pair.
    """
    logging.info("Initializing pre-trained model.")

    if args.bf16:
        dtype = torch.bfloat16
    elif args.fp16:
        dtype = torch.float16
    else:
        dtype = torch.float32

    if args.load_8bit:
        quant_bits = 8
    elif args.load_4bit:
        quant_bits = 4
    else:
        quant_bits = None

    model = moe_peft.LLMModel.from_pretrained(
        name_or_path=args.base_model,
        device=args.device,
        attn_impl=args.attn_impl,
        use_sliding_window=args.sliding_window,
        bits=quant_bits,
        load_dtype=dtype,
    )
    tokenizer = moe_peft.Tokenizer(args.base_model)
    return tokenizer, model


def init_adapter_config(
    config: Dict[str, Any],
    llm_model: moe_peft.LLMModel,
) -> List[
    Union[moe_peft.GenerateConfig, moe_peft.EvaluateConfig, moe_peft.TrainConfig]
]:
    """Create or load every adapter in ``config["lora"]`` and build run configs.

    For each adapter entry, either loads it from ``<args.dir>/<name>`` (when
    ``--load_adapter`` is set) or initializes it from its config. It then
    builds the configuration matching the selected mode: a ``GenerateConfig``
    for ``--inference``, one or more ``EvaluateConfig`` for ``--evaluate``,
    or a ``TrainConfig`` otherwise.

    Args:
        config: Parsed top-level configuration. A ``cutoff_len`` of -1 is
            replaced in place by the model's ``max_seq_len_``.
        llm_model: Base model that the adapters are attached to.

    Returns:
        The per-adapter run configurations, in the order of ``config["lora"]``.
    """
    config_list = []

    # -1 means "use the model's maximum sequence length".
    if config["cutoff_len"] == -1:
        config["cutoff_len"] = llm_model.max_seq_len_
        logging.info(f"Setting cutoff_len to {llm_model.max_seq_len_} automatically.")

    for lora_config in config["lora"]:
        adapter_name = lora_config["name"]
        adapter_path = f"{args.dir}{os.sep}{adapter_name}"

        # A freshly initialized adapter would later be saved over an existing
        # one, so confirm before proceeding.
        if not args.load_adapter and os.path.exists(adapter_path):
            if args.overwrite:
                logging.warning(
                    f"Overwriting existed adapter model file: {adapter_path}"
                )
            elif not query_yes_no(
                f"Existed adapter model file detected: {adapter_path}\n" + "Overwrite?"
            ):
                logging.info("User canceled training due to file conflict.")
                # sys.exit is always available, unlike the `site` builtin exit.
                sys.exit(0)

        if args.load_adapter:
            llm_model.load_adapter(adapter_path, adapter_name)
        else:
            llm_model.init_adapter(moe_peft.adapters.lora_config_factory(lora_config))

        if args.inference:
            generate_config = moe_peft.GenerateConfig(adapter_name=adapter_name)
            if not args.disable_prompter:
                generate_config.prompt_template = lora_config.get("prompt", None)
            new_configs = [generate_config]
        elif args.evaluate:
            # One adapter may yield any number of evaluation configs (possibly 0).
            new_configs = list(moe_peft.EvaluateConfig.from_config(lora_config))
        else:
            new_configs = [moe_peft.TrainConfig.from_config(lora_config)]

        config_list.extend(new_configs)

        # Log exactly the configs produced for this adapter; indexing
        # config_list[-1] would fail (or show another adapter's config) when
        # none were produced, and would hide all but the last of several.
        if args.verbose:
            for new_config in new_configs:
                logging.info(new_config.__dict__)

    return config_list


def inference_callback(cur_pos, outputs):
    """Streaming hook: print the current position and each adapter's output.

    Args:
        cur_pos: Position value reported by the generator.
        outputs: Mapping from adapter name to a sequence of outputs; only the
            first element of each sequence is printed.
    """
    print(f"POSITION: {cur_pos}")
    for name, texts in outputs.items():
        first_text = texts[0]
        print(f"{name} OUTPUT: {first_text}")


def inference(
    model: moe_peft.LLMModel,
    tokenizer: moe_peft.Tokenizer,
    configs: List[moe_peft.GenerateConfig],
    concurrent_jobs: int,
    max_gen_len: int = 128,
) -> None:
    """Run an interactive read-generate-print loop over all adapters.

    Each line read from stdin becomes the single prompt for every config in
    ``configs``. The results are printed per adapter. Typing ``QUIT`` or
    sending EOF (e.g. Ctrl-D) ends the loop.

    Args:
        model: Base model with the adapters already attached.
        tokenizer: Tokenizer matching ``model``.
        configs: One generation config per adapter; their ``prompts`` are
            overwritten on every iteration.
        concurrent_jobs: Forwarded to ``moe_peft.generate``.
        max_gen_len: Maximum number of tokens to generate (default 128).
    """
    # The streaming callback depends only on CLI flags, so choose it once.
    callback = None if args.disable_log else inference_callback

    while True:
        try:
            input_raw = input("INPUT WITHOUT PROMPT: ")
        except EOFError:
            # stdin closed: leave the loop like QUIT instead of crashing.
            print()
            return
        if input_raw == "QUIT":
            return

        for config in configs:
            config.prompts = [input_raw]

        outputs = moe_peft.generate(
            model,
            tokenizer,
            configs,
            max_gen_len=max_gen_len,
            use_cache=not args.disable_cache,
            concurrent_jobs=concurrent_jobs,
            cache_implementation=args.cache_implementation,
            stream_callback=callback,
        )
        print(f"\n{'=' * 10}\n")
        print(f"PROMPT: {input_raw}")
        for adapter_name, output in outputs.items():
            print(f"{adapter_name} OUTPUT:")
            print(output[0])
        print(f"\n{'=' * 10}\n")


# Main Function
if __name__ == "__main__":
    if args.debug:
        # Autograd anomaly detection: slow, intended for debugging only.
        torch.autograd.set_detect_anomaly(True)

    # Inference and evaluation always run on previously saved adapters.
    if args.inference or args.evaluate:
        args.load_adapter = True
        inference_mode = True
    else:
        inference_mode = False

    moe_peft.setup_logging("INFO", args.log_file)

    # These flag pairs are resolved by precedence (first flag wins) instead of
    # being rejected, so make the ignored flag visible to the user.
    for preferred_flag, ignored_flag in (
        ("bf16", "fp16"),
        ("load_8bit", "load_4bit"),
        ("inference", "evaluate"),
    ):
        if getattr(args, preferred_flag) and getattr(args, ignored_flag):
            logging.warning(
                f"Both --{preferred_flag} and --{ignored_flag} were given; "
                f"--{ignored_flag} is ignored."
            )

    moe_peft_executor = moe_peft.executor

    if not moe_peft_executor.check_available():
        sys.exit(-1)

    # Default attention: flash attention for inference/evaluation on CUDA when
    # the flash-attn 2 package is available, eager attention otherwise.
    if args.attn_impl is None:
        if (
            inference_mode
            and moe_peft_executor.device_name() == "cuda"
            and is_flash_attn_2_available()
        ):
            args.attn_impl = "flash_attn"
        else:
            args.attn_impl = "eager"

    if args.device is None:
        args.device = moe_peft_executor.default_device_name()

    moe_peft_executor.use_deterministic_algorithms(args.deterministic)
    moe_peft_executor.allow_tf32(args.tf32)
    moe_peft_executor.manual_seed(args.seed)

    with open(args.config, "r", encoding="utf8") as fp:
        config = json.load(fp)

    tokenizer, model = load_base_model()
    adapters = init_adapter_config(config, model)

    moe_peft_executor.empty_cache()

    if os.getenv("MOE_PEFT_EVALUATE_MODE") is None:
        logging.info("Using efficient operators.")
    else:
        logging.info("Using deterministic operators.")

    if args.inference:
        inference(
            model=model,
            tokenizer=tokenizer,
            configs=adapters,
            concurrent_jobs=config.get("inference_lora_simultaneously_num", 2),
        )
    elif args.evaluate:
        moe_peft.evaluate(
            model=model,
            tokenizer=tokenizer,
            configs=adapters,
            max_concurrent_jobs=config.get("eval_lora_simultaneously_num", None),
            retrying_steps=config.get("eval_rollback_retrying_steps", 20),
            max_seq_len=config["cutoff_len"],
            save_file=config.get("evaluate_result", None),
        )
    else:
        moe_peft.train(
            model=model,
            tokenizer=tokenizer,
            configs=adapters,
            max_concurrent_jobs=config.get("train_lora_simultaneously_num", None),
            strategy=config["train_strategy"],
            cutoff_len=config["cutoff_len"],
            save_step=config["save_step"],
            save_dir=args.dir,
        )
