# NOTE: this script loads a pre-trained LLM and runs inference over a binary dataset in one of two ways: either it computes the LLM's distortion on the dataset over all possible compressed contexts, or it uses an off-the-shelf prompt compression algorithm to compress the context and then computes the distortion of the black-box LLM on the compressed prompt.

import os
import re
import sys
import argparse
import random
import json
import itertools
import time

from tqdm import tqdm
import numpy as np
import torch
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer, RobertaForTokenClassification
import evaluate
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import pytorch_cos_sim
#from llmlingua import PromptCompressor
from LLMLingua.llmlingua import PromptCompressor

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from Selective_Context.selective_context import SelectiveContext


def seed_everything(seed):
    """Seed every random number generator the script relies on.

    Seeds Python's ``random``, NumPy's global RNG and torch's CPU and
    all-CUDA-device generators with ``seed``, exports ``PYTHONHASHSEED``,
    and puts cuDNN into deterministic, non-benchmarking mode.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

@torch.no_grad()
def get_loss(model, input_ids, attn_mask, labels, metric='log_loss', tokenizer=None, helper_model=None):
    """Compute the distortion of ``model`` on a single example under ``metric``.

    Args:
        model: causal LM; called directly for ``'log_loss'`` and through
            ``model.generate`` for the generative metrics.
        input_ids: prompt token ids; for generative metrics only row 0 of the
            generated continuation is decoded.
        attn_mask: attention mask matching ``input_ids``.
        labels: for ``'log_loss'``, the targets passed to the model; for the
            generative metrics, the reference answer ids (row 0 is decoded).
        metric: one of ``'log_loss'``, ``'generation'``, ``'rougeL'``, ``'bertscore'``.
        tokenizer: required for the generative metrics (decoding, and its
            ``eos_token_id`` is used as the generation pad token).
        helper_model: required for the generative metrics: an encoder with an
            ``encode`` method for ``'generation'``; an object with a
            ``compute`` method (e.g. an ``evaluate`` metric) for ``'rougeL'``
            and ``'bertscore'``.

    Returns:
        ``(loss, logits)`` for ``'log_loss'``; ``(distortion, response_str)``
        for the generative metrics, where a lower distortion means the
        response is closer to the reference.

    Raises:
        ValueError: if ``metric`` is not recognised, or ``tokenizer`` is
            missing for a generative metric.
        AssertionError: if ``helper_model`` is missing for a generative metric.
    """
    if metric == 'log_loss':
        outputs = model(input_ids, attention_mask=attn_mask, labels=labels)
        return outputs.loss.item(), outputs.logits

    if metric not in ('generation', 'rougeL', 'bertscore'):
        # Previously an unknown metric silently returned None.
        raise ValueError(f"Invalid metric {metric}")

    if metric == 'generation':
        assert helper_model is not None, "helper_model must be provided for args.distortion = generation"
    else:
        assert helper_model is not None, "helper_model must be provided for args.distortion = rougeL or bertscore"

    output_str, ground_truth = _generate_response(model, input_ids, attn_mask, labels, tokenizer)

    if metric == 'generation':
        embeddings = helper_model.encode([output_str, ground_truth])
        similarity = pytorch_cos_sim(embeddings[0], embeddings[1]).item()
        similarity = (similarity + 1) / 2   # map cosine similarity from [-1, 1] to [0, 1]
        return 1 - similarity, output_str   # higher similarity -> smaller distortion
    if metric == 'rougeL':
        score = helper_model.compute(predictions=[output_str], references=[ground_truth])['rougeL']
    else:
        score = helper_model.compute(predictions=[output_str], references=[ground_truth], lang='en')['f1'][0]
    return 1 - score, output_str


def _generate_response(model, input_ids, attn_mask, labels, tokenizer):
    """Generate up to 16 new tokens and return the decoded (response, reference) strings for row 0."""
    if tokenizer is None:
        raise ValueError("tokenizer must be provided for generative distortion metrics")
    outputs = model.generate(input_ids, attention_mask=attn_mask, max_length=input_ids.shape[1] + 16, num_beams=1, no_repeat_ngram_size=2, pad_token_id=tokenizer.eos_token_id)
    outputs = outputs[:, input_ids.shape[1]:]       # only keep the generated part of the output
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    reference = tokenizer.decode(labels[0], skip_special_tokens=True)
    return response, reference


    #elif metric == 'accuracy':
    #    #return (outputs.logits.argmax(-1) == labels).float().mean().item(), outputs.logits
    #    assert input_ids.shape[0] == 1, "Batch size must be 1 for args.distortion=accuracy"
    #    gt = labels[:, -1].item()
    #    input_ids = input_ids[:, :-1]
    #    attn_mask = attn_mask[:, :-1]
    #    outputs = model(input_ids, attention_mask=attn_mask)
    #    #response = model.generate(input_ids, attention_mask=attn_mask, max_length=input_ids.shape[1] + 4, num_beams=1, no_repeat_ngram_size=2, early_stopping=True)
    #    #print(response)
    #    #print(tokenizer.decode(input_ids[0], skip_special_tokens=True))
    #    #print(tokenizer.decode(response[0], skip_special_tokens=True))
    #    pred = outputs.logits.argmax(-1)[:, -1].item()
    #    return 1 - int(pred == gt), outputs.logits      # 1 if incorrect (high distortion), 0 if correct (low distortion)

def get_input(context, query, answer, tokenizer, metric, device='cpu'):
    """Tokenize one (context, query, answer) example into model inputs for ``metric``.

    The prompt is a fixed instruction followed by ``Context: ...``,
    ``Query: ...`` and ``Answer: ``.

    Args:
        context: (possibly compressed) context string.
        query: query string.
        answer: reference answer string.
        tokenizer: Hugging Face-style tokenizer.
        metric: ``'log_loss'`` or one of ``'generation'``, ``'rougeL'``, ``'bertscore'``.
        device: device the returned tensors are moved to.

    Returns:
        ``(input_ids, attn_mask, labels)``.

        * For ``'log_loss'``, the answer is appended to the prompt. ``labels``
          is a copy of ``input_ids`` with every position before the trailing
          answer tokens set to -100 (the ignore index of Hugging Face
          causal-LM losses). An answer that tokenizes to zero tokens therefore
          yields fully masked labels.
        * For the generative metrics, ``input_ids`` is the prompt alone and
          ``labels`` are the answer's token ids.

    Raises:
        ValueError: if ``metric`` is not recognised.
    """
    prompt = f"Instruction: Generate an answer based on the context and query provided.\n\nContext: {context}\nQuery: {query}\nAnswer: "
    if metric == "log_loss":
        input_ids = tokenizer(prompt + answer, return_tensors="pt")["input_ids"].to(device)
        # NOTE(review): the answer is tokenized on its own, which may differ
        # slightly from its tokenization inside the full string.
        answer_len = tokenizer(answer, add_special_tokens=False, return_tensors="pt")["input_ids"].shape[1]
        labels = input_ids.clone()
        # Use an explicit prompt length: `labels[:, :-0]` would mask nothing
        # when the answer has zero tokens.
        n_prompt = max(input_ids.shape[1] - answer_len, 0)
        labels[:, :n_prompt] = -100
        attn_mask = torch.ones_like(input_ids)
    elif metric in ["generation", "rougeL", "bertscore"]:
        input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to(device)
        labels = tokenizer(answer, add_special_tokens=False, return_tensors="pt")["input_ids"].to(device)
        attn_mask = torch.ones_like(input_ids)
    else:
        raise ValueError(f"Invalid metric {metric}")

    return input_ids, attn_mask, labels


def _output_filename(args):
    """Return the ``.jsonl`` path that results for this run are written to.

    The basename is built, in order, from:

    * the encoder name (see the note below);
    * the decoder name;
    * ``iter_size`` (LLMLingua modes) and ``ratio`` (compression modes);
    * the distortion metric and the mode;
    * an ``_ft`` flag;
    * the dataset file stem.
    """
    enc_model_name = args.enc_model_id.split("/")[-1]
    dec_model_name = args.dec_model_id.split("/")[-1]
    data_filename = args.data_path.split("/")[-1].split(".")[0]

    name = ""
    # NOTE(review): the encoder name is only prepended when encoder and decoder
    # share a basename. Runs whose encoder differs from the decoder therefore
    # lose the encoder name, and two such runs that differ only in the encoder
    # write to the same file. Kept unchanged to preserve existing result
    # filenames; confirm this is intended.
    if args.mode not in ["inference", "optimal"] and enc_model_name == dec_model_name:
        name += f"{enc_model_name}_"
    name += dec_model_name
    if args.mode in ["llmlingua", "llmlingua_query"]:
        name += f"_iter_size={args.iter_size}"
    if args.mode in ["llmlingua", "selective", "llmlingua_query", "llmlingua2", "llmlingua2_dynamic"]:
        name += f"_ratio={args.ratio}"
    name += f"_{args.distortion}_{args.mode}"
    if args.from_finetuned:
        name += "_ft"
    name += f"_{data_filename}.jsonl"
    return os.path.join(args.save_dir, name)


def _compression_ratio(tokenizer, compressed, original):
    """Return len(tokens(compressed)) / len(tokens(original)), without special tokens."""
    n_compressed = len(tokenizer(compressed, add_special_tokens=False)['input_ids'])
    n_original = len(tokenizer(original, add_special_tokens=False)['input_ids'])
    return n_compressed / n_original


def _make_record(idx, context, query, answer, loss, response=None, compressed_context=None, ratio=None):
    """Build one output row. Optional fields are included only when not None, in a fixed key order."""
    record = {"idx": idx, "context": context, "query": query, "answer": answer}
    if compressed_context is not None:
        record["compressed_context"] = compressed_context
    if response is not None:
        record["response"] = response
    record["loss"] = loss
    if ratio is not None:
        record["ratio"] = ratio
    return record


def main(args):
    """Score the decoder LLM on every (context, query, answer) triple in ``args.data_path``.

    One JSON line per query is appended to the file returned by
    ``_output_filename(args)``; a pre-existing file with that name is deleted
    first. ``args.mode`` selects which context the decoder sees:

    * ``inference``: the original context.
    * ``llmlingua`` / ``llmlingua_query``: the context compressed by LLMLingua,
      without / with conditioning on the query.
    * ``selective``: the context compressed by Selective Context.
    * ``optimal``: every candidate compressed context listed for the example
      in ``./data/compressed_contexts.json``; the losses of all candidates are
      recorded.
    * ``llmlingua2`` / ``llmlingua2_dynamic``: models are loaded, but the loop
      below does not implement these modes and raises ``ValueError``.

    Raises:
        ValueError: for an unknown ``args.dtype``, ``args.mode`` or ``args.distortion``.
    """
    seed_everything(args.seed)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtypes = {"fp16": torch.float16, "fp32": torch.float32, "bf16": torch.bfloat16}
    if args.dtype not in dtypes:
        raise ValueError(f"Invalid dtype {args.dtype}")
    torch_dtype = dtypes[args.dtype]

    # Validate up front so a typo fails before any file is deleted or model loaded.
    if args.mode not in ["inference", "llmlingua", "llmlingua_query", "llmlingua2", "llmlingua2_dynamic", "selective", "optimal"]:
        raise ValueError(f"Invalid mode {args.mode}")
    generative_metrics = ["generation", "rougeL", "bertscore"]
    if args.distortion not in generative_metrics + ["log_loss"]:
        raise ValueError(f"Invalid distortion metric {args.distortion}")
    is_generative = args.distortion in generative_metrics

    # NOTE: models are compared by basename only, so this will not work for two
    # different finetunings of the same model.
    same_model = args.enc_model_id.split("/")[-1] == args.dec_model_id.split("/")[-1]

    filename = _output_filename(args)
    print(f"Saving results to {filename}")

    # check if filename exists. if it does, delete it
    if os.path.exists(filename):
        print(f"File {filename} already exists. Deleting it.")
        os.remove(filename)

    def load_decoder():
        """Load ``args.dec_model_id`` as a causal LM on ``device`` in eval mode."""
        return AutoModelForCausalLM.from_pretrained(args.dec_model_id, torch_dtype=torch_dtype).to(device).eval()

    if args.mode == "inference":
        dec_tokenizer = AutoTokenizer.from_pretrained(args.dec_model_id)
        dec_model = load_decoder()
    elif args.mode in ["llmlingua", "llmlingua_query"]:
        enc_tokenizer = AutoTokenizer.from_pretrained(args.enc_model_id)
        dec_tokenizer = AutoTokenizer.from_pretrained(args.dec_model_id)
        compressor = PromptCompressor(args.enc_model_id, device_map=device, torch_dtype=torch_dtype)     # TODO: we have no control over the dtype here...
        compressor.model.eval()
        # Reuse the compressor's LLM as the decoder when both are the same model.
        dec_model = compressor.model if same_model else load_decoder()
    elif args.mode in ["llmlingua2", "llmlingua2_dynamic"]:
        enc_tokenizer = AutoTokenizer.from_pretrained(args.enc_model_id)
        dec_tokenizer = AutoTokenizer.from_pretrained(args.dec_model_id)
        # Not yet used by the loop below.
        enc_model = RobertaForTokenClassification.from_pretrained(args.enc_model_id, torch_dtype=torch_dtype).to(device).eval()
        dec_model = load_decoder()
    elif args.mode == "selective":
        enc_tokenizer = AutoTokenizer.from_pretrained(args.enc_model_id)
        dec_tokenizer = AutoTokenizer.from_pretrained(args.dec_model_id)
        compressor = SelectiveContext(args.enc_model_id, torch_dtype=torch_dtype)
        # Only reuse the compressor's LM when it is the requested decoder;
        # otherwise dec_tokenizer (from dec_model_id) would not match the model.
        dec_model = compressor.model.eval() if same_model else load_decoder()
    else:  # "optimal"
        dec_tokenizer = AutoTokenizer.from_pretrained(args.dec_model_id)
        dec_model = load_decoder()
        with open("./data/compressed_contexts.json", "r") as f:
            compressed_contexts = json.load(f)

    with open(args.data_path, "r") as f:
        data = json.load(f)

    # NOTE: `args.with_template` is currently unused; no few-shot template is applied.
    helper_model = None
    if args.distortion == "generation":
        helper_model = SentenceTransformer("all-mpnet-base-v2", device=device)
        if dec_tokenizer.pad_token_id is not None:
            dec_model.generation_config.pad_token_id = dec_tokenizer.pad_token_id
    elif args.distortion == "rougeL":
        helper_model = evaluate.load("rouge")
    elif args.distortion == "bertscore":
        helper_model = evaluate.load("bertscore")

    def score(ctx, query, answer):
        """Return get_loss's (distortion, output) for the decoder given ``ctx`` as context."""
        input_ids, attn_mask, labels = get_input(ctx, query, answer, dec_tokenizer, metric=args.distortion, device=device)
        return get_loss(dec_model, input_ids, attn_mask, labels, args.distortion, tokenizer=dec_tokenizer, helper_model=helper_model)

    start_time = time.time()
    idx = 0
    for key in tqdm(data.keys()):
        context = data[key]["context"]
        queries = data[key]["queries"]
        answers = data[key]["answers"]

        for query, answer in zip(queries, answers):
            if args.mode == "inference":
                loss, output = score(context, query, answer)
                record = _make_record(idx, context, query, answer, loss, response=output if is_generative else None)
            elif args.mode in ["llmlingua", "llmlingua_query", "selective"]:
                if args.mode == "llmlingua_query":
                    compressed_prompt = compressor.compress_prompt(
                        context,
                        instruction="",
                        question=query,
                        rate=args.ratio,
                        iterative_size=args.iter_size,
                        concate_question=False,
                        condition_compare=True,
                        use_context_level_filter=False
                    )['compressed_prompt']  # NOTE: condition_in_question is ignored for token-level compression
                elif args.mode == "llmlingua":
                    compressed_prompt = compressor.compress_prompt(
                        context,
                        instruction="",
                        question="",
                        rate=args.ratio,
                        iterative_size=args.iter_size,
                        use_context_level_filter=False
                    )['compressed_prompt']
                else:
                    compressed_prompt, _ = compressor(
                        context,
                        reduce_ratio=args.ratio,
                        reduce_level='token'
                    )

                ratio = _compression_ratio(enc_tokenizer, compressed_prompt, context)
                loss, output = score(compressed_prompt, query, answer)
                record = _make_record(
                    idx, context, query, answer, loss,
                    response=output if is_generative else None,
                    compressed_context=compressed_prompt,
                    ratio=ratio,
                )
            elif args.mode == "optimal":
                ctx_list = compressed_contexts[key]
                print(f"Context: {context}, Query: {query}, Answer: {answer}")
                context_dict = {}
                for ctx_tokens in tqdm(ctx_list):
                    # Each candidate is stored as a list of decoder tokens.
                    ctx = dec_tokenizer.convert_tokens_to_string(ctx_tokens)
                    loss, output = score(ctx, query, answer)
                    context_dict[ctx] = (output, loss) if is_generative else loss

                record = {"prompt_id": key, "idx": idx, "context": context, "query": query, "answer": answer, "context_dict": context_dict}
            else:
                raise ValueError("Invalid mode")

            idx += 1

            # Append per record so partial results survive an interrupted run.
            with open(filename, "a") as f:
                f.write(json.dumps(record) + "\n")

    end_time = time.time()
    print(f"Total time taken: {end_time - start_time:.2f}s")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Run inference on our dataset')
    parser.add_argument('--dec_model_id', type=str, required=True, help='Hugging Face ID of the decoder LLM that answers the queries')
    parser.add_argument('--enc_model_id', type=str, required=True, help='Hugging Face ID of the model used by the prompt compressor (ignored by the inference and optimal modes)')
    parser.add_argument('--data_path', type=str, default='./data/data.json', help='path to the dataset')
    parser.add_argument('--save_dir', type=str, default='./out', help='path to save the results')
    parser.add_argument('--mode', type=str, required=True,
                        choices=["inference", "llmlingua", "llmlingua_query", "llmlingua2", "llmlingua2_dynamic", "selective", "optimal"],
                        help='operating mode for this script')
    parser.add_argument('--distortion', type=str, default='log_loss',
                        choices=["log_loss", "generation", "rougeL", "bertscore"],
                        help='distortion metric to use')
    parser.add_argument('--from_finetuned', action='store_true', help='mark the run as using a finetuned model; currently only adds "_ft" to the results filename')
    parser.add_argument('--with_template', action='store_true', help='currently unused; reserved for a few-shot template with generation distortion')
    parser.add_argument('--dtype', type=str, default='fp16', choices=["fp16", "fp32", "bf16"], help='data type for the model')
    parser.add_argument('--seed', type=int, default=1337, help='random seed')
    parser.add_argument('--ratio', type=float, default=0.5, help='compression ratio')
    parser.add_argument('--iter_size', type=int, default=4, help='iterative_size for LLMLingua algorithm')

    args = parser.parse_args()
    print(args)

    # Results are grouped per mode under <save_dir>/<mode>/; create it if needed.
    args.save_dir = os.path.join(args.save_dir, args.mode)
    os.makedirs(args.save_dir, exist_ok=True)

    main(args)
