from mledojo.gym.competition import CompetitionRegistry, CompInfo
from mledojo.competitions import get_metric
from mledojo.gym.env import KaggleEnvironment
import sys
import os
import re
import argparse

# Put this file's parent directory on the import path. Presumably that is where
# the `prompt` and `utils` packages imported just below live -- confirm against
# the repository layout.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from prompt.prompts import dojo_test_instruction, dojo_test_debug_instruction
from utils.chat import ChatClient


def format_task_description(overview, sample_submission, data_structure, output_path, data_path):
    """Join the five task-info observations into one newline-separated text.

    Each argument is indexed as ``obs["feedback"]["base"]["feedback"]`` and the
    value found there is converted with ``str``. The result starts with a fixed
    header line, followed by the five values in argument order.

    Raises:
        KeyError: if an observation lacks one of the nested keys.
    """
    header = '''Task Description includes overview, sample submission, data structure, output path, and data path.'''
    observations = (overview, sample_submission, data_structure, output_path, data_path)

    parts = [header]
    for observation in observations:
        parts.append(str(observation["feedback"]["base"]["feedback"]))
    return "\n".join(parts)


class DojoTest:
    """Drive an LLM through a generate -> execute -> debug loop for one competition.

    The solver registers the competition, asks the environment for the task
    description, prompts the chat model for code, executes that code in the
    environment and, on failure, feeds the error back to the model until the
    code succeeds or the iteration budget is used up.
    """

    # One Markdown fenced block: group 1 is the first token of the info string
    # (the language tag, possibly empty), group 2 is the body. The closing
    # fence must start a line so that backticks inside a code line do not
    # terminate the block early.
    _CODE_FENCE_RE = re.compile(
        r"```[ \t]*([\w+#.-]*)[^\n]*\n(.*?)^[ \t]*```",
        re.DOTALL | re.MULTILINE,
    )
    # Language tags (lower-cased) that are treated as Python code.
    _PYTHON_FENCE_TAGS = frozenset({"python", "python3", "py"})

    def __init__(self, competition_name: str, data_dir: str, output_dir: str, gpu_device: str,
                 model: str = "gpt-4o", temperature: float = 0.1):
        """Store the run configuration and create the chat client.

        Args:
            competition_name: Name used for registry and metric lookup.
            data_dir: Competition data directory; stored as an absolute path.
            output_dir: Directory for outputs; stored as an absolute path.
            gpu_device: GPU device identifier forwarded to the environment.
            model: Chat model name passed to ``ChatClient``.
            temperature: Sampling temperature passed to ``ChatClient``.
        """
        self.competition_name = competition_name
        self.data_dir = str(os.path.abspath(data_dir))
        self.output_dir = str(os.path.abspath(output_dir))
        self.gpu_device = gpu_device
        # Created by _initialize_environment(); declared here so the attribute
        # always exists on an instance.
        self.env = None
        self.chat_client = ChatClient(model=model, temperature=temperature, max_completion_tokens=20000)

    def _initialize_environment(self):
        """Register the competition and build ``self.env``."""
        registry = CompetitionRegistry()
        registry.register(
            name=self.competition_name,
            data_dir=self.data_dir,
            comp_info=CompInfo(
                category="General",
                level="kaggle",
                output_type="submission.csv"
            ),
            metric_class=get_metric(self.competition_name)
        )

        self.env = KaggleEnvironment.make(
            competition_name=self.competition_name,
            output_dir=self.output_dir,
            competition_registry=registry,
            score_mode="raw",
            gpu_device=self.gpu_device,
            gpu_memory_limit=32,      # NOTE(review): unit not visible here (GB?) -- confirm
            execution_timeout=43200   # NOTE(review): presumably seconds (12 h) -- confirm
        )

    def _get_task_description(self):
        """Request the five task-info observations and format them as one text."""
        info_types = ("overview", "sample_submission", "data_structure", "output_path", "data_path")
        # env.step returns a pair; only the first element (the observation) is used here.
        observations = [self.env.step("request_info", info_type=info_type)[0] for info_type in info_types]
        return format_task_description(*observations)

    def _format_prompt(self):
        """Return the initial user prompt with the task description filled in."""
        task_description = self._get_task_description()
        return dojo_test_instruction.format(description=task_description)

    def _extract_code(self, response: str) -> str:
        """Extract Python code from an LLM response.

        Fenced blocks are paired opening-to-closing in document order. The first
        block tagged as Python (``python``, ``python3`` or ``py``, any case) is
        preferred, then the first untagged block, then the first block of any
        tag. When the response contains no fenced block at all, the stripped
        response itself is returned.

        Args:
            response: Raw text returned by the chat model.

        Returns:
            The selected code with surrounding whitespace removed. This may be
            an empty string when the chosen block is empty.
        """
        blocks = [
            (match.group(1).lower(), match.group(2).strip())
            for match in self._CODE_FENCE_RE.finditer(response)
        ]
        if not blocks:
            return response.strip()

        for accepted_tags in (self._PYTHON_FENCE_TAGS, {""}):
            for tag, body in blocks:
                if tag in accepted_tags:
                    return body
        return blocks[0][1]

    def _execute_code(self, code: str) -> tuple[bool, str]:
        """Execute code in the environment.

        Args:
            code: Source code sent with the ``execute_code`` action.

        Returns:
            ``(True, message)`` when the observation's ``action_status`` is
            ``"SUCCESS"``; the message contains the reward when one is given,
            otherwise the feedback. ``(False, feedback)`` for any other status,
            and ``(False, str(exc))`` when the step raises.
        """
        try:
            observation, reward = self.env.step("execute_code", code=code)

            action_status = observation.get("action_status", "unknown")
            success = action_status == "SUCCESS"

            feedback_msg = observation.get("feedback", {}).get("base", {}).get("feedback", "Unknown feedback")

            if success and reward is not None:
                print(f"Code executed successfully! Reward: {reward}")
                return True, f"Success! Reward: {reward}"
            elif success:
                print(f"Code executed successfully! Feedback: {feedback_msg}")
                return True, f"Success! {feedback_msg}"
            else:
                print(f"Code execution failed: {feedback_msg}")
                return False, feedback_msg
        except Exception as e:
            # Deliberately broad: any failure is turned into feedback text so
            # the caller can hand it back to the model in the debug prompt.
            error_msg = str(e)
            print(f"Code execution failed with exception: {error_msg}")
            return False, error_msg

    def solve_competition(self, max_iterations: int = 10):
        """Iteratively ask the model for code and execute it until it succeeds.

        Args:
            max_iterations: Maximum number of generate/execute rounds.

        Returns:
            ``(True, code, message)`` on the first successful execution, else
            ``(False, None, last_error)``. The loop stops early when the model
            returns nothing or no code can be extracted.
        """
        print(f"Starting competition solving for: {self.competition_name}")

        self._initialize_environment()

        messages = [{"role": "user", "content": self._format_prompt()}]
        feedback = "Unknown error occurred"  # reported if the loop body never runs
        attempts = 0  # rounds actually started; may be fewer than max_iterations

        for iteration in range(max_iterations):
            attempts = iteration + 1
            print(f"\n=== Iteration {attempts}/{max_iterations} ===")

            response = self.chat_client.chat(messages)
            if not response:
                feedback = "Failed to get response from LLM"
                print(feedback)
                break

            print("LLM Response received")

            code = self._extract_code(response)
            if not code:
                feedback = "No code found in LLM response"
                print(feedback)
                break

            print(f"Code extracted, length: {len(code)} characters")

            success, feedback = self._execute_code(code)

            if success:
                print(f"✅ Competition solved successfully in {attempts} iterations!")
                print(f"Final score: {feedback}")
                return True, code, feedback

            if iteration < max_iterations - 1:
                debug_prompt = dojo_test_debug_instruction.format(error=feedback)

                # Keep the full conversation so the model sees its previous
                # attempt next to the error it produced.
                messages.append({"role": "assistant", "content": response})
                messages.append({"role": "user", "content": debug_prompt})

                print(f"❌ Iteration {attempts} failed: {feedback}")
                print("Preparing debug prompt for next iteration...")
            else:
                print(f"❌ Max iterations reached. Final error: {feedback}")

        # Report the rounds actually attempted; the loop may have stopped early.
        print(f"Failed to solve competition after {attempts} iteration(s)")
        return False, None, feedback

    def run(self):
        """Run the solver and save the winning code, if any.

        On success the code is written as UTF-8 to ``successful_solution.py``
        inside ``output_dir`` (created when missing).

        Returns:
            ``(success, result)`` where ``result`` is the success message or the
            last error.
        """
        success, final_code, result = self.solve_competition()

        if success:
            print("\n🎉 Competition solved successfully!")
            print(f"Result: {result}")

            if final_code:
                os.makedirs(self.output_dir, exist_ok=True)
                output_file = os.path.join(self.output_dir, "successful_solution.py")
                # Explicit encoding: generated code may contain non-ASCII text
                # and the platform default is not always UTF-8.
                with open(output_file, 'w', encoding="utf-8") as f:
                    f.write(final_code)
                print(f"Successful solution saved to: {output_file}")
        else:
            print("\n💔 Failed to solve competition")
            print(f"Final error: {result}")

        return success, result


def main():
    """Parse command-line arguments, prepare directories and run the solver.

    The data directory is ``./workplace/refact/<name>/refact<index>`` and the
    output directory is ``<output_dir>/<name>_refact<index>``, where ``<name>``
    is the competition name with everything from the first ``_refact`` removed.
    """
    parser = argparse.ArgumentParser(description="Run dojo test for a competition")
    parser.add_argument("--competition_name", type=str, required=True,
                        help="Name of the competition")
    parser.add_argument("--refact_index", type=str, default="1",
                        help="Index of the refact")
    parser.add_argument("--gpu_device", type=str, default="0",
                        help="GPU device index (default: 0)")
    parser.add_argument("--output_dir", type=str, default=os.path.abspath("./dojo-test"),
                        help="Output directory path")
    parser.add_argument("--model", type=str, default="gpt-5",
                        help="Model to use (default: gpt-5)")
    parser.add_argument("--temperature", type=float, default=1.0,
                        help="Temperature for the model (default: 1.0)")

    args = parser.parse_args()

    competition_name = args.competition_name.split("_refact")[0]
    data_dir = os.path.abspath(f"./workplace/refact/{competition_name}/refact{args.refact_index}")
    output_dir = os.path.abspath(
        os.path.join(args.output_dir, f"{competition_name}_refact{args.refact_index}")
    )

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(output_dir, exist_ok=True)

    # NOTE(review): the solver receives the name as typed, not the stripped
    # `competition_name` used for the paths above -- confirm that the registry
    # and metric lookup expect any `_refact...` suffix.
    solver = DojoTest(args.competition_name, data_dir, output_dir, args.gpu_device, args.model, args.temperature)
    solver.run()


# Run the command-line entry point only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()