"""Test runner for the GUI Agent"""
import argparse
import json
import logging
import os
import pickle
import re
import time
from pathlib import Path
import base64
import io
from typing import List

from browser_env import (
    Action,
    ActionTypes,
    ScriptBrowserEnv,
    StateInfo,
    Trajectory,
    create_stop_action,
)
from browser_env.actions import is_equivalent
from browser_env.helper_functions import (
    RenderHelper,
    get_action_description,
)
from MMInA_evaluation.evaluator import evaluator_router
from utils.early_stop import early_stop
from utils.help_functions import is_domain_type, save_scores_to_json
from agent.llm_config import load_tool_llm


class TestRunner:
    """Handles the main test execution loop"""
    
    def __init__(self, args: argparse.Namespace, agent):
        """Set up the browser environment and per-run bookkeeping.

        Args:
            args: Parsed options. ``slow_mo``, ``viewport_width``/``viewport_height``,
                ``save_trace_enabled`` and ``sleep_after_execution`` configure the
                headless ``ScriptBrowserEnv`` (which also receives the whole
                namespace); ``result_dir`` is where earlier token/timing data is
                resumed from.
            agent: Agent that produces the actions; its ``planner`` is cleared.
        """
        self.args = args
        self.agent = agent
        self.logger = logging.getLogger("logger")

        # Initialize environment
        self.env = ScriptBrowserEnv(
            headless=True,
            slow_mo=args.slow_mo,
            viewport_size={
                "width": args.viewport_width,
                "height": args.viewport_height,
            },
            save_trace_enabled=args.save_trace_enabled,
            sleep_after_execution=args.sleep_after_execution,
            args=args,  # Pass args to the environment
        )
        self.planner = None
        self.agent.planner = None

        # Per-config result stores keyed by config path. _process_config_file
        # assigns into these, so they must exist before the first config runs.
        self.trajActions = {}
        self.metrics_dict = {}
        self.trajSuccess = {}

        # Token and timing tracking (restored below when resuming a run).
        self.total_input_tokens = 0
        self.config_processing_times = {}
        self.config_token_counts = {}

        # Load existing token timing data if available
        self._load_existing_token_timing_data()
    
    def _track_tokens(self, llm_wrapper):
        """Add the prompt tokens of the wrapper's latest LLM call to the running total.

        Args:
            llm_wrapper: Object exposing ``last_usage``, a mapping whose
                ``'prompt_tokens'`` entry describes the most recent request.

        Returns:
            The prompt-token count attributed to the latest call. This is 0 when
            the wrapper reports no usage, so missing accounting data cannot abort
            the episode; a warning is logged instead.
        """
        usage = getattr(llm_wrapper, "last_usage", None)
        try:
            prompt_tokens = usage["prompt_tokens"]
        except (TypeError, KeyError):
            # No usage object at all, or it lacks a prompt-token entry.
            prompt_tokens = None

        if prompt_tokens is None:
            self.logger.warning(
                "No prompt-token usage reported by %s; counting 0 tokens",
                type(llm_wrapper).__name__,
            )
            prompt_tokens = 0

        self.total_input_tokens += prompt_tokens
        return prompt_tokens
    
    def _load_existing_token_timing_data(self):
        """Restore token/timing counters saved by an earlier run, if any.

        Reads ``token_timing_data.json`` from ``args.result_dir``. If the file
        does not exist the counters are left untouched. If reading, parsing or
        logging it fails, a warning is logged and all three counters are reset
        to zero / empty.
        """
        saved_path = Path(self.args.result_dir) / "token_timing_data.json"
        if not saved_path.exists():
            return

        try:
            with open(saved_path, 'r') as handle:
                saved = json.load(handle)

            # Restore the tracking variables
            self.total_input_tokens = saved.get('total_input_tokens', 0)
            self.config_processing_times = saved.get('config_processing_times', {})
            self.config_token_counts = saved.get('config_token_counts', {})

            log_info = self.logger.info
            log_info("[Resume] Loaded existing token timing data")
            log_info(f"[Resume] Total tokens: {self.total_input_tokens}")
            log_info(f"[Resume] Configs already processed: {len(self.config_processing_times)}")
            log_info(f"[Resume] Total time so far: {sum(self.config_processing_times.values()):.2f} seconds")
        except Exception as err:
            self.logger.warning(f"Could not load existing token timing data: {err}")
            # Start from a clean slate rather than half-restored state.
            self.total_input_tokens, self.config_processing_times, self.config_token_counts = 0, {}, {}
    
    def run(self, config_file_list: list[str]):
        """Process every config file in order, then close the browser environment.

        The environment is closed in a ``finally`` block so the browser is
        released even when processing one of the configs raises.

        Args:
            config_file_list: Paths of task config files, run sequentially.
        """
        try:
            for config_file in config_file_list:
                self._process_config_file(config_file)
        finally:
            self.env.close()
    
    def _process_config_file(self, config_file: str):
        """Run one task episode for *config_file*, score it and record the results.

        The agent is queried for actions until one of these happens:

        * it emits a STOP / ``'finished'`` action;
        * the environment reports termination;
        * ``early_stop`` fires. A stop action is then synthesised, carrying an
          LLM fallback answer when ``args.enable_fallback`` is set.

        The trajectory is scored with the MMInA evaluator. The score, processing
        time and prompt-token count are then stored per config file.

        Args:
            config_file: Path of the task's JSON config. Its second path
                component (after removing ``'MMInA/'``) is the sub-domain.

        Raises:
            NotImplementedError: If the sub-domain is not a singlehop, 2hop or
                multihop domain (checked after scoring).
        """
        start_time = time.time()
        config_tokens = 0

        sub_domain = config_file.replace('MMInA/', '').split('/')[1]

        # Nothing appends to this list; it is still stored in self.trajActions
        # so that every processed config has an entry there.
        action_list = []

        render_helper = RenderHelper(config_file, self.args.result_dir)
        # try/finally: the render helper is closed even if the episode, the
        # evaluator or the domain check raises.
        try:
            config, intent = self._read_task_config(config_file)
            site = config["sites"][0]

            # Expose the first number in the path (if any) as the task counter.
            numbers = re.findall(r'\d+', config_file)
            self.args.task_cnt = int(numbers[0]) if numbers else None
            self.args.hop_cnt = 0

            self.logger.info(f"[Config file]: {config_file}")
            self.logger.info(f"[Intent]: {intent}")

            self.agent.reset(config_file)
            self.agent.current_step = 0

            # Environment reset
            obs, info = self.env.reset(options={"config_file": config_file})
            current_url = info["page"].url
            state_info: StateInfo = {"observation": obs, "info": info, "current_url": current_url}
            trajectory: Trajectory = [state_info]
            print("CURRENT: ", current_url)

            meta_data = {"action_history": [],
                         "response_history": []}

            print("config_file: ", config_file)

            # (intent, answer) pairs of finished sub-queries; used to enrich
            # later sub-queries when args.subtask is enabled.
            sub_query_answers = []

            self._start_training_conversation(config_file, sub_domain, intent)

            intent_list = [intent]

            # Process each sub-query sequentially
            for sub_query_idx, current_intent in enumerate(intent_list):

                # Enhance current intent with previous subtask results if available
                if self.args.subtask and sub_query_idx > 0 and sub_query_answers:
                    enhanced_intent = self._enhance_intent_with_previous_results(
                        current_intent, sub_query_answers, sub_query_idx
                    )
                    self.logger.info(f"[Subtask {sub_query_idx + 1}] Enhanced intent with previous results")
                else:
                    enhanced_intent = current_intent

                # Every sub-query after the first starts from a fresh environment.
                if sub_query_idx > 0:
                    self.logger.info(f"[Sub-query {sub_query_idx + 1}] Resetting environment for new sub-query")
                    obs, info = self.env.reset(options={"config_file": config_file})
                    current_url = info["page"].url
                    state_info = {"observation": obs, "info": info, "current_url": current_url}
                    trajectory = [state_info]
                    meta_data = {"action_history": [],
                                 "response_history": [],
                                 "page": self.env.page}
                    print("CURRENT: ", current_url)

                while True:
                    current_url = current_url.lower()

                    early_stop_flag, stop_info = early_stop(
                        trajectory, self.args.max_steps, {
                            "parsing_failure": self.args.parsing_failure_th,
                            "repeating_action": self.args.repeating_action_failure_th,
                        }
                    )

                    if early_stop_flag:
                        if self.args.enable_fallback:
                            from utils.fallback_answer import generate_fallback_answer
                            self.logger.info(f"[Fallback] Early stop detected: {stop_info}. Generating fallback answer...")

                            fallback_result = generate_fallback_answer(
                                question=enhanced_intent,
                                trajectory=trajectory,
                                model=self.agent.llm,
                                num_screenshots=self.args.fallback_screenshots
                            )
                            # The stop action is appended to the trajectory exactly
                            # once, by the shared code below. With args.subtask its
                            # answer is also recorded once by the STOP handling.
                            action = create_stop_action(fallback_result['answer'])
                            self.logger.info(f"[Fallback] Generated fallback answer: {fallback_result['answer']}...")
                        else:
                            action = create_stop_action(f"Early stop: {stop_info}")
                    else:
                        # Parameter names are kept as-is: action_self_check receives
                        # this callable and may pass them by keyword.
                        def gen_action(intent, meta, error_message=None):
                            next_act, next_meta = self.agent.next_action_custom(
                                trajectory,
                                intent,
                                meta_data=meta,
                                model=self.args.loaded_model,
                                args=self.args,
                                error_message=error_message
                            )
                            return next_act, next_meta

                        if self.args.action_check:
                            from utils.action_check import action_self_check
                            action, meta_data = action_self_check(gen_action, enhanced_intent, self.env.page, trajectory, meta_data, max_retries=3, repeat_threshold=self.args.repeating_action_failure_th)
                        else:
                            action, meta_data = gen_action(enhanced_intent, meta_data)

                        config_tokens += self._track_tokens(self.agent.llm)

                    # The agent may return a single action or a list of actions.
                    if isinstance(action, list):
                        trajectory.extend(action)
                    else:
                        trajectory.append(action)

                    action_str = get_action_description(action)
                    render_helper.render(
                        action, state_info, meta_data, self.args.render_screenshot
                    )
                    meta_data["action_history"].append(action_str)
                    meta_data["page"] = self.env.page

                    final_step = action[-1] if isinstance(action, list) else action
                    if final_step["action_type"] in [ActionTypes.STOP, 'finished']:
                        self.logger.info(f"[Sub-query {sub_query_idx + 1}] Completed")
                        if self.args.subtask:
                            answer = final_step.get('answer', '')
                            sub_query_answers.append((enhanced_intent, answer))
                            self.logger.info(f"[Subtask {sub_query_idx + 1}] Answer stored: {answer[:100]}...")
                        break

                    obs, _, terminated, _, info, current_url = self.env.step(action, observation=obs)
                    print("CURRENT: ", current_url)

                    state_info = {"observation": obs, "info": info}
                    trajectory.append(state_info)

                    if terminated:
                        # add an action placeholder so the trajectory ends with an action
                        trajectory.append(create_stop_action(""))
                        self.logger.info(f"[Sub-query {sub_query_idx + 1}] Terminated")
                        if self.args.subtask:
                            sub_query_answers.append((enhanced_intent, "Task terminated without completion"))
                            self.logger.info(f"[Subtask {sub_query_idx + 1}] Terminated without completion")
                        break

            # Store trajectory info
            self.trajActions[config_file] = action_list

            # evaluate the scores
            evaluate_model = load_tool_llm(self.args)
            evaluator = evaluator_router(config_file, evaluate_model)
            score = evaluator(
                trajectory=trajectory,
                config_file=config_file,
                page=self.env.page,
                client=self.env.get_page_client(self.env.page),
            )

            last_action = trajectory[-1]
            pred = last_action.get("answer", "")
            reasoning = last_action.get("reasoning", "")
            # Answers containing 'Early stop' (as produced by the non-fallback
            # early-stop path) are scored 0 regardless of the evaluator.
            score = 0.0 if 'Early stop' in pred else score
            self.logger.info(f"[Result] Predicted answer: {pred}\nReasoning: {reasoning}")

            self.metrics_dict[config_file] = {
                "config": config_file,
                "success": score,
            }
            self.trajSuccess[config_file] = score

            if not any(is_domain_type(sub_domain, domain_type)
                       for domain_type in ('singlehop', '2hop', 'multihop')):
                raise NotImplementedError(f"Unsupported domain type: {sub_domain}")
            result = "PASS" if score == 1 else "FAIL"
            self.logger.info(f"[Result] ({result}) {config_file}")

            self.agent.experience_memory = None
            self.agent.experience_texts, self.agent.experience_images = None, None

            # Store timing and token data for this config file
            self.config_processing_times[config_file] = time.time() - start_time
            self.config_token_counts[config_file] = config_tokens

            self.logger.info(f"[Result] {config_file} - Success: {score}")

            # Save token and timing data incrementally after each config file
            self._save_token_timing_data_incremental()
        finally:
            render_helper.close()

        self._end_training_conversation(config_file, site, sub_domain, score, current_url, intent)

    @staticmethod
    def _read_task_config(config_file: str):
        """Load *config_file* and return ``(config, intent)``, with the intent rewritten for this runner."""
        with open(config_file) as f:
            config = json.load(f)

        intent = config["intent"]
        # Rewrite the Kiwix Wikipedia landing URL and the localhost:7770 site URL
        # to the hosts used here.
        # NOTE(review): 'iewer#' looks like a truncated 'viewer#'; confirm it
        # matches the URLs that actually appear in the task intents.
        intent = intent.replace('https://library.kiwix.org/iewer#wikipedia_en_all_maxi_2024-01/A/User%3AThe_other_Kiwix_guy/Landing', 'https://www.wikipedia.org/')
        intent = intent.replace("http://localhost:7770/", "http://ec2-3-146-212-252.us-east-2.compute.amazonaws.com:7770/")
        if 'shopping' in intent:
            intent += "\n\n***NOTE: Please 1. directly search the product name 2.find the query product and click it 3. carefully check the product image and description to answer the question***"
        if any(tag in config_file for tag in ('compare', 'normal', 'multi567', 'multipro')):
            intent = intent.replace('the action is finished just after click the search button!',
                                    'the action is finished when you successfully search the tarket hotel/ticket/flight/... with correct city, number of people, dates, etc.')
            # The lines after the opening quotes are part of the string literal
            # (leading spaces included) and must stay byte-identical.
            intent += """***NOTE***: 
                1. Please always jump to multiple websites using the goto_url tool to complete the task. 
                2. Make sure you finish all the tasks. If you need to book a hotel/ticket/flight/..., you need to search the city, number of people, dates, etc.
                3. The task is finished until you see the final target, if you need to book a hotel/ticket/flight/...,  only when you see the destination hotel/ticket/flight/... with correct number of people, dates, etc. you can think the task is finished.
                4. Attention: If you think all the actions had been done, return the final url as the answer!!!
                """
        return config, intent

    def _start_training_conversation(self, config_file: str, sub_domain: str, intent: str):
        """Open a training-data conversation for this task when the agent has a collector enabled."""
        if not getattr(self.agent, 'training_collector', None):
            return
        from utils.training_data_collector import get_collector
        collector = get_collector()
        if collector and collector.enabled:
            # Create conversation ID from task info
            conversation_id = f"{sub_domain}_{config_file.split('/')[-1].split('.')[0]}"
            collector.start_conversation(
                conversation_id=conversation_id,
                task_description=intent
            )
            self.logger.info(f"Started conversation collection for task: {conversation_id}")

    def _end_training_conversation(self, config_file: str, site, sub_domain: str, score, current_url: str, intent: str):
        """Close the active training-data conversation, saving it when ``args.save_examples_memory`` is set."""
        if not getattr(self.agent, 'training_collector', None):
            return
        from utils.training_data_collector import get_collector
        collector = get_collector()
        if not (collector and collector.enabled and collector.current_conversation_id):
            return
        conversation_summary = {
            "task_id": config_file.split('/')[-1].split('.')[0],
            "site": site,
            "sub_domain": sub_domain,
            "success": score,
            "final_url": current_url,
            "task_completed": True,
            "task_description": intent
        }
        if self.args.save_examples_memory:
            saved_file = collector.end_conversation(conversation_summary, score)
            if saved_file:
                self.logger.info(f"Conversation saved: {saved_file}")
    
    def _save_token_timing_data_incremental(self):
        """Save token and timing data incrementally after each config file"""
        # Calculate comprehensive statistics
        total_configs = len(self.config_processing_times)
        total_time = sum(self.config_processing_times.values())
        total_tokens = sum(self.config_token_counts.values())
        
        if total_configs > 0:
            avg_time = total_time / total_configs
            avg_tokens = total_tokens / total_configs
            min_time = min(self.config_processing_times.values())
            max_time = max(self.config_processing_times.values())
            min_tokens = min(self.config_token_counts.values())
            max_tokens = max(self.config_token_counts.values())
        else:
            avg_time = avg_tokens = min_time = max_time = min_tokens = max_tokens = 0
        
        token_timing_data = {
            "total_input_tokens": self.total_input_tokens,
            "config_processing_times": self.config_processing_times,
            "config_token_counts": self.config_token_counts,
            "summary": {
                "total_configs_processed": total_configs,
                "total_processing_time": total_time,
                "average_processing_time": avg_time,
                "min_processing_time": min_time,
                "max_processing_time": max_time,
                "total_tokens": total_tokens,
                "average_tokens_per_config": avg_tokens,
                "min_tokens_per_config": min_tokens,
                "max_tokens_per_config": max_tokens
            }
        }
        
        with open(Path(self.args.result_dir) / "token_timing_data.json", "w") as f:
            json.dump(token_timing_data, f, indent=4)
        
        # Save token and timing data to JSON (no logger output)
        
    def _save_results(self):
        """Write final results: ``scores.json`` plus the token/timing summary.

        The token/timing file is produced by
        ``_save_token_timing_data_incremental`` so the final file has the same
        layout as the per-config snapshots. This replaces a second, slimmer
        summary format; every key of that format is still present.
        """
        with open(Path(self.args.result_dir) / "scores.json", "w", encoding="utf-8") as f:
            json.dump(self.scores, f, indent=4)

        # Save token and timing data
        self._save_token_timing_data_incremental()
        