"""Tree of Thoughts (ToT) algorithm implementation for creative reasoning."""

import datetime
import uuid
from typing import List, Tuple, Dict
from src.utils.llm_api_client import LLMAPIClient
from src.data_models.task_config import TaskConfig
from src.utils.llm_response_parser import extract_json_from_response


class reasoning_model:
    """Tree of Thoughts reasoning model for creative reasoning tasks.

    The search expands a tree level by level for ``search_depth`` levels. At
    every node the LLM proposes ``num_thoughts_per_step`` thoughts. A second
    LLM call then votes for up to ``num_final_solutions`` of them, and each
    voted thought becomes a child node. A child's context is its parent's
    context extended with the thought. The thoughts held by the last level are
    handed to a final LLM call that writes ``num_final_solutions`` solutions.
    Every LLM call is recorded in ``self.intermediate_logs``.
    """

    # Sampling temperature used for every LLM call this model makes.
    _TEMPERATURE = 0.7

    def __init__(self, task_config: "TaskConfig", backbone_llm_name: str, num_analogous_problems: int = 10, num_solutions_per_problem: int = 5, num_exploratory_ideas: int = 50, num_new_rule_sets: int = 3, num_final_solutions: int = 3, num_solutions_combinational: int = 10, num_thoughts_per_step: int = 5, search_depth: int = 3):
        """Initialize the reasoning model.

        Args:
            task_config: TaskConfig object containing task information
            backbone_llm_name: Name of the backbone LLM to use
            num_analogous_problems: Number of analogous problems (ignored for ToT)
            num_solutions_per_problem: Number of solutions per problem (ignored for ToT)
            num_exploratory_ideas: Number of exploratory ideas (ignored for ToT)
            num_new_rule_sets: Number of new rule sets (ignored for ToT)
            num_final_solutions: Number of final solutions to generate. It is
                also the number of thoughts kept per node by the voting step.
            num_solutions_combinational: Number of combinational solutions (ignored for ToT)
            num_thoughts_per_step: Number of thoughts generated per node. Twice
                this value caps the number of nodes kept per level.
            search_depth: Number of tree levels to expand (used for ToT)

        Raises:
            ValueError: If any used argument has the wrong type or a non-positive value.
        """
        # Validate inputs
        if not isinstance(task_config, TaskConfig):
            raise ValueError("task_config must be a TaskConfig object")

        if not isinstance(backbone_llm_name, str) or not backbone_llm_name.strip():
            raise ValueError("backbone_llm_name must be a non-empty string")

        if not isinstance(num_final_solutions, int) or num_final_solutions <= 0:
            raise ValueError("num_final_solutions must be a positive integer")

        if not isinstance(num_thoughts_per_step, int) or num_thoughts_per_step <= 0:
            raise ValueError("num_thoughts_per_step must be a positive integer")

        if not isinstance(search_depth, int) or search_depth <= 0:
            raise ValueError("search_depth must be a positive integer")

        self.task_description_text = task_config.task_description
        self.num_final_solutions = num_final_solutions
        self.backbone_llm_name = backbone_llm_name

        # ToT-specific parameters
        self.num_thoughts_per_step = num_thoughts_per_step
        self.search_depth = search_depth

        # Initialize LLM API client
        self.llm_client = LLMAPIClient()

        # One (step_name, [llm_call_log_entry]) tuple per LLM call, in call order.
        self.intermediate_logs: List[Tuple[str, List[Dict]]] = []

    def run(self) -> Tuple[str, List[Tuple[str, List[Dict]]]]:
        """Run the reasoning model to generate a solution using Tree of Thoughts.

        Returns:
            Tuple of (generated solution text, intermediate logs)

        Raises:
            RuntimeError: If the model fails to generate a solution
        """
        try:
            # The root node's context is the bare task description.
            root_context = self.task_description_text

            # Level 0: expand the root.
            initial_thoughts = self._generate_thoughts(root_context, 0)
            initial_selected = self._vote_and_select_thoughts(initial_thoughts, root_context, 0)

            # Each node is (thought, full_path_context). The context carries the
            # whole path from the root so deeper levels see every earlier step.
            current_nodes = [
                (thought, f"{root_context}\n\nPath to this node:\n{thought}")
                for thought in initial_selected
            ]

            # Levels 1 .. search_depth-1: expand every node of the current level.
            for level in range(1, self.search_depth):
                next_level_nodes = []
                for node_idx, (_, node_path_context) in enumerate(current_nodes):
                    node_thoughts = self._generate_thoughts(node_path_context, level, node_idx)
                    selected_from_node = self._vote_and_select_thoughts(
                        node_thoughts, node_path_context, level, node_idx
                    )
                    for selected_thought in selected_from_node:
                        extended_path_context = f"{node_path_context}\n\nNext step:\n{selected_thought}"
                        next_level_nodes.append((selected_thought, extended_path_context))

                # Bound the level size to avoid exponential growth. NOTE: this
                # keeps the *first* nodes in expansion order; it does not re-rank.
                current_nodes = next_level_nodes[: self.num_thoughts_per_step * 2]

            # Only the thoughts of the last level are passed on, not their paths.
            final_thoughts = [node_thought for node_thought, _ in current_nodes]
            final_solution = self._generate_final_solutions(final_thoughts)

            return final_solution, self.intermediate_logs

        except Exception as e:
            raise RuntimeError(f"Error during solution generation: {e}") from e

    def _call_llm(self, prompt: str) -> Tuple[str, str, str]:
        """Send ``prompt`` to the backbone LLM once.

        Returns:
            Tuple of (raw_response, llm_call_id, timestamp), where the timestamp
            is an ISO-8601 UTC string taken just before the call.
        """
        llm_call_id = str(uuid.uuid4())
        timestamp = datetime.datetime.now(datetime.timezone.utc).isoformat()
        raw_response = self.llm_client.call_llm_model(
            prompt=prompt,
            model_name=self.backbone_llm_name,
            temperature=self._TEMPERATURE,
        )
        return raw_response, llm_call_id, timestamp

    def _log_llm_call(self, step_name: str, llm_call_id: str, prompt: str, raw_response: str, timestamp: str, parsed_output: object) -> None:
        """Append one LLM call record to ``self.intermediate_logs`` under ``step_name``."""
        llm_call_log_entry = {
            "llm_call_id": llm_call_id,
            "prompt": prompt,
            "raw_response": raw_response,
            "llm_model_name": self.backbone_llm_name,
            "temperature": self._TEMPERATURE,
            "timestamp": timestamp,
            "parsed_output": parsed_output,
        }
        self.intermediate_logs.append((step_name, [llm_call_log_entry]))

    @staticmethod
    def _normalize_thoughts(parsed: list, count: int) -> List[str]:
        """Return exactly ``count`` thoughts built from the parsed LLM list.

        Items are stringified and stripped. Empty items are dropped, the list
        is truncated to ``count``, and any shortfall is padded with
        "Generated thought N" placeholders.
        """
        stripped = (str(item).strip() for item in parsed)
        thoughts = [text for text in stripped if text][:count]
        while len(thoughts) < count:
            thoughts.append(f"Generated thought {len(thoughts) + 1}")
        return thoughts

    @staticmethod
    def _choose_indices(raw_indices: list, num_thoughts: int, num_to_select: int) -> List[int]:
        """Turn the LLM's 1-based votes into at most ``num_to_select`` distinct 0-based indices.

        Votes that are not integer-like, are out of range, or repeat an earlier
        vote are ignored. If fewer than ``num_to_select`` valid votes remain,
        the lowest unused indices are appended in order. Working on indices
        rather than thought text guarantees termination even when thoughts repeat.
        """
        chosen: List[int] = []
        for raw in raw_indices:
            if len(chosen) >= num_to_select:
                break
            try:
                idx = int(raw) - 1  # Convert to 0-based
            except (ValueError, TypeError, OverflowError):
                continue
            if 0 <= idx < num_thoughts and idx not in chosen:
                chosen.append(idx)

        for idx in range(num_thoughts):
            if len(chosen) >= num_to_select:
                break
            if idx not in chosen:
                chosen.append(idx)
        return chosen

    def _generate_thoughts(self, current_context: str, step_number: int, node_idx: int = 0) -> List[str]:
        """Generate thoughts for the current context at a given step.

        Args:
            current_context: Current context for thought generation
            step_number: Current step number (0-based)
            node_idx: Index of the node being processed (0-based)

        Returns:
            Exactly ``num_thoughts_per_step`` thoughts. If the response cannot
            be parsed as a list, placeholder thoughts are returned.

        Raises:
            RuntimeError: If thought generation fails
        """
        try:
            count = self.num_thoughts_per_step
            # Show an example with exactly as many items as we request.
            example_format = "[" + ", ".join(f'"thought {i}"' for i in range(1, count + 1)) + "]"

            prompt = f"""You are working on a creative reasoning problem. Based on the current context, generate {count} diverse and creative thoughts or ideas that could help solve this problem.

Current Context:
{current_context}

Please generate {count} distinct thoughts as a JSON array of strings. Each thought should be a concise but meaningful idea that could contribute to solving the problem.

Example format:
{example_format}"""

            raw_response, llm_call_id, timestamp = self._call_llm(prompt)

            try:
                parsed_thoughts = extract_json_from_response(raw_response)
                if not isinstance(parsed_thoughts, list):
                    raise ValueError("Response is not a list")
                thoughts = self._normalize_thoughts(parsed_thoughts, count)
            except Exception:
                # Deliberate best-effort fallback: keep the search going with placeholders.
                thoughts = self._normalize_thoughts([], count)
                parsed_thoughts = None

            if step_number == 0:
                step_name = "ToT: Root Thought Generation"
            else:
                step_name = f"ToT: Thought Generation (Level {step_number + 1}, Node {node_idx + 1})"
            self._log_llm_call(step_name, llm_call_id, prompt, raw_response, timestamp, parsed_thoughts)

            return thoughts

        except Exception as e:
            raise RuntimeError(f"Error during thought generation at step {step_number + 1}: {e}") from e

    def _vote_and_select_thoughts(self, thoughts: List[str], current_context: str, step_number: int, node_idx: int = 0) -> List[str]:
        """Vote and select the top thoughts from the generated thoughts.

        Args:
            thoughts: List of thoughts to evaluate
            current_context: Current context for evaluation
            step_number: Current step number (0-based)
            node_idx: Index of the node being processed (0-based)

        Returns:
            ``min(num_final_solutions, len(thoughts))`` selected thoughts. LLM
            votes are honored first (deduplicated, in vote order), then padded
            with the earliest unselected thoughts. If the response cannot be
            parsed, the first thoughts are returned.

        Raises:
            RuntimeError: If thought selection fails
        """
        try:
            num_selected_thoughts = min(self.num_final_solutions, len(thoughts))
            thoughts_text = "\n".join(f"{i}. {thought}" for i, thought in enumerate(thoughts, start=1))

            prompt = f"""You are evaluating creative thoughts for a problem-solving task. Based on the current context, select the {num_selected_thoughts} most promising thoughts from the list below.

Current Context:
{current_context}

Thoughts to evaluate:
{thoughts_text}

Please select the {num_selected_thoughts} most promising thoughts by returning their numbers as a JSON array of integers.

Example format (if selecting thoughts 1, 3, and 5):
[1, 3, 5]

Select the thoughts that are most likely to lead to creative and effective solutions."""

            raw_response, llm_call_id, timestamp = self._call_llm(prompt)

            try:
                selected_indices = extract_json_from_response(raw_response)
                if not isinstance(selected_indices, list):
                    raise ValueError("Response is not a list")
                parsed_output = selected_indices
            except Exception:
                # Deliberate best-effort fallback: with no votes, the first thoughts are kept.
                selected_indices = []
                parsed_output = None

            chosen = self._choose_indices(selected_indices, len(thoughts), num_selected_thoughts)
            selected_thoughts = [thoughts[i] for i in chosen]

            if step_number == 0:
                step_name = "ToT: Root Thought Voting"
            else:
                step_name = f"ToT: Thought Voting (Level {step_number + 1}, Node {node_idx + 1})"
            self._log_llm_call(step_name, llm_call_id, prompt, raw_response, timestamp, parsed_output)

            return selected_thoughts

        except Exception as e:
            raise RuntimeError(f"Error during thought selection at step {step_number + 1}: {e}") from e

    def _generate_final_solutions(self, accumulated_thoughts: List[str]) -> str:
        """Generate final solutions from the thoughts reached by the search.

        Args:
            accumulated_thoughts: Thoughts held by the nodes of the final tree
                level, as passed by ``run``

        Returns:
            Final solution text (the raw LLM response)

        Raises:
            RuntimeError: If final solution generation fails or the LLM returns
                an empty response
        """
        try:
            thoughts_text = "\n".join(f"- {thought}" for thought in accumulated_thoughts)
            prompt = f"""You have been working through a creative reasoning problem using a Tree of Thoughts approach. Based on the original problem and all the accumulated thoughts from the search process, generate {self.num_final_solutions} distinct, creative, and comprehensive solutions.

Original Problem:
{self.task_description_text}

Accumulated Thoughts from Tree of Thoughts Search:
{thoughts_text}

Please provide {self.num_final_solutions} distinct creative solutions. Each solution should be well-developed and incorporate insights from the thought exploration process. Present them clearly and concisely."""

            raw_response, llm_call_id, timestamp = self._call_llm(prompt)

            if not raw_response or not raw_response.strip():
                raise RuntimeError("Failed to generate final solutions - empty response from LLM")

            # No structured parsing for the final solution text.
            self._log_llm_call("ToT: Final Solution Generation", llm_call_id, prompt, raw_response, timestamp, None)

            return raw_response

        except Exception as e:
            raise RuntimeError(f"Error during final solution generation: {e}") from e