import argparse
import os
import sys
from functools import reduce

import numpy as np
import torch
import yaml
from datasets import Dataset
from joblib import Parallel, delayed
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from peft import AutoPeftModelForCausalLM
from tqdm import tqdm

sys.path.append(os.getcwd())

from src.utils.dataset import get_dataset


def generate(config, prompts, device="cuda:0"):
    """Generate one continuation per prompt with the SFT model described by *config*.

    The model and tokenizer are loaded from
    ``DATA_DIR/data/models/DATASET/sft-models/<config["model_directory"]>``;
    ``DATA_DIR`` and ``DATASET`` are module-level globals set by the script
    entry point.

    Args:
        config: Mapping read from an ``sft.yaml`` file. The keys used here are
            ``model_directory``, ``generation_kwargs`` and ``gen_batch_size``;
            the presence of ``lora_config`` selects the PEFT model class.
        prompts: Sequence of prompt strings.
        device: Torch device string the model and inputs are placed on.

    Returns:
        A list of decoded continuations, one per prompt and in the same order,
        with the prompt tokens and special tokens removed.
    """
    # Nothing to generate: avoid loading the model onto the device for no work.
    # len() is used (rather than truthiness) because the loop below already
    # relies on len(prompts).
    if len(prompts) == 0:
        return []

    model_path = os.path.join(DATA_DIR, "data/models", DATASET, "sft-models", config["model_directory"])
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    # Left padding keeps every prompt flush against the generated tokens, so the
    # continuation can be sliced off by the (shared) padded input length below.
    tokenizer.padding_side = "left"
    # Batched tokenization with padding=True requires a pad token; fall back to
    # EOS when the saved tokenizer does not define one.
    if tokenizer.pad_token is None and tokenizer.eos_token is not None:
        tokenizer.pad_token = tokenizer.eos_token
    generation_config = GenerationConfig(**config["generation_kwargs"], pad_token_id=tokenizer.pad_token_id)
    model_class = AutoPeftModelForCausalLM if "lora_config" in config else AutoModelForCausalLM
    model = model_class.from_pretrained(model_path, device_map=device, torch_dtype=torch.float16).to(device)

    results = []
    batch_size = config["gen_batch_size"]
    for start in tqdm(range(0, len(prompts), batch_size)):
        batch = prompts[start : start + batch_size]
        encoded = tokenizer(batch, padding=True, truncation=True, return_tensors="pt", pad_to_multiple_of=8).to(
            device
        )
        answers = model.generate(**encoded, generation_config=generation_config)
        # generate() returns prompt + continuation; keep only the new tokens.
        prompt_length = encoded["input_ids"].shape[1]
        results += tokenizer.batch_decode(answers[:, prompt_length:], skip_special_tokens=True)
    return results


if __name__ == "__main__":
    # Script entry point: build a dataset of paired outputs ("label" / "output")
    # for every prompt, either by mixing several models (plus the human
    # reference) or by sampling twice from a single model, and write it to JSON.
    #
    # DATA_DIR and DATASET are module-level globals that generate() reads.
    DATA_DIR = os.getenv("DATA_DIR", ".")
    DATASET = os.getenv("DATASET", "tldr")
    COMBINE_MODELS = os.getenv("COMBINE_MODELS", "f")
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--combine_models",
        type=str,
        default="f",
        help="Whether to mix and match different model outputs or generate twice from one",
    )
    args = parser.parse_args()
    # Either the CLI flag or the environment variable can enable combining; any
    # value containing the letter "t" counts as true.
    combine_models = "t" in args.combine_models or "t" in COMBINE_MODELS

    dataset = Dataset.from_list(get_dataset(DATASET, None, "text"))
    num_rows = len(dataset)
    outputs = [[] for _ in range(num_rows)]
    used_models = [[] for _ in range(num_rows)]

    if combine_models:
        config_dirs = ["gpt2", "gpt2-medium", "gpt2-large", "gpt2-xl", "gpt-j-6B"]
        # Index len(config_dirs) (the last entry) stands for the human reference.
        unique_outputs = np.arange(len(config_dirs) + 1)
        choices = np.zeros((num_rows, 2), dtype=int)
        choices[:, 0] = np.random.choice(unique_outputs[:-1], num_rows, replace=True, p=[0.35, 0.22, 0.2, 0.2, 0.03])
        # Sample a model to compare against, preference to a more similar size.
        # The offset is clamped to at least 1 so the second choice always differs
        # from the first and every row ends up with exactly two outputs.
        offsets = np.maximum(np.ceil(np.random.exponential(size=num_rows)), 1)
        choices[:, 1] = np.clip(choices[:, 0] + offsets, 0, len(config_dirs))
    else:
        config_dirs = [os.getenv("CONFIG_DIR", "gpt2")]
        # Visit the single model twice so each row receives two samples from it.
        unique_outputs = [0, 0]
        choices = np.zeros((num_rows, 2), dtype=int)

    for model_idx in unique_outputs:
        # Rows for which this source was picked as either member of the pair.
        index_i = np.where((choices == model_idx).any(axis=1))[0]
        dataset_i = dataset.select(index_i)
        if combine_models and model_idx == len(config_dirs):  # Add human outputs
            print(f"Getting {len(dataset_i)} human outputs")
            outputs_i = dataset_i["label"]
            used_model = "human"
        else:
            config_path = os.path.join("configs", config_dirs[model_idx], "sft.yaml")
            with open(config_path) as config_file:
                # NOTE(review): yaml.Loader can construct arbitrary Python
                # objects. Kept for compatibility with existing configs; switch
                # to yaml.safe_load if the configs never need Python tags.
                config = yaml.load(config_file, yaml.Loader)
            print(f"Generating {len(dataset_i)} outputs from {config_dirs[model_idx]}")
            used_model = config_dirs[model_idx]
            if len(dataset_i) == 0:
                # No rows were assigned to this model; do not load it at all.
                continue

            devices = torch.cuda.device_count()
            if devices == 0:
                raise RuntimeError("No CUDA devices available; generation requires at least one GPU.")

            # Split the prompts into contiguous chunks, one per GPU. Only
            # non-empty chunks get a worker, so no model is loaded for nothing
            # when there are fewer prompts than devices.
            prompts = dataset_i["prompt"]
            chunk_size = int(np.ceil(len(dataset_i) / devices))
            jobs = [
                (prompts[start : start + chunk_size], f"cuda:{device_idx}")
                for device_idx, start in enumerate(range(0, len(dataset_i), chunk_size))
            ]
            stacked_outputs_i = Parallel(n_jobs=len(jobs))(
                delayed(generate)(config, chunk, device=device) for chunk, device in jobs
            )
            # Parallel preserves job order, so flattening restores prompt order.
            outputs_i = [o for g in stacked_outputs_i for o in g]

        for idx, o in zip(index_i, outputs_i):
            outputs[idx].append(o)
            used_models[idx].append(used_model)

    # Randomly decide which of the two outputs of each row becomes "label";
    # the other one becomes "output".
    label_idx = np.random.randint(0, 2, num_rows)

    label = [o[i] for o, i in zip(outputs, label_idx)]
    label_model = [u[i] for u, i in zip(used_models, label_idx)]

    output_idx = np.abs(label_idx - 1)
    output = [o[i] for o, i in zip(outputs, output_idx)]
    output_model = [u[i] for u, i in zip(used_models, output_idx)]

    generated_dataset = dataset.remove_columns("label")
    generated_dataset = reduce(
        lambda acc, val: acc.add_column(val[0], val[1]),
        [("label", label), ("output", output), ("label_model", label_model), ("output_model", output_model)],
        generated_dataset,
    )
    # Drop pairs in which either side is an empty string.
    generated_dataset = generated_dataset.filter(lambda row: row["output"] != "" and row["label"] != "")

    if combine_models:
        output_dir = os.path.join(DATA_DIR, "data/datasets/", DATASET, "sft")
    else:
        output_dir = os.path.join(DATA_DIR, "data/datasets/", DATASET, config["output_directory"])
    # Create the target directory up front so the final write does not fail
    # after all generation work has completed.
    os.makedirs(output_dir, exist_ok=True)
    generated_dataset.to_json(os.path.join(output_dir, "outputs.json"), orient="records")
