import os

# Set before the `vllm` import further down. NOTE(review): presumably vLLM reads
# this engine-iteration timeout (seconds, judging by the `_S` suffix) from the
# environment, so the assignment has to stay above that import — confirm.
os.environ["VLLM_ENGINE_ITERATION_TIMEOUT_S"] = "1800"

import asyncio
import time
import re
import json
import numpy as np
from typing import List, Dict, Any, Optional, Tuple
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
from loguru import logger


class LLMClientWithTools:
    """
    LLM Client with tool support for batch processing time series analysis.
    """
    
    def __init__(self,
        model_path: str = "./ckpt",
        max_model_len: int = 12000,
        tensor_parallel: int = 1,
        max_tool_iters: int = 15,
        system_prompt: str = "You are a helpful AI Assistant that provides well-reasoned and detailed responses. You first think about the reasoning process as an internal monologue and then provide the user with the answer. Respond in the following format, with the answer included between the <answer> and </answer> tags: <think>\n...\n</think>\n<answer>\nYour answer\n</answer>",
        temperature=0.2
    ):
        """
        Store the configuration, declare the tool schemas and load the model.

        Args:
            model_path: Path to the model
            max_model_len: Maximum model length
            tensor_parallel: Tensor parallel size
            max_tool_iters: Maximum number of tool-calling rounds per batch
            system_prompt: System prompt placed at the start of every conversation
            temperature: Sampling temperature used when generating
        """

        def str_param(description):
            # JSON-schema entry for a string argument.
            return {"type": "string", "description": description}

        def int_param(description):
            # JSON-schema entry for an integer argument.
            return {"type": "integer", "description": description}

        def function_tool(name, description, properties):
            # Every declared property is required, in declaration order.
            return {
                "type": "function",
                "function": {
                    "name": name,
                    "description": description,
                    "parameters": {
                        "type": "object",
                        "properties": properties,
                        "required": list(properties),
                    },
                },
            }

        # Plain configuration values.
        self.model_path = model_path
        self.max_model_len = max_model_len
        self.tensor_parallel = tensor_parallel
        self.max_tool_iters = max_tool_iters
        self.system_prompt = system_prompt
        self.temperature = temperature

        # Per-batch context; these are overwritten by llm_batch_generate.
        self.current_messages_history = []
        self.current_batch_to_global_mapping = []  # batch index -> global index
        self.current_timeseries_data = []
        self.current_questions = []  # questions of the batch being processed

        # Tool schemas handed to the chat template.
        slice_tool = function_tool(
            "get_timeseries_slice",
            "Get the current timeseries_slice of one of the time series in a given location, you should call this tool during thinking to better recognize the local fluctuations of a given period",
            {
                "metric_name": str_param("The name of the metric to get the timeseries slice for"),
                "start": int_param("The start index of the timeseries slice"),
                "end": int_param("The end index of the timeseries slice"),
            },
        )
        compare_tool = function_tool(
            "compare_timeseries_slice",
            "Compare two slices of timeseries data from potentially different metrics. Use this tool to analyze relationships, correlations, or differences between different timeseries segments.",
            {
                "metric_name_1": str_param("The name or identifier of the first timeseries metric to slice. This should match or be contained in the metric names mentioned in the conversation."),
                "start_1": int_param("The starting index (inclusive) for the first timeseries slice. Must be >= 0."),
                "end_1": int_param("The ending index (exclusive) for the first timeseries slice. Must be > start_1."),
                "metric_name_2": str_param("The name or identifier of the second timeseries metric to slice. This should match or be contained in the metric names mentioned in the conversation."),
                "start_2": int_param("The starting index (inclusive) for the second timeseries slice. Must be >= 0."),
                "end_2": int_param("The ending index (exclusive) for the second timeseries slice. Must be > start_2."),
            },
        )
        self.tools = [slice_tool, compare_tool]

        # Dispatch table: tool name -> bound implementation.
        self.tool_functions = {
            "get_timeseries_slice": self._execute_get_timeseries_slice,
            "compare_timeseries_slice": self._execute_compare_timeseries_slice
        }

        # Load the engine and tokenizer last, once all settings are in place.
        self._initialize_model()
    
    def _initialize_model(self):
        """Create the vLLM engine and load the tokenizer from ``self.model_path``."""
        # Imported for its side effects only; nothing from it is referenced here.
        # NOTE(review): presumably registers the time-series model with vLLM — confirm.
        import thinktime.vllm.chatts_vllm

        engine_options = dict(
            model=self.model_path,
            enforce_eager=True,
            gpu_memory_utilization=0.95,
            max_model_len=self.max_model_len,
            tensor_parallel_size=self.tensor_parallel,
            trust_remote_code=True,
            disable_custom_all_reduce=True,
            enable_prefix_caching=False,
            dtype='float16',
            # At most 50 "timeseries" multimodal items per prompt.
            limit_mm_per_prompt={"timeseries": 50},
        )
        self.model = LLM(**engine_options)

        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_path,
            trust_remote_code=True,
        )
    
    def _hash_str(self, text: str) -> str:
        """Return the first 8 hex digits of the MD5 of *text*, or ``"empty"`` for falsy input.

        Intended for debugging output only.
        """
        import hashlib
        if text:
            digest = hashlib.md5(text.encode())
            return digest.hexdigest()[:8]
        return "empty"
    
    def _execute_get_timeseries_slice(self, metric_name: str, start: int, end: int, user_question: str, timeseries_list: list) -> list:
        """
        Execute the ``get_timeseries_slice`` tool for a single conversation.

        The metric is resolved with a positional heuristic: ``user_question`` is
        split on the ``<ts><ts/>`` placeholder and the index of the first text
        segment that contains ``metric_name`` (case-insensitive) is used as the
        index into ``timeseries_list``. When no series can be resolved this way,
        the first series is used instead and the response text says so.

        Args:
            metric_name: Metric name requested by the model.
            start: Requested start index; coerced with ``int()`` and clamped to
                the series length.
            end: Requested end index (exclusive); coerced with ``int()`` and
                clamped so that the slice holds at least one point.
            user_question: Original user question containing ``<ts><ts/>``
                placeholders.
            timeseries_list: Timeseries available to this conversation.

        Returns:
            A multimodal content list. On success it holds a text item (wrapped
            in ``<tool_response>`` tags) followed by one ``{"timeseries": ...}``
            item with the slice. On any failure it holds a single text item with
            the error message wrapped in ``<tool_response>`` tags.
        """
        try:
            series_pool = timeseries_list if timeseries_list is not None else []
            needle = metric_name.lower()
            first_timeseries = series_pool[0] if len(series_pool) > 0 else None
            target_timeseries = None

            # Segment i of the split question is the text preceding the i-th
            # placeholder, so a name found in segment i is mapped to series i.
            if user_question and needle in user_question.lower():
                for ts_index, segment in enumerate(user_question.split('<ts><ts/>')):
                    if needle not in segment.lower():
                        continue
                    if ts_index < len(series_pool):
                        target_timeseries = series_pool[ts_index]
                        break
                    logger.warning(f"⚠️ [TIMESERIES INDEX] Index {ts_index} out of range (have {len(series_pool)} timeseries)")

            # Metric not resolved: fall back to the first series when there is
            # one, otherwise report the error to the model.
            if target_timeseries is None:
                error_msg = f"Error: Metric '{metric_name}' not found in conversation context"
                logger.warning(f"❌ [TIMESERIES NOT FOUND] {error_msg}, fall back to the first timeseries, context: {user_question}")

                if first_timeseries is None:
                    # Same list shape as every other return so the caller can
                    # feed the error back to the model instead of dropping it.
                    return [
                        {"type": "text", "text": f"<tool_response>\n{error_msg}\n</tool_response>"}
                    ]

                target_timeseries = first_timeseries
                logger.warning(f"⚠️ [FALLBACK] Using first available timeseries for metric '{metric_name}'")
                metric_name = f"the first available metric (the provided {metric_name} was not found, please check the provided metric name)"

            ts_length = len(target_timeseries)
            if ts_length == 0:
                # Clamping cannot produce a non-empty slice from an empty series.
                raise ValueError("the selected timeseries is empty")

            # Tool-call JSON may deliver indices as strings or floats.
            requested_start, requested_end = int(start), int(end)
            start = max(0, min(requested_start, ts_length - 1))
            end = max(start + 1, min(requested_end, ts_length))

            if requested_start != start or requested_end != end:
                logger.warning(f"⚠️ [BOUNDS ADJUSTED] From [{requested_start}:{requested_end}] to [{start}:{end}] for length {ts_length}")

            slice_data = target_timeseries[start:end]

            result_text = f"<tool_response>\nThe slice of {metric_name} from {start} to {end} is: <ts><ts/>.\n</tool_response>"
            return [
                {"type": "text", "text": result_text},
                {"timeseries": slice_data}
            ]

        except Exception as e:
            error_msg = f"<tool_response>\nError processing timeseries slice for {metric_name}: {str(e)}\n</tool_response>"
            logger.error(f"❌ [TIMESERIES EXCEPTION] {error_msg}")
            return [
                {"type": "text", "text": error_msg}
            ]

    def _execute_compare_timeseries_slice(
        self, 
        metric_name_1: str, 
        start_1: int, 
        end_1: int,
        metric_name_2: str, 
        start_2: int, 
        end_2: int,
        user_question: str,
        timeseries_list: list
    ) -> list:
        """
        Execute the ``compare_timeseries_slice`` tool for a single conversation.

        Each metric is resolved with a positional heuristic: ``user_question``
        is split on the ``<ts><ts/>`` placeholder and the index of the first
        text segment containing the metric name (case-insensitive) selects the
        entry of ``timeseries_list``. A metric that cannot be resolved falls
        back to the first series, and the response text says so.

        Args:
            metric_name_1: Name of the first metric.
            start_1: Requested start index of the first slice; coerced with
                ``int()`` and clamped to the series length.
            end_1: Requested end index (exclusive) of the first slice; coerced
                and clamped so the slice holds at least one point.
            metric_name_2: Name of the second metric.
            start_2: Same as ``start_1`` for the second slice.
            end_2: Same as ``end_1`` for the second slice.
            user_question: Original user question containing ``<ts><ts/>``
                placeholders.
            timeseries_list: Timeseries available to this conversation.

        Returns:
            A multimodal content list. On success: one text item (wrapped in
            ``<tool_response>`` tags) followed by two ``{"timeseries": ...}``
            items, in the order they are referenced in the text. On failure: a
            single text item with the error wrapped in ``<tool_response>`` tags.
        """
        try:
            if timeseries_list is None or len(timeseries_list) == 0:
                # Nothing to compare and nothing to fall back to.
                error_msg = "Error: Could not find timeseries data for comparison"
                logger.error(f"❌ [COMPARE TIMESERIES] {error_msg}")
                return [
                    {"type": "text", "text": f"<tool_response>\n{error_msg}\n</tool_response>"}
                ]

            # Segment i is the text preceding the i-th placeholder.
            question_segments = user_question.split('<ts><ts/>') if user_question else []

            def find_timeseries_by_metric(metric_name: str):
                # First in-range segment mentioning the metric wins.
                needle = metric_name.lower()
                for i, segment in enumerate(question_segments):
                    if needle in segment.lower() and i < len(timeseries_list):
                        return timeseries_list[i]
                return None

            def clamp_bounds(start, end, length):
                # Force a valid, non-empty [start, end) window inside the series.
                if length == 0:
                    raise ValueError("the selected timeseries is empty")
                start = max(0, min(start, length - 1))
                end = max(start + 1, min(end, length))
                return start, end

            first_timeseries = timeseries_list[0]

            target_timeseries_1 = find_timeseries_by_metric(metric_name_1)
            target_timeseries_2 = find_timeseries_by_metric(metric_name_2)

            if target_timeseries_1 is None:
                logger.warning(f"⚠️ [COMPARE FALLBACK] Metric '{metric_name_1}' not found, using first available timeseries")
                target_timeseries_1 = first_timeseries
                metric_name_1 = f"the first available metric (requested {metric_name_1} was not found)"

            if target_timeseries_2 is None:
                logger.warning(f"⚠️ [COMPARE FALLBACK] Metric '{metric_name_2}' not found, using first available timeseries")
                target_timeseries_2 = first_timeseries
                metric_name_2 = f"the first available metric (requested {metric_name_2} was not found)"

            # Tool-call JSON may deliver indices as strings or floats.
            requested = (int(start_1), int(end_1), int(start_2), int(end_2))
            start_1, end_1 = clamp_bounds(requested[0], requested[1], len(target_timeseries_1))
            start_2, end_2 = clamp_bounds(requested[2], requested[3], len(target_timeseries_2))

            if requested != (start_1, end_1, start_2, end_2):
                logger.warning(f"⚠️ [COMPARE BOUNDS] Adjusted bounds for comparison")

            slice_data_1 = target_timeseries_1[start_1:end_1]
            slice_data_2 = target_timeseries_2[start_2:end_2]

            result_text = (f"<tool_response>\nComparison between {metric_name_1}[{start_1}:{end_1}] and "
                          f"{metric_name_2}[{start_2}:{end_2}]: <ts><ts/> vs <ts><ts/>\n</tool_response>")

            return [
                {"type": "text", "text": result_text},
                {"timeseries": slice_data_1}, 
                {"timeseries": slice_data_2}
            ]

        except Exception as e:
            # The opening tag must be <tool_response>, not a closing tag.
            error_msg = f"<tool_response>\nError comparing timeseries slices: {str(e)}\n</tool_response>"
            logger.error(f"❌ [COMPARE TIMESERIES EXCEPTION] {error_msg}")
            return [
                {"type": "text", "text": error_msg}
            ]
    
    def _parse_and_execute_tool_calls(self, llm_output: str, batch_idx: int = 0, user_question: str = "", timeseries_list: Optional[list] = None) -> Any:
        """
        Parse the last tool call in an LLM output and execute it.

        A tool call is only executed when the output, after removing
        ``<|im_end|>`` and surrounding whitespace, ends with ``</tool_call>``.
        The payload between the last ``<tool_call>`` and that closing tag must
        be JSON of the form ``{"name": ..., "arguments": {...}}``.

        Args:
            llm_output: The LLM's output text that may contain tool calls.
            batch_idx: Index in the current batch (kept for interface
                compatibility; not used here).
            user_question: The original user question for this batch item,
                forwarded to the tool.
            timeseries_list: Timeseries of this batch item, forwarded to the
                tool. ``None`` is treated as an empty list.

        Returns:
            Whatever the tool function returns on success. Otherwise a plain
            string starting with ``"No"`` (no complete tool call present) or
            ``"Error"`` (the call could not be parsed or executed).
        """
        try:
            cleaned_output = llm_output.replace('<|im_end|>', '').strip()

            if not cleaned_output.endswith('</tool_call>'):
                return "No complete tool call found. Tool calls must end with </tool_call>"

            tool_call_pattern = r'<tool_call>(.*?)</tool_call>'
            tool_call_matches = re.findall(tool_call_pattern, cleaned_output, re.MULTILINE | re.DOTALL)

            if not tool_call_matches:
                return "No valid tool call blocks found"

            # A stray, unclosed <tool_call> before the real one makes the
            # non-greedy match start too early; keep only the text after the
            # last opening tag inside the final match.
            raw_payload = tool_call_matches[-1].rpartition('<tool_call>')[2]
            latest_tool_call_content = json.loads(raw_payload.strip())

            function_name = latest_tool_call_content['name']
            args = latest_tool_call_content.get('arguments') or {}
            if isinstance(args, str):
                # Some outputs carry the arguments as a JSON-encoded string.
                args = json.loads(args)

            tool_function = self.tool_functions.get(function_name)
            if tool_function is None:
                raise ValueError(f"unknown tool '{function_name}'")

            # Copy so the parsed payload is left untouched; the context values
            # always override anything the model supplied under these names.
            call_kwargs = dict(args)
            call_kwargs['user_question'] = user_question
            call_kwargs['timeseries_list'] = timeseries_list if timeseries_list is not None else []

            return tool_function(**call_kwargs)

        except Exception as err:
            logger.error(f"❌ [TOOL CALL ERROR] {str(err)}")
            return f"Error tool calls: {str(err)}"
    
    def _generate_batch(self, messages_list: List[List[Dict]], timeseries_list: List[List[List[float]]], remove_think_token: bool=False) -> List[str]:
        """
        Generate one response per conversation in a batch.

        Each message history is rendered with the tokenizer's chat template
        (with the tool schemas attached) and sent to the engine together with
        the item's timeseries as multimodal data.

        Args:
            messages_list: List of message histories.
            timeseries_list: Timeseries data for each history, paired by
                position. ``None`` entries are treated as "no timeseries".
            remove_think_token: When true, a trailing ``"<think>\\n"`` at the end
                of the rendered prompt is stripped before generation.

        Returns:
            One response string per (history, timeseries) pair, in input order.
            Items whose prompt could not be rendered, or for which the engine
            returned no completion, get an empty string.
        """
        all_inputs = []
        # Positions (within the zipped input) of the prompts actually sent, so
        # engine outputs can be mapped back when some items are skipped.
        sent_positions = []
        total_items = 0

        timeseries_counts = [0 if item is None else len(item) for item in timeseries_list]
        logger.success(f"{len(messages_list)=}, {len(timeseries_list)=}, timeseries per item: {timeseries_counts}")

        for i, (messages, timeseries) in enumerate(zip(messages_list, timeseries_list)):
            total_items += 1
            num_timeseries = 0 if timeseries is None else len(timeseries)
            try:
                formatted_prompt = self.tokenizer.apply_chat_template(
                    messages,
                    tools=self.tools,
                    tokenize=False, 
                    add_generation_prompt=True
                )
            except Exception as e:
                # Skip this item instead of reusing a stale prompt from the
                # previous iteration (or hitting an unbound variable).
                logger.warning(f"[Chat Template Error]: {e}")
                continue

            if remove_think_token and formatted_prompt.endswith("<think>\n"):
                formatted_prompt = formatted_prompt[:-len("<think>\n")]
            if i == 0:
                logger.success(f"[Chat Template Applied ({num_timeseries=})] {formatted_prompt}")

            # Plain prompt dicts are passed to generate(); multimodal data is
            # attached only when the item actually has timeseries.
            inputs = {"prompt": formatted_prompt}
            if num_timeseries > 0:
                inputs["multi_modal_data"] = {"timeseries": timeseries}

            all_inputs.append(inputs)
            sent_positions.append(i)

        responses = [""] * total_items
        if not all_inputs:
            return responses

        sampling_params = SamplingParams(
            n=1,
            repetition_penalty=1.0,
            temperature=self.temperature,
            top_p=1.0,
            top_k=-1,
            min_p=0.0,
            max_tokens=2048,
            guided_decoding=None,
        )

        outputs = self.model.generate(all_inputs, sampling_params, use_tqdm=True)

        # Outputs are assumed to come back in request order — the original code
        # relied on the same assumption.
        for position, output in zip(sent_positions, outputs):
            completions = getattr(output, "outputs", None)
            if completions is not None and len(completions) > 0:
                responses[position] = completions[0].text

        return responses
    
    def llm_batch_generate(
        self, 
        questions: List[str], 
        input_timeseries_list: List[List[List[float]]]
    ) -> List[str]:
        """
        Batch generate responses for multiple questions with timeseries data.

        Every question gets an initial generation. Afterwards, for at most
        ``self.max_tool_iters`` rounds, items whose latest response ends with a
        tool call have the tool executed, its result appended as a user
        message, and a new generation requested. Items without a tool call
        drop out of the loop.

        Args:
            questions: List of question strings.
            input_timeseries_list: Timeseries per question, paired by position;
                each element is a list of series, or ``None`` for no
                timeseries. The argument is not modified: tool results are
                collected in per-item copies.

        Returns:
            One string per question: the conversation after the first user
            turn, rendered with the chat template, starting at the content of
            the first assistant message and with the final ``<|im_end|>``
            removed.

        Raises:
            ValueError: If fewer timeseries entries than questions are given.
        """
        batch_size = len(questions)

        if len(input_timeseries_list) < batch_size:
            raise ValueError(
                f"Got {batch_size} questions but only {len(input_timeseries_list)} timeseries entries"
            )
        if len(input_timeseries_list) > batch_size:
            logger.warning(
                f"Got {len(input_timeseries_list)} timeseries entries for {batch_size} questions; extra entries are ignored"
            )

        # Per-item working copies: tool slices are appended to these lists, so
        # the caller's lists must not be shared (and None must become a list).
        timeseries_per_item = [
            list(item) if item is not None else []
            for item in input_timeseries_list[:batch_size]
        ]

        # Every history starts with the system prompt.
        messages_histories = [
            [{"role": "system", "content": self.system_prompt}]
            for _ in range(batch_size)
        ]

        # Per-batch context kept on the instance.
        self.current_questions = questions
        self.current_messages_history = messages_histories
        self.current_batch_to_global_mapping = list(range(batch_size))  # Initial 1:1 mapping

        if batch_size == 0:
            return []

        # User turn: the question text followed by one item per timeseries.
        for history, question, item_timeseries in zip(messages_histories, questions, timeseries_per_item):
            user_content = [{"type": "text", "text": question}]
            user_content.extend({"timeseries": ts} for ts in item_timeseries)
            history.append({"role": "user", "content": user_content})

        responses = self._generate_batch(messages_histories, timeseries_per_item)

        # NOTE(review): "<think>\n" is prepended on the first turn only —
        # presumably the generation prompt already ends with it, so the raw
        # completion lacks it; confirm against the chat template.
        for history, response in zip(messages_histories, responses):
            history.append({"role": "assistant", "content": "<think>\n" + response})

        # Items that may still want to call a tool.
        active_indices = list(range(batch_size))

        for tool_iter in range(self.max_tool_iters):
            if not active_indices:
                break

            logger.success(f"--> Tool iteration {tool_iter + 1}: Processing {len(active_indices)} active items out of {batch_size}")

            items_with_tool_calls = []

            for active_batch_idx, global_idx in enumerate(active_indices):
                tool_results = self._parse_and_execute_tool_calls(
                    responses[global_idx], 
                    batch_idx=active_batch_idx, 
                    user_question=questions[global_idx],
                    timeseries_list=timeseries_per_item[global_idx]
                )

                # A list is multimodal tool output; a string counts only when it
                # is not one of the "Error..."/"No..." status messages. Anything
                # else (including a dict) is treated as "no tool call".
                is_tool_output = isinstance(tool_results, list) or (
                    isinstance(tool_results, str)
                    and not tool_results.startswith(("Error", "No"))
                )
                if not tool_results or not is_tool_output:
                    continue

                items_with_tool_calls.append(global_idx)

                # Tool output is fed back to the model as a user message.
                messages_histories[global_idx].append({"role": "user", "content": tool_results})

                # Slices returned by the tool become extra multimodal inputs,
                # in the order they appear in the tool message.
                if isinstance(tool_results, list):
                    for item in tool_results:
                        if isinstance(item, dict) and "timeseries" in item:
                            timeseries_per_item[global_idx].append(item["timeseries"])

            if not items_with_tool_calls:
                logger.info(f"No more tool calls needed. Finishing at iteration {tool_iter + 1}")
                break

            # Only items that just called a tool stay active.
            active_indices = items_with_tool_calls

            logger.info(f"Continuing with {len(active_indices)} items that need tool calls: {active_indices}")

            active_messages = [messages_histories[i] for i in active_indices]
            active_timeseries = [timeseries_per_item[i] for i in active_indices]

            self.current_messages_history = active_messages
            self.current_batch_to_global_mapping = active_indices  # Map batch idx to global idx

            active_responses = self._generate_batch(active_messages, active_timeseries, remove_think_token=True)

            # Map responses back to their global positions.
            for global_idx, response in zip(active_indices, active_responses):
                responses[global_idx] = response
                messages_histories[global_idx].append({"role": "assistant", "content": response})

        # Render everything after the system prompt and the first user turn,
        # then keep the text between the first assistant header and the last
        # end-of-turn marker.
        final_results = []
        for history in messages_histories:
            rendered = self.tokenizer.apply_chat_template(
                history[2:],
                tokenize=False,
                add_generation_prompt=False
            )
            final_results.append(
                rendered.split('<|im_start|>assistant\n', 1)[-1].rsplit('<|im_end|>', 1)[0]
            )

        return final_results
    

def main():
    """Demo: run the same single-spike question 100 times and print the first five answers."""
    batch = 100
    prompt = "I have a time series called TS1: <ts><ts/>. What is the anoamlies in this time series data?"
    # Flat series with a single spike (placeholder data; replace with real data).
    spike_series = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]

    client = LLMClientWithTools()

    questions = [prompt] * batch
    # Each item gets its own fresh lists; they must not be shared between items.
    input_timeseries_list = [[list(spike_series)] for _ in range(batch)]

    responses = client.llm_batch_generate(questions, input_timeseries_list)

    for number, response in enumerate(responses[:5], start=1):
        print(f"\n=== Question {number} ===")
        print(f"Response: {response}")

# Run the example only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
