import os

import fire

from src.modeling import LLaMA
from src.modeling_30b import LLaMA30B
from src.modeling_args import LoraModelArgs, ModelArgs
from src.modeling_hf import LLaMAHF
from src.modeling_lora import LoraLLaMA
from src.modeling_lora_30b import LoraLLaMA30B
from src.modeling_lora_hf import LoraLLaMAHF
from src.tokenizer import Tokenizer
from src.evaluator import DistributedEvaluator
import src.utils as utils


def evaluate_for_multiple_files(
        task: str,
        folder_path: str,
        log_dir: str,
        model_type: str,
        evaluator: DistributedEvaluator,
        batch_size: int,
        max_seq_len: int,
        prompt: str
):
    """Run generation and evaluation for every ``.jsonl`` file in a folder.

    Only entries directly inside ``folder_path`` whose name ends with
    ``.jsonl`` are processed; everything else is ignored. For each such file
    the evaluator first generates outputs and then evaluates them, writing to
    ``<log_dir>/<model_type>-<filename>``.

    Args:
        task: Task name forwarded to ``evaluator.evaluating``.
        folder_path: Directory containing the ``.jsonl`` label files.
        log_dir: Directory in which the per-file output is placed.
        model_type: Model identifier used as the output file name prefix.
        evaluator: Evaluator providing ``generate`` and ``evaluating``.
        batch_size: Batch size forwarded to ``evaluator.generate``.
        max_seq_len: Maximum sequence length forwarded to
            ``evaluator.generate``.
        prompt: Prompt forwarded to ``evaluator.generate``.
    """
    # os.listdir() yields entries in arbitrary order. Sorting makes the
    # processing order reproducible and identical for every process that
    # lists the same folder (this script is launched once per rank).
    for filename in sorted(os.listdir(folder_path)):
        if not filename.endswith('.jsonl'):
            continue
        label_file = os.path.join(folder_path, filename)
        datalist = evaluator.generate(
            label_file=label_file,
            batch_size=batch_size,
            max_seq_len=max_seq_len,
            prompt=prompt
        )
        evaluator.evaluating(
            task=task,
            datalist=datalist,
            output_file=os.path.join(log_dir, f"{model_type}-{filename}")
        )


def main(
        ckpt_dir: str,
        task: str,
        label_file: str,
        model_type: str = "7B",
        prompt: str = "",
        log_dir: str = "log",
        max_seq_len: int = 512,
        max_batch_size: int = 128,
        lora_rank: int = 256,
        tokenizer_path: str = None,
        config_file: str = None,
        seed: int = None
):
    """Build a model, load a checkpoint and evaluate it on ``task``.

    Args:
        ckpt_dir: Checkpoint location passed to ``model.load``.
        task: Task name. For ``BBH``, ``AGIEval`` and ``MMLU`` the
            ``label_file`` argument is treated as a folder of ``.jsonl``
            files; for any other task it is treated as a single label file.
        label_file: Label file path (or folder path, see ``task``).
        model_type: Model identifier. ``"30B"`` selects the 30B model
            classes, any value containing ``orca`` (case-insensitive)
            selects the HF classes, anything else the default classes.
        prompt: Prompt forwarded to the evaluator's ``generate``.
        log_dir: Output directory; created by rank 0 if it is missing.
        max_seq_len: Maximum sequence length for the model args and for
            generation.
        max_batch_size: Batch size used for generation.
        lora_rank: LoRA rank; a value ``<= 0`` selects the non-LoRA classes.
        tokenizer_path: Tokenizer model path; defaults to
            ``config/tokenizer.model``.
        config_file: Model config JSON; defaults to
            ``config/<model_type>/params.json``.
        seed: Seed passed to ``utils.setup_model_parallel``; defaults to 1.
    """
    # Resolve the optional arguments to their defaults.
    if tokenizer_path is None:
        tokenizer_path = 'config/tokenizer.model'
    if config_file is None:
        config_file = f"config/{model_type}/params.json"
    if seed is None:
        seed = 1

    local_rank, world_size = utils.setup_model_parallel(use_float16=True, seed=seed)

    # A positive rank selects the LoRA args and model classes.
    use_lora = lora_rank > 0
    if use_lora:
        base_args = LoraModelArgs(
            max_seq_len=max_seq_len,
            local_rank=local_rank,
            world_size=world_size,
            r=lora_rank
        )
    else:
        base_args = ModelArgs(
            max_seq_len=max_seq_len,
            local_rank=local_rank,
            world_size=world_size
        )
    # NOTE(review): from_json presumably fills the args from the config
    # file - confirm in src.modeling_args.
    params = base_args.from_json(config_file)

    # Choose the architecture family, then its LoRA / plain variant.
    if model_type == "30B":
        model_cls = LoraLLaMA30B if use_lora else LLaMA30B
    elif 'orca' in model_type.lower():
        model_cls = LoraLLaMAHF if use_lora else LLaMAHF
    else:
        model_cls = LoraLLaMA if use_lora else LLaMA
    model = model_cls(params)

    tokenizer = Tokenizer(tokenizer_path)
    evaluator = DistributedEvaluator(model=model, tokenizer=tokenizer)
    model.load(ckpt_dir)

    # Only rank 0 creates the log directory; the barrier call follows so the
    # other ranks do not proceed before it (assuming utils.barrier
    # synchronizes all ranks - TODO confirm).
    if local_rank == 0:
        os.makedirs(log_dir, exist_ok=True)
    utils.barrier()

    multi_file_tasks = ('BBH', "AGIEval", "MMLU")
    if task in multi_file_tasks:
        # These tasks are evaluated file by file from a folder.
        evaluate_for_multiple_files(
            task=task,
            folder_path=label_file,
            log_dir=log_dir,
            model_type=model_type,
            evaluator=evaluator,
            batch_size=max_batch_size,
            max_seq_len=max_seq_len,
            prompt=prompt
        )
    else:
        results = evaluator.generate(
            label_file=label_file,
            batch_size=max_batch_size,
            max_seq_len=max_seq_len,
            prompt=prompt
        )
        output_path = os.path.join(log_dir, f"{task}-{model_type}")
        evaluator.evaluating(task=task, datalist=results, output_file=output_path)


# Script entry point: python-fire exposes main()'s parameters as command-line
# arguments (e.g. --ckpt_dir, --task, --label_file).
if __name__ == '__main__':
    fire.Fire(main)
