"""Enhanced Graph of Thoughts (EGoT) algorithm implementation for creative reasoning."""

import math
import json
from typing import List, Tuple, Dict, Any
try:
    from utils.llm_api_client import LLMAPIClient
    from utils.llm_response_parser import extract_json_from_response
    from data_models.task_config import TaskConfig
except ImportError:
    from src.utils.llm_api_client import LLMAPIClient
    from src.utils.llm_response_parser import extract_json_from_response
    from src.data_models.task_config import TaskConfig


class reasoning_model:
    """Enhanced Graph of Thoughts (EGoT) reasoning model for creative reasoning tasks.

    For each root node the model runs a METHODNODE, then expands a tree of
    ANSWERINGNODE -> EVALUATIONNODE pairs for ``graph_depth`` levels. Except at
    the last level, every evaluated answer is condensed by an
    AGGREGATERATIONALENODE into a child node for the next level. All evaluated
    answers are ranked by ``s * Pr(s)`` and the best ``num_final_solutions`` are
    returned as formatted text.
    """
    
    def __init__(self, task_config: TaskConfig, backbone_llm_name: str, 
                 num_analogous_problems: int, num_solutions_per_problem: int, 
                 num_exploratory_ideas: int, num_new_rule_sets: int, 
                 num_final_solutions: int, num_solutions_combinational: int,
                 num_thoughts_per_step: int, search_depth: int):
        """Initialize the EGoT reasoning model.
        
        Args:
            task_config: TaskConfig object; must expose ``task_description``,
                ``feasibility_check_points`` and ``known_solutions``.
            backbone_llm_name: Name of the backbone LLM to use.
            num_analogous_problems: Accepted for interface compatibility; unused here.
            num_solutions_per_problem: Accepted for interface compatibility; unused here.
            num_exploratory_ideas: Accepted for interface compatibility; unused here.
            num_new_rule_sets: Accepted for interface compatibility; unused here.
            num_final_solutions: Number of top-ranked solutions returned by ``run``.
            num_solutions_combinational: Accepted for interface compatibility; unused here.
            num_thoughts_per_step: Accepted for interface compatibility; unused here.
            search_depth: Accepted for interface compatibility; unused here.

        Raises:
            ValueError: If ``task_config`` lacks the required attributes,
                ``backbone_llm_name`` is empty, or ``num_final_solutions`` is
                not a positive integer.
        """
        # Validate inputs (duck-typed: only the attributes are checked).
        if not hasattr(task_config, 'task_description') or not hasattr(task_config, 'feasibility_check_points') or not hasattr(task_config, 'known_solutions'):
            raise ValueError("task_config must be a TaskConfig object")
        
        if not isinstance(backbone_llm_name, str) or not backbone_llm_name.strip():
            raise ValueError("backbone_llm_name must be a non-empty string")
        
        if not isinstance(num_final_solutions, int) or num_final_solutions <= 0:
            raise ValueError("num_final_solutions must be a positive integer")
        
        # Store required parameters
        self.task_config = task_config
        self.backbone_llm_name = backbone_llm_name
        # Branching factor: answering nodes spawned from every node at each level.
        self.num_leaves = 3
        self.num_final_solutions = num_final_solutions
        
        # EGoT-specific parameters with fixed default values
        self.graph_depth = 3          # number of answering levels per root
        self.num_root_nodes = 1       # independent traversals performed by run()
        self.tmax = 0.7               # upper temperature used by cosine annealing
        self.threshold_extreme = 70   # NOTE(review): not referenced in this class
        self.threshold_normal = 50    # NOTE(review): not referenced in this class
        self.e = 0.1                  # Pr(s) is raised to 1/e in _calculate_temperature
        
        # Initialize LLM API client
        self.llm_client = LLMAPIClient()
        
        # Not populated by this class; run() returns an empty list for the logs.
        self.intermediate_logs: List[Tuple[str, List[Dict]]] = []
    
    def _method_node(self, prompt: str, temperature: float) -> Tuple[str, str]:
        """Execute METHODNODE: ask the LLM for a method to solve the problem.
        
        Args:
            prompt: Problem statement embedded in the method prompt.
            temperature: Sampling temperature for the LLM call (callers in this
                class pass 0).
            
        Returns:
            ``(ma, me)`` - method analysis and method explanation; keys missing
            from the parsed JSON default to ``""``.

        Raises:
            Exception: Wrapping any failure of the LLM call or response parsing.
        """
        try:
            method_prompt = f"""Analyze the following problem and provide a method to solve it.

Problem: {prompt}

Please provide your response in JSON format with the following structure:
{{
    "ma": "method analysis - detailed analysis of the problem",
    "me": "method explanation - explanation of the proposed method"
}}"""
            
            response = self.llm_client.call_llm_model(
                prompt=method_prompt,
                model_name=self.backbone_llm_name,
                temperature=temperature,
            )
            
            parsed_response = extract_json_from_response(response)
            return parsed_response.get("ma", ""), parsed_response.get("me", "")
            
        except Exception as e:
            raise Exception(f"METHODNODE execution failed: {e}") from e
    
    def _answering_node(self, prompt: str, temperature: float) -> Tuple[str, str]:
        """Execute ANSWERINGNODE: ask the LLM for a creative solution.
        
        Args:
            prompt: Context (task, method, previous rationale) for the answer.
            temperature: Sampling temperature (1.0 at the first level, the
                annealed temperature afterwards).
            
        Returns:
            ``(a, ra)`` - answer and rationale; missing keys default to ``""``.

        Raises:
            Exception: Wrapping any failure of the LLM call or response parsing.
        """
        try:
            answering_prompt = f"""Based on the provided context, generate a creative solution.

Context: {prompt}

Please provide your response in JSON format with the following structure:
{{
    "a": "answer - your creative solution to the problem",
    "ra": "rationale - explanation of why this solution works"
}}"""
            
            response = self.llm_client.call_llm_model(
                prompt=answering_prompt,
                model_name=self.backbone_llm_name,
                temperature=temperature,
            )
            
            parsed_response = extract_json_from_response(response)
            return parsed_response.get("a", ""), parsed_response.get("ra", "")
            
        except Exception as e:
            raise Exception(f"ANSWERINGNODE execution failed: {e}") from e
    
    def _evaluation_node(self, prompt: str, temperature: float) -> Tuple[float, str, float]:
        """Execute EVALUATIONNODE: have the LLM score a solution.
        
        Args:
            prompt: Solution and rationale to evaluate.
            temperature: Sampling temperature for the LLM call (callers in this
                class pass 0).
            
        Returns:
            ``(s, rs, Pr(s))`` - score clamped to [0, 100], the reasoning, and
            the confidence clamped to [0, 1].

        Raises:
            Exception: Wrapping any failure of the LLM call, response parsing,
                or numeric conversion of ``s`` / ``Pr(s)``.
        """
        try:
            evaluation_prompt = f"""Evaluate the following solution based on creativity, feasibility, and utility.

Solution to evaluate: {prompt}

Please provide your response in JSON format with the following structure:
{{
    "s": <score from 0 to 100>,
    "rs": "reasoning for the score",
    "Pr(s)": <probability/confidence from 0.0 to 1.0>
}}"""
            
            response = self.llm_client.call_llm_model(
                prompt=evaluation_prompt,
                model_name=self.backbone_llm_name,
                temperature=temperature,
            )
            
            parsed_response = extract_json_from_response(response)
            s = float(parsed_response.get("s", 0))
            rs = parsed_response.get("rs", "")
            pr_s = float(parsed_response.get("Pr(s)", 0.0))
            
            # Clamp model output to the ranges requested in the prompt.
            s = max(0, min(100, s))
            pr_s = max(0.0, min(1.0, pr_s))
            
            return s, rs, pr_s
            
        except Exception as e:
            raise Exception(f"EVALUATIONNODE execution failed: {e}") from e
    
    def _aggregate_rationale_node(self, prompt: str, temperature: float) -> str:
        """Execute AGGREGATERATIONALENODE: condense rationale and evaluation.
        
        Args:
            prompt: Rationale, score and scoring reasoning to aggregate.
            temperature: Sampling temperature for the LLM call (callers in this
                class pass 0).
            
        Returns:
            ``rpr`` - aggregated rationale (``""`` if the key is missing).

        Raises:
            Exception: Wrapping any failure of the LLM call or response parsing.
        """
        try:
            aggregate_prompt = f"""Aggregate and synthesize the following rationales, identifying any inaccurate information.

Rationales to aggregate: {prompt}

Please provide your response in JSON format with the following structure:
{{
    "rpr": "aggregated rationale and identification of inaccurate information"
}}"""
            
            response = self.llm_client.call_llm_model(
                prompt=aggregate_prompt,
                model_name=self.backbone_llm_name,
                temperature=temperature,
            )
            
            parsed_response = extract_json_from_response(response)
            return parsed_response.get("rpr", "")
            
        except Exception as e:
            raise Exception(f"AGGREGATERATIONALENODE execution failed: {e}") from e
    
    def _calculate_temperature(self, s: float, pr_s: float, nc: int, nt: int) -> float:
        """Calculate the answering temperature via cosine annealing.
        
        Args:
            s: Score from the evaluation node, on a 0-100 scale.
            pr_s: Confidence from the evaluation node, in [0, 1].
            nc: Current depth of the path (1-based).
            nt: Normalising length for the annealing schedule; a value <= 0
                disables the cosine term.
            
        Returns:
            tu - temperature for the next answering nodes.
        """
        # c = (s / 100) * Pr(s) ** (1/e). The score is normalised to [0, 1]
        # first; otherwise the clamp below saturates c at 1 for almost any
        # evaluation, pinning tmin at 0.
        c = (s / 100.0) * (pr_s ** (1 / self.e))
        c = max(0.0, min(1.0, c))
        
        # tmin = 1 - sqrt(1 - (c - 1)^2): high confidence -> low minimum temperature.
        tmin = 1 - math.sqrt(1 - (c - 1) ** 2)
        tmin = max(0.0, min(1.0, tmin))
        
        # tu = tmin + 0.5 * (tmax - tmin) * (1 + cos(Nc / Nt))
        cosine_term = math.cos(nc / nt) if nt > 0 else 0
        return tmin + 0.5 * (self.tmax - tmin) * (1 + cosine_term)
    
    def run(self) -> Tuple[str, List[Tuple[str, List[Dict]]]]:
        """Run the EGoT algorithm to generate creative solutions.
        
        Returns:
            ``(solution_text, intermediate_logs)``; the logs are always empty
            for EGoT.

        Raises:
            Exception: Wrapping any failure raised during graph traversal.
        """
        try:
            all_solutions: List[Dict[str, Any]] = []
            
            # One independent graph traversal per root node.
            for root_idx in range(self.num_root_nodes):
                all_solutions.extend(self._traverse_graph_branch(root_idx))
            
            if all_solutions:
                # Rank by confidence-weighted score s * Pr(s), best first.
                all_solutions.sort(key=lambda x: x['s'] * x['pr_s'], reverse=True)
                solution_text = self._format_solutions(all_solutions[:self.num_final_solutions])
            else:
                solution_text = "No valid solutions generated."
            
            # Return empty intermediate_logs as per EGoT requirements
            return solution_text, []
            
        except Exception as e:
            raise Exception(f"EGoT algorithm execution failed: {e}") from e
    
    def _node_temperature(self, node: Dict[str, Any], depth: int, nc: int, nt: int) -> float:
        """Return the answering temperature for the children of ``node``.

        The first level uses 1.0. Deeper levels anneal from the evaluation of
        the node's own parent answer (its ``s`` / ``pr_s``). If that evaluation
        is unavailable, 0.5 is used.
        """
        if depth == 0:
            return 1.0
        if node.get('s') is None or node.get('pr_s') is None:
            return 0.5  # Default temperature
        return self._calculate_temperature(node['s'], node['pr_s'], nc, nt)
    
    def _traverse_graph_branch(self, root_idx: int) -> List[Dict[str, Any]]:
        """Expand one root's graph level by level up to ``graph_depth``.
        
        Every node at a level spawns ``num_leaves`` answering/evaluation pairs.
        All evaluated answers are collected as candidate solutions. Except at
        the last level, each answer is condensed into a child node for the
        next level.
        
        Args:
            root_idx: Zero-based index of the root node
                (``0 .. num_root_nodes - 1``).
            
        Returns:
            All evaluated solutions from this branch, in generation order.
        """
        solutions: List[Dict[str, Any]] = []
        
        # Start with METHODNODE
        method_prompt = f"Root {root_idx + 1}: {self.task_config.task_description}"
        ma, me = self._method_node(method_prompt, 0)
        
        # The root node has no parent evaluation, hence s / pr_s are None.
        current_level_nodes: List[Dict[str, Any]] = [
            {'ma': ma, 'me': me, 'ra': '', 'depth': 0, 's': None, 'pr_s': None}
        ]
        
        # NOTE(review): nt scales with num_final_solutions, although the
        # branching factor is num_leaves and a root-to-leaf path has
        # graph_depth nodes. Kept as-is; confirm the intended schedule length.
        nt = self.graph_depth * self.num_final_solutions
        
        for depth in range(self.graph_depth):
            print(f"Depth: {depth}")
            next_level_nodes: List[Dict[str, Any]] = []
            nc = depth + 1
            
            for node_index, node in enumerate(current_level_nodes):
                print(f"Node Index: {node_index}")
                # All children of a node share the temperature derived from
                # that node's own parent evaluation.
                tu = self._node_temperature(node, depth, nc, nt)
                
                for solution_idx in range(self.num_leaves):
                    print(f"Solution Index: {solution_idx}")
                    
                    answering_prompt = f"Task Description: {self.task_config.task_description}\nMethod: {node['ma']}\nExplanation: {node['me']}\nPrevious Rationale: {node['ra']}\nDepth: {depth + 1}, Solution: {solution_idx + 1}"
                    a, ra = self._answering_node(answering_prompt, tu)
                    
                    evaluation_prompt = f"Solution: {a}\nRationale: {ra}"
                    s, rs, pr_s = self._evaluation_node(evaluation_prompt, 0)
                    
                    solutions.append({
                        'answer': a,
                        'rationale': ra,
                        's': s,
                        'rs': rs,
                        'pr_s': pr_s,
                        'root': root_idx + 1,
                        'depth': depth + 1,
                        'solution_idx': solution_idx + 1,
                    })
                    
                    # AGGREGATERATIONALENODE (except for the last depth)
                    if depth < self.graph_depth - 1:
                        aggregate_prompt = f"Rationale: {ra}\nScore: {s}\nReasoning: {rs}"
                        rpr = self._aggregate_rationale_node(aggregate_prompt, 0)
                        next_level_nodes.append({
                            'ma': node['ma'],
                            'me': node['me'],
                            'ra': rpr,
                            'depth': depth + 1,
                            # This answer's evaluation drives its child's temperature.
                            's': s,
                            'pr_s': pr_s,
                        })
            
            current_level_nodes = next_level_nodes
        
        return solutions
    
    def _format_solutions(self, solutions: List[Dict[str, Any]]) -> str:
        """Format the final solutions into readable text.
        
        Args:
            solutions: Solution dicts as produced by ``_traverse_graph_branch``.
            
        Returns:
            Formatted solution text, or a placeholder message if empty.
        """
        if not solutions:
            return "No solutions generated."
        
        parts = ["Enhanced Graph of Thoughts - Final Solutions:\n\n"]
        for i, solution in enumerate(solutions, 1):
            parts.append(
                f"Solution {i} (Confidence: {solution['s'] * solution['pr_s']:.2f}):\n"
                f"Answer: {solution['answer']}\n"
                f"Rationale: {solution['rationale']}\n"
                f"Score: {solution['s']}/100\n"
                f"Probability: {solution['pr_s']:.2f}\n"
                f"From Root: {solution['root']}, Depth: {solution['depth']}, Solution: {solution['solution_idx']}\n\n"
            )
        
        return "".join(parts)
