import os
from modules import constants
os.environ["HF_HOME"] = os.path.join(constants.DATA_DIR, "raw")
import torch
import pandas as pd

from sae_lens import SAE
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSequenceClassification

from transformer_lens import HookedTransformer

import argparse

from modules.helpers import shuffle_ngrams
import pickle

# Command-line interface: one optional flag, --ngram, selecting which shuffle
# variant of the passages to process. The value stays a string at this point
# and defaults to "normal".
argparser = argparse.ArgumentParser()
argparser.add_argument("--ngram", default="normal", type=str, help="shuffle to use")
args = argparser.parse_args()
ngram = args.ngram


# Generator: gemma-2b-it wrapped as a TransformerLens HookedTransformer, on the GPU.
model = HookedTransformer.from_pretrained("gemma-2b-it", device="cuda")

# Scorer: Vectara's hallucination evaluation model, also on the GPU.
# NOTE(review): trust_remote_code=True lets code shipped with the model
# repository run locally — confirm that this is still required and acceptable.
hhem = AutoModelForSequenceClassification.from_pretrained(
    "vectara/hallucination_evaluation_model",
    trust_remote_code=True,
    device_map="cuda",
)

# Working directory for this shuffle variant:
# <DATA_DIR>/interim/sae_activations/<ngram>. Created if it does not exist.
output_dir = os.path.join(constants.DATA_DIR, "interim", "sae_activations", str(ngram))
os.makedirs(output_dir, exist_ok=True)

# The passages to summarise are read from a CSV stored under that same
# directory; the CSV's first column is used as the DataFrame index.
source_csv_path = os.path.join(output_dir, "gemma-2b-it", "source_text.csv")
source_df = pd.read_csv(source_csv_path, index_col=0)

# Instruction template for the summarisation request. The "<PASSAGE>"
# placeholder marks where the passage text is substituted in.
prompt = (
    "You are a chat bot answering questions using data. "
    "You must stick to the answers provided solely by the text in the passage provided. "
    "You are asked the question "
    "'Provide a concise summary of the following passage, covering the core pieces of information described.' "
    "<PASSAGE>\n"
)


# Shuffle variants this experiment is run with. Not read anywhere below; kept
# as a reference for valid --ngram values.
ngrams = ["normal", 1, 6, 30]
output_list = []
hall_score_list = []

# Convert the shuffle size once, before any generation happens, so that an
# invalid --ngram value fails immediately instead of on the first row.
if ngram != "normal":
    ngram = int(ngram)

num_rows = source_df.shape[0]
for i in range(num_rows):
    print(f"Processing {i} of {num_rows}")
    # `i` is a position, not an index label: use .iloc so this also works when
    # the CSV's index column is not exactly 0..n-1.
    source = source_df["source"].iloc[i]

    # Feed either the untouched passage or its n-gram-shuffled version.
    if ngram == "normal":
        passage = source
    else:
        passage = shuffle_ngrams(source, n=ngram)
    new_prompt = prompt.replace("<PASSAGE>", passage)

    output = model.generate(new_prompt, temperature=0, max_new_tokens=300, top_p=1)
    output_list.append(output)

    # Get the hallucination score for the (original passage, generated text)
    # pair. The passage given to the scorer is always the unshuffled source.
    # NOTE(review): the split assumes generate() returns the prompt verbatim
    # followed by the continuation; if the prompt is not reproduced exactly,
    # the whole returned string is scored instead — confirm.
    pair = [(source, output.split(new_prompt)[-1])]
    hall_score = hhem.predict(pair)[0]
    hall_score_list.append(hall_score.item())


# Collect the per-row results into one table aligned with the source rows.
# The dict keys must match the names listed in `columns`: pandas keeps only
# the listed columns and fills any name missing from the dict with NaN, so
# the generated text has to be stored under "summary".
summary_df = pd.DataFrame(
    {
        "source": source_df["source"],
        "summary": output_list,
        "hall_score": hall_score_list,
    },
    index=source_df.index,
    columns=["source", "summary", "hall_score"],
)

# Written next to the inputs for this shuffle variant, index included.
summary_df.to_csv(os.path.join(output_dir, "summary-gemma-2b-it-FIXED.csv"), index=True)

