import gc
import multiprocessing
import os
import time

import pandas as pd
import torch
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM

# Make GPU 0 the current CUDA device, but only when CUDA is actually usable:
# calling set_device unconditionally raises on hosts without a CUDA device,
# which would abort the script before anything else runs.
if torch.cuda.is_available():
    torch.cuda.set_device(0)


# Models to evaluate, in evaluation order. The mapping goes from the tag used
# as the output column name to the id passed to `from_pretrained`; every tag
# here is identical to its model id, so the mapping is built as an identity
# over the id list.
model_tag_to_id = {
    name: name
    for name in (
        "meta-llama/Meta-Llama-3-8B-Instruct",
        "meta-llama/Meta-Llama-3-70B-Instruct",
        "google/gemma-2-27b-it",
        "mistralai/Codestral-22B-v0.1",
        "google/gemma-2-9b-it",
        "mistralai/Mistral-7B-Instruct-v0.3",
        "HuggingFaceH4/zephyr-7b-beta",
    )
}

# Evaluation data: a local JSON file loaded through the `datasets` JSON
# builder and read back from the "train" split.
data_file = "data/alfworld/sft/test.json"
eval_dataset = load_dataset(
    "json",
    data_files=data_file,
    split="train",
)
# Debugging aid: keep at most the first 14 examples. The range is capped at
# the dataset length because `select` raises IndexError for out-of-range
# indices, which would crash on files with fewer than 14 rows.
eval_dataset = eval_dataset.select(range(min(14, len(eval_dataset))))  # debugging
# Number of prompts tokenized and generated together per step.
eval_batch_size = 1

# Each run writes into its own directory named after the local start time
# (YYYYmmdd-HHMMSS) under "outputs/".
outdir = "outputs/" + time.strftime("%Y%m%d-%H%M%S")
os.makedirs(outdir, exist_ok=True)
# Extension-less path stem; the ".csv" / ".json" suffixes are appended when
# the results are written.
outfile = outdir + "/model_generation"

# Column name -> list of per-example values, filled in as models are evaluated.
data_to_save = dict()

# The prompt and oracle columns come from the raw (un-templated) dataset and
# do not depend on the model, so build them once instead of once per model.
# Only the content of the first message of each field is kept.
data_to_save["prompt"] = [datapoint["prompt"][0]["content"] for datapoint in eval_dataset]
data_to_save["oracle"] = [datapoint["response"][0]["content"] for datapoint in eval_dataset]

# End-of-turn markers that some chat templates emit in addition to the
# tokenizer's EOS token (e.g. Llama-3's "<|eot_id|>", Gemma's "<end_of_turn>").
# Each one is only used as a stop token for tokenizers that actually know it.
end_of_turn_tokens = ("<|eot_id|>", "<end_of_turn>")

# Evaluate every model in turn: load it, render and tokenize the prompts,
# generate a completion per example, then persist all results gathered so far.
for model_tag, model_id in model_tag_to_id.items():
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )

    tokenizer = AutoTokenizer.from_pretrained(model_id, truncation=True, padding=True)
    # Left-side padding/truncation keeps the end of the prompt adjacent to the
    # first generated token, which decoder-only generation relies on.
    tokenizer.truncation_side = "left"
    tokenizer.padding_side = "left"
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Stop on EOS plus any end-of-turn marker present in this tokenizer's
    # vocabulary. Unknown tokens resolve to the unk id (or None when there is
    # no unk token); passing those to `generate` would either stop generation
    # on <unk> or put an invalid entry in the list, so they are filtered out.
    stop_token_ids = [tokenizer.eos_token_id]
    for stop_token in end_of_turn_tokens:
        stop_id = tokenizer.convert_tokens_to_ids(stop_token)
        if (
            stop_id is not None
            and stop_id != tokenizer.unk_token_id
            and stop_id not in stop_token_ids
        ):
            stop_token_ids.append(stop_id)

    def process(row):
        """Render the row's chat messages into a single prompt string.

        Replaces ``row["prompt"]`` (a list of chat messages) with the text
        produced by the current tokenizer's chat template, including the
        generation prompt that cues the assistant turn, and returns the row.
        """
        row["prompt"] = tokenizer.apply_chat_template(
            row["prompt"], tokenize=False, add_generation_prompt=True
        )
        return row

    eval_dataset_proc = eval_dataset.map(
        process,
        num_proc=multiprocessing.cpu_count(),
    )

    responses = []
    for eidx in tqdm(range(0, len(eval_dataset_proc), eval_batch_size)):
        batch = eval_dataset_proc[eidx : eidx + eval_batch_size]
        # The chat template is expected to have already inserted whatever
        # special tokens (e.g. BOS) the model needs, so the tokenizer must not
        # add them a second time.
        tokenized_inputs = tokenizer(
            batch["prompt"],
            return_tensors="pt",
            padding=True,
            truncation=True,
            add_special_tokens=False,
        ).to(model.device)

        outputs = model.generate(
            tokenized_inputs["input_ids"],
            attention_mask=tokenized_inputs["attention_mask"],
            max_new_tokens=256,
            eos_token_id=stop_token_ids,
            pad_token_id=tokenizer.pad_token_id,
            # `temperature` only takes effect when sampling is enabled; set it
            # explicitly so every model is decoded the same way regardless of
            # its checkpoint's default generation config.
            do_sample=True,
            temperature=0.3,
        )

        # With left padding every row's prompt occupies the first `prompt_len`
        # positions, so everything after that is newly generated text.
        prompt_len = tokenized_inputs["input_ids"].shape[-1]
        batch_responses = tokenizer.batch_decode(
            outputs[:, prompt_len:],
            skip_special_tokens=True,
        )

        responses.extend(batch_responses)

    data_to_save[model_tag] = responses

    # Rewrite the output files after each model so finished results survive a
    # failure while loading or running a later model.
    df = pd.DataFrame(data_to_save)
    df.to_csv(f"{outfile}.csv", index=False)
    df.to_json(f"{outfile}.json", orient="records", lines=False)
    print(f"Saved outputs to {outfile}")

    # Release the weights before the next model is loaded: drop the reference,
    # collect any reference cycles still holding parameters, then return the
    # cached GPU memory to the driver.
    del model
    gc.collect()
    torch.cuda.empty_cache()