"""Test runner for the GUI Agent"""
import argparse
import json
import logging
import os
import pickle
import re
from pathlib import Path
import base64
import io
from typing import List

from browser_env import (
    Action,
    ActionTypes,
    ScriptBrowserEnv,
    StateInfo,
    Trajectory,
    create_stop_action,
)
from browser_env.actions import is_equivalent
from browser_env.helper_functions import (
    RenderHelper,
    get_action_description,
)
from Mind2Web_evaluation.evaluator import evaluator_router
from utils.early_stop import early_stop
from utils.help_functions import is_domain_type, save_scores_to_json
from utils.training_data_collector import get_collector
from agent.llm_config import load_tool_llm


class TestRunner:
    """Handles the main test execution loop"""
    
    def __init__(self, args: argparse.Namespace, agent):
        """Create the browser environment and the per-run result containers.

        Args:
            args: Parsed command-line options. Read here: ``slow_mo``,
                ``viewport_width``, ``viewport_height``, ``save_trace_enabled``
                and ``sleep_after_execution``; the whole namespace is also
                forwarded to ``ScriptBrowserEnv``.
            agent: The GUI agent driven by :meth:`run`. Its ``planner``
                attribute is cleared.
        """
        self.args = args
        self.agent = agent
        self.logger = logging.getLogger("logger")
        # Initialize environment
        self.env = ScriptBrowserEnv(
            headless=True,
            slow_mo=args.slow_mo,
            viewport_size={
                "width": args.viewport_width,
                "height": args.viewport_height,
            },
            save_trace_enabled=args.save_trace_enabled,
            sleep_after_execution=args.sleep_after_execution,
            args=args,  # Pass args to the environment
        )
        # Per-task results keyed by config-file path. _process_config_file
        # writes into all three, so they must exist before run() starts.
        self.trajActions: dict = {}
        self.trajSuccess: dict = {}
        self.metrics_dict: dict = {}
        self.planner = None
        self.agent.planner = None
            
        
    def run(self, config_file_list: list[str]):
        """Process every task config, save the results, and close the browser.

        Configs whose ``start_url`` contains ``'google'`` are skipped. Results
        are saved with ``_save_results`` after all configs are processed. The
        browser environment is closed even if a task or the save raises.

        Args:
            config_file_list: Paths to JSON task configs; each must contain a
                ``start_url`` key.
        """
        try:
            for config_file in config_file_list:
                with open(config_file, encoding="utf-8") as f:
                    config_data = json.load(f)
                # Tasks whose start URL mentions google are not run.
                if 'google' in config_data['start_url']:
                    continue
                self._process_config_file(config_file)

            self._save_results()
        finally:
            # Always release the browser, even when a task fails mid-run.
            self.env.close()
    

    
    def _process_config_file(self, config_file: str):
        """Run a single task episode and record its evaluation result.

        Flow:
          1. Read ``intent``, ``task_id`` and ``start_url`` from the JSON config
             and append a stop-action reminder to the intent.
          2. Reset the agent and the browser environment.
          3. Repeatedly ask the agent for the next action (optionally through
             ``action_self_check``), render it, and step the environment until
             the agent stops, ``early_stop`` fires, or the environment reports
             termination.
          4. Score the trajectory with ``evaluator_router`` and store the score
             in ``self.metrics_dict`` / ``self.trajSuccess``.
          5. If training-data collection is enabled, end the conversation.

        Returns early, without recording a score, when the reset lands on
        ``about:blank`` / a blocked page, or when ``env.step`` raises. The
        ``RenderHelper`` is closed on every exit path.

        Args:
            config_file: Path to the task's JSON config file.
        """
        sub_domain = Path(config_file).stem.rsplit('_', 1)[0]
        # File name up to its first '.'; used as the training conversation id
        # suffix and as the summary's task_id.
        task_file_id = config_file.split('/')[-1].split('.')[0]

        # NOTE(review): nothing appends to this list, so an empty list is stored
        # in self.trajActions for every task — confirm what _save_results expects.
        action_list = []

        render_helper = RenderHelper(config_file, self.args.result_dir)
        try:
            # Get intent and task info
            with open(config_file, encoding="utf-8") as f:
                _c = json.load(f)
            intent = _c["intent"]
            intent += ' Please yield the stop action once you have completed the task.'
            # Looked up so a config without task_id fails fast (KeyError).
            task_id = _c["task_id"]  # noqa: F841

            # NOTE(review): digits are searched across the whole path, so a number
            # in a directory name is picked up before the file's own number.
            numbers = re.findall(r'\d+', config_file)
            self.args.task_cnt = int(numbers[0]) if numbers else None
            self.args.hop_cnt = 0

            self.logger.info(f"[Config file]: {config_file}")
            self.logger.info(f"[Intent]: {intent}")

            self.agent.reset(config_file)
            self.agent.current_step = 0
            trajectory: Trajectory = []

            # Environment reset
            obs, info = self.env.reset(
                options={"config_file": config_file},
            )
            current_url = info["page"].url
            state_info: StateInfo = {"observation": obs, "info": info, "current_url": current_url}
            trajectory.append(state_info)
            print("CURRENT: ", current_url)
            if 'about:blank' in current_url or info["is_blocked"]:
                self.logger.info(f"[Result] (Cannot navigate to {_c['start_url']}) {config_file}")
                return

            meta_data = {"action_history": [],
                         "response_history": []}

            print("config_file: ", config_file)

            # (intent, answer) pairs from finished sub-queries, used to enrich
            # later sub-queries when args.subtask is set.
            sub_query_answers = []

            # Start conversation for this task if training data collection is enabled
            if hasattr(self.agent, 'training_collector') and self.agent.training_collector:
                collector = get_collector()
                if collector and collector.enabled:
                    conversation_id = f"{sub_domain}_{task_file_id}"
                    collector.start_conversation(
                        conversation_id=conversation_id,
                        task_description=intent
                    )
                    self.logger.info(f"Started conversation collection for task: {conversation_id}")

            def gen_action(intent, meta, error_message=None):
                """Ask the agent for its next action(s) and updated meta data.

                ``trajectory`` is resolved at call time, so a sub-query reset
                that rebinds it below is picked up automatically.
                """
                action, meta = self.agent.next_action_custom(
                    trajectory,
                    intent,
                    meta_data=meta,
                    model=self.args.loaded_model,
                    args=self.args,
                    error_message=error_message
                )
                return action, meta

            # Currently a single sub-query; the loop keeps the multi-query path.
            intent_list = [intent]

            # Process each sub-query sequentially
            for sub_query_idx, current_intent in enumerate(intent_list):
                # Enhance current intent with previous subtask results if available
                if self.args.subtask and sub_query_idx > 0 and sub_query_answers:
                    enhanced_intent = self._enhance_intent_with_previous_results(
                        current_intent, sub_query_answers, sub_query_idx
                    )
                    self.logger.info(f"[Subtask {sub_query_idx + 1}] Enhanced intent with previous results")
                else:
                    enhanced_intent = current_intent

                # Reset environment for each sub-query if not the first one
                if sub_query_idx > 0:
                    self.logger.info(f"[Sub-query {sub_query_idx + 1}] Resetting environment for new sub-query")
                    obs, info = self.env.reset(options={"config_file": config_file})
                    current_url = info["page"].url
                    state_info = {"observation": obs, "info": info, "current_url": current_url}
                    # Clear trajectory and start fresh for new sub-query
                    trajectory = [state_info]
                    # Same keys as the first sub-query's meta_data, plus the page.
                    meta_data = {"action_history": [],
                                 "response_history": [],
                                 "page": self.env.page}
                    print("CURRENT: ", current_url)

                # Process current sub-query
                while True:
                    current_url = current_url.lower()

                    early_stop_flag, stop_info = early_stop(
                        trajectory, self.args.max_steps, {
                            "parsing_failure": self.args.parsing_failure_th,
                            "repeating_action": self.args.repeating_action_failure_th,
                        }
                    )

                    if early_stop_flag:
                        # Check if fallback is enabled and should be used
                        if self.args.enable_fallback:
                            from utils.fallback_answer import generate_fallback_answer
                            self.logger.info(f"[Fallback] Early stop detected: {stop_info}. Generating fallback answer...")

                            # Generate fallback answer using the LLM
                            fallback_result = generate_fallback_answer(
                                question=enhanced_intent,
                                trajectory=trajectory,
                                model=self.agent.llm,
                                num_screenshots=self.args.fallback_screenshots
                            )
                            # Only build the stop action here: the shared code
                            # below appends every action to the trajectory, and the
                            # stop branch records the answer for subtasks. Doing
                            # either here as well duplicated them.
                            action = create_stop_action(fallback_result['answer'])
                            self.logger.info(f"[Fallback] Generated fallback answer: {fallback_result['answer']}...")
                        else:
                            action = create_stop_action(f"Early stop: {stop_info}")
                    elif self.args.action_check:
                        from utils.action_check import action_self_check
                        action, meta_data = action_self_check(
                            gen_action,
                            enhanced_intent,
                            self.env.page,
                            trajectory,
                            meta_data,
                            max_retries=3,
                            repeat_threshold=self.args.repeating_action_failure_th,
                        )
                    else:
                        action, meta_data = gen_action(enhanced_intent, meta_data)

                    # The agent may return a single action or a list of actions.
                    if isinstance(action, list):
                        trajectory.extend(action)
                    else:
                        trajectory.append(action)

                    action_str = get_action_description(action)
                    render_helper.render(
                        action, state_info, meta_data, self.args.render_screenshot
                    )
                    meta_data["action_history"].append(action_str)
                    meta_data["page"] = self.env.page

                    final_step_action = action[-1] if isinstance(action, list) else action
                    if final_step_action["action_type"] in [ActionTypes.STOP, 'finished']:
                        self.logger.info(f"[Sub-query {sub_query_idx + 1}] Completed")

                        # Store the subtask answer if using subtask decomposition
                        if self.args.subtask:
                            answer = final_step_action.get('answer', '')
                            sub_query_answers.append((enhanced_intent, answer))
                            self.logger.info(f"[Subtask {sub_query_idx + 1}] Answer stored: {answer[:100]}...")

                        break

                    try:
                        obs, _, terminated, _, info, current_url = self.env.step(action, observation=obs)
                    except Exception as e:
                        # Abandon this task (no score recorded) but keep the traceback.
                        self.logger.exception(f"Error in step: {e}")
                        return
                    print("CURRENT: ", current_url)

                    # Same shape as the initial state entry.
                    state_info = {"observation": obs, "info": info, "current_url": current_url}
                    trajectory.append(state_info)

                    if terminated:
                        # Placeholder so the trajectory still ends with an action.
                        trajectory.append(create_stop_action(""))
                        self.logger.info(f"[Sub-query {sub_query_idx + 1}] Terminated")

                        # Store the subtask answer if using subtask decomposition
                        if self.args.subtask:
                            sub_query_answers.append((enhanced_intent, "Task terminated without completion"))
                            self.logger.info(f"[Subtask {sub_query_idx + 1}] Terminated without completion")

                        break

            self.agent.experience_memory = None
            self.agent.experience_texts, self.agent.experience_images = None, None
            # Store trajectory info
            self.trajActions[config_file] = action_list

            # Evaluate the scores
            evaluate_model = load_tool_llm(self.args)
            evaluator = evaluator_router(config_file, evaluate_model)
            score = evaluator(
                trajectory=trajectory,
                config_file=config_file,
                page=self.env.page,
                client=self.env.get_page_client(self.env.page),
            )
            last_action = trajectory[-1]
            # A stop action may carry answer=None; treat that as an empty answer.
            pred = last_action.get("answer", "") or ""
            reasoning = last_action.get("reasoning", "")
            self.logger.info(f"[Result] Predicted answer: {pred}\nReasoning: {reasoning}")
            # Forced early stops never count as successes.
            score = 0.0 if 'Early stop' in pred else score
            self.metrics_dict[config_file] = {
                "config": config_file,
                "success": score,
            }
            self.trajSuccess[config_file] = score

            result = "PASS" if score == 1 else "FAIL"
            self.logger.info(f"[Result] ({result}) {config_file}")
        finally:
            render_helper.close()

        # End conversation for this task if training data collection is enabled
        if hasattr(self.agent, 'training_collector') and self.agent.training_collector:
            collector = get_collector()
            if collector and collector.enabled and collector.current_conversation_id:
                # Create conversation summary
                conversation_summary = {
                    "task_id": task_file_id,
                    "sub_domain": sub_domain,
                    "success": score,
                    "final_url": current_url,
                    "task_completed": True,
                    "task_description": intent
                }

                # End the conversation
                if self.args.save_examples_memory:
                    saved_file = collector.end_conversation(conversation_summary, score)
                    if saved_file:
                        self.logger.info(f"Conversation saved: {saved_file}")
    