import os
import json
from tqdm import tqdm
import argparse
from transformers import LlamaTokenizer, LlamaForCausalLM, AutoTokenizer, AutoModelForSeq2SeqLM
from datasets import Dataset
from torch.utils.data import DataLoader
import torch


def run(inputs, prompts, args):
    """Generate paraphrases for every (input, prompt) pair and save them as JSON.

    Args:
        inputs: Iterable of ``(task_id, task_input)`` pairs. ``task_id`` must be
            convertible with ``int()``; ``task_input`` is the text that gets
            wrapped by each prompt.
        prompts: Mapping of prompt id -> ``(prefix, suffix)`` strings that are
            concatenated around each task input. Prompt ids must be
            convertible with ``int()``.
        args: Parsed options. Reads ``model_path``, ``max_length``,
            ``temperature``, ``beam_size``, ``batch_size`` and ``output_dir``.

    Side effects:
        Creates ``args.output_dir`` if needed and writes ``params.json`` (the
        run configuration) and ``paraphrased_instructions.json`` (one record
        per returned sequence) into it.
    """
    print("Loading Model...")
    tokenizer = LlamaTokenizer.from_pretrained(args.model_path, padding_side='left')
    # Batched tokenization below pads to the longest sample, which requires a
    # pad token; fall back to EOS when the checkpoint does not define one.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = LlamaForCausalLM.from_pretrained(args.model_path, torch_dtype=torch.float16)
    model = model.cuda()

    print("Loading Data...")

    def collate_fn(batch):
        # Read fields by key rather than relying on dict value ordering.
        texts = [sample["input_text"] for sample in batch]
        encoded = tokenizer(texts, padding='longest', truncation=True, return_tensors="pt", max_length=1024)
        encoded['task_ids'] = torch.tensor([int(sample["task_id"]) for sample in batch])
        encoded['prompt_ids'] = torch.tensor([int(sample["prompt_id"]) for sample in batch])
        return encoded

    # NOTE(review): do_sample is not set, so `temperature` may be ignored by
    # generate() (it normally only affects sampling) — confirm the intent.
    params = {
        "max_new_tokens": args.max_length,
        "temperature": args.temperature,
    }

    # Number of sequences generate() returns per input row.
    num_return = 1
    if args.beam_size > 1:
        num_return = args.beam_size
        params["num_beams"] = args.beam_size
        params["num_return_sequences"] = num_return

    os.makedirs(args.output_dir, exist_ok=True)

    with open(os.path.join(args.output_dir, "params.json"), "w", encoding="utf-8") as f:
        json.dump(vars(args), f)
    for k, v in vars(args).items():
        print(k+':'+str(v))

    # One sample per (task input, prompt template) combination.
    dataset = []
    for idx, task_input in tqdm(inputs):
        for prompt_id, (prefix, suffix) in prompts.items():
            dataset.append({
                "input_text": prefix + task_input + suffix,
                "task_id": idx,
                "prompt_id": prompt_id
            })
    dataset = Dataset.from_list(dataset)
    dataloader = DataLoader(dataset, batch_size=args.batch_size, collate_fn=collate_fn)

    all_outputs = []

    print("Running...")

    for batch in tqdm(dataloader, "Paraphrasing"):
        task_ids = batch.pop('task_ids')
        prompt_ids = batch.pop('prompt_ids')

        batch = {k: v.cuda() for k, v in batch.items()}
        outputs = model.generate(**batch, **params)
        # Inputs are left-padded to a common length, so dropping that many
        # leading tokens leaves only the newly generated continuation.
        outputs = outputs[:, batch["input_ids"].shape[-1]:]
        outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)

        # generate() returns `num_return` consecutive rows per input, so row i
        # belongs to input i // num_return and is its (i % num_return)-th
        # candidate. Zipping the rows directly against the ids would drop most
        # candidates and attach the rest to the wrong task/prompt.
        for i, output in enumerate(outputs):
            source = i // num_return
            all_outputs.append({
                "task_id": task_ids[source].item(),
                "prompt_id": prompt_ids[source].item(),
                "paraphrase_id": i % num_return,
                "paraphrased_instructions": output.strip()
            })

    output_path = os.path.join(args.output_dir, "paraphrased_instructions.json")
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(all_outputs, f, indent=2)
        

def main():
    """Parse command-line options, pick prompts and inputs, and call ``run``.

    Reads the JSON file given by ``--input_dir`` (a list of records with an
    ``"input_text"`` field), keeps the records in ``[--start, --end)``, and
    pairs each with its stringified index in the file as the task id. Only
    the prompt templates whose ids appear in ``--use_prompt`` are passed on.
    """
    argparser = argparse.ArgumentParser()
    argparser.add_argument("--output_dir", type=str, default="./results/flan2021_llama")
    argparser.add_argument("--input_dir", type=str, default="./data/Zero_Shot/flan2021.json")

    argparser.add_argument("--model_path", type=str, default="../../../models/models--meta-llama--Llama-2-13b-hf/snapshots/8e5c561236425864e6d2f6a1b41bbb25cc7d2f26")
    argparser.add_argument("--beam_size", type=int, default=3)
    argparser.add_argument("--temperature", type=float, default=0.3)
    argparser.add_argument("--max_length", type=int, default=512)
    argparser.add_argument("--batch_size", type=int, default=2)

    argparser.add_argument("--start", type=int, default=0)
    argparser.add_argument("--end", type=int, default=20)
    argparser.add_argument("--use_prompt", type=int, default=[1, 2, 3], nargs="+")

    args = argparser.parse_args()

    print("Used prompt ids: ", args.use_prompt)

    # Each template is a (prefix, suffix) pair wrapped around the task input.
    all_prompts = {
        1: ("Here's an input utterance:\n\n", "\n\n\nNow, your task is to paraphrase the input by only changing the instruction but leaving everything else the same.\nHere's the new utterance:\n\n"),
        2: ("You are given an utterance which is a combination of task instruction and the actual input. Your job is to paraphrase the task instruction and leave the input unchanged. Here's the utterance to be paraphrased:\n\n\n", "\n\n\nNow, generate the new utterance:\n\n\n"),
        3: ("You are provided with the utterance of a specific task and I need you to paraphrase it. The actual input, question, and examples in the task should not be changed. You should only paraphrase the instructions. Task:\n\n\n", "\n\n\nThe paraphrased utterance:\n\n\n")
    }

    # Keep only the requested templates; unknown ids are ignored.
    requested = set(args.use_prompt)
    prompts = {
        prompt_id: template
        for prompt_id, template in all_prompts.items()
        if prompt_id in requested
    }
    if not prompts:
        # Fail before the model is loaded: there would be nothing to generate.
        argparser.error(
            "--use_prompt must include at least one of: "
            + ", ".join(str(prompt_id) for prompt_id in all_prompts)
        )

    with open(args.input_dir, "r", encoding="utf-8") as f:
        records = json.load(f)

    # Task ids are the records' positions in the input file, as strings.
    inputs = [
        (str(idx), record["input_text"])
        for idx, record in enumerate(records[args.start:args.end], start=args.start)
    ]
    run(inputs, prompts, args)


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()



