import os
import ast
import json
import math
import argparse
import numpy as np
import pandas as pd
from tqdm import tqdm

from functools import partial
import logging
import math
from typing import Any, Callable, List, Tuple

import torch
import torch.nn.functional as F

from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel, PeftConfig

from pyserini.search.lucene import LuceneSearcher

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# vLLM runtime switches, set at import time so they are already in place when
# the LLM engine is constructed later in this module. See vLLM's
# environment-variable documentation for the exact meaning of each key.
os.environ.update(
    {
        "VLLM_WORKER_MULTIPROC_METHOD": "spawn",
        "VLLM_ALLOW_LONG_MAX_MODEL_LEN": "1",
    }
)
class rank1():
    name: str = "rank1"

    def __init__(
        self,
        model_name_or_path: str,
        batch_size: int = 999999999999,
        context_size: int = 16000,
        max_output_tokens: int = 8192,
        fp_options: str = "float16",
        num_gpus: int = 1,
        device: str = "cuda",
        dataset_prompt: str = None,
    ):
        """
        rank1 is a reasoning reranker model (using test-time compute) which generates a reasoning chain before deciding true or false

        Args:
            model_name_or_path: Path to the model or name of the model on HuggingFace Hub
            batch_size: Maximum batch size for processing (default: very large number to let vLLM handle batching).
                It is only stored on the instance; nothing in this constructor consumes it.
            context_size: Maximum context length for the model, passed to vLLM as `max_model_len` (default: 16000)
            max_output_tokens: Maximum number of tokens to generate (default: 8192)
            fp_options: Floating point precision to use, passed to vLLM as `dtype`, e.g. 'float16' (default: 'float16')
            num_gpus: Number of GPUs to use for tensor parallelism (default: 1)
            device: Device name (default: 'cuda'). It is only stored on the instance; it is not forwarded to vLLM here.
            dataset_prompt: Optional template containing the placeholder "FILL_QUERY_HERE"; `return_prompt`
                substitutes the query into it. None (default) leaves queries unchanged.
        """
        self.batch_size = batch_size
        self.context_size = context_size
        self.max_output_tokens = max_output_tokens
        self.num_gpus = num_gpus
        self.device = device
        self.model_name_or_path = model_name_or_path
        self.dataset_prompt = dataset_prompt

        # Initialize the tokenizer: pad on the left and reuse EOS as the pad token
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.tokenizer.padding_side = "left"
        self.tokenizer.pad_token = self.tokenizer.eos_token

        # Cache commonly used token IDs. " true" / " false" carry a leading space,
        # matching the stop strings "</think> true" / "</think> false" below.
        self.true_token = self.tokenizer(" true", add_special_tokens=False).input_ids[0]
        self.false_token = self.tokenizer(" false", add_special_tokens=False).input_ids[0]
        self.think_token = self.tokenizer("<think>", add_special_tokens=False).input_ids[0]
        # Last id is taken here (first id above) — presumably so a multi-piece
        # "</think>" resolves to its closing piece; confirm against the tokenizer.
        self.think_end_token = self.tokenizer("</think>", add_special_tokens=False).input_ids[-1]

        self.model = LLM(
            model=model_name_or_path,
            tensor_parallel_size=int(num_gpus),
            trust_remote_code=True,
            max_model_len=context_size,
            gpu_memory_utilization=0.9,
            enforce_eager=True,
            dtype=fp_options,
        )
        # Greedy decoding that stops once the verdict follows </think>. The top-20
        # logprobs per step are requested because _process_with_vllm reads the
        # true/false logprobs from the final generated step.
        self.sampling_params = SamplingParams(
            temperature=0,
            max_tokens=max_output_tokens,
            logprobs=20,
            stop=["</think> true", "</think> false"],
            skip_special_tokens=False
        )

    def _fix_incomplete_responses(
        self,
        original_prompts: List[str],
        generated_texts: List[str]
    ) -> Tuple[List[str], List[int], List[float]]:
        """
        This function is used to fix incomplete responses from the vLLM model. In some cases the model does not generate the end </think> token.
            In these cases, we should force it to generate it so that we have some prediction.

        Each generated text is cut back to its last '.', '!' or '?' (when it contains one), appended to its
        prompt followed by "</think>", and the model is asked for exactly one more token restricted to the
        cached true/false token ids.

        Args:
            original_prompts: The original prompts that were used to generate the texts
            generated_texts: The texts that were generated by the vLLM model

        Returns:
            final_texts: The text produced by the forcing step for each prompt (the earlier reasoning text is not included)
            token_counts: The number of tokens generated by the forcing step for each prompt
            scores: P(true) / (P(true) + P(false)) read from the forced token's logprobs, or 0.5 when
                those logprobs cannot be read
        """
        sentence_enders = ('.', '!', '?')
        cleaned_texts = []
        for generated in generated_texts:
            trimmed = generated.rstrip()
            if not trimmed.endswith(sentence_enders):
                # Drop the dangling partial sentence, if there is a complete one before it
                last_punct = max(trimmed.rfind(mark) for mark in sentence_enders)
                if last_punct != -1:
                    trimmed = trimmed[:last_punct + 1]
            cleaned_texts.append(trimmed.strip())

        forced_prompts = [
            f"{original_prompt}\n{cleaned_text}\n</think>"
            for original_prompt, cleaned_text in zip(original_prompts, cleaned_texts)
        ]
        if not forced_prompts:
            # Nothing to fix: skip the model call entirely
            return [], [], []

        new_sampling_args = SamplingParams(
            temperature=0,
            max_tokens=1,
            logprobs=20,
            allowed_token_ids=[self.true_token, self.false_token],
            skip_special_tokens=False
        )
        outputs = self.model.generate(forced_prompts, new_sampling_args)

        # Read the true/false logprobs of the single forced token
        all_final_texts = []
        all_token_counts = []
        all_scores = []
        for request_output in outputs:
            # Per-item defaults, so a failure never reuses values from a previous item
            text, token_count, score = "", 0, 0.5
            try:
                completion = request_output.outputs[0]
                text = completion.text
                token_count = len(completion.token_ids)
                final_logits = completion.logprobs[-1]
                # Explicit check instead of `assert`, which is stripped under `python -O`
                if self.true_token not in final_logits or self.false_token not in final_logits:
                    raise ValueError(f"final logits are missing true or false: {final_logits}")
                true_score = math.exp(final_logits[self.true_token].logprob)
                false_score = math.exp(final_logits[self.false_token].logprob)
                # ZeroDivisionError (both probabilities underflow) is handled below
                score = true_score / (true_score + false_score)
            except (AttributeError, LookupError, TypeError, ValueError, ArithmeticError) as e:
                # Best effort: keep whatever was read and fall back to a neutral score
                print(f"Error: {e} on fixing error, setting at 0.5 score: {getattr(request_output, 'outputs', None)}")
                score = 0.5

            all_final_texts.append(text)
            all_token_counts.append(token_count)
            all_scores.append(score)

        return all_final_texts, all_token_counts, all_scores

    def truncate(self, text, length):
        """Cut `text` down to its leading tokenizer tokens.

        The text is split with ``self.tokenizer.tokenize``, sliced with
        ``[:length]``, and the kept tokens are turned back into a string with
        ``self.tokenizer.convert_tokens_to_string``.
        """
        tokenizer = self.tokenizer
        kept_tokens = tokenizer.tokenize(text)[:length]
        return tokenizer.convert_tokens_to_string(kept_tokens)

    def _generate_model_outputs(self, prompts):
        """Run `prompts` through ``self.model.generate`` with the instance's default sampling params."""
        engine = self.model
        sampling = self.sampling_params
        return engine.generate(prompts, sampling)

    def _process_with_vllm(self, prompts):
        """Generate for every prompt and turn each generation into a relevance score.

        A generation counts as complete when the logprobs of its last generated
        step contain both the cached true and false token ids; its score is then
        P(true) / (P(true) + P(false)). Every other generation is handed to
        ``_fix_incomplete_responses`` and its result is written back at the
        prompt's original position.

        Returns:
            A tuple ``(texts, token_counts, scores)`` of lists aligned with `prompts`.
        """
        raw_outputs = self._generate_model_outputs(prompts)

        # Result slots are filled by index, since incomplete items are resolved later
        n_prompts = len(prompts)
        texts = [None] * n_prompts
        token_counts = [None] * n_prompts
        scores = [None] * n_prompts

        # (index, prompt, generated text) for each generation that needs the forcing step
        pending = []

        for idx, request_output in enumerate(raw_outputs):
            completion = request_output.outputs[0]
            generated = completion.text

            try:
                last_step = completion.logprobs[-1]
            except Exception as e:
                print(f"Error: {e} on getting final logits: {completion}")
                pending.append((idx, prompts[idx], generated))
                continue

            has_verdict = self.true_token in last_step and self.false_token in last_step
            if not has_verdict:
                pending.append((idx, prompts[idx], generated))
                continue

            p_true = math.exp(last_step[self.true_token].logprob)
            p_false = math.exp(last_step[self.false_token].logprob)

            texts[idx] = generated
            token_counts[idx] = len(completion.token_ids)
            scores[idx] = p_true / (p_true + p_false)

        if pending:
            pending_indices = [entry[0] for entry in pending]
            pending_prompts = [entry[1] for entry in pending]
            pending_texts = [entry[2] for entry in pending]

            fixed_texts, fixed_counts, fixed_scores = self._fix_incomplete_responses(
                pending_prompts, pending_texts
            )

            # Write the forced results back into the slots they came from
            for slot, fixed_text, fixed_count, fixed_score in zip(
                pending_indices, fixed_texts, fixed_counts, fixed_scores
            ):
                texts[slot] = fixed_text
                token_counts[slot] = fixed_count
                scores[slot] = fixed_score

        return texts, token_counts, scores

    def return_prompt(self, query, doc_content, prompt) -> str:
        """Build the relevance-judgement prompt for one query/passage pair.

        When `prompt` is truthy, the query is first substituted for the
        "FILL_QUERY_HERE" placeholder inside it; otherwise the query is used as
        given. The returned text ends with an opening "<think>" tag.
        """
        if prompt:
            query = prompt.replace("FILL_QUERY_HERE", query)

        instruction = (
            "Determine if the following passage is relevant to the query. "
            "Answer only with 'true' or 'false'."
        )
        return f"{instruction}\nQuery: {query}\nPassage: {doc_content}\n<think>"

    @torch.inference_mode()
    def predict(self, queries, passages, save_reasoning_text=False):
        """This is setup to run with mteb but can be adapted to your purpose

        Args:
            queries: Query strings, paired position-by-position with `passages`
            passages: Passage strings to judge against the corresponding query
            save_reasoning_text: When True, also return the generated texts

        Returns:
            A list of scores, one per query/passage pair, or a tuple ``(texts, scores)``
            when `save_reasoning_text` is True. Empty input yields empty lists
            without calling the model.
        """
        prompts = [
            self.return_prompt(query, passage, self.dataset_prompt)
            for query, passage in zip(queries, passages)
        ]
        if not prompts:
            # Nothing to score: avoid indexing prompts[0] below and skip generation
            return ([], []) if save_reasoning_text else []

        print(f"Example prompt: ```\n{prompts[0]}\n```")

        texts, _token_counts, scores = self._process_with_vllm(prompts)
        if save_reasoning_text:
            return texts, scores

        return scores