import argparse
import os
import nltk
import json
import tqdm
import torch
import random
from filelock import FileLock
from transformers.utils import is_offline_mode
import evaluate
import nltk
import vllm
import evaluate
from ..utils import (
    load_hf_lm,
    generate_completions,
    load_hf_tokenizer, 
    load_vllm_model,
)

# Make sure the NLTK "punkt" sentence-tokenizer data is available, because
# compute_rouge calls nltk.sent_tokenize. If it is missing, download it once,
# unless transformers reports offline mode, in which case a download cannot
# be attempted and we fail with an actionable message instead.
try:
    nltk.data.find("tokenizers/punkt")
except (LookupError, OSError):
    if is_offline_mode():
        raise LookupError(
            "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files"
        )
    # Guard the download with a file lock (".lock" in the current working
    # directory), presumably so that several processes started from the same
    # directory do not download concurrently.
    with FileLock(".lock"):
        nltk.download("punkt", quiet=True)


def compute_rouge(predictions, targets, rouge=None):
    """Score generated summaries against reference summaries with ROUGE.

    Each text is whitespace-stripped and re-laid out with one sentence per
    line (via ``nltk.sent_tokenize``), because the rougeLsum variant splits
    on newlines.

    Args:
        predictions: generated summary strings.
        targets: reference summary strings, aligned one-to-one with
            ``predictions``.
        rouge: an already-loaded ``evaluate`` ROUGE metric. When ``None``,
            one is loaded with ``evaluate.load("rouge")``.

    Returns:
        dict mapping each key returned by ``rouge.compute`` to its value
        multiplied by 100 and rounded to 4 decimal places. This assumes the
        metric returns scalar scores (the aggregated form), which is what
        ``round`` requires here.
    """
    metric = evaluate.load("rouge") if rouge is None else rouge

    def _one_sentence_per_line(text):
        # Strip first, then put every detected sentence on its own line.
        return "\n".join(nltk.sent_tokenize(text.strip()))

    formatted_predictions = [_one_sentence_per_line(p) for p in predictions]
    formatted_references = [_one_sentence_per_line(t) for t in targets]

    raw_scores = metric.compute(
        predictions=formatted_predictions,
        references=formatted_references,
        use_stemmer=True,
    )
    return {name: round(score * 100, 4) for name, score in raw_scores.items()}


def _load_split(data_dir, split):
    """Read ``<data_dir>/<split>.json`` as JSON Lines: one JSON example per line."""
    examples = []
    with open(os.path.join(data_dir, f"{split}.json"), encoding="utf-8") as fin:
        for line in fin:
            # Skip blank lines (e.g. a trailing newline at end of file) instead
            # of failing with a JSONDecodeError.
            if line.strip():
                examples.append(json.loads(line))
    return examples


def _load_model(args):
    """Load the generation backend selected by ``args``.

    Returns:
        ``(model, tokenizer)``. ``tokenizer`` is ``None`` for the vLLM backend:
        ``load_vllm_model`` receives the tokenizer path directly, and the HF
        tokenizer is only consumed by ``generate_completions``.

    Raises:
        ValueError: if ``args.model_name_or_path`` is empty.
    """
    if not args.model_name_or_path:
        raise ValueError("--model_name_or_path is required to generate predictions.")

    if args.use_vllm:
        print("Loading vllm model...")
        model = load_vllm_model(
            args.model_name_or_path,
            tokenizer_name_or_path=args.tokenizer_name_or_path,
            use_slow_tokenizer=args.use_slow_tokenizer,
        )
        return model, None

    tokenizer = load_hf_tokenizer(
        model_name_or_path=args.model_name_or_path,
        tokenizer_name_or_path=args.tokenizer_name_or_path,
        use_fast_tokenizer=not args.use_slow_tokenizer,
    )
    print("Loading model and tokenizer with huggingface...")
    model = load_hf_lm(
        model_name_or_path=args.model_name_or_path,
        load_in_8bit=args.load_in_8bit,
        # With several GPUs, "balanced_low_0" spreads the model across devices
        # while keeping GPU 0 lighter; otherwise let accelerate decide.
        device_map="balanced_low_0" if torch.cuda.device_count() > 1 else "auto",
        gptq_model=args.gptq,
    )
    # For GPT-NeoX and OPT models, align the tokenizer's max length with the
    # model's position-embedding limit.
    from transformers import GPTNeoXForCausalLM, OPTForCausalLM
    if isinstance(model, (GPTNeoXForCausalLM, OPTForCausalLM)):
        tokenizer.model_max_length = model.config.max_position_embeddings
        print("Set tokenizer.model_max_length to model.config.max_position_embeddings: {}".format(model.config.max_position_embeddings))
    return model, tokenizer


def _generate_predictions(model, tokenizer, prompts, args, max_new_tokens=128):
    """Generate one completion per prompt with temperature 0.

    Returns a list of whitespace-stripped strings aligned with ``prompts``.
    """
    if args.use_vllm:
        sampling_params = vllm.SamplingParams(
            temperature=0,
            max_tokens=max_new_tokens,
        )
        # vLLM might not return outputs for some prompts (e.g. if the prompt is
        # too long), so remap results by prompt text; missing ones become "".
        generations = model.generate(prompts, sampling_params)
        prompt_to_output = {g.prompt: g.outputs[0].text for g in generations}
        return [prompt_to_output.get(prompt, "").strip() for prompt in prompts]

    predictions = generate_completions(
        model=model,
        tokenizer=tokenizer,
        prompts=prompts,
        max_new_tokens=max_new_tokens,
        temperature=0,
        batch_size=args.eval_batch_size or 1,
    )
    # Strip like the vLLM path so the saved predictions are formatted the same
    # for both backends (compute_rouge strips again, so scores are unaffected).
    return [prediction.strip() for prediction in predictions]


def _save_predictions(path, prompts, predictions, targets):
    """Write one JSON object per example (prompt, prediction, target) to ``path``."""
    with open(path, "w", encoding="utf-8") as fout:
        for prompt, prediction, target in zip(prompts, predictions, targets):
            fout.write(json.dumps({
                "prompt": prompt,
                "prediction": prediction,
                "target": target,
            }) + "\n")


def main(args):
    """Generate tl;dr summaries for the validation and test splits and score them.

    Reads ``<args.data_dir>/{validation,test}.json`` (JSON Lines with
    ``dialogue`` and ``summary`` fields). It writes per-split predictions to
    ``<args.save_dir>/predictions/<split>.jsonl`` and the ROUGE scores of all
    splits to ``<args.save_dir>/metrics.json``.

    Raises:
        ValueError: if ``args.model_name_or_path`` is empty.
        RuntimeError: if the backend returns a different number of
            predictions than there are examples.
    """
    random.seed(42)

    # One metric instance is reused for every split; experiment_id is
    # forwarded to evaluate.load only when given.
    if args.experiment_id is not None:
        rouge = evaluate.load("rouge", experiment_id=args.experiment_id)
    else:
        rouge = evaluate.load("rouge")

    all_tasks = {split: _load_split(args.data_dir, split) for split in ("validation", "test")}

    predictions_dir = os.path.join(args.save_dir, "predictions")
    os.makedirs(predictions_dir, exist_ok=True)  # also creates args.save_dir

    model, tokenizer = _load_model(args)

    tldr_utterance = "\ntl;dr\n"
    rouge_score_dict = {}
    for split in tqdm.tqdm(all_tasks, desc="Evaluating"):
        task_examples = all_tasks[split]
        prompts = [f"{example['dialogue'].strip()}{tldr_utterance}" for example in task_examples]
        predictions = _generate_predictions(model, tokenizer, prompts, args)
        targets = [example["summary"].strip() for example in task_examples]
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if len(predictions) != len(targets):
            raise RuntimeError("number of predictions and targets are not the same.")

        _save_predictions(
            os.path.join(predictions_dir, f"{split}.jsonl"), prompts, predictions, targets
        )

        rouge_score_dict[split] = compute_rouge(predictions, targets, rouge=rouge)
        print(f"Split {split} - Rouge: {rouge_score_dict[split]}")

    # save the performance
    with open(os.path.join(args.save_dir, "metrics.json"), "w", encoding="utf-8") as fout:
        json.dump(rouge_score_dict, fout, indent=4)


if __name__ == "__main__":
    # Command-line entry point: parse options, default the output directory
    # to <model_name_or_path>/eval_results, then run the evaluation.
    parser = argparse.ArgumentParser(
        description="Evaluate a language model on dialogue summarization with ROUGE."
    )
    parser.add_argument(
        "--data_dir",
        type=str,
        default="data/datasets/sum/dialog_sum/",
        help="Directory containing validation.json and test.json (one JSON example per line).",
    )
    parser.add_argument(
        "--save_dir",
        type=str,
        default=None,
        help="Output directory for predictions and metrics; defaults to <model_name_or_path>/eval_results.",
    )
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        required=True,
        help="Model to load for generating the predictions.",
    )
    parser.add_argument(
        "--tokenizer_name_or_path",
        type=str,
        default=None,
        help="if specified, we will load the tokenizer from here.",
    )
    parser.add_argument(
        "--use_slow_tokenizer",
        action="store_true",
        help="If given, we will use the slow tokenizer.",
    )
    parser.add_argument(
        "--eval_batch_size",
        type=int,
        default=1,
        help="Batch size for generation (Hugging Face backend only).",
    )
    parser.add_argument(
        "--load_in_8bit",
        action="store_true",
        help="Load the model in 8-bit mode to reduce memory usage (Hugging Face backend only).",
    )
    parser.add_argument(
        "--gptq",
        action="store_true",
        help="If given, we're evaluating a 4-bit quantized GPTQ model (Hugging Face backend only).",
    )
    parser.add_argument(
        "--use_vllm",
        action="store_true",
        help="If given, we will use the vllm library, which will likely increase the inference throughput.",
    )
    parser.add_argument(
        "--experiment_id",
        type=str,
        default=None,
        help="Optional experiment_id forwarded to evaluate.load('rouge').",
    )
    args = parser.parse_args()

    if args.save_dir is None:
        args.save_dir = os.path.join(args.model_name_or_path, "eval_results")

    main(args)
