import os
import copy
import time
import json
import csv
import argparse
import pickle as pkl
import re
from typing import List, Dict
import numpy as np
import torch
import torch.nn.functional as F
import faiss
from tqdm import tqdm, trange
from transformers import AutoTokenizer, AutoModel
from vllm import LLM, SamplingParams


###################################
# LLM Helper Functions
###################################
def call_llm(prompt: str, model, tokenizer, sampling_params) -> str:
    """
    Run a single-turn chat completion and return the reply text.

    The prompt is stripped and wrapped as one ``user`` message. It is rendered
    to a string with the tokenizer's chat template (``tokenize=False``, with the
    generation prompt appended). It is then passed to ``model.generate`` as a
    batch of one.

    Args:
        prompt: The user message; surrounding whitespace is removed.
        model: Object exposing ``generate(list_of_texts, sampling_params)``
            (a ``vllm.LLM`` in this script).
        tokenizer: Tokenizer exposing ``apply_chat_template``.
        sampling_params: Forwarded unchanged to ``model.generate``.

    Returns:
        The text of the first completion for the (only) request, stripped.
    """
    messages = [{"role": "user", "content": prompt.strip()}]
    rendered = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    first_request = model.generate([rendered], sampling_params)[0]
    return first_request.outputs[0].text.strip()


def decompose_question(question: str, prompt_template: str, llm_model, tokenizer, sampling_params) -> List[Dict]:
    """
    Decompose the main question into labelled sub-questions using an LLM.

    ``${question}`` in ``prompt_template`` is replaced with ``question``, and the
    filled prompt is sent through ``call_llm``. Only response lines that start
    with ``### Q`` are parsed; every other line is ignored. Expected line
    format::

        ### Q1: <sub-question text> ## Need Context? ## Yes

    Args:
        question: The original question to decompose.
        prompt_template: Prompt text containing the literal placeholder ``${question}``.
        llm_model: Forwarded to ``call_llm`` as ``model``.
        tokenizer: Forwarded to ``call_llm``.
        sampling_params: Forwarded to ``call_llm``.

    Returns:
        One dict per parsed line, in response order, with keys:

        * ``label``: text before the first ":" with "#" and whitespace stripped (e.g. ``"Q1"``).
        * ``text``: the stripped sub-question text after the first ":".
        * ``needs_context``: True iff the text after the marker starts with
          "yes" (case-insensitive).

        A line without the marker is kept with ``needs_context=False``. A line
        without a ":" is skipped. Previously both cases raised and discarded the
        whole decomposition.
    """
    context_marker = "## Need Context? ##"
    call_prompt = prompt_template.replace("${question}", question)
    response = call_llm(call_prompt, model=llm_model, tokenizer=tokenizer, sampling_params=sampling_params)

    decomposed_questions = []
    for raw_line in response.strip().split("\n"):
        line = raw_line.strip()
        if not line.startswith("### Q"):
            continue
        # partition() tolerates a missing or repeated marker. The old two-way
        # unpack of split() raised ValueError on either.
        question_part, _, context_part = line.partition(context_marker)
        label_part, colon, text_part = question_part.partition(":")
        if not colon:
            # Without the "Q<n>:" separator the label and text cannot be told apart.
            continue
        decomposed_questions.append({
            "label": label_part.strip("#").strip(),
            "text": text_part.strip(),
            "needs_context": context_part.strip().lower().startswith("yes"),
        })
    return decomposed_questions


###################################
# Retrieval Functions
###################################
def load_embedding(dataset: str, embedding_model: str, num_shards: int = 4):
    """
    Load the passage corpus and its pre-computed embeddings for ``dataset``.

    Reads ``{dataset}/corpus.tsv`` as tab-separated rows. The passage text is
    the second column, and rows whose first column is ``"id"`` (the header) are
    skipped. It then unpickles ``num_shards`` shard files
    ``{dataset}/{embedding_model}/embeddings-{i}-of-{num_shards}.pkl`` and
    concatenates them along axis 0.

    Args:
        dataset: Directory holding ``corpus.tsv`` and the embedding sub-directories.
        embedding_model: Name of the sub-directory containing the shards.
        num_shards: Number of shard files. Defaults to 4, the count that was
            previously hard-coded into the file names.

    Returns:
        ``(sentences, passage_embeddings)``: the list of passage strings and the
        concatenated array. Presumably row ``i`` corresponds to
        ``sentences[i]``; TODO confirm the shards were written in corpus order.

    Note:
        The shards are read with ``pickle``, which can execute arbitrary code.
        Only load shard files you produced yourself.
    """
    sentences = []
    # The csv module requires newline="" so quoted fields containing newlines parse correctly.
    with open(f"{dataset}/corpus.tsv", "r", newline="", encoding="utf-8") as f:
        for row in tqdm(csv.reader(f, delimiter='\t')):
            if row[0] == "id":
                continue
            sentences.append(row[1])

    shards = []
    for i in trange(num_shards):
        path = f"{dataset}/{embedding_model}/embeddings-{i}-of-{num_shards}.pkl"
        with open(path, "rb") as f:
            shards.append(pkl.load(f))
    passage_embeddings = np.concatenate(shards, axis=0)
    print("Passage Size:", passage_embeddings.shape)
    return sentences, passage_embeddings

def average_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """
    Mean-pool token embeddings over the positions the attention mask keeps.

    Args:
        last_hidden_states: Token embeddings; the code sums over dim 1 and
            broadcasts a trailing mask axis, so presumably (batch, seq, hidden).
        attention_mask: Mask aligned with the first two dims of
            ``last_hidden_states``; non-zero entries mark tokens to keep.

    Returns:
        The sum of kept token vectors divided by the number of kept tokens, per
        row. A row whose mask is all zeros gives NaN (0 / 0).
    """
    keep = attention_mask.unsqueeze(-1).bool()
    token_sum = last_hidden_states.masked_fill(~keep, 0.0).sum(dim=1)
    token_count = attention_mask.sum(dim=1).unsqueeze(-1)
    return token_sum / token_count


def embed_text(text: str, tokenizer, model, embedding_model_name: str) -> torch.Tensor:
    """
    Encode ``text`` as a query embedding with the given embedding model.

    Supported ``embedding_model_name`` values:

    * any name containing ``"e5-"`` (e.g. ``"e5-base"``, the CLI default, or
      ``"e5-large"``): input is prefixed with ``"query: "`` and truncated to
      128 tokens. The embedding is the attention-masked mean of the last hidden
      state, L2-normalised. Previously only the exact name ``"e5-large"``
      reached a pooling branch, so ``"e5-base"`` raised UnboundLocalError.
    * ``"gte-base"``: truncated to 100 tokens. The embedding is the
      first-token hidden state, L2-normalised.
    * ``"dragon"``: truncated to 100 tokens. The embedding is the first-token
      hidden state, not normalised.

    Args:
        text: The query string.
        tokenizer: HF-style tokenizer, called with ``return_tensors='pt'``.
        model: HF-style model returning an object with ``last_hidden_state``.
        embedding_model_name: Selects prefixing, pooling and normalisation (see above).

    Returns:
        A CPU tensor with one row for ``text``.

    Raises:
        ValueError: If ``embedding_model_name`` is not one of the supported values.
    """
    is_e5 = "e5-" in embedding_model_name
    if is_e5:
        encoded_input = tokenizer("query: " + text, max_length=128, padding=True, truncation=True, return_tensors='pt')
    elif embedding_model_name in ("gte-base", "dragon"):
        encoded_input = tokenizer(text, max_length=100, padding=True, truncation=True, return_tensors='pt')
    else:
        raise ValueError(f"Unsupported embedding model name: {embedding_model_name!r}")

    # Place inputs on the model's device. The script moves the model to CUDA,
    # so this matches the former hard-coded .cuda() there.
    device = next(model.parameters()).device
    encoded_input = {key: value.to(device) for key, value in encoded_input.items()}

    with torch.no_grad():
        if is_e5:
            outputs = model(**encoded_input)
            embeddings = average_pool(outputs.last_hidden_state, encoded_input['attention_mask'])
        elif embedding_model_name == "gte-base":
            embeddings = model(**encoded_input).last_hidden_state[:, 0]
        else:  # "dragon"
            embeddings = model(**encoded_input, output_hidden_states=True, return_dict=True).last_hidden_state[:, 0, :]
        sentence_embeddings = embeddings.detach().cpu()

    if is_e5 or embedding_model_name == "gte-base":
        sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
    return sentence_embeddings


def retrieve_context(query: str, cpu_index, corpus: List[str], embedding_tokenizer, embedding_model, embedding_model_name: str, top_k: int = 3) -> List[str]:
    """
    Retrieve up to ``top_k`` passages for ``query`` from a FAISS index.

    The query is embedded with ``embed_text``, converted to a C-contiguous
    float32 array (the layout FAISS works on), and searched in ``cpu_index``.
    Each returned id indexes into ``corpus``.

    Args:
        query: The query text.
        cpu_index: A FAISS index whose ids are positions in ``corpus``.
        corpus: Passage texts.
        embedding_tokenizer: Forwarded to ``embed_text``.
        embedding_model: Forwarded to ``embed_text``.
        embedding_model_name: Forwarded to ``embed_text``.
        top_k: Number of neighbours requested.

    Returns:
        Passages in the order returned by ``cpu_index.search``. The list is
        shorter than ``top_k`` when FAISS reports missing results with id -1.
    """
    query_embedding = embed_text(query, embedding_tokenizer, embedding_model, embedding_model_name)
    query_matrix = np.ascontiguousarray(query_embedding, dtype=np.float32)
    _, neighbor_ids = cpu_index.search(query_matrix, top_k)
    # FAISS pads with -1 when it has fewer than top_k results. corpus[-1] would
    # silently return the last passage, so drop those slots.
    return [corpus[idx] for idx in neighbor_ids[0] if idx >= 0]


###################################
# Answering Functions
###################################
def zigzag_visit(lst: List) -> List:
    """
    Reorder ``lst``: every other item front-to-back, then the skipped items back-to-front.

    Items at positions 0, 2, 4, ... fill the first ceil(n/2) slots in order.
    Items at positions 1, 3, 5, ... fill the remaining slots starting from the
    end. If ``lst`` is sorted best-first, the best items end up at both ends of
    the result and the weakest ones in the middle.

    Example:
        Input: [1, 2, 3, 4, 5, 6, 7]
        Output: [1, 3, 5, 7, 6, 4, 2]

    Returns:
        A new list; ``lst`` itself is not modified.
    """
    front = list(lst[::2])
    back = list(lst[1::2])
    back.reverse()
    return front + back

def answer_sub_claim(sub_q: str, context_passages: List[str], model, tokenizer, sampling_params) -> str:
    """
    Ask the LLM whether the claim ``sub_q`` is correct, given retrieved passages.

    The passages are reordered with ``zigzag_visit`` and joined with blank
    lines. The prompt asks for a bare "Yes" or "No" and lets the model fall
    back on its own knowledge when the context does not settle the claim.

    Returns:
        The stripped LLM reply. It is requested, but not enforced, to be Yes or No.
    """
    context_block = "\n\n".join(zigzag_visit(context_passages))
    prompt_lines = [
        "You have the following context passages:",
        context_block,
        "",
        f"Please verify whether the claim '{sub_q}' is correct using the context as reference. ",
        "If no answer is found in the context, use your own knowledge.",
        "Please only output Yes or No and do not give any explanation.",
    ]
    reply = call_llm("\n".join(prompt_lines), model=model, tokenizer=tokenizer, sampling_params=sampling_params)
    return reply.strip()

def answer_sub_question(sub_q: str, context_passages: List[str], model, tokenizer, sampling_params) -> str:
    """
    Ask the LLM for a short-span answer to ``sub_q``, given retrieved passages.

    The passages are reordered with ``zigzag_visit`` and joined with blank
    lines. The prompt requests an answer as short as possible and lets the
    model use its own knowledge when the context has no answer.

    Returns:
        The stripped LLM reply.
    """
    context_block = "\n\n".join(zigzag_visit(context_passages))
    prompt_lines = [
        "You have the following context passages:",
        context_block,
        "",
        f"Please answer the question '{sub_q}' with a short span using the context as reference.",
        "If no answer is found in the context, use your own knowledge. Your answer needs to be as short as possible.",
    ]
    reply = call_llm("\n".join(prompt_lines), model=model, tokenizer=tokenizer, sampling_params=sampling_params)
    return reply.strip()


###################################
# Orchestration Functions
###################################
def replace_placeholders(question_text: str, answers_so_far: Dict[str, str]) -> str:
    """
    Substitute each "#<n>" reference in ``question_text`` with the answer stored under "Q<n>".

    Example: with ``answers_so_far = {"Q1": "Paris"}``,
    ``"Who is the mayor of #1?"`` becomes ``"Who is the mayor of Paris?"``.
    References with no stored answer are left as-is.

    All references are resolved in a single regex pass. As a result, "#1" never
    matches inside "#12", and an answer that itself contains "#<n>" is not
    substituted again. The previous loop of ``str.replace`` calls had both
    problems.
    """
    def _lookup(match):
        return answers_so_far.get(f"Q{match.group(1)}", match.group(0))

    return re.sub(r"#(\d+)", _lookup, question_text)


def generate_final_answer(original_question: str, sub_questions: Dict[str, str], sub_answers: Dict[str, str],
                          model, tokenizer, sampling_params, dataset: str, passages: List[str] = None, add_passage: int = 1) -> str:
    """
    Ask the LLM for the final answer to ``original_question`` from the sub-question trace.

    Each entry of ``sub_answers`` is rendered as
    ``### <label>: <sub-question>, Answer for <label>: <answer>``. The
    sub-question text is looked up in ``sub_questions`` under the same label.
    Three prompt variants are used:

    * ``dataset`` in {"hover", "exfever"}: verify the claim from the
      sub-answers only (Yes/No); ``passages`` and ``add_passage`` are ignored.
    * otherwise, if ``add_passage`` is truthy: the de-duplicated ``passages``
      are shown before the sub-answers.
    * otherwise: sub-answers only.

    Every variant asks for the answer wrapped in ``<answer>``/``</answer>`` tags.

    Args:
        original_question: The question (or claim) to answer.
        sub_questions: Label -> resolved sub-question text.
        sub_answers: Label -> sub-answer; its iteration order sets the rendering order.
        model: Forwarded to ``call_llm``.
        tokenizer: Forwarded to ``call_llm``.
        sampling_params: Forwarded to ``call_llm``.
        dataset: Dataset name; selects the prompt variant and answer format.
        passages: Retrieved passages. ``None`` is treated as an empty list
            (it previously crashed in ``set(None)``).
        add_passage: Whether to include ``passages`` in the prompt.

    Returns:
        The stripped LLM reply; the answer tags are not removed here.
    """
    sub_answer_text = "\n".join(
        f"### {label}: {sub_questions[label]}, Answer for {label}: {sub_answer}"
        for label, sub_answer in sub_answers.items()
    )
    # NOTE(review): for strategyqa this renders as "... with True or False only. using ...";
    # the wording is kept unchanged so model outputs do not shift.
    final_prompt = "True or False only." if dataset in ["strategyqa"] else "a short span"

    if dataset in ["hover", "exfever"]:
        prompt = (
            "You are given some subquestions and their answers:\n"
            f"{sub_answer_text}\n"
            "\n"
            f"Please verify the correctness of the claim: '{original_question}' using the subquestions as reference. "
            "Please provide a concise and clear reasoning followed by a concise conclusion. "
            "Your answer should be Yes or No only. \n"
            "Wrap your answer with <answer> and </answer> tags."
        )
    elif add_passage:
        # Order-preserving de-duplication. With set() the passage order, and so
        # the prompt, changed from run to run under string hash randomisation.
        passage_text = "\n\n".join(dict.fromkeys(passages or []))
        prompt = (
            "You have the following passages:\n"
            f"{passage_text}\n"
            "\n"
            "You are also given some subquestions and their answers:\n"
            f"{sub_answer_text}\n"
            "\n"
            f"Please answer the question '{original_question}' with {final_prompt} using the documents and subquestions as reference.\n"
            "Make sure your response is grounded in documents and provides clear reasoning followed by a concise conclusion. "
            "If no relevant information is found, use your own knowledge. \n"
            "Wrap your answer with <answer> and </answer> tags."
        )
    else:
        prompt = (
            "You are given some subquestions and their answers:\n"
            f"{sub_answer_text}\n"
            "\n"
            f"Please answer the question '{original_question}' with {final_prompt} using the subquestions as reference. "
            "Provides clear reasoning followed by a concise conclusion. "
            "If no relevant information is found, use your own knowledge. \n"
            "Wrap your answer with <answer> and </answer> tags."
        )

    final = call_llm(prompt, model=model, tokenizer=tokenizer, sampling_params=sampling_params)
    return final.strip()


def multi_turn_qa(question: str, sub_questions: List[Dict], answer: str,
                  embedding_tokenizer, embedding_model, embedding_model_name: str,
                  llm_model, llm_tokenizer, sampling_params, dataset: str, add_passage, topk: int,
                  cpu_index=None, corpus=None):
    """
    Orchestrate the multi-turn QA process.

    For each sub-question, in order:

    1. Resolve "#<n>" references with earlier answers (``replace_placeholders``).
    2. Retrieve ``topk`` passages. This always happens; the decomposition's
       ``needs_context`` flag is currently not consulted.
    3. Answer the sub-question from those passages.

    Then combine the sub-answers, plus a pooled passage list, into a final answer.

    Args:
        question: The original question.
        sub_questions: Dicts with at least ``label`` and ``text`` keys.
        answer: Gold answer; only printed for inspection.
        embedding_tokenizer: Forwarded to ``retrieve_context``.
        embedding_model: Forwarded to ``retrieve_context``.
        embedding_model_name: Forwarded to ``retrieve_context``.
        llm_model: Forwarded to the answering functions.
        llm_tokenizer: Forwarded to the answering functions.
        sampling_params: Forwarded to the answering functions.
        dataset: Forwarded to ``generate_final_answer``.
        add_passage: Forwarded to ``generate_final_answer``.
        topk: Passages retrieved per sub-question.
        cpu_index: FAISS index to search. Defaults to the module-level
            ``cpu_index`` built by the ``__main__`` section.
        corpus: Passage texts. Defaults to the module-level ``corpus``.

    Returns:
        ``(final_answer, answer_dict, passage_dict)``, where the dicts map
        sub-question label -> answer / retrieved passages.

    Raises:
        ValueError: If no index or corpus is passed and none exists at module level.
    """
    # Fall back to module scope so existing callers that rely on the
    # __main__-built globals keep working.
    if cpu_index is None:
        cpu_index = globals().get("cpu_index")
    if corpus is None:
        corpus = globals().get("corpus")
    if cpu_index is None or corpus is None:
        raise ValueError("multi_turn_qa needs `cpu_index` and `corpus`: pass them explicitly "
                         "or build them at module level first.")

    subquestions_dict = {subq_dict["label"]: subq_dict["text"] for subq_dict in sub_questions}
    answer_dict = {}
    passage_dict = {}
    all_passages = []
    # Passages per sub-question that go into the final prompt's pooled list.
    per_question_budget = 5 if len(sub_questions) <= 3 else 3

    for subq_dict in sub_questions:
        q_label = subq_dict["label"]
        q_text_resolved = replace_placeholders(subq_dict["text"], answer_dict)

        passages = retrieve_context(q_text_resolved, cpu_index, corpus, embedding_tokenizer,
                                    embedding_model, embedding_model_name, top_k=topk)
        # Order-preserving de-duplication. A set() made the pooled order vary between runs.
        all_passages = list(dict.fromkeys(all_passages + passages[:per_question_budget]))

        sub_answer = answer_sub_question(q_text_resolved, passages, llm_model, llm_tokenizer, sampling_params)
        answer_dict[q_label] = sub_answer
        passage_dict[q_label] = passages
        subquestions_dict[q_label] = q_text_resolved

    final_answer = generate_final_answer(question, subquestions_dict, answer_dict,
                                         llm_model, llm_tokenizer, sampling_params, dataset, all_passages, add_passage)
    print("-------\nquestion:", question,
          "\nsub-questions:", sub_questions,
          "\nanswers:", answer_dict,
          "\nFinal Answer:", final_answer,
          "\ngold answer:", answer)
    return final_answer, answer_dict, passage_dict


###################################
# Data Loading and Index Building
###################################
def load_data(dataset: str, expname: str, save_dir: str) -> List[Dict]:
    """
    Load decomposed questions from the experiment's JSONL file.

    Only the part of ``dataset`` before the first "-" names the directory
    (e.g. "hotpotqa-dev" -> "hotpotqa"). The file read is
    ``{save_dir}/{dataset}/prompts_decompose_test_t0.0_{expname}/generate.jsonl``.
    It is decoded as UTF-8, and blank lines are skipped instead of making
    ``json.loads`` raise.

    Returns:
        One parsed JSON object per non-blank line, in file order.
    """
    dataset = dataset.split("-")[0]
    path = f"{save_dir}/{dataset}/prompts_decompose_test_t0.0_{expname}/generate.jsonl"
    with open(path, "r", encoding="utf-8") as f:
        questions = [json.loads(line) for line in f if line.strip()]
    print("========")
    print(f"Loaded {len(questions)} examples from {dataset}!")
    print("========")
    return questions


def build_index(dataset: str, embedding_model_name: str, num_threads: int = 32):
    """
    Build a flat inner-product FAISS index from pre-computed passage embeddings.

    Loads the corpus and embeddings with ``load_embedding`` and sets FAISS's
    OpenMP thread count. It then adds every embedding to an ``IndexFlatIP``,
    which does exact, brute-force inner-product search. Inner product equals
    cosine similarity only if the stored vectors are L2-normalised; TODO
    confirm for each embedding model.

    Args:
        dataset: Forwarded to ``load_embedding``.
        embedding_model_name: Forwarded to ``load_embedding``.
        num_threads: OpenMP threads for FAISS. Defaults to 32, the previously
            hard-coded value.

    Returns:
        ``(corpus, embeddings, cpu_index)``, with ``embeddings`` exactly as loaded.
    """
    corpus, embeddings = load_embedding(dataset, embedding_model_name)
    faiss.omp_set_num_threads(num_threads)
    # FAISS indexes operate on C-contiguous float32 matrices; this is a no-op
    # when the shards are already stored that way.
    vectors = np.ascontiguousarray(embeddings, dtype=np.float32)
    cpu_index = faiss.IndexFlatIP(vectors.shape[1])
    cpu_index.add(vectors)
    return corpus, embeddings, cpu_index


###################################
# Main Execution
###################################
if __name__ == "__main__":
    parser = argparse.ArgumentParser("")
    parser.add_argument("--llm_model_path", type=str, default="")
    parser.add_argument("--llm_tokenizer", type=str, default="")
    parser.add_argument("--dataset", type=str, default="hotpotqa")
    parser.add_argument("--expname", type=str, default="")
    parser.add_argument("--save_dir", type=str, default="")
    parser.add_argument("--top_p", type=float, default=0.99)
    parser.add_argument("--k", type=int, default=10)
    parser.add_argument("--add_passage", type=int, default=1)
    parser.add_argument("--temperature", type=float, default=0.0)
    parser.add_argument("--tensor_parallel_size", type=int, default=1)
    parser.add_argument("--sentence_embedding_model", type=str)
    parser.add_argument("--sentence_embedding_model_save_name", type=str, default="e5-base",
                        choices=["dragon", "gte-base", "e5-base"])
    args = parser.parse_args()

    # Load questions and build the FAISS index.
    # NOTE: `corpus` and `cpu_index` must stay module-level names, because
    # multi_turn_qa may read them from module scope.
    questions = load_data(args.dataset, args.expname, args.save_dir)
    tokenizer = AutoTokenizer.from_pretrained(args.sentence_embedding_model,
                                              trust_remote_code=True)
    embedding_model = AutoModel.from_pretrained(args.sentence_embedding_model,
                                                trust_remote_code=True).cuda()
    embedding_model.eval()
    corpus, embeddings, cpu_index = build_index(args.dataset, args.sentence_embedding_model_save_name)

    # Tracks processed example indices within this run. Nothing pre-populates
    # it, so every example is processed.
    saved_examples = []
    visited_idx = {}
    # Not written to by the code below.
    os.makedirs(f"{args.save_dir}/output/{args.dataset}/", exist_ok=True)

    # Load the LLM and its tokenizer.
    print("LLM Tokenizer:", args.llm_tokenizer)
    llm_tokenizer = AutoTokenizer.from_pretrained(args.llm_tokenizer)
    sampling_params = SamplingParams(
        temperature=args.temperature,
        top_p=args.top_p,
        repetition_penalty=1.05,
        max_tokens=512
    )
    model_path = args.llm_model_path
    llm = LLM(model=model_path, tensor_parallel_size=args.tensor_parallel_size,
              gpu_memory_utilization=0.85, trust_remote_code=True)

    # Process every question. A failing example is reported and skipped, so it
    # is missing from the output file.
    for index, item in enumerate(tqdm(questions)):
        if index in visited_idx:
            continue
        try:
            question = item["question"]
            answer = item["answer"]
            decomposed_question = item["decomposed"]
            final_answer, intermediate_answers, intermediate_passages = multi_turn_qa(
                question, decomposed_question, answer,
                tokenizer, embedding_model, args.sentence_embedding_model_save_name,
                llm, llm_tokenizer, sampling_params, args.dataset, args.add_passage, args.k
            )
            new_item = copy.deepcopy(item)
            new_item["index"] = index
            new_item["final_answer"] = final_answer
            new_item["intermediate_answers"] = intermediate_answers
            new_item["intermediate_passages"] = intermediate_passages
            saved_examples.append(new_item)
            visited_idx[index] = 1
        except Exception as e:
            print(f"Error processing item {item}: {e}")
            continue

    # Save results if any examples were processed.
    if saved_examples:
        output_path = f"{args.save_dir}/{args.dataset}/prompts_decompose_test_t0.0_{args.expname}/test_{args.sentence_embedding_model_save_name}_k{args.k}_passage{args.add_passage}.jsonl"
        # load_data reads from the directory named by the dataset prefix before
        # "-", so this directory (keyed on the full dataset name) may not exist.
        # Without this, all results would be lost at open().
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        with open(output_path, "w", encoding="utf-8") as f:
            for saved_example in tqdm(saved_examples):
                f.write(json.dumps(saved_example) + '\n')