import sys

import random
import numpy as np
import pandas as pd
import torch

from transformers import set_seed
from datasets import Dataset

from vllm import LLM
from vllm.model_executor.parallel_utils.parallel_state import destroy_model_parallel


def set_random_seed(seed):
    """Seed every random number generator this module relies on.

    Seeds Python's ``random``, NumPy's global RNG (which pandas also falls back
    to when no ``random_state`` is given), PyTorch on CPU and -- when CUDA is
    available -- on all CUDA devices, and Hugging Face ``transformers``.

    Args:
        seed: Integer seed applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    # The former ``pd.util.testing.N = seed`` was dropped: ``N`` is only the
    # default row count of pandas' test-data helpers (it seeds nothing), and
    # ``pandas.util.testing`` was removed in pandas 2.0, so the attribute
    # access raised AttributeError and aborted seeding.
    torch.manual_seed(seed)
    set_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)


def save_df_dataset(df, save_path):
    """Convert a pandas DataFrame into a Hugging Face ``Dataset`` and persist it.

    Args:
        df: DataFrame to convert via ``Dataset.from_pandas``.
        save_path: Destination passed unchanged to ``Dataset.save_to_disk``.
    """
    Dataset.from_pandas(df).save_to_disk(save_path)


def generate_responses(model_name_or_path, prompts, sampling_params, policies=None):
    """Generate one completion per prompt with a freshly loaded vLLM model.

    Each prompt is prefixed with a policy string drawn at random (NumPy's
    global RNG) from ``policies``. After generation the model-parallel state
    is destroyed and the model released, even if generation fails.

    Args:
        model_name_or_path: Model identifier or path passed to ``LLM(model=...)``.
        prompts: Iterable of prompt strings.
        sampling_params: Forwarded unchanged to ``LLM.generate``.
        policies: Candidate prefix strings; ``None`` (the default) means ``[""]``,
            i.e. no prefix.

    Returns:
        A list with one entry per output: the text of the first completion, or
        ``None`` when that output had no completions / no ``text`` attribute.
    """
    import gc

    # Avoid a shared mutable default; ``None`` stands for "no prefix".
    if policies is None:
        policies = [""]

    # A slow-tokenizer variant was used previously:
    # LLM(model=model_name_or_path, tokenizer_mode="slow")
    llm = LLM(model=model_name_or_path)
    try:
        # Prefixes are sampled after LLM() is constructed, as before. vLLM may
        # reseed global RNGs during construction (TODO confirm for the pinned
        # version), so keeping this order keeps seeded runs reproducible.
        policy_prompts = [np.random.choice(policies) + prompt for prompt in prompts]

        # Generation
        outputs = llm.generate(policy_prompts, sampling_params)
        generations = []
        error_count = 0
        for output in outputs:
            try:
                generated_text = output.outputs[0].text
            except (IndexError, AttributeError):
                generated_text = None
                error_count += 1
            generations.append(generated_text)

        print(f"Generated all responses. #error = {error_count}.")
    finally:
        # Reset computing resources even when generation raises, so a later
        # LLM() in the same process does not find them still held.
        destroy_model_parallel()
        del llm
        # ``del`` only drops this reference; collect cycles and hand cached
        # CUDA blocks back to the driver so the memory is actually reusable.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()

    return generations


if __name__ == '__main__':
    # Minimal CLI dispatcher: `python <this file> <name>` calls the module-level
    # callable `<name>` with no arguments.
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} <function_name>")
    command = globals().get(sys.argv[1])
    if not callable(command):
        sys.exit(f"error: no callable named {sys.argv[1]!r} in this module")
    command()
