import re
import torch
import numpy as np
import random
from sentence_transformers import SentenceTransformer
from FlagEmbedding import BGEM3FlagModel


class RAG:
    """Retrieve the chunks of a long input that best match a question.

    The input text is cut into fixed-size word chunks, every chunk is scored
    against the question with a BGE-M3 embedding model, and the best-scoring
    chunks are concatenated until the token budget ``context_window`` is
    reached.
    """

    # Default number of space-separated words per chunk; ``__init__`` may
    # override it per instance.
    chunk_size = 300

    def __init__(self, device, context_window, tokenizer, is_lexical=False, chunk_size=None):
        """Load the embedding model and store the retrieval settings.

        Args:
            device: Device forwarded to the embedding model constructor.
            context_window: Token budget for the retrieved context.
            tokenizer: Object exposing ``encode(text)``; the length of its
                result is used as the token count of a chunk.
            is_lexical: When true, load ``BGEM3FlagModel`` and score with its
                lexical (sparse) weights; otherwise load a
                ``SentenceTransformer`` and use its ``similarity`` method.
            chunk_size: Optional number of words per chunk. ``None`` keeps the
                class default.

        Raises:
            ValueError: If ``chunk_size`` is given and is not positive.
        """
        if chunk_size is not None:
            if chunk_size <= 0:
                raise ValueError(f"chunk_size must be positive, got {chunk_size}")
            self.chunk_size = chunk_size

        if is_lexical:
            self.embedding_model = BGEM3FlagModel("BAAI/bge-m3", use_fp16=True, device=device)
        else:
            self.embedding_model = SentenceTransformer('BAAI/bge-m3', device=device)
        self.is_lexical = is_lexical
        self.context_window = context_window
        self.tokenizer = tokenizer

    def process(self, inputs, questions, is_batch):
        """Chunk the input(s) and build the retrieved context(s).

        Args:
            inputs: One text, or a list of texts when ``is_batch`` is true.
            questions: One question, or a list paired with ``inputs``.
            is_batch: Whether ``inputs`` and ``questions`` are lists.

        Returns:
            One context string, or a list of context strings when batched.
        """
        # chunk
        batch_chunks = self.split_into_word_chunks(inputs, is_batch)

        # retrieve
        if is_batch:
            return [
                self.retreive(question, chunks)
                for question, chunks in zip(questions, batch_chunks)
            ]
        return self.retreive(questions, batch_chunks)

    def retreive(self, question, chunks):
        """Concatenate the chunks most similar to ``question``.

        Chunks are appended in descending similarity order, each followed by a
        blank line. The loop stops after the chunk that makes the running
        token count reach ``context_window``, so the result may exceed the
        budget by part of that last chunk.

        Args:
            question: Query text.
            chunks: List of chunk strings to rank.

        Returns:
            The concatenated context, or ``""`` when ``chunks`` is empty.
        """
        # Nothing to rank: skip the encoder entirely instead of feeding it an
        # empty batch.
        if len(chunks) == 0:
            return ""

        if self.is_lexical:
            emb_q = self.embedding_model.encode([question], return_sparse=True)
            emb_chunks = self.embedding_model.encode(chunks, return_sparse=True)
            similarities = self.embedding_model.compute_lexical_matching_score(
                emb_q['lexical_weights'],
                emb_chunks['lexical_weights'],
            )[0]
        else:
            emb_q = self.embedding_model.encode([question])
            emb_chunks = self.embedding_model.encode(chunks)
            similarities = self.embedding_model.similarity(emb_q, emb_chunks)[0]

        # Negate so that argsort (ascending) yields best-first indices.
        descending_idx = np.argsort(-similarities)

        # Fill in chunks until the context window is filled
        selected = []
        current_num_tokens = 0
        for rank, chunk_idx in enumerate(descending_idx, start=1):
            selected_chunk = chunks[chunk_idx]
            selected.append(selected_chunk)
            current_num_tokens += len(self.tokenizer.encode(selected_chunk))

            if current_num_tokens >= self.context_window:
                print(f"\t\tTotal number of tokens: {current_num_tokens}")
                print(f"\t\tNUMBER of chunks: {rank}")
                break

        return "".join(f"{chunk}\n\n" for chunk in selected)

    def split_into_word_chunks(self, texts, is_batch):
        """Chunk one text, or every text of a list when ``is_batch`` is true."""
        if is_batch:
            return [self._chunk(text) for text in texts]
        return self._chunk(texts)

    def _chunk(self, text):
        """Split ``text`` on single spaces into chunks of ``chunk_size`` words.

        The last chunk holds the remaining words. Empty chunks are dropped, so
        an empty ``text`` gives an empty list.
        """
        words = text.split(' ')
        size = self.chunk_size
        pieces = (
            ' '.join(words[start:start + size])
            for start in range(0, len(words), size)
        )
        return [piece for piece in pieces if piece]


def get_vanilla_prompt_format(data, use_answer_tag=True):
    """Build the vanilla prompt template for a dataset.

    The template is ``pre + mid + post``. It contains a ``{context}``
    placeholder and, for the question-answering datasets, an ``{input}``
    placeholder. The three parts are printed before the template is returned.

    Args:
        data: Dataset name. Supported: hotpotqa, 2wikimqa, musique,
            narrativeqa, qasper, multifieldqa_en, gov_report, multi_news.
        use_answer_tag: When true, add the instruction to wrap the answer in
            ``<answer>`` tags to ``pre`` and ``post``.

    Returns:
        The prompt template string.

    Raises:
        NotImplementedError: If ``data`` is not a supported dataset.
    """
    multi_hop_sets = ("hotpotqa", "2wikimqa", "musique")
    single_doc_sets = ("narrativeqa", "qasper", "multifieldqa_en")
    summary_sets = ("gov_report", "multi_news")

    tag_hint = (
        " Put your answer inside the answer tag, like <answer>your answer</answer>."
        if use_answer_tag
        else ""
    )

    if data in multi_hop_sets:
        pre = "Answer the question based on the given passages." + tag_hint
        mid = "\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words."
        post = tag_hint + "\n\nQuestion: {input}\nAnswer:"
    elif data in single_doc_sets:
        # Without the answer tag these datasets get a "no explanation" hint instead.
        qa_hint = tag_hint if use_answer_tag else " Do not provide any explanation."
        pre = "Answer the question based on the given passages as concisely as you can, using a single phrase or sentence if possible. If the question cannot be answered based on the information in the given passages, write \"unanswerable\". If the question is a yes/no question, answer \"yes\", \"no\", or \"unanswerable\"." + qa_hint
        mid = '\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages as concisely as you can, using a single phrase or sentence if possible. If the question cannot be answered based on the information in the given passages, write \"unanswerable\". If the question is a yes/no question, answer \"yes\", \"no\", or \"unanswerable\".'
        post = qa_hint + "\n\nQuestion: {input}\n\nAnswer:"
    elif data in summary_sets:
        summary_extra = ""
        pre = "Read the following text and write a one-page summary." + tag_hint + summary_extra
        mid = "\n\n{context}\n\nNow, write a one-page summary."
        post = tag_hint + summary_extra + "\n\nSummary:"
    else:
        raise NotImplementedError("Not supported data for vanilla prompting")

    print(f'pre: {pre}')
    print(f'mid: {mid}')
    print(f'post: {post}')
    return pre + mid + post


def get_modelpath(model_name: str) -> str:
    """Map a short model name to its model path.

    The first entry of the table whose key occurs as a substring of
    ``model_name`` wins, so the table order is significant.

    Raises:
        NotImplementedError: If no known key occurs in ``model_name``.
    """
    known_models = (
        ('llama3_8b', 'meta-llama/Llama-3.1-8B-Instruct'),
        ('llama3_70b', 'meta-llama/Llama-3.1-70B-Instruct'),
        ('qwen_32b', "Qwen/Qwen3-32B"),
        ('qwen_8b', "Qwen/Qwen3-8B"),
        ('gpt_120b', "openai/gpt-oss-120b"),
        ('gpt_20b', "openai/gpt-oss-20b"),
    )
    for marker, modelpath in known_models:
        if marker in model_name:
            return modelpath
    raise NotImplementedError(f"Provide a proper model name\nCurrent model: {model_name}")

def predict(client, messages, tokenizer_kwargs, model_name, is_batch):
    """Run ``client`` on the message(s) and return the generated text.

    Args:
        client: Callable invoked as ``client(inputs, **kwargs)``. When
            ``model_name`` contains ``'llama'`` it must also expose a
            ``tokenizer`` attribute (presumably a Hugging Face
            text-generation pipeline; confirm against the caller).
        messages: One prompt, or a list of prompts when ``is_batch`` is true.
        tokenizer_kwargs: Extra keyword arguments forwarded to ``client``.
        model_name: Name used to select chat wrapping (``'chat'``) and the
            llama-specific stop tokens (``'llama'``).
        is_batch: Whether ``messages`` is a list of prompts.

    Returns:
        The ``'generated_text'`` of the first candidate: a string, or a list
        of strings when batched. Empty responses yield ``""`` and a printed
        warning.
    """
    def _first_generated_text(resp):
        # An empty or missing response is reported and mapped to "".
        if not resp or len(resp) <= 0:
            print(f"Unexpected response format: {resp}")
            return ""
        return resp[0]['generated_text']

    batch_messages = messages
    if 'chat' in model_name:
        batch_messages = [chat(msg) for msg in messages] if is_batch else chat(messages)

    # Llama models additionally stop on <|eot_id|> and pad with the EOS token.
    llama_kwargs = {}
    if 'llama' in model_name:
        llama_kwargs['eos_token_id'] = [
            client.tokenizer.eos_token_id,
            client.tokenizer.convert_tokens_to_ids("<|eot_id|>"),
        ]
        llama_kwargs['pad_token_id'] = client.tokenizer.eos_token_id

    response = client(batch_messages, **llama_kwargs, **tokenizer_kwargs)

    if is_batch:
        return [_first_generated_text(resp) for resp in response]
    return _first_generated_text(response)



def chat(prompt):
    """Wrap ``prompt`` as a one-message chat: a single user turn with one text part."""
    text_part = {"type": "text", "text": prompt}
    user_turn = {"role": "user", "content": [text_part]}
    return [user_turn]


def truncate(prompt, tokenizer, max_len, model_max_len=None):
    """Shorten ``prompt`` to at most ``max_len`` tokens by dropping its middle.

    Truncation runs in two passes, both keeping the head and the tail:

    1. A rough word-level pass. When the prompt has more whitespace-separated
       words than ``min(2 * max_len, model_max_len)``, only ``max_len`` words
       are kept so that very long inputs are never tokenized in full. This
       pass re-joins words with single spaces.
    2. A token-level pass, adopted from the LongBench-v2 prediction code
       (https://github.com/THUDM/LongBench/blob/main/pred.py), that keeps
       ``max_len`` token ids and decodes them back to text.

    Args:
        prompt: Text to truncate.
        tokenizer: Object exposing ``encode(text)`` and
            ``decode(ids, skip_special_tokens=True)``.
        max_len: Maximum number of tokens to keep.
        model_max_len: Optional word threshold cap for the rough pass;
            defaults to ``max_len``.

    Returns:
        The possibly shortened prompt; ``""`` when ``max_len`` is not positive.
    """
    # A non-positive budget keeps nothing. Without this guard the tail slice
    # ``[-0:]`` would select the whole sequence.
    if max_len <= 0:
        return ""
    if model_max_len is None:
        model_max_len = max_len

    # Explicit head/tail sizes: the tail receives the extra item for odd max_len.
    head = max_len // 2
    tail = max_len - head

    ## First, roughly truncate the prompt (prevent OOM)
    words = prompt.split()
    rough_threshold = min(2 * max_len, model_max_len)
    # Require more than max_len words as well: otherwise (model_max_len <
    # max_len) the head and tail slices overlap and duplicate text.
    if len(words) > max_len and len(words) > rough_threshold:
        prompt = ' '.join(words[:head] + words[-tail:])

    input_ids = tokenizer.encode(prompt)
    if len(input_ids) > max_len:
        input_ids = input_ids[:head] + input_ids[-tail:]
        prompt = tokenizer.decode(input_ids, skip_special_tokens=True)
    return prompt


def extract_answer_choice(response):
    """Extract the answer letter (A-D) from a model response.

    Asterisks are removed first. The parenthesised form
    ``The correct answer is (X)`` is searched for across the whole text before
    the bare form ``The correct answer is X``.

    Returns:
        The matched letter, or ``None`` when neither form occurs.
    """
    cleaned = response.replace('*', '')
    patterns = (
        r'The correct answer is \(([A-D])\)',
        r'The correct answer is ([A-D])',
    )
    for pattern in patterns:
        found = re.search(pattern, cleaned)
        if found:
            return found.group(1)
    return None
        

def seed_everything(seed):
    """Seed the Python, NumPy and PyTorch random generators with ``seed``.

    Also turns off cuDNN benchmark mode and turns on cuDNN deterministic mode.
    """
    # Python and NumPy global generators.
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch generators: default, current CUDA device, all CUDA devices.
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)

    # cuDNN settings.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
