import argparse
import csv
import json
import os
from functools import partial

import torch
import transformers
import yaml
from accelerate.utils import find_executable_batch_size
from PIL import Image
from torchvision import transforms
from tqdm import tqdm
from vllm import SamplingParams

import multimodalmodels
from api_models import api_models_map
from baselines.model_utils import (
    get_template,
    load_model_and_tokenizer,
    load_vllm_model,
)

# Raise the transformers logging threshold to ERROR so that warning messages
# emitted while generating completions are not printed.
transformers.utils.logging.set_verbosity_error()


def parse_args():
    """Parse the command-line arguments for completion generation.

    ``--model_name``, ``--test_cases_path`` and ``--save_path`` are marked as
    required because the script uses them unconditionally (config lookup,
    input file, output file). Without this, omitting one of them only
    surfaced later as a ``KeyError``/``TypeError`` -- for ``--save_path``
    potentially after all completions had already been generated.

    Returns:
        argparse.Namespace: The parsed arguments. On invalid or missing
        arguments argparse prints a usage message and exits.
    """
    parser = argparse.ArgumentParser(description="Running red teaming with baseline methods.")
    parser.add_argument(
        "--model_name",
        type=str,
        required=True,
        help="The name of the model in the models config file",
    )
    parser.add_argument(
        "--models_config_file",
        type=str,
        default="./configs/model_configs/models.yaml",
        help="The path to the config file with model hyperparameters",
    )
    parser.add_argument(
        "--behaviors_path",
        type=str,
        default="./data/behavior_datasets/harmbench_behaviors_text_all.csv",
        help="The path to the behaviors file",
    )
    parser.add_argument(
        "--test_cases_path",
        type=str,
        required=True,
        help="The path to the test cases file to generate completions for",
    )
    parser.add_argument(
        "--save_path",
        type=str,
        required=True,
        help="The path for saving completions",
    )
    parser.add_argument(
        "--max_new_tokens", type=int, default=256, help="Max new tokens for completions"
    )
    parser.add_argument(
        "--generate_with_vllm",
        action="store_true",
        help="Whether to generate completions with vLLM (if applicable)",
    )
    parser.add_argument(
        "--incremental_update",
        action="store_true",
        help="Whether to incrementally update completions or generate a new completions file from scratch",
    )
    return parser.parse_args()


def main():
    """Generate completions for every test case and save them as JSON.

    Steps:
      1. Parse CLI arguments and load the selected model's config entry,
         overriding its ``num_gpus`` with the number of visible CUDA devices
         when the two differ.
      2. Load the test cases, a JSON object mapping
         ``behavior_id -> list of test cases``.
      3. With ``--incremental_update`` and an existing save file, skip
         behaviors whose saved completion count already matches their test
         case count; any other behavior is regenerated from scratch.
      4. Build the generation function, run it over the flattened test cases
         and regroup the results as
         ``{behavior_id: [{"test_case": ..., "generation": ...}, ...]}``.
      5. Write the result to ``args.save_path``, creating parent directories
         as needed.

    ``args.behaviors_path`` is accepted on the command line but not read
    here: nothing in this function uses the behaviors file.
    """
    # ========== load arguments and config ========== #
    args = parse_args()
    print(args)

    # Load model config file (fall back to the default path if an empty value was passed)
    config_file = args.models_config_file or "configs/model_configs/models.yaml"
    with open(config_file) as file:
        # NOTE(review): yaml.full_load can construct more than plain data
        # types; only point this at trusted config files.
        model_configs = yaml.full_load(file)

    num_gpus = model_configs[args.model_name].get("num_gpus", 0)
    # Use the number of GPUs actually visible whenever it differs from the config
    num_gpus_available = torch.cuda.device_count()
    if num_gpus_available != num_gpus:
        print(
            f"Warning: num_gpus in config ({num_gpus}) does not match num_gpus available ({num_gpus_available}). Using {num_gpus_available} GPUs."
        )
        num_gpus = num_gpus_available
    model_config = model_configs[args.model_name]["model"]
    model_config["num_gpus"] = num_gpus

    print("model_config", model_config)

    # ========== load test cases ========== #
    with open(args.test_cases_path, "r", encoding="utf-8") as f:
        test_cases_data = json.load(f)

    returned_data = {}

    if args.incremental_update and os.path.exists(args.save_path):
        with open(args.save_path, "r", encoding="utf-8") as file:
            returned_data = json.load(file)

        # Remove test cases that have already been generated
        pending_test_cases_data = {}
        for behavior_id, _test_cases in test_cases_data.items():
            # Check if behavior_id is in returned_data and completions have been generated for all test cases
            if behavior_id in returned_data and len(returned_data[behavior_id]) == len(_test_cases):
                continue
            print(behavior_id, len(_test_cases))
            pending_test_cases_data[behavior_id] = _test_cases
            # If not all completions are present, we regenerate all of them.
            # Discard the stale partial results so the fresh completions are
            # not appended on top of the old ones.
            returned_data[behavior_id] = []
        test_cases_data = pending_test_cases_data

    # ==== Flatten dict to array ======
    test_cases = [
        dict(behavior_id=behavior_id, test_case=t)
        for behavior_id, _test_cases in test_cases_data.items()
        for t in _test_cases
    ]

    if not test_cases:
        print("No test cases to generate completions for")
        return
    print(f"Generating completions for {len(test_cases)} test cases")

    # Load the model only once we know there is work to do; this avoids an
    # expensive model load when an incremental run is already complete.
    generation_function = load_generation_function(
        model_config,
        args.max_new_tokens,
        test_cases_path=args.test_cases_path,
        generate_with_vllm=args.generate_with_vllm,
    )

    # ==== Generate ====
    print("Generating completions...")
    generations = generation_function(test_cases=[t["test_case"] for t in test_cases])
    print("Done")

    # ==== Unflatten into dict of {behavior_id: [{test_case: t, generation: g}, ..], ...}
    for t, g in zip(test_cases, generations):
        returned_data.setdefault(t["behavior_id"], []).append(
            dict(test_case=t["test_case"], generation=g)
        )

    # Create directories for save_path
    print(f"Saving completions to {args.save_path}...")
    save_dir = os.path.dirname(args.save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    with open(args.save_path, "w", encoding="utf-8") as file:
        json.dump(returned_data, file, indent=4)
    print("Saved")


def _vllm_generate(model, test_cases, template, **generation_kwargs):
    """Generate one completion per test case with a vLLM-style model.

    Args:
        model: Object whose ``generate(prompts, **kwargs)`` returns one result
            per prompt, each exposing ``.outputs[0].text``.
        test_cases: Iterable of instruction strings.
        template: Mapping whose ``"prompt"`` entry is a format string with an
            ``{instruction}`` placeholder.
        **generation_kwargs: Forwarded unchanged to ``model.generate``.

    Returns:
        list[str]: The first output text of each result, whitespace-stripped.
    """
    prompts = []
    for test_case in test_cases:
        prompts.append(template["prompt"].format(instruction=test_case))

    results = model.generate(prompts, **generation_kwargs)

    completions = []
    for result in results:
        first_output = result.outputs[0]
        completions.append(first_output.text.strip())
    return completions


def _hf_generate_with_batching(model, tokenizer, test_cases, template, **generation_kwargs):
    """Generate completions with a Hugging Face model, batch by batch.

    The batched loop is wrapped in accelerate's ``find_executable_batch_size``
    starting at a batch size of 10 (per accelerate's documentation, the
    decorator re-invokes the function with a smaller batch size after an
    out-of-memory error). The result list is created inside the wrapped
    function, so every invocation starts from an empty list.

    Args:
        model: Model exposing ``generate`` and ``device``.
        tokenizer: Tokenizer used to encode the prompts and decode the outputs.
        test_cases: Sequence of instruction strings.
        template: Mapping whose ``"prompt"`` entry is a format string with an
            ``{instruction}`` placeholder.
        **generation_kwargs: Forwarded unchanged to ``model.generate``.

    Returns:
        list[str]: One stripped completion per test case, in input order.
    """

    @find_executable_batch_size(starting_batch_size=10)
    def _generate_all(batch_size):
        completions = []
        for start in tqdm(range(0, len(test_cases), batch_size)):
            chunk = test_cases[start : start + batch_size]
            prompts = [template["prompt"].format(instruction=case) for case in chunk]
            print("inputs", prompts)

            encoded = tokenizer(prompts, return_tensors="pt", padding=True)
            encoded = encoded.to(model.device)
            with torch.no_grad():
                output_ids = model.generate(
                    inputs=encoded["input_ids"],
                    attention_mask=encoded["attention_mask"],
                    **generation_kwargs,
                ).cpu()

            # Keep only the tokens after the (padded) prompt width.
            # NOTE(review): this presumes the prompt fills the first columns of
            # every row, i.e. a left-padding tokenizer -- confirm in the loader.
            prompt_width = encoded["input_ids"].shape[1]
            decoded = []
            for sequence in output_ids[:, prompt_width:]:
                decoded.append(tokenizer.decode(sequence, skip_special_tokens=True).strip())
            print("batch_generations", decoded)
            completions.extend(decoded)
        return completions

    return _generate_all()

def notebook_load_generation_function(
    config_file, model_name, max_new_tokens
):
    """Build a generation function for interactive (notebook) use.

    Reads the model entry ``model_name`` from the YAML ``config_file``,
    replaces its ``num_gpus`` with the number of visible CUDA devices when the
    two differ, forces ``dtype`` to ``"float16"``, and delegates to
    ``load_generation_function`` with no test cases path and vLLM disabled.

    Args:
        config_file: Path to the models YAML config.
        model_name: Key of the model entry inside the config.
        max_new_tokens: Maximum number of new tokens per completion.

    Returns:
        The callable returned by ``load_generation_function``.
    """
    with open(config_file) as config_handle:
        all_model_configs = yaml.full_load(config_handle)

    configured_gpus = all_model_configs[model_name].get("num_gpus", 0)
    # The number of visible devices wins whenever it differs from the config
    visible_gpus = torch.cuda.device_count()
    if visible_gpus != configured_gpus:
        print(
            f"Warning: num_gpus in config ({configured_gpus}) does not match num_gpus available ({visible_gpus}). Using {visible_gpus} GPUs."
        )
        configured_gpus = visible_gpus

    model_config = all_model_configs[model_name]["model"]
    model_config.update(num_gpus=configured_gpus, dtype="float16")

    return load_generation_function(
        model_config,
        max_new_tokens,
        test_cases_path=None,
        generate_with_vllm=False,
    )

def notebook_main(test_cases, generation_function):
    """Run ``generation_function`` over ``test_cases`` and pair up the results.

    Args:
        test_cases: Sequence of test cases, passed to the generation function
            as the ``test_cases`` keyword argument.
        generation_function: Callable returning one generation per test case.

    Returns:
        list[dict]: ``{"test_case": ..., "generation": ...}`` records in input
        order; pairing stops at the shorter of the two sequences.
    """
    # ==== Generate ====
    print("Generating completions...")
    generations = generation_function(test_cases=test_cases)
    print("Done")

    return [
        {"test_case": case, "generation": completion}
        for case, completion in zip(test_cases, generations)
    ]

def _api_model_generate(model, test_cases, **generation_kwargs):
    """Generate completions with an API-backed model.

    Multimodal test cases are saved as pairs ``[[img, text], ...]``; they are
    split into parallel ``images`` and ``prompts`` lists before calling the
    model. Text-only test cases are passed through as ``prompts``.

    Args:
        model: Object exposing ``generate(prompts=..., [images=...], **kwargs)``.
        test_cases: Sequence of prompt strings, or of ``(image, prompt)`` pairs.
        **generation_kwargs: Forwarded unchanged to ``model.generate``.

    Returns:
        Whatever ``model.generate`` returns; an empty list when ``test_cases``
        is empty (the model is not called in that case).
    """
    # Nothing to generate: return early instead of failing on test_cases[0].
    if len(test_cases) == 0:
        return []
    # The first element decides whether the batch is multimodal.
    if isinstance(test_cases[0], (tuple, list)):
        images, prompts = zip(*test_cases)
        return model.generate(prompts=list(prompts), images=list(images), **generation_kwargs)
    return model.generate(prompts=test_cases, **generation_kwargs)


def load_generation_function(
    model_config, max_new_tokens, test_cases_path, generate_with_vllm=False
):
    """Return a callable that generates completions for the configured model.

    Backends are tried in this order:
      1. A class named ``model_config["model_name_or_path"]`` found in
         ``multimodalmodels``; its ``generate`` is bound with an ``images``
         directory located next to ``test_cases_path``.
      2. An API model, when ``api_models_map(**model_config)`` returns a
         truthy value.
      3. vLLM, when ``generate_with_vllm`` is set.
      4. Hugging Face generation with batching otherwise.

    Args:
        model_config: Model section of the config; must contain
            ``"model_name_or_path"`` and may contain ``"chat_template"``.
        max_new_tokens: Maximum number of new tokens per completion.
        test_cases_path: Path of the test cases file; only used to locate the
            ``images`` directory in the multimodal case.
        generate_with_vllm: Use vLLM instead of Hugging Face for local models.

    Returns:
        functools.partial: The backend's generation function with the model
        and generation settings already bound.
    """
    model_name_or_path = model_config["model_name_or_path"]

    # 1. Multimodal model classes are looked up by name in ``multimodalmodels``.
    if model_name_or_path in dir(multimodalmodels):
        multimodal_cls = getattr(multimodalmodels, model_name_or_path)
        model = multimodal_cls()
        image_dir = os.path.join(os.path.dirname(test_cases_path), "images")
        return partial(
            model.generate,
            image_dir=image_dir,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            num_beams=1,
        )

    # 2. API-backed models.
    api_model = api_models_map(**model_config)
    if api_model:
        return partial(
            _api_model_generate,
            model=api_model,
            max_new_tokens=max_new_tokens,
            temperature=0.0,
            use_tqdm=True,
        )

    # 3. Local model served through vLLM.
    if generate_with_vllm:
        print("Using vLLM generation")
        sampling_params = SamplingParams(temperature=0, max_tokens=max_new_tokens)
        model = load_vllm_model(**model_config)
        prompt_template = get_template(
            model_name_or_path, chat_template=model_config.get("chat_template")
        )
        return partial(
            _vllm_generate,
            model=model,
            template=prompt_template,
            sampling_params=sampling_params,
            use_tqdm=True,
        )

    # 4. Local model through Hugging Face ``generate``.
    print("Using HF generation")
    model, tokenizer = load_model_and_tokenizer(**model_config)
    prompt_template = get_template(
        model_name_or_path, chat_template=model_config.get("chat_template")
    )
    return partial(
        _hf_generate_with_batching,
        model=model,
        tokenizer=tokenizer,
        template=prompt_template,
        max_new_tokens=max_new_tokens,
        do_sample=False,
    )


# Run the completion-generation pipeline only when this file is executed as a
# script, not when it is imported (e.g. for the notebook helpers).
if __name__ == "__main__":
    main()
