import datasets
import random
from transformers import AutoModelForCausalLM, AutoTokenizer
import re
import torch
import pandas as pd
import os
from sentence_transformers import SentenceTransformer
import numpy as np
from sklearn.cluster import AgglomerativeClustering
import ast
from sklearn.metrics.pairwise import cosine_similarity
import math
from collections import Counter
import argparse

#### PICK DATASET AND MODEL

# Command-line interface: which causal LM to query and which QA dataset to sample questions from.
parser = argparse.ArgumentParser(description="Generate answers using a specified model and dataset")

parser.add_argument('--model_name', type=str, required=True, choices=['meta-llama/Llama-2-7b-chat-hf','tiiuae/falcon-7b-instruct','mistralai/Mistral-7B-Instruct-v0.1','meta-llama/Llama-2-13b-chat-hf'],help='Name of the model to use')

parser.add_argument('--dataset', type=str, required=True, choices=['bioasq', 'nq', 'trivia_qa','squad'],help='Name of the dataset to use')

# A secret passed on the command line ends up in shell history and process listings,
# so the token may instead come from the HF_TOKEN environment variable. An explicit
# --huggingfacetoken still takes precedence, so existing invocations keep working.
parser.add_argument('--huggingfacetoken', type=str, default=os.environ.get('HF_TOKEN'),
                    help='Your Hugging Face token (defaults to the HF_TOKEN environment variable)')

args = parser.parse_args()

model_name = args.model_name
data_name = args.dataset
print(f"Using model: {model_name}")
print(f"Using dataset: {data_name}")

##### SET HUGGINGFACE TOKEN
from huggingface_hub import login
huggingface_token = args.huggingfacetoken
# With no token available, skip login() (which would otherwise prompt interactively)
# and rely on any credentials already cached on this machine.
if huggingface_token:
    login(huggingface_token)

def _random_rows(split, n):
    """Return `n` distinct rows of a `datasets` split, chosen uniformly at random."""
    return split.select(random.sample(range(len(split)), n))


def _collect_pairs(rows, answer_list_of):
    """Build index-aligned (questions, answers) lists from dataset rows.

    `answer_list_of(row)` returns the row's list of reference answers; only the first
    one is kept. A row contributes nothing unless it has both a question and at least
    one answer, so the two returned lists always have the same length.
    """
    questions, answers = [], []
    for row in rows:
        question = row.get('question')
        answer_list = answer_list_of(row)
        if question and answer_list:
            questions.append(question)
            answers.append(answer_list[0])
    return questions, answers


def _squad_answer_list(row):
    """Reference answers of a SQuAD-format row (`row['answers']['text']`), or [] if absent."""
    return (row.get('answers') or {}).get('text', [])


def dataset_function(dataset_name, sample):
    """Draw random (question, reference answer) pairs from a QA dataset.

    Parameters:
    dataset_name (str): One of 'nq', 'trivia_qa', 'squad', 'bioasq'.
    sample (int): Number of rows to draw, without replacement.

    Returns:
    (list, list): Index-aligned questions and answers (first reference answer only).
    For the Hugging Face datasets, rows lacking a question or an answer are dropped as
    a pair, so at most `sample` pairs are returned.

    Raises:
    ValueError: If `dataset_name` is unknown, or (from random.sample) if `sample`
    exceeds the number of available rows.
    """
    if dataset_name == 'nq':
        # NQ-open: questions are sampled from the train split.
        split = datasets.load_dataset('nq_open')['train']
        return _collect_pairs(_random_rows(split, sample), lambda row: row.get('answer'))

    if dataset_name == 'trivia_qa':
        dataset = datasets.load_dataset('TimoImhof/TriviaQA-in-SQuAD-format')['unmodified']
        # Only an 'unmodified' split exists here: hold out 20% and sample from that part.
        split = dataset.train_test_split(test_size=0.2)['test']
        return _collect_pairs(_random_rows(split, sample), _squad_answer_list)

    if dataset_name == 'squad':
        split = datasets.load_dataset('rajpurkar/squad')['validation']
        return _collect_pairs(_random_rows(split, sample), _squad_answer_list)

    if dataset_name == 'bioasq':
        frame = pd.read_csv('data/bioasq_exact.csv')
        # Materialize each column once instead of once per sampled index.
        all_questions = frame['questions'].tolist()
        all_answers = frame['answers'].tolist()
        picks = random.sample(range(len(all_questions)), sample)
        return [all_questions[i] for i in picks], [all_answers[i] for i in picks]

    raise ValueError(f"Unknown dataset_name: {dataset_name!r}")


print('######downloading data:')
# Request 400 randomly sampled questions (and their reference answers) from the chosen dataset.
questions, answers = dataset_function(dataset_name=data_name, sample=400)

"""## Choosing LLM & Generating answers"""


def create_model_and_tokenizer(model_name):
    """Load a causal LM and its tokenizer from the Hugging Face Hub.

    Llama and Mistral repositories are requested with `use_auth_token=True`; any other
    model (e.g. Falcon) is loaded without it. The model is loaded with
    `device_map="auto"` and `torch_dtype="auto"`.

    Parameters:
    model_name (str): Hub repository id, e.g. 'meta-llama/Llama-2-7b-chat-hf'.

    Returns:
    (model, tokenizer)
    """
    # Test each substring explicitly: `'llama' or 'mistral' in name` is always truthy
    # because the non-empty literal 'llama' is evaluated on its own.
    lowered = model_name.lower()
    if 'llama' in lowered or 'mistral' in lowered:
        auth = {'use_auth_token': True}
    else:
        auth = {}
    tokenizer = AutoTokenizer.from_pretrained(model_name, **auth)
    model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype="auto", **auth)
    return model, tokenizer

def generate_direct_answer(question, temp, max_length=200):
    """Sample a short, direct answer to `question` from the global causal LM.

    Reads the module-level `model` and `tokenizer`, so it must run while `model` still
    refers to the language model (not after it is rebound to another object).

    Parameters:
    question (str): The question to answer.
    temp (float): Sampling temperature; sampling is always enabled.
    max_length (int): Cap on prompt + generated tokens, forwarded to `generate`.

    Returns:
    str: First line of the generated continuation, with echoed copies of the question
    and of the "Answer:" cue removed.
    """
    prompt = f"Answer the following question directly, without any elaboration or extra details. Provide only the answer, starting immediately after the word 'Answer' Question: {question}. Answer:"

    # Place inputs on the model's device rather than assuming "cuda" is available.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_len = inputs["input_ids"].shape[1]

    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_length=max_length,  # includes the prompt tokens
            do_sample=True,         # stochastic sampling; `temp` sets how spread out it is
            temperature=temp,
        )

    # Decode only the newly generated tokens. Decoding the whole sequence and deleting the
    # prompt by string matching fails whenever detokenization does not reproduce the prompt
    # exactly.
    answer = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True).strip()

    # The model may still echo the question or the "Answer:" cue in its continuation.
    answer = answer.replace(question, '').strip()
    answer = answer.replace("Answer:", "").strip()

    return answer.split('\n')[0]

print('######Downloading model:')

model, tokenizer = create_model_and_tokenizer(model_name)

print('#####Generating Low Temp Answer:')

# One low-temperature (temp=0.1) answer per question; it is later compared with the
# reference answer to derive the hallucination label.
best_answers = []
for idx, question in enumerate(questions):
    answer = generate_direct_answer(question, 0.1)
    best_answers.append(answer)
    # enumerate gives the true position; questions.index() is O(n) per call and would
    # report the first occurrence if a question appeared twice.
    print(f"{idx}-Q: {question}\nA: {answer}\n")


# Ten temperature-1 samples per question; these are clustered later for semantic entropy.
high_temp = []
print('#####Generating High Temp Answers:')
for idx, question in enumerate(questions):
    samples = [generate_direct_answer(question, 1) for _ in range(10)]
    print(f"{idx}-Q: {question}\nA: {samples}\n")
    high_temp.append(samples)


# One row per question: reference answer, low-temperature answer, and the 10 samples.
df = pd.DataFrame()
df['q'] = questions
df['answer'] = answers
df['low_temp_1'] = best_answers
df['high_temp_10'] = high_temp

"""## Generating embedding for q and answer"""

# Sentence-BERT encoder; get_qa_embedding reads it through the global name `model`.
# NOTE: this rebinds `model`. generate_direct_answer also reads the global `model`,
# so it must not be called after this point.
emb_model = 'all-MiniLM-L6-v2'
model = SentenceTransformer(model_name_or_path=emb_model)

def get_qa_embedding(question, answer=None, is_question=False):
    """Embed a question, or a question/answer pair, with the global Sentence-BERT `model`.

    Parameters:
    question (str): The question text.
    answer (str, optional): Answer to pair with the question. Ignored when
        `is_question` is true; when None, only the question is encoded.
    is_question (bool): If true, encode the question alone regardless of `answer`.

    Returns:
    The result of `model.encode` for a single string (a numpy vector with
    SentenceTransformer's default settings — confirm if the encoder changes).
    """
    # Pairs are joined with a literal " [SEP] " marker before encoding.
    text = question if (is_question or answer is None) else f"{question} [SEP] {answer}"
    return model.encode(text)

"""## Getting 'ground truth' labels"""

from sklearn.metrics.pairwise import cosine_similarity

print('')
# The low-temperature answer is labelled 'no' (not a hallucination) when its
# question-conditioned embedding has cosine similarity >= 0.95 with the reference
# answer's embedding, and 'yes' otherwise.
is_hallucination = []
cosineSim = []
for question, low_temp, answer in zip(df['q'], df['low_temp_1'], df['answer']):
    generated_vec = get_qa_embedding(question, str(low_temp))
    reference_vec = get_qa_embedding(question, str(answer))
    similarity = cosine_similarity([generated_vec, reference_vec])[0][1]
    is_hallucination.append('no' if similarity >= 0.95 else 'yes')
    cosineSim.append(similarity)



df['lowtemp_hallucination'] = is_hallucination
df['lowtemp_cosineSim'] = cosineSim

"""## Generating the clusters"""
print('#####Generating Clusters')


def clustering_answers(question, answers, distance_threshold=0.05):
    """Group sampled answers to one question into semantic clusters.

    Each answer is embedded together with its question via get_qa_embedding. Answers
    are then merged by average-linkage agglomerative clustering on the precomputed
    cosine distance (1 - cosine similarity), stopping at `distance_threshold`.

    Parameters:
    question (str): The question the answers respond to.
    answers (list): Sampled answers; each is passed through str() before embedding.
    distance_threshold (float): Linkage distance at or above which clusters are not merged.

    Returns:
    list: One integer cluster label per answer, in input order.
    """
    # AgglomerativeClustering needs at least two samples; label smaller inputs directly.
    if len(answers) < 2:
        return [0] * len(answers)

    response_embeddings = [get_qa_embedding(question, str(a)) for a in answers]
    cosine_dist_matrix = 1 - cosine_similarity(response_embeddings)

    # scikit-learn >= 1.2 calls this option `metric` (`affinity` was removed in 1.4).
    # Older releases reject `metric` with a TypeError at construction, so fall back to
    # `affinity` there. Only the constructor sits inside the try.
    options = dict(n_clusters=None, distance_threshold=distance_threshold, linkage='average')
    try:
        clustering = AgglomerativeClustering(metric='precomputed', **options)
    except TypeError:
        clustering = AgglomerativeClustering(affinity='precomputed', **options)
    cluster_labels = clustering.fit_predict(cosine_dist_matrix)

    return list(cluster_labels)

# Cluster the high-temperature samples of every question.
clusters = []
for question, samples in zip(df['q'], df['high_temp_10']):
    # Named `samples` (not `high_temp`) so the module-level list of all samples is kept intact.
    cluster_assignments = clustering_answers(question, samples)
    clusters.append(cluster_assignments)
    print(cluster_assignments)

df['cluster_assignments'] = clusters

"""## Calculating Semantic Entropy/Confidence Score"""

def semantic_entropy(cluster_assignments):
    """Shannon entropy (natural log) of the cluster-size distribution.

    H = -sum_k p_k * ln(p_k), where p_k is the fraction of items carrying label k.
    An empty input yields 0.
    """
    n_items = len(cluster_assignments)
    shares = [size / n_items for size in Counter(cluster_assignments).values()]
    weighted_log_sum = 0
    for share in shares:
        weighted_log_sum += share * math.log(share)
    # Every term is <= 0, so negate to obtain a non-negative entropy.
    return -weighted_log_sum

# One semantic-entropy score per question. The comprehension's loop variable stays local,
# so the module-level `clusters` list is not overwritten with the last row's labels.
semantic_entropies = [semantic_entropy(assignments) for assignments in df['cluster_assignments']]

df['semantic_entropy_score'] = semantic_entropies

"""## Calculating ROC Score"""

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, roc_auc_score
from sklearn.preprocessing import MinMaxScaler

# Positive class = hallucination ('yes'); semantic entropy is the score for that class.
binary_list = [1 if answer == "yes" else 0 for answer in df['lowtemp_hallucination']]
# Min-max rescaling is monotone increasing, so it does not change the ROC curve or AUC;
# it only maps the scores into [0, 1].
scaler = MinMaxScaler()
y_scores = scaler.fit_transform(np.array(df['semantic_entropy_score']).reshape(-1, 1)).flatten()

fpr, tpr, thresholds = roc_curve(binary_list, y_scores)
auc = roc_auc_score(binary_list, y_scores)
print('Final ROC-AUC: ', auc)

# Plot the ROC curve
plt.figure()
plt.plot(fpr, tpr, color='blue', lw=2, label=f'AUC = {auc:.3f}')
plt.plot([0, 1], [0, 1], color='grey', lw=2, linestyle='--')  # (random classifier)
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curve')
plt.legend(loc="lower right")

# Hub ids such as 'meta-llama/Llama-2-7b-chat-hf' contain '/', which would be read as a
# missing sub-directory, so replace it. Also create the output directory, and add the .png
# extension explicitly: savefig uses fname verbatim when `format` is given.
os.makedirs('roc_figures', exist_ok=True)
safe_model_name = model_name.replace('/', '_')
plt.savefig(os.path.join('roc_figures', f'roc_{safe_model_name}_{data_name}.png'), dpi=300, format='png')

plt.show()
