import os
import json
import numpy as np
from comment import Comment
from conductor import Conductor
from reducer import Reducer
from evaluator import Evaluator
from experts import (
    ModelingExpert, 
    ProgrammingExpert,
    LPFileGenerator,
    ModelingKnowledgeSupplementExpert,
    ParameterExtractor,
    CodeReviewer,
    ProgrammingExampleProvider,
    TerminologyInterpreter,
)
from comment_pool import CommentPool
from utils import extract_code_from_string


def chain_of_experts(problem,
                     max_collaborate_nums=3,
                     model_name="deepseek-ai/DeepSeek-V3",
                     enable_reflection=False,
                     max_trials=1):
    """
    Run Chain of Experts pipeline

    Each trial runs a forward pass (the conductor picks an expert, the expert
    adds a comment to the shared pool), then the reducer turns the pool into
    an answer. With reflection enabled, the evaluator checks the code extracted
    from that answer; on failure the experts are revisited in reverse order
    until one accepts responsibility and refines its comment, after which the
    reducer is run again.

    Args:
        problem: a dict of problem_description and code_example.
        max_collaborate_nums: Maximum number of expert collaborations per trial
        model_name: LLM model name
        enable_reflection: Whether to enable backward reflection
        max_trials: Maximum number of forward-backward trials

    Return:
        answer: the reducer's latest answer string (the raw answer, not the
            code extracted from it). If the reducer failed, or if
            ``max_trials`` is not positive so no trial ran, the placeholder
            "# Error in generating final answer" is returned.
    """
    # Initialize all experts with the specified model
    all_experts = [
        TerminologyInterpreter(model_name),
        ParameterExtractor(model_name),
        ModelingExpert(model_name),
        ProgrammingExampleProvider(model_name),
        ProgrammingExpert(model_name),
        # LPFileGenerator(model_name),  # Optionally enabled
        ModelingKnowledgeSupplementExpert(model_name),
        CodeReviewer(model_name),
    ]

    num_experts = len(all_experts)
    reducer = Reducer(model_name)
    comment_pool = CommentPool(all_experts, visible_matrix=np.ones((num_experts, num_experts)))
    conductor = Conductor(model_name)
    evaluator = Evaluator(model_name)
    # Experts in the order they commented; kept in step with comment_pool so
    # backward reflection can pop both together.
    expert_stack = []

    # Bound up front so the final return is safe even when no trial runs.
    answer = "# Error in generating final answer"

    for trial in range(max_trials):
        print(f"Starting trial {trial + 1}/{max_trials}")

        # Forward thought construction
        for step in range(max_collaborate_nums):
            # Reset every step: if the conductor itself fails, the handler
            # must not report a stale (or unbound) expert.
            next_expert = None
            try:
                next_expert = conductor.forward(problem, comment_pool, max_collaborate_nums)
                print(f'Choose next expert: {next_expert.name}')

                comment_text = next_expert.forward(problem, comment_pool)
                print(f'Given comment by {next_expert.name}:\n{comment_text[:200]}...')

                comment_pool.add_comment(Comment(next_expert, comment_text))
                expert_stack.append(next_expert)

            except Exception as e:
                # Best effort: a failed step is skipped, not fatal.
                if next_expert is None:
                    print(f"Error in conductor: {e}")
                else:
                    print(f"Error in expert {next_expert.name}: {e}")
                continue

        # Generate final answer using reducer
        try:
            answer = reducer.forward(problem, comment_pool)
            print("Reducer generated final answer")
        except Exception as e:
            print(f"Error in reducer: {e}")
            answer = "# Error in generating final answer"

        # Extract code from answer
        code = extract_code_from_string(answer)

        if not enable_reflection:
            # No reflection, use the answer from first trial
            break

        # Reflection mechanism
        try:
            # Generate test sample
            test_sample = evaluator.forward(problem)
            print('Generated test sample for reflection')
            test_samples = [test_sample]

            # Evaluate the generated code
            feedback = evaluator.evaluate(test_samples, code)

            if feedback is None:
                print("No issues found, using original answer")
                break

            print("Found issues, starting backward reflection...")
            feedback_pool = CommentPool(all_experts, visible_matrix=np.ones((num_experts, num_experts)))
            feedback_pool.add_comment(Comment(evaluator, feedback))

            # Backward reflection: walk the experts from most recent to
            # oldest until one of them accepts the fault.
            reflection_success = False
            while expert_stack and not reflection_success:
                try:
                    previous_expert = expert_stack.pop()
                    previous_comment = comment_pool.pop_comment()

                    result = previous_expert.backward(feedback_pool)
                    result = json.loads(result)

                    if result['is_caused_by_you']:
                        # Expert admits fault and provides refined result
                        previous_comment.comment_text = result['refined_result']
                        expert_stack.append(previous_expert)
                        comment_pool.add_comment(previous_comment)
                        reflection_success = True
                        print(f"Expert {previous_expert.name} provided refinement")
                    else:
                        # Expert denies fault, continue backtracking
                        feedback_pool.add_comment(Comment(previous_expert, result['reason']))

                except Exception as e:
                    print(f"Error in backward reflection: {e}")
                    break

            if reflection_success:
                # Regenerate answer with refined comments
                try:
                    answer = reducer.forward(problem, comment_pool)
                    code = extract_code_from_string(answer)
                    print("Refined answer generated")
                except Exception as e:
                    # Keep the previous answer if regeneration fails.
                    print(f"Error in generating refined answer: {e}")

        except Exception as e:
            print(f"Error in reflection mechanism: {e}")
            # Continue with original answer if reflection fails
            break

    return answer


if __name__ == '__main__':
    # Manual smoke run: load one sample problem and print what the pipeline returns.
    from utils import read_problem

    run_options = {
        'max_collaborate_nums': 3,
        'model_name': 'deepseek-ai/DeepSeek-V3',
        'enable_reflection': False,
        'max_trials': 1,
    }

    # Test with sample problem
    sample_problem = read_problem('LPWP', 'prob_250')
    final_answer = chain_of_experts(sample_problem, **run_options)

    print("Final result:", final_answer)