import os
import math
from typing import Text, List, Optional

from datasets import load_dataset


def _shard_bounds(num_rows: int, rank: int, world_size: int) -> tuple:
    """Return the half-open ``(start, stop)`` row range owned by ``rank``.

    Rows are split into ``world_size`` contiguous chunks of
    ``ceil(num_rows / world_size)`` rows. The final chunk may be shorter, and
    ranks that fall past the end get an empty ``(num_rows, num_rows)`` range.
    """
    split_size = math.ceil(num_rows / world_size)
    start = min(rank * split_size, num_rows)
    stop = min((rank + 1) * split_size, num_rows)
    return start, stop


def _shard_output_path(output_path: Text, rank, world_size) -> Text:
    """Map ``<base>.jsonl`` to ``<base>/<rank:05>-of-<world_size:05>.jsonl``.

    Only a trailing ``.jsonl`` suffix is stripped, so a directory whose name
    happens to contain ``.jsonl`` is not truncated.
    """
    suffix = ".jsonl"
    base = output_path[: -len(suffix)] if output_path.endswith(suffix) else output_path
    return os.path.join(
        base,
        f"{str(rank).zfill(5)}-of-{str(world_size).zfill(5)}.jsonl",
    )


def load_instruction_dataset(dataset_name: Optional[Text] = "tatsu-lab/alpaca_eval", script_args = None) -> List:
    """Load an evaluation prompt dataset, optionally truncated and sharded per rank.

    Args:
        dataset_name: One of ``"anonymized_for_nips/imdb_preference"``,
            ``"anonymized_for_nips/openai_summarize_comparisons_relabel"`` or
            ``"tatsu-lab/alpaca_eval"``.
        script_args: Optional run configuration. When given, it must expose
            ``sanity_check``, ``world_size``, ``rank`` and ``output_path``.
            If ``world_size != 1``, ``script_args.output_path`` is rewritten
            in place to a per-rank shard file path. When ``None``, the full
            dataset is returned with no truncation or sharding.

    Returns:
        The loaded dataset, with its prompt column renamed to ``raw_prompt``
        (despite the ``List`` annotation, this is the object produced by
        ``datasets.load_dataset`` and its ``select``/``rename_columns`` calls).

    Raises:
        NotImplementedError: If ``dataset_name`` is not one of the supported IDs.
    """
    if dataset_name == "anonymized_for_nips/imdb_preference":
        dataset = load_dataset(
            "anonymized_for_nips/imdb_preference", split="test"
        ).select_columns("prompt").rename_columns({"prompt": "raw_prompt"})
    elif dataset_name == "anonymized_for_nips/openai_summarize_comparisons_relabel":
        dataset = load_dataset(
            "anonymized_for_nips/openai_summarize_comparisons_relabel", split="test"
        ).shuffle(seed=42).select_columns("prompt").rename_columns({"prompt": "raw_prompt"})
        # Fixed-seed shuffle, then keep at most 1000 rows so the evaluation
        # subset is reproducible and never indexes past the end of the split.
        dataset = dataset.select(range(min(1000, len(dataset))))
    elif dataset_name == "tatsu-lab/alpaca_eval":
        dataset = load_dataset(
            "tatsu-lab/alpaca_eval", split="eval"
        ).rename_columns({"instruction": "raw_prompt"})
    else:
        raise NotImplementedError(f"Unsupported dataset: {dataset_name!r}")

    # No run configuration: nothing to truncate or shard.
    if script_args is None:
        return dataset

    if script_args.sanity_check:
        # Small smoke-test subset; clamp so tiny datasets do not raise IndexError.
        dataset = dataset.select(range(min(20, len(dataset))))
    if script_args.world_size != 1:
        start, stop = _shard_bounds(len(dataset), script_args.rank, script_args.world_size)
        dataset = dataset.select(range(start, stop))
        # Each rank writes its own shard file inside a directory named after
        # the original output path (minus the ``.jsonl`` suffix).
        script_args.output_path = _shard_output_path(
            script_args.output_path, script_args.rank, script_args.world_size
        )
    return dataset


def get_local_model_name(model_name):
    """Resolve a Hugging Face model ID to a pre-downloaded copy when one is known.

    Only the IDs listed in ``known_copies`` are remapped, to absolute paths
    under ``/mnt/petrelfs/share_data`` (presumably a shared cluster
    filesystem — confirm for your environment). Any other value is returned
    unchanged, so callers may still pass hub IDs or their own local paths.
    """
    known_copies = (
        (
            "meta-llama/Meta-Llama-3-8B-Instruct",
            "/mnt/petrelfs/share_data/llama3_hf/Meta-Llama-3-8B-Instruct",
        ),
        (
            "meta-llama/Meta-Llama-3-70B-Instruct",
            "/mnt/petrelfs/share_data/llama3_hf/Meta-Llama-3-70B-Instruct",
        ),
        (
            "meta-llama/Llama-2-70b-chat-hf",
            "/mnt/petrelfs/share_data/llm-safety/models/Llama-2-70b-chat-hf",
        ),
    )
    # First matching hub ID wins; fall back to the caller's value unchanged.
    return next(
        (local_path for hub_id, local_path in known_copies if model_name == hub_id),
        model_name,
    )
