import csv
import numpy as np
import logging
import os
import pandas as pd
import time
import warnings
from tqdm import tqdm

from datasets import load_dataset
from rank_bm25 import BM25Okapi
from tokenizers import Tokenizer
from tokenizers.models import WordPiece

from core.llm import LLM
from core.messages import Message, Role, merge_messages
from evaluation.utils import get_bm25_context, get_prompt_context, get_gt_answer

def cosine_similarity_vectorized(a, B):
    """Return the cosine similarity between ``a`` and every row of ``B``.

    ``a`` is the query embedding (shape: [embedding_dim]) and ``B`` is the
    matrix of document embeddings (shape: [num_documents, embedding_dim]).
    The result contains one similarity value per row of ``B``.
    """
    dot_products = np.dot(B, a)
    # Product of each row norm with the query norm, i.e. the cosine denominator.
    norm_products = np.linalg.norm(B, axis=1) * np.linalg.norm(a)
    return dot_products / norm_products


def sample_answer(
    llm: "LLM",
    prompt: str,
    temperature: float,
    max_total_tokens: int,
    max_new_tokens: int,
) -> str:
    """Generate one answer to ``prompt`` with ``llm`` within a token budget.

    The number of tokens left for generation is ``max_total_tokens`` minus the
    tokenized prompt length, further capped by ``max_new_tokens`` when that is
    positive.

    Args:
        llm: Model wrapper providing ``tokenize`` and ``call``.
        prompt: Text sent to the model as a single user message.
        temperature: Sampling temperature forwarded to ``llm.call``.
        max_total_tokens: Budget shared by the prompt and the generated answer.
        max_new_tokens: Upper bound on generated tokens; values <= 0 disable
            this extra cap.

    Returns:
        The generated answer with ``&lt;`` / ``&gt;`` replaced by ``<`` / ``>``.

    Raises:
        ValueError: If the prompt leaves no tokens for the answer.
    """
    n_tokens_prompt = llm.tokenize(prompt).shape[1]
    max_tokens_to_generate = max_total_tokens - n_tokens_prompt
    if max_new_tokens > 0:
        max_tokens_to_generate = min(max_new_tokens, max_tokens_to_generate)

    # The hard failure has to be tested first: every value <= 0 is also <= 10,
    # so checking the warning threshold first would make this branch unreachable.
    if max_tokens_to_generate <= 0:
        raise ValueError(f"Too many tokens in the prompt: {n_tokens_prompt}, while the limit is {max_total_tokens}.")
    if max_tokens_to_generate <= 10:
        warnings.warn(f"Too few tokens left for the answer: {max_tokens_to_generate}.", stacklevel=2)

    messages = [Message(Role.USER, prompt)]
    # The second element returned by llm.call is not used here.
    answer, _truncated = llm.call(
        messages,
        temperature=temperature,
        max_new_tokens=max_tokens_to_generate,
    )
    return answer.replace("&lt;", "<").replace("&gt;", ">")


def get_openai_embedding(text):
    """Return the ``text-embedding-3-small`` embedding of ``text``.

    A new ``OpenAI`` client is created on every call and the first embedding of
    the response is returned.
    """
    # Function-scope import: the OpenAI SDK is only needed by the
    # embedding-based retrieval mode, the only place this helper is called.
    from openai import OpenAI

    client = OpenAI()
    response = client.embeddings.create(
        input=text,
        model="text-embedding-3-small",  # Choose the embedding model
    )
    return response.data[0].embedding


def main(
    base: str = "llama3-8b-instruct",
    adapter_id: str = "",
    temperature: float = 0,
    dataset: str = "nyt",
    output_filename: str = "output_rewritten.csv",
    n_questions: int = 1000,
    bm25: bool = False,
    openai_rag: bool = False,
    n_questions_bm25: int = 1000,
    n_documents_bm25: int = 7,
    oracle: bool = False,
    rewritten_questions: bool = True,
):
    """Answer questions from a ``squadshifts`` test split and save them as CSV.

    Each question is sent to the model either on its own or preceded by a
    context obtained through BM25 retrieval, OpenAI-embedding retrieval, or
    directly from the dataset item (oracle). One ``question;ground truth;answer``
    row per question is written to
    ``outputs/<dataset>/<adapter or base>/<output_filename>``. If that file
    already exists the program exits without generating anything.

    Args:
        base: Base model name. Must contain "llama" or "qwen", which selects the
            system message; also names the output sub-directory when no adapter
            is given.
        adapter_id: When non-empty, the model is loaded from this adapter and
            the output sub-directory is derived from it.
        temperature: Sampling temperature used for generation.
        dataset: ``squadshifts`` configuration name; also names the files read
            from ``./datasets`` and the output directory.
        output_filename: Name of the CSV file. Replaced by ``output_bm25.csv``
            or ``output_oracle.csv`` when that mode is active and the name does
            not mention it.
        n_questions: Maximum number of dataset items to answer; values <= 0
            remove the limit.
        bm25: Retrieve the prompt context with BM25.
        openai_rag: Retrieve the prompt context by cosine similarity between
            OpenAI embeddings of the question and of the documents.
        n_questions_bm25: Number of dataset items whose contexts form the
            retrieval corpus; values <= 0 use all items.
        n_documents_bm25: Number of retrieved documents placed in the prompt
            (used by both retrieval modes).
        oracle: Use the context returned by ``get_prompt_context`` for the item.
        rewritten_questions: Read the questions line by line from
            ``./datasets/<dataset>_filtered.csv`` instead of the dataset's own
            ``question`` field.

    Raises:
        ValueError: If the combination of arguments is not supported.
        NotImplementedError: If ``base`` is neither a llama nor a qwen model.
    """
    # BM25 and Oracle cannot be used at the same time
    if bm25 and oracle:
        raise ValueError("bm25 and oracle cannot be used at the same time")
    # Only one retriever is built below, so both flags together would leave the
    # embedding branch of the answering loop without its data.
    if bm25 and openai_rag:
        raise ValueError("bm25 and openai_rag cannot be used at the same time")
    if openai_rag and not rewritten_questions:
        raise ValueError("openai_rag requires rewritten_questions")

    if "llama" in base:
        opening_message = Message(
            Role.SYSTEM,
            "You are a knowledgeable assistant trained to provide accurate and helpful information. Please respond to the user's queries promptly."
        )
    elif "qwen" in base:
        opening_message = Message(
            Role.SYSTEM,
            "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."
        )
    else:
        raise NotImplementedError(f"Unknown base {base}")

    # Validate the remaining arguments before downloading the dataset or
    # loading the model, so a bad invocation fails immediately.
    if rewritten_questions:
        if n_questions > 1000:
            raise ValueError("n_questions must be <= 1000 when rewritten_questions is set")
        if (bm25 or openai_rag) and n_questions_bm25 != 1000:
            raise ValueError("n_questions_bm25 must be 1000 when retrieval is combined with rewritten_questions")
        if "rewritten" not in output_filename:
            raise ValueError("output_filename must contain 'rewritten' when rewritten_questions is set")

    dataset_name = dataset
    output_directory_name = dataset
    dataset = load_dataset("squadshifts", dataset, trust_remote_code=True)["test"]

    if rewritten_questions:
        path = f"./datasets/{dataset_name}_filtered.csv"
        # One question per line; question i replaces the question of dataset item i.
        with open(path, encoding="utf-8") as f:
            all_questions = [line.strip() for line in f]

    if adapter_id:
        llm = LLM.from_adapter(adapter_id, opening_message=opening_message)
    else:
        llm = LLM(base, opening_message=opening_message)
    llm.load_model()

    if bm25 or openai_rag:
        # Build the retrieval corpus; paragraphs are tokenized with the model tokenizer for BM25
        tokenizer = llm.tokenizer
        contexts = []
        tokenized_contexts = []
        context_set = set()
        prev_context = None
        for i, item in enumerate(dataset):
            if n_questions_bm25 > 0 and i >= n_questions_bm25:
                break
            context = get_bm25_context(item)
            context_set.update(context)
            # Skip items whose context equals that of the previous item.
            if context == prev_context:
                continue
            prev_context = context
            for paragraph in context:
                contexts.append(paragraph)
                tokenized_contexts.append(tokenizer.encode(paragraph))
        print("Number of contexts", len(context_set), flush=True)
        print("Number of tokens", sum(len(tokens) for tokens in tokenized_contexts))

        if bm25:
            bm25_retriever = BM25Okapi(tokenized_contexts)
            if 'bm25' not in output_filename:
                print("WARNING: bm25 not in output_filename. Renaming it")
                output_filename = "output_bm25.csv"
        elif openai_rag:
            embeddings_path = f'./datasets/{dataset_name}_embeddings.csv'

            # Check if the document embeddings file exists
            if os.path.exists(embeddings_path):
                # Load embeddings from the CSV file; row k is used for contexts[k]
                document_embeddings = pd.read_csv(embeddings_path, header=None).values
                print(f"Loaded embeddings from {embeddings_path}")
            else:
                # Generate embeddings for documents and save them
                print(f"Embedding {len(contexts)} documents", flush=True)
                document_embeddings = np.array([get_openai_embedding(doc) for doc in tqdm(contexts)])
                pd.DataFrame(document_embeddings).to_csv(embeddings_path, index=False, header=False)
                print(f"Generated and saved embeddings to {embeddings_path}")

            questions_path = f'./datasets/{dataset_name}_questions_embeddings.csv'
            # Check if the question embeddings file exists
            if os.path.exists(questions_path):
                # Load embeddings from the CSV file
                questions_embeddings = pd.read_csv(questions_path, header=None).values
                print(f"Loaded embeddings from {questions_path}")
            else:
                # Generate embeddings for the questions and save them
                questions_embeddings = np.array([get_openai_embedding(q) for q in tqdm(all_questions)])
                pd.DataFrame(questions_embeddings).to_csv(questions_path, index=False, header=False)
                print(f"Generated and saved embeddings to {questions_path}")
                # NOTE(review): the run stops right after the question embeddings
                # are first generated, so a second invocation is needed to
                # produce answers -- confirm this is intended.
                raise SystemExit(0)

    elif oracle and 'oracle' not in output_filename:
        print("WARNING: oracle not in output_filename. Renaming it")
        output_filename = "output_oracle.csv"

    if adapter_id:
        # Name the output directory after the adapter: everything after the
        # checkpoint marker when present, otherwise the last path component.
        marker = "checkpoints/huggingface/"
        if marker in adapter_id:
            adapter_name = adapter_id.split(marker, 1)[1]
        else:
            adapter_name = adapter_id.split('/')[-1]
        output_dir = f"outputs/{output_directory_name}/{adapter_name}"
    else:
        output_dir = f"outputs/{output_directory_name}/{base}"
    os.makedirs(output_dir, exist_ok=True)
    output_file = output_dir + "/" + output_filename

    print(f"Writing answers to output_file {output_file}", flush=True)

    if os.path.exists(output_file):
        print(f"File {output_file} exists already. Not re-creating it", flush=True)
        raise SystemExit(0)

    questions = []
    gt_answers = []
    answers = []
    for i, item in enumerate(dataset):
        if n_questions > 0 and i >= n_questions:
            break
        if rewritten_questions:
            question = all_questions[i]
        else:
            question = item['question']

        if openai_rag:
            # Rank documents by cosine similarity, most similar first.
            embedding = questions_embeddings[i]
            similarities = cosine_similarity_vectorized(embedding, document_embeddings)
            top_n_indices = similarities.argsort()[-n_documents_bm25:][::-1]
            context = "\n\n".join(contexts[idx] for idx in top_n_indices)
            prompt = f"{context}\n\nQuestion: {question}"
        elif bm25:
            context_tokens = tokenizer.encode(question)
            top_docs = bm25_retriever.get_top_n(context_tokens, tokenized_contexts, n=n_documents_bm25)
            # Retrieved documents are token sequences; decode them back to text
            # and drop the begin-of-text marker.
            context = '\n\n'.join([tokenizer.decode(doc).replace('<|begin_of_text|>', '') for doc in top_docs])
            prompt = f"{context}\n\nQuestion: {question}"
        elif oracle:
            context = get_prompt_context(item)
            prompt = f"{context}\n\nQuestion: {question}"
        else:
            prompt = question

        gt_answer = str(get_gt_answer(item))

        questions.append(question)
        gt_answers.append(gt_answer)

        answer = sample_answer(
            llm=llm,
            prompt=prompt,
            temperature=temperature,
            max_total_tokens=6128,
            max_new_tokens=500,
        )
        answers.append(answer)

    # Explicit UTF-8 so non-ASCII answers cannot fail on the platform default
    # encoding after all generation work is done.
    with open(output_file, 'w', newline='', encoding="utf-8") as file:
        print(f"Writing answers to {output_file}", flush=True)
        writer = csv.writer(file, delimiter=';', quotechar='"', quoting=csv.QUOTE_MINIMAL, escapechar='\\')
        for question, gt_answer, answer in zip(questions, gt_answers, answers):
            # Newlines are removed so that every record stays on a single line.
            writer.writerow([question.replace('\n', ''), gt_answer.replace('\n', ''), answer.replace('\n', ' ')])
    print("Writing finished", flush=True)

if __name__ == "__main__":
    # Script entry point: hand main() to jsonargparse's CLI helper so it is
    # driven from the command line.
    from jsonargparse import CLI
    CLI(main)
