import gym
import json
import copy
import asyncio
import aiohttp
import numpy as np
import os
import random
import argparse
from tqdm import tqdm
from web_agent_site.envs import WebAgentTextEnv
from web_agent_site.utils import DEBUG_PROD_SIZE
from dataclasses import dataclass, field
from typing import List, Dict, Optional
import logging
import re 
from datetime import datetime
import time
from datasets import load_dataset

# Seed Python's and NumPy's global RNGs from the wall clock. Every run
# therefore draws different random values -- this does NOT make runs
# deterministic or reproducible.
SEED = int(time.time())
random.seed(SEED)
np.random.seed(SEED)

@dataclass
class TaskState:
    """Tracks the state of a single task trajectory.

    ``observations`` is expected to start with the initial page (``reset``
    initializes it that way) and gains one entry per ``add_step``, so
    ``observations[i]`` is the page that was shown when ``actions[i]`` was
    chosen. ``add_step`` keeps ``rationales`` index-aligned with ``actions``.
    """

    # Maximum number of actions in one trajectory; ``is_complete`` stops here.
    # Unannotated on purpose: a plain class constant, not a dataclass field.
    MAX_ROUNDS = 10

    task_id: int
    persona: str
    env: WebAgentTextEnv  # Environment instance owned by this task
    actions: List[str] = field(default_factory=list)
    observations: List[str] = field(default_factory=list)
    rationales: List[str] = field(default_factory=list)
    is_finished: bool = False
    current_round: int = 0
    last_action: Optional[str] = None
    # The observation shown to the agent BEFORE ``last_action`` was taken.
    last_observation: Optional[str] = None

    def add_step(self, action: str, observation: str, rationale: Optional[str] = None):
        """Record one action together with the observation it produced.

        Args:
            action: The action that was taken.
            observation: The observation AFTER taking the action.
            rationale: Optional rationale for the action. A missing or empty
                rationale is stored as ``""`` so ``rationales`` stays
                index-aligned with ``actions``.
        """
        self.last_action = action
        # The page the agent saw when choosing `action` is the current last
        # observation (captured before the new one is appended below).
        self.last_observation = self.observations[-1] if self.observations else None

        self.actions.append(action)
        self.observations.append(observation)
        # Always append (even when empty) so rationales[i] belongs to actions[i].
        self.rationales.append(rationale or "")
        self.current_round += 1

    def is_complete(self) -> bool:
        """Return True once the task is finished or has used ``MAX_ROUNDS`` steps."""
        return self.is_finished or self.current_round >= self.MAX_ROUNDS

    def reset(self):
        """Reset the environment and restart the trajectory from its initial page."""
        obs, _ = self.env.reset()
        obs = clean_obs(obs)
        # Keep the text after the last 'Instruction: \n' marker, minus its first
        # line (presumably the instruction text itself -- confirm).
        observation = "\n".join(obs.split('Instruction: \n')[-1].split('\n')[1:])
        self.observations = [observation]
        self.actions = []
        self.rationales = []
        self.is_finished = False
        self.current_round = 0
        self.last_action = None
        self.last_observation = observation  # Initial observation

def clean_obs(obs: str) -> str:
    """Normalization hook for raw environment observations.

    Currently the identity: the text is returned unchanged, so the
    ``[button] ... [button_]`` markup the prompts refer to is preserved.
    (A disabled variant used to rewrite ``[button]``/``[button_]`` into plain
    brackets.)
    """
    cleaned = obs
    return cleaned

class OpenAITaskGenerator:
    def __init__(
        self,
        api_key: str,
        model: str = "gpt-4o-mini",
        max_tokens: int = 2048,
        temperature: float = 0.7,
        num_tasks: int = 1000,
        output_dir: str = "alice_data_openai",
        exp_name: str = None,
        wait_for_user: bool = False,
        skip_instruction: bool = False,
        skip_reasoning: bool = False,
        skip_jsonl: bool = False,
        skip_trajectory: bool = False,
        feedback_mode: bool = False,
        feedback_json: str = None,
        skill_feedback: bool = False,
        combine_jsonl: str = None,
        resume_dir: str = None
    ):
        self.api_key = api_key
        self.model = model
        self.max_tokens = max_tokens
        self.temperature = temperature
        self.num_tasks = num_tasks
        self.output_dir = output_dir
        self.exp_name = exp_name
        self.wait_for_user = wait_for_user
        self.skip_instruction = skip_instruction
        self.skip_reasoning = skip_reasoning
        self.skip_jsonl = skip_jsonl
        self.skip_trajectory = skip_trajectory
        self.feedback_mode = feedback_mode
        self.feedback_json = feedback_json
        self.feedback_summary = None
        self.skill_feedback = skill_feedback
        self.combine_jsonl = combine_jsonl
        self.resume_dir = resume_dir
        # Load personas
        # with open('persona.json') as f:
        #     self.personas = json.load(f)

        persona_data = load_dataset("proj-persona/PersonaHub", "persona")['train']
        self.personas = [persona['persona'] for persona in persona_data]
            
        # Initialize task states
        self.task_states: Dict[int, TaskState] = {}
        
        # Create output directories
        self.setup_directories()
        
        # System prompt template
        self.system_prompt = '''
In the web environment, your actions are strictly limited to two types:

1. search[keywords]: Use this action only when a "[button] Search [button_]" is present in the current web page content. You must replace "keywords" with any valid search query you want to search.

2. click[HTML Element]: Use this action to click on an HTML Element in the page content. "HTML Element" can be any clickable element in the page represented inside "[button]" and "[button_]", such as an item id, action button, or attributes and options like color or size. Note that the 'HTML Element' 'must' be present in the current page content. Also, do not click the attributes inside the "[clicked button]" and "[clicked button_]", "item name", and "button" iteself (e.g. click[button] is not allowed).

Only use search action when a "[button] Search [button_]" is present in the current web page content and otherwise, use click action (click item id, attributes like color and size, or action button).\n\n'''

        # Generate feedback if JSON is provided
        if self.feedback_json:
            asyncio.run(self.generate_feedback())

    def setup_directories(self):
        """Set ``exp_dir``, ``request_dir`` and ``response_dir``.

        When resuming (``resume_dir`` is set) the existing directory is used
        as-is: nothing is created and no config is written. For a fresh run,
        ``<output_dir>/<YYYYmmdd_HHMMSS>[_<exp_name>]`` is created together
        with its ``requests/`` and ``responses/`` subdirectories, and the run
        settings are written to ``config.json``.
        """
        resuming = bool(self.resume_dir)
        if resuming:
            self.exp_dir = self.resume_dir
        else:
            stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            name_suffix = f"_{self.exp_name}" if self.exp_name else ""
            self.exp_dir = os.path.join(self.output_dir, stamp + name_suffix)
            os.makedirs(self.exp_dir, exist_ok=True)

        self.request_dir = os.path.join(self.exp_dir, "requests")
        self.response_dir = os.path.join(self.exp_dir, "responses")
        if resuming:
            return

        for subdir in (self.request_dir, self.response_dir):
            os.makedirs(subdir, exist_ok=True)

        config = {
            "timestamp": stamp,
            "exp_name": self.exp_name,
            "model": self.model,
            "max_tokens": self.max_tokens,
            "temperature": self.temperature,
            "num_tasks": self.num_tasks,
        }
        config_path = os.path.join(self.exp_dir, "config.json")
        with open(config_path, "w") as handle:
            json.dump(config, handle, indent=4)

    def get_round_dir(self, round_num: int) -> tuple:
        """Return ``(request_dir, response_dir)`` for a round, creating both.

        Both paths end in ``round_NNN`` (zero-padded to three digits) under
        ``self.request_dir`` and ``self.response_dir`` respectively.
        """
        label = f"round_{round_num:03d}"
        paths = tuple(os.path.join(base, label) for base in (self.request_dir, self.response_dir))
        for path in paths:
            os.makedirs(path, exist_ok=True)
        return paths

    def create_request_json(self, task_state: TaskState) -> dict:
        """Build the chat-completions request for the task's next step.

        The system message sets up the persona, plus the feedback summary when
        ``feedback_mode`` is on and a summary exists. The user message is
        ``self.system_prompt`` (the action-format rules), then round-specific
        guidance, then the current page (``task_state.observations[-1]``).

        Args:
            task_state: Trajectory state of the task being advanced.

        Returns:
            Request dict with ``model``, ``messages``, ``max_tokens`` and
            ``temperature`` keys.
        """
        use_feedback = self.feedback_mode and self.feedback_summary

        persona_msg = f'''You are a web-shop-agent that can interact with the webpage by taking actions. You need to buy something that you want at the end. Also, you should adopt the identity of following persona :\n            "{task_state.persona}" \nYou should take actions that are consistent with the persona you have adopted.\n'''
        system_msg = persona_msg

        if use_feedback:
            from prompt_template import PromptTemplate
            system_msg += PromptTemplate().call_summary_prompt(self.feedback_summary)

        if task_state.current_round == 0:
            user_msg = (
                "Now here is the current page content. Read the page carefully. Based on the page content and your persona, "
                + ("and the feedback provided, " if use_feedback else "")
                + "provide any valid action that seems very interesting. "
                "When outputting the action, "
                "please write your action after the prompt 'Action:'."
            )
        else:
            if task_state.last_action in ["click[description]", "click[features]"]:
                # After opening a detail tab, show only the page that preceded it.
                user_msg = (
                    f"In the 'previous' round:\n"
                    f"- The page showed: \n{task_state.last_observation}\n"
                    f"- Given that page, you took action: \n{task_state.last_action}\n\n\n"
                )
            elif len(task_state.actions) > 1:
                # List every action after the initial search, numbered from 2.
                previous_actions = "\n".join(
                    f"{i}. {a}" for i, a in enumerate(task_state.actions[1:], start=2)
                )
                user_msg = (
                    f"In the previous rounds, after the initial search, you have taken the following actions:\n"
                    f"{previous_actions}\n\n"
                )
            else:
                user_msg = ""

            user_msg += (
                f"Now here is the new page content. Read carefully the page content. You can not do search action anymore. Based on your persona, "
                + ("the feedback provided, " if use_feedback else "")
                + "and the current web page content, "
                f"give a brief thought and provide a valid action. "
            )

            # On an item page (it offers "buy now"), warn against re-clicking
            # already-selected options. The page marks them with the tokens
            # "[clicked button]" and "[clicked button_]" (as in the system prompt).
            if "buy now" in task_state.observations[-1].lower():
                user_msg += "Do not click the attribute between [clicked button] and [clicked button_]. "

            user_msg += (
                f"When outputting the action, "
                f"please write your action after the prompt 'Action:'."
            )

        user_msg = self.system_prompt + "\n\n" + user_msg
        return {
            "model": self.model,
            "messages": [
                {"role": "system", "content": system_msg},
                {"role": "user", "content": f"{user_msg}\n\nPage Content:\n{task_state.observations[-1]}\n\n"}
            ],
            "max_tokens": self.max_tokens,
            "temperature": self.temperature
        }

    async def process_response(self, task_state: TaskState, response: dict) -> bool:
        """Apply one model reply to ``task_state``: parse it, step the env, record the step.

        The action is the first ``Action: search[...]`` / ``Action: click[...]``
        in the reply, lower-cased, with whitespace just inside the brackets
        removed. If none is found, ``click[back to search]`` is used instead.
        The rationale is the reply text that precedes the action.

        If the agent clicks ``buy now`` right after ``click[description]`` or
        ``click[features]``, a synthetic ``click[< prev]`` step is recorded
        first and the task is treated as finished.

        Args:
            task_state: Task to advance (mutated in place).
            response: Chat-completions response dict, or the error placeholder
                built by ``make_api_request``.

        Returns:
            True if this step marked the task finished, otherwise False. Any
            exception is logged with its traceback and yields False.
        """
        try:
            content = response["choices"][0]["message"]["content"]

            action_match = re.search(r'Action\s*:\s*(search\[[^\]]+\]|click\[[^\]]+\])', content, re.IGNORECASE)
            if action_match:
                action = action_match.group(1).lower()
                action = re.sub(r'\[\s+', '[', action)
                action = re.sub(r'\s+\]', ']', action)
                # End the rationale where the parsed action starts. Searching for
                # the first "action:" substring instead can truncate the reasoning
                # when the model writes e.g. "my next action: ..." earlier on.
                rationale_end_idx = action_match.start()
            else:
                logging.warning(f"Task {task_state.task_id}: Failed to extract action from response, using default action")
                action = "click[back to search]"
                lowered = content.lower()
                rationale_end_idx = lowered.find('action :')
                if rationale_end_idx == -1:
                    rationale_end_idx = lowered.find('action:')

            rationale = content[:rationale_end_idx].strip() if rationale_end_idx > 0 else ""
            # A whitespace-only prefix also gets the placeholder, so every
            # recorded step carries a non-empty rationale.
            rationale = rationale or "No explicit reasoning provided."

            # Check if this is a buy now action after description/features
            is_buy_now_after_detail = (
                action.lower() == 'click[buy now]' and
                task_state.last_action and
                task_state.last_action.lower() in ['click[description]', 'click[features]']
            )

            if is_buy_now_after_detail:
                # Record a synthetic return to the page shown before the detail
                # click. NOTE(review): only the recorded trajectory gets this step;
                # the env is not stepped with 'click[< prev]' before the
                # 'click[buy now]' below -- confirm that is intended.
                prev_obs = task_state.last_observation
                task_state.add_step('click[< prev]', prev_obs, "Going back to previous page to complete purchase")

            # Take step in environment using task-specific environment
            obs, _, _, _ = task_state.env.step(action)
            obs = clean_obs(obs)
            # Keep everything after the line that follows "Instruction:".
            match = re.search(r'(?i)instruction:\s*\n.*?\n(.*)', obs, re.DOTALL)
            observation = match.group(1).strip() if match else ""

            task_state.add_step(action, observation, rationale)

            # After add_step, last_observation is the page the agent acted on.
            if ('click[buy now]' in action.lower() and 'buy now' in task_state.last_observation.lower()) or is_buy_now_after_detail:
                task_state.is_finished = True
                logging.info(f"Task {task_state.task_id} finished successfully after {task_state.current_round} rounds")
                return True

            return False

        except Exception as e:
            logging.exception(f"Error processing response for task {task_state.task_id}: {e}")
            return False

    async def generate_tasks(self):
        """Run the full data-generation pipeline.

        Stages, each controlled by a constructor flag:

        1. Trajectories (``skip_trajectory``): roll out every task for up to
           10 rounds, then save ``final_results.json`` and
           ``final_results_filtered.json``. When skipped, both are loaded from
           ``exp_dir`` instead (the filtered set is recomputed if its file is
           missing); if ``final_results.json`` is missing the method returns.
        2. Instructions (``skip_instruction``): add instructions and save the
           ``*_with_instructions.json`` files.
        3. Post-hoc reasoning (``skip_reasoning``): reload the
           ``*_with_instructions.json`` files from ``exp_dir``, add reasoning
           and save the ``*_with_reasoning.json`` files.
        4. ``final_results.json`` is always rewritten with the latest results.
        5. JSONL (``skip_jsonl``): write training JSONL and, when
           ``combine_jsonl`` names an existing file, ``*_combined.jsonl``
           copies with that file appended.
        """
        if self.skip_trajectory:
            logging.info("Skipping trajectory generation")
            loaded = self._load_existing_results()
            if loaded is None:
                return
            final_results, filtered_results = loaded
        else:
            self._init_task_states()
            await self._run_trajectory_rounds()
            final_results, filtered_results = self._save_trajectory_results()

        if not self.skip_instruction:
            logging.info("Generating instructions for successful tasks")
            await self.generate_instructions(final_results)
            await self.generate_instructions(filtered_results, is_filtered=True)
            self._write_json("final_results_with_instructions.json", final_results)
            self._write_json("final_results_filtered_with_instructions.json", filtered_results)
        else:
            logging.info("Skipping instruction generation")

        if not self.skip_reasoning:
            # Reasoning always starts from the instruction files on disk.
            final_results = self._read_json("final_results_with_instructions.json")
            filtered_results = self._read_json("final_results_filtered_with_instructions.json")
            logging.info("Generating post-hoc reasoning for successful tasks")
            await self.add_post_hoc_reasoning(final_results)
            await self.add_post_hoc_reasoning(filtered_results, is_filtered=True)
            self._write_json("final_results_with_reasoning.json", final_results)
            self._write_json("final_results_filtered_with_reasoning.json", filtered_results)
        else:
            logging.info("Skipping reasoning generation")

        # Always save final results
        self._write_json("final_results.json", final_results)

        if not self.skip_jsonl:
            logging.info("Creating training JSONL files")
            self.convert_json_into_jsonl_with_rationale(final_results)
            self.convert_json_into_jsonl_with_rationale(filtered_results, is_filtered=True)

            if self.combine_jsonl and os.path.exists(self.combine_jsonl):
                logging.info(f"Combining with existing JSONL file: {self.combine_jsonl}")
                combined_path = self._write_combined_jsonl(
                    "train_w_as_action.jsonl", "train_w_as_action_combined.jsonl"
                )
                logging.info(f"Created combined JSONL file: {combined_path}")
                filtered_combined_path = self._write_combined_jsonl(
                    "train_w_as_action_filtered.jsonl", "train_w_as_action_filtered_combined.jsonl"
                )
                logging.info(f"Created combined filtered JSONL file: {filtered_combined_path}")
        else:
            logging.info("Skipping JSONL conversion")

        logging.info("All processing completed successfully!")

    def _write_json(self, filename, data):
        """Write ``data`` as indented JSON to ``exp_dir/filename``."""
        with open(os.path.join(self.exp_dir, filename), "w") as f:
            json.dump(data, f, indent=4)

    def _read_json(self, filename):
        """Load and return the JSON stored at ``exp_dir/filename``."""
        with open(os.path.join(self.exp_dir, filename), "r") as f:
            return json.load(f)

    def _load_existing_results(self):
        """Load saved trajectories for ``skip_trajectory``; return ``(final, filtered)`` or None."""
        results_file = os.path.join(self.exp_dir, "final_results.json")
        if not os.path.exists(results_file):
            logging.error("No existing results found to skip trajectory generation")
            return None
        final_results = self._read_json("final_results.json")
        # The filtered set is needed by every later stage, so load or rebuild it.
        if os.path.exists(os.path.join(self.exp_dir, "final_results_filtered.json")):
            filtered_results = self._read_json("final_results_filtered.json")
        else:
            filtered_results = [self.filter_trajectory_from_last_search(result) for result in final_results]
        return final_results, filtered_results

    def _init_task_states(self):
        """Create one TaskState (with its own freshly reset environment) per task."""
        for i in range(self.num_tasks):
            # SEED is already time-based and the current time is added again,
            # so per-task seeds (and persona choices) differ on every run.
            task_seed = SEED + i + int(time.time())
            random.seed(task_seed)
            np.random.seed(task_seed)
            persona = random.choice(self.personas)

            env = gym.make('WebAgentTextEnv-v0', observation_mode='text_rich', num_products=DEBUG_PROD_SIZE)
            obs, _ = env.reset()
            obs = clean_obs(obs)
            observation = "\n".join(obs.split('Instruction: \n')[-1].split('\n')[1:])

            self.task_states[i] = TaskState(
                task_id=i,
                persona=persona,
                env=env,
                observations=[observation],
                last_observation=observation
            )

        logging.info(f"Initialized {len(self.task_states)} task states")

    async def _run_trajectory_rounds(self):
        """Advance all unfinished tasks in lock-step rounds (at most 10) via the API."""
        round_num = 0
        pbar = tqdm(total=self.num_tasks, desc="Completed tasks")
        try:
            while any(not state.is_complete() for state in self.task_states.values()) and round_num < 10:
                request_dir, response_dir = self.get_round_dir(round_num)

                requests = [
                    (task_id, self.create_request_json(state))
                    for task_id, state in self.task_states.items()
                    if not state.is_complete()
                ]
                if not requests:
                    logging.info("No active tasks remaining")
                    break

                logging.info(f"Round {round_num}: Processing {len(requests)} active tasks")

                with open(os.path.join(request_dir, "requests.jsonl"), "w") as f:
                    for task_id, request in requests:
                        f.write(json.dumps({"task_id": task_id, "request": request}) + "\n")

                async with aiohttp.ClientSession() as session:
                    responses = await asyncio.gather(*[
                        self.make_api_request(session, task_id, request)
                        for task_id, request in requests
                    ])

                newly_completed = 0
                for task_id, response in responses:
                    task_state = self.task_states.get(task_id)
                    if task_state is None or task_state.is_complete():
                        continue
                    await self.process_response(task_state, response)
                    if task_state.is_complete():
                        newly_completed += 1
                pbar.update(newly_completed)

                with open(os.path.join(response_dir, "responses.jsonl"), "w") as f:
                    for task_id, _ in responses:
                        state = self.task_states.get(task_id)
                        if state is None:
                            continue
                        data = {
                            "task_id": task_id,
                            "round": round_num,
                            "persona": state.persona,
                            "action": state.actions[-1] if state.actions else None,
                            "observation": state.observations[-1],
                            "rationale": state.rationales[-1] if state.rationales else None,
                            "is_finished": state.is_finished,
                            "current_round": state.current_round,
                            "is_complete": state.is_complete()
                        }
                        f.write(json.dumps(data) + "\n")

                round_num += 1

                if self.wait_for_user:
                    input(f"Round {round_num-1} completed. Press Enter to continue to next round...")

            # Mark any remaining tasks as complete after max rounds
            if round_num >= 10:
                for state in self.task_states.values():
                    if not state.is_complete():
                        state.is_finished = True  # Mark as finished even if not successful
        finally:
            pbar.close()

        logging.info(f"All tasks completed after {round_num} rounds")

    def _save_trajectory_results(self):
        """Serialize task states, save raw and filtered results; return both lists."""
        final_results = [
            {
                "task_id": task_id,
                "persona": state.persona,
                "actions": state.actions,
                "observations": state.observations,
                "rationales": state.rationales,
                "is_finished": state.is_finished,
                "rounds_taken": state.current_round
            }
            for task_id, state in self.task_states.items()
        ]
        self._write_json("final_results.json", final_results)

        filtered_results = [self.filter_trajectory_from_last_search(result) for result in final_results]
        self._write_json("final_results_filtered.json", filtered_results)
        return final_results, filtered_results

    def _write_combined_jsonl(self, original_name, combined_name):
        """Write ``exp_dir/combined_name`` = ``exp_dir/original_name`` + ``combine_jsonl``; return its path."""
        combined_path = os.path.join(self.exp_dir, combined_name)
        original_path = os.path.join(self.exp_dir, original_name)
        with open(combined_path, 'w') as outfile:
            for source_path in (original_path, self.combine_jsonl):
                with open(source_path, 'r') as infile:
                    for line in infile:
                        outfile.write(line)
        return combined_path

    async def generate_instructions(self, results, is_filtered=False):
        """Generate a user-style shopping instruction for each purchased trajectory.

        A task qualifies when it is marked finished and has a last action
        containing ``click[buy now]``. For each qualifying task, the action
        sequence and final observation are sent to the chat API. The stripped
        reply is stored in place as ``task["instruction"]`` on the matching
        dict in ``results``. Requests and replies are also written as JSONL
        under ``requests/instructions{suffix}`` and
        ``responses/instructions{suffix}``.

        Args:
            results: List of task result dicts (``task_id``, ``actions``,
                ``observations``, ``is_finished``, ...).
            is_filtered: Whether these are the filtered results; selects the
                ``_filtered`` suffix for output paths and log messages.
        """
        suffix = "_filtered" if is_filtered else ""
        instruction_request_dir = os.path.join(self.request_dir, f"instructions{suffix}")
        instruction_response_dir = os.path.join(self.response_dir, f"instructions{suffix}")
        os.makedirs(instruction_request_dir, exist_ok=True)
        os.makedirs(instruction_response_dir, exist_ok=True)

        successful_tasks = [
            task for task in results
            if task["is_finished"]
            and task["actions"]  # tasks marked finished without any step have no last action
            and "click[buy now]" in task["actions"][-1].lower()
        ]

        if not successful_tasks:
            logging.warning(f"No successful tasks found for {'filtered ' if is_filtered else ''}instruction generation")
            return

        logging.info(f"Generating {'filtered ' if is_filtered else ''}instructions for {len(successful_tasks)} successful tasks")

        system_msg = (
            "You are a helpful assistant trained to understand web environment and "
            "generate shopping instructions. You are given an action sequence and a "
            "final product description. Your task is to generate only an user query that will "
            "lead to the final product description."
        )

        requests = []
        for task in successful_tasks:
            task_id = task["task_id"]
            action_sequence = task["actions"]
            final_state = task["observations"][-1]

            user_msg = (
                f"Now here are the given action sequence and final product description.\n\n"
                f"Action Sequence:\n{action_sequence}\n\n"
                f"Final Product Description:\n{final_state}\n\n"
                f"Considering both search keywords and product detail, please generate an user query.\n"
                f"Please put more weight on the search keywords than the product detail. Do not directly include the product name in the query and rather give a high-level description of the product.\n"
                f"Note that clicked attributes in action sequence, like size, color, and options should be included in the query. (Buy now is not an attribute)\n"
                f"Attributes without [clicked button] should not be included in the query, as they are not part of the product.\n"
                f"You should also include the price condition in the query (e.g. price lower than XX dollars).\n"
                f"You should not include any other text than the query. Randomly start the query with words 'Find me','Show me', 'I am looking for', 'I need', 'I want', or similar words.\n\n"
                f"User Query: "
            )

            request = {
                "model": self.model,
                "messages": [
                    {"role": "system", "content": system_msg},
                    {"role": "user", "content": user_msg}
                ],
                "max_tokens": self.max_tokens,
                "temperature": self.temperature
            }

            requests.append((task_id, request))

        request_file = os.path.join(instruction_request_dir, f"instruction_requests{suffix}.jsonl")
        with open(request_file, "w") as f:
            for task_id, request in requests:
                f.write(json.dumps({"task_id": task_id, "request": request}) + "\n")

        async with aiohttp.ClientSession() as session:
            responses = await asyncio.gather(*[
                self.make_api_request(session, task_id, request)
                for task_id, request in requests
            ])

        # Index results once instead of scanning the list for every response;
        # setdefault keeps the first occurrence of a task_id (first-match lookup).
        tasks_by_id = {}
        for task in results:
            tasks_by_id.setdefault(task["task_id"], task)

        response_file = os.path.join(instruction_response_dir, f"instruction_responses{suffix}.jsonl")
        with open(response_file, "w") as f:
            for task_id, response in responses:
                content = response["choices"][0]["message"]["content"]
                instruction = (content or "").strip()
                # make_api_request reports failures as placeholder content with this prefix.
                if instruction.startswith("Error: API request"):
                    logging.warning(f"Task {task_id}: instruction request failed, storing error text: {instruction}")

                task = tasks_by_id.get(task_id)
                if task is not None:
                    task["instruction"] = instruction

                f.write(json.dumps({"task_id": task_id, "instruction": instruction}) + "\n")

    async def make_api_request(self, session: aiohttp.ClientSession, task_id: object, request: dict) -> tuple:
        """Send one chat-completion request and return ``(task_id, response_json)``.

        ``task_id`` is not interpreted; it is echoed back so callers that fan out
        many requests with ``asyncio.gather`` can match responses to inputs.
        Callers in this module pass either an int or a ``(task_id, step_idx)``
        tuple.

        Failures never raise. Each of the following is logged and turned into a
        placeholder shaped like a normal completion, whose content starts with
        ``"Error: API request"``:

        - a non-200 status;
        - a timeout;
        - any other exception;
        - a 200 body without a string ``choices[0].message.content``.

        Callers can therefore always read
        ``response["choices"][0]["message"]["content"]``.

        Args:
            session: Open aiohttp session used for the POST.
            task_id: Opaque identifier echoed back unchanged.
            request: JSON body for the ``/v1/chat/completions`` endpoint.

        Returns:
            Tuple ``(task_id, response_dict)``.
        """

        def error_response(message: str) -> dict:
            # Same shape as a successful completion so callers need no special casing.
            return {"choices": [{"message": {"content": message}}]}

        try:
            async with session.post(
                "https://api.openai.com/v1/chat/completions",
                headers={"Authorization": f"Bearer {self.api_key}"},
                json=request,
                # Bound each request so one stalled connection cannot hang gather().
                timeout=aiohttp.ClientTimeout(total=30),
            ) as response:
                if response.status != 200:
                    error_text = await response.text()
                    logging.error(f"API error for task {task_id}: {response.status} - {error_text}")
                    return task_id, error_response(f"Error: API request failed with status {response.status}")

                result = await response.json()
        except asyncio.TimeoutError:
            logging.error(f"Timeout for API request for task {task_id}")
            return task_id, error_response("Error: API request timed out")
        except Exception as e:
            logging.error(f"Error making API request for task {task_id}: {e}")
            return task_id, error_response(f"Error: API request failed: {str(e)}")

        # A 200 body without usable text (e.g. content=None) would otherwise make
        # callers fail on ["choices"][0]["message"]["content"].strip() and abort
        # the whole gather() batch.
        try:
            content = result["choices"][0]["message"]["content"]
        except (KeyError, IndexError, TypeError):
            content = None
        if not isinstance(content, str):
            logging.error(f"Malformed API response for task {task_id}: {result}")
            return task_id, error_response("Error: API request failed: malformed response")

        return task_id, result

    async def add_post_hoc_reasoning(self, results, is_filtered=False):
        """Generate a post-hoc rationale for every observation-action pair.

        A task is processed when all of the following hold:

        - it is finished;
        - it has at least one action, and the last one contains ``click[buy now]``;
        - it already has an ``"instruction"``.

        For each such task, one chat-completion request is sent per step,
        asking why that step's action was taken given the observation before
        it. Requests and responses are written as JSONL under
        ``reasoning{suffix}`` in ``self.request_dir`` / ``self.response_dir``.
        Rationales are stored in place in ``task["reasonings"]``, one entry per
        action.

        Steps whose API call failed are stored as ``None`` rather than the error
        placeholder text. Downstream JSONL conversion requires every reasoning
        to be non-None, so such tasks are excluded instead of training on an
        error message.

        Args:
            results: List of task result dicts; successful ones are mutated in place.
            is_filtered: Whether these are filtered results (selects the
                ``_filtered`` directory/file suffix).
        """
        suffix = "_filtered" if is_filtered else ""
        label = "filtered " if is_filtered else ""
        # Prefix shared by every placeholder that make_api_request returns on failure.
        api_error_prefix = "Error: API request"

        reasoning_request_dir = os.path.join(self.request_dir, f"reasoning{suffix}")
        reasoning_response_dir = os.path.join(self.response_dir, f"reasoning{suffix}")
        os.makedirs(reasoning_request_dir, exist_ok=True)
        os.makedirs(reasoning_response_dir, exist_ok=True)

        # `task["actions"]` is checked before indexing [-1] so an empty
        # trajectory is skipped instead of raising IndexError.
        successful_tasks = [
            task for task in results
            if task["is_finished"]
            and task["actions"]
            and "click[buy now]" in task["actions"][-1].lower()
            and "instruction" in task
        ]

        if not successful_tasks:
            logging.warning(f"No successful tasks with instructions found for {label}reasoning generation")
            return

        logging.info(f"Generating {label}post-hoc reasoning for {len(successful_tasks)} successful tasks")

        system_msg = (
            "You are an AI assistant tasked with explaining actions taken in a web environment. "
        )

        # One session for all tasks; tasks are still processed one after another,
        # with the steps of a single task requested concurrently.
        async with aiohttp.ClientSession() as session:
            for task in tqdm(successful_tasks, desc="Generating reasoning"):
                task_id = task["task_id"]
                instruction = task["instruction"]
                observations = task["observations"]
                actions = task["actions"]

                reasoning_requests = []
                for step_idx, action in enumerate(actions):
                    current_observation = observations[step_idx]  # Observation before the action was taken
                    previous_actions = actions[:step_idx]

                    user_msg = (
                        "Given the instruction you need to follow and the current observation, provide a rationale for why the 'last action' was taken to follow the instruction. "
                        "You can also refer to the previous actions to provide a rationale. "
                        "The rationale should naturally fit with '[your rationale]. Thus, my action is [chosen action].' "
                        "You only need to provide 'your rationale' part. Be very concise and clear. "
                        f"Now, here are the given instruction, previous actions, current observation, and the last action.\n\n"
                        f"Instruction: {instruction}\n\n"
                        f"Previous actions before the last action: {previous_actions}\n\n"
                        f"Current observation: {current_observation}\n\n"
                        f"Last action taken based on the current observation: {action}\n\n"
                        "Why was this last action taken? Provide a rationale:"
                    )

                    request = {
                        "model": "gpt-4o-mini",
                        "messages": [
                            {"role": "system", "content": system_msg},
                            {"role": "user", "content": user_msg}
                        ],
                        "max_tokens": self.max_tokens,
                        "temperature": self.temperature
                    }

                    reasoning_requests.append((step_idx, request))

                # Save reasoning requests for this task
                request_file = os.path.join(reasoning_request_dir, f"reasoning_requests_task_{task_id}{suffix}.jsonl")
                with open(request_file, "w") as f:
                    for step_idx, request in reasoning_requests:
                        f.write(json.dumps({"task_id": task_id, "step": step_idx, "request": request}) + "\n")

                # make_api_request echoes (task_id, step_idx) back, so results can be matched by step.
                reasoning_responses = await asyncio.gather(*(
                    self.make_api_request(session, (task_id, step_idx), request)
                    for step_idx, request in reasoning_requests
                ))

                # Reuse an existing list only if it lines up with the actions;
                # otherwise indexing by step could raise or mix stale entries.
                reasonings = task.get("reasonings")
                if not isinstance(reasonings, list) or len(reasonings) != len(actions):
                    reasonings = [None] * len(actions)
                    task["reasonings"] = reasonings

                response_file = os.path.join(reasoning_response_dir, f"reasoning_responses_task_{task_id}{suffix}.jsonl")
                with open(response_file, "w") as f:
                    for (_, step_idx), response in reasoning_responses:
                        reasoning = response["choices"][0]["message"]["content"].strip()

                        if reasoning.startswith(api_error_prefix):
                            # Leave the step empty so the task is not exported with an error string.
                            logging.warning(f"Reasoning request failed for task {task_id}, step {step_idx}: {reasoning}")
                            reasonings[step_idx] = None
                        else:
                            reasonings[step_idx] = reasoning

                        # The raw text (including errors) is still logged to disk for debugging.
                        f.write(json.dumps({
                            "task_id": task_id,
                            "step": step_idx,
                            "reasoning": reasoning
                        }) + "\n")

    def convert_json_into_jsonl_with_rationale(self, results, is_filtered=False):
        """Convert the task results into JSONL format for training with reasonings
        
        Args:
            results: List of task results
            is_filtered: Whether these are filtered results
        """
        # Filter successful tasks with instructions and reasonings
        successful_tasks = [
            task for task in results
            if task["is_finished"] and
            "click[buy now]" in task["actions"][-1].lower() and
            "instruction" in task and
            "reasonings" in task and
            all(r is not None for r in task["reasonings"])
        ]
        
        if not successful_tasks:
            logging.warning("No complete tasks with instructions and reasonings found for JSONL conversion")
            return
        
        logging.info(f"Converting {len(successful_tasks)} tasks to JSONL format")
        
        # Define shared messages
        shared_system_msg = '''You are an agent with a strict task of completing a web shopping assignment based on the page content and the user's instructions.

        1. search[keywords]: Use this action only when a "[button] Search [button_]" is present in the current web page content. You must replace "keywords" with any valid search query you want to search.

2. click[HTML Element]: Use this action to click on an HTML Element in the page content. "HTML Element" can be any clickable element in the page represented inside "[button]" and "[button_]", such as an item id, action button, or attributes and options like color or size. Note that the 'HTML Element' 'must' be present in the current page content. Also, do not click the "clicked button" or "item name".

Only use search action when a "[button] Search [button_]" is present in the current web page content and otherwise, use click action (click item id, attributes like color and size, or action button). Now, here is the task:\n\nTask: \n'''

        shared_user_msg = (
            "Now here is the new page content. Based on the given task and the current web page content, give a concise reasoning and then provide a valid action. When outputting the action, "
            "please write your action after the prompt 'Action:'."
        )

        shared_second_user_msg = (
            "Now here is the new page content. Read carefully the page content. Based on the previous actions, the given task, and the current web page content, give a concise reasoning and then provide a valid action. When outputting the action, "
            "please write your action after the prompt 'Action:'."
        )
        
        # Prepare output paths with filtered suffix if needed
        suffix = "_filtered" if is_filtered else ""
        jsonl_output_path = os.path.join(self.exp_dir, f"train{suffix}.jsonl")
        jsonl_output_path_w_as = os.path.join(self.exp_dir, f"train_w_as_action{suffix}.jsonl")
        
        # Generate JSONL without action sequence history
        with open(jsonl_output_path, 'w') as f:
            for task in successful_tasks:
                task_id = task["task_id"]
                persona = task["persona"]
                instruction = task["instruction"]
                
                for i, (observation, action, reasoning) in enumerate(zip(task["observations"][:-1], task["actions"], task["reasonings"])):
                    messages = []
                    system_msg_instance = f"{shared_system_msg}{instruction}\n"
                    user_msg_instance = f"{shared_user_msg}\n\nPage Content:\n{observation}\n\n"
                    assistant_msg_instance = (
                        f"{reasoning}\n"
                        f"Action: {action}"
                    )
                    
                    messages.append({'role': 'system', 'content': system_msg_instance})
                    messages.append({'role': 'user', 'content': user_msg_instance})
                    messages.append({'role': 'assistant', 'content': assistant_msg_instance})
                    
                    f.write(json.dumps({'dataset': 'webshop', 'id': task_id, 'persona': persona, 'messages': messages}) + '\n')
        
        # Generate JSONL with action sequence history
        with open(jsonl_output_path_w_as, 'w') as f:
            for task in successful_tasks:
                task_id = task["task_id"]
                persona = task["persona"]
                instruction = task["instruction"]
                
                for i, (observation, action, reasoning) in enumerate(zip(task["observations"][:-1], task["actions"], task["reasonings"])):
                    messages = []
                    system_msg_instance = f"{shared_system_msg}{instruction}\n"
                    messages.append({'role': 'system', 'content': system_msg_instance})
                    
                    # Use action history
                    previous_actions = task["actions"][:i]
                    if previous_actions:
                        action_summary = "To complete the given task, you have taken the following actions:\n" + "\n".join(
                            f"{idx+1}. {a}" for idx, a in enumerate(previous_actions)
                        )
                        user_instruction = (
                            f"{action_summary}\nDo not repeat the previous actions.\n\n"
                            f"{shared_second_user_msg}"
                        )
                    else:
                        user_instruction = shared_user_msg
                    
                    user_msg_instance = f"{user_instruction}\n\nPage Content:\n{observation}\n\n"
                    assistant_msg_instance = (
                        f"{reasoning}\nAction: {action}"
                    )
                    
                    messages.append({'role': 'user', 'content': user_msg_instance})
                    messages.append({'role': 'assistant', 'content': assistant_msg_instance})
                    
                    f.write(json.dumps({'dataset': 'webshop', 'id': task_id, 'persona': persona, 'messages': messages}) + '\n')
        
        logging.info(f"JSONL files created: {jsonl_output_path} and {jsonl_output_path_w_as}")

    async def generate_feedback(self):
        """Generate model performance feedback based on previous trajectories
        
        Analyzes the trajectories in the feedback_json file to identify strengths and weaknesses.
        Generates concise feedback on what the model is good at and what skills need improvement.
        If skill_feedback is enabled, focuses on predefined skills instead of open-ended feedback.
        """
        if not self.feedback_json or not os.path.exists(self.feedback_json):
            logging.warning(f"Feedback JSON file not found: {self.feedback_json}")
            return
        
        try:
            # Load the feedback data
            with open(self.feedback_json, 'r') as f:
                feedback_data = json.load(f)
            
            rewards = feedback_data.get("rewards", [])
            instructions = feedback_data.get("instructions", [])
            action_sequences = feedback_data.get("action_sequence", [])
            obs_sequences = feedback_data.get("obs_sequence", [])
            
            if not rewards or not instructions or not action_sequences or not obs_sequences:
                logging.warning("Missing required fields in feedback JSON")
                return
            
            # Create feedback directories
            feedback_request_dir = os.path.join(self.request_dir, "feedback")
            feedback_response_dir = os.path.join(self.response_dir, "feedback")
            os.makedirs(feedback_request_dir, exist_ok=True)
            os.makedirs(feedback_response_dir, exist_ok=True)
            
            # Pair rewards with actions and instructions
            trajectory_data = []
            for i in range(len(rewards)):
                if i < len(instructions) and i < len(action_sequences) and i < len(obs_sequences):
                    trajectory_data.append({
                        "reward": rewards[i],
                        "instruction": instructions[i],
                        "actions": action_sequences[i],
                        "observations": obs_sequences[i]
                    })
            
            # Separate trajectories with zero and non-zero rewards
            non_zero_reward_trajectories = [t for t in trajectory_data if t["reward"] > 0]
            zero_reward_trajectories = [t for t in trajectory_data if t["reward"] == 0]
            
            # Sort non-zero reward trajectories by reward (ascending for lowest, descending for highest)
            non_zero_reward_trajectories.sort(key=lambda x: x["reward"])
            
            # Get 4 random trajectories with reward < 0.5
            low_reward_trajectories = [t for t in non_zero_reward_trajectories if t["reward"] < 0.5]
            # Randomly select 4 trajectories (or fewer if not enough available)
            import random
            import time
            random.seed(time.time())
            low_non_zero_trajectories = random.sample(low_reward_trajectories, min(5, len(low_reward_trajectories)))
            
            # Get 2 trajectories with 0 reward
            zero_reward_trajectories = zero_reward_trajectories[:2]
            
            # Get 10 highest reward trajectories (successful ones)
            successful_trajectories = sorted(non_zero_reward_trajectories, key=lambda x: x["reward"], reverse=True)[:2]
            
            # Combine for our failed trajectories
            failed_trajectories = low_non_zero_trajectories + zero_reward_trajectories
            
            # Create system message for feedback analysis
            if self.skill_feedback:
                predefined_skills = [
                    "Avoiding action repetition (not stuck in the middle)",
                    "Performing valid actions",
                    "Creating detailed search queries",
                    "Selecting appropriate product attributes",
                    "Completing purchases",
                    "Making informed product selections",
                    # "Recovering from failures (backtracking)"
                ]

                predefined_guides = [
                    "Avoid repeating the same action in consecutive steps",
                    "Only use search when the search button is present, and click only on elements that exist on the current page",
                    "Use very detailed search queries (more than 5~8 words)",
                    "Try to click many product attributes (size, color, etc.)",
                    "Always aim to complete the shopping task by clicking the 'Buy Now' button when you find a suitable product",
                    "Explore multiple products before making a selection, considering their attributes and your preferences",
                    # "Use navigation buttons (prev/next) strategically to explore different options and recover from incorrect selections"
                ]
                
                system_msg = (
                    "You are an AI assistant tasked with analyzing web shopping trajectories. "
                    "To get a high reward, the model needs to complete the task with the given instruction, fulfilling the task requirements such as price, clicking on the right product, correctly clicking the attributes like size and color, etc."
                    "Given trajectories of varying rewards, your task is to identify which skills the current model lacks the most:\n\n"
                    "Using your feedback, you will explore the web shopping task on the next round, where your trajectories will be used to train the model."
                     "Please note that during your exploration, you don't have an instruction to follow and just explore the web shopping task with your feedback."
                )
                # such as finding right price using detailed search queries, clicking on the right product, cliking on the correct attributes like size and color, etc
            else:
                system_msg = (
                    "You are an AI assistant tasked with analyzing web shopping trajectories. "
                    "To get a high reward, the model needs to complete the task with the given instruction, fulfilling the task requirements of product type, price, attributes like size and color, etc."
                    "Given trajectories of varying rewards, identify strengths in successful trajectories and weaknesses "
                    "in failed trajectories. Provide concise feedback (2 points maximum) on "
                    "what skills need improvement to achieve a high reward."
                    "Using your feedback, you will explore the web shopping task on the next round, where your trajectories will be used to train the model."
                    "For example, if the model lacks detailed search queries, you need to make an initial query very detailed when the search page is shown as your search will be used as the data for fine-tuning the model."
                )
            
            # Create user message with trajectory examples
            user_msg = (
                "Now here are the trajectories of the current model:\n\n"
            )
            
            # Add successful trajectories
            user_msg += "\nSuccessful trajectories (high reward):\n"
            for i, traj in enumerate(successful_trajectories):
                user_msg += f"\nTrajectory {i+1}:\n"
                user_msg += f"Instruction: {traj['instruction']}\n"
                user_msg += f"Observation-Action Sequence:\n"
                
                # Create observation-action pairs
                for j in range(min(len(traj['observations']), len(traj['actions']))):
                    user_msg += f"  Observation {j+1}: {traj['observations'][j][:200]}... (truncated)\n"
                    user_msg += f"  Action {j+1}: {traj['actions'][j]}\n\n"
                    
                user_msg += f"Reward: {traj['reward']}\n"
            
            # Add failed trajectories with observation-action sequences
            user_msg += "\nFailed trajectories (low or zero reward):\n"
            for i, traj in enumerate(failed_trajectories):
                user_msg += f"\nTrajectory {i+1}:\n"
                user_msg += f"Instruction: {traj['instruction']}\n"
                user_msg += f"Observation-Action Sequence:\n"
                
                # Create observation-action pairs
                for j in range(min(len(traj['observations']), len(traj['actions']))):
                    user_msg += f"  Observation {j+1}: {traj['observations'][j][:200]}... (truncated)\n"
                    user_msg += f"  Action {j+1}: {traj['actions'][j]}\n\n"
                    
                user_msg += f"Reward: {traj['reward']}\n"
            
            
            if self.skill_feedback:
                user_msg += (
                    "-"*50 + "\n"
                    "\nBased on these trajectories, analyze which of the following predefined skills the model lacks the most:\n\n"
                    + "\n".join(f"{i+1}. {skill}\n   Guide: {guide}" for i, (skill, guide) in enumerate(zip(predefined_skills, predefined_guides))) + "\n\n"
                    "Select the top 2 skills that the model needs to improve the most, based on the evidence in the trajectories. "
                    "Do not choose skills at random - only select skills where there is clear evidence of deficiency. "
                    "For each selected skill, use a corresponding predefined guide on what you need to do during your exploration of the web shopping task. "
                    "The example format could be like this: \n"
                    "1. Current Model Lacks [Name of skill from the predefined list]\n"
                    "   Exploration guide: [What you will do to improve this skill during exploration]\n\n"
                )
            else:
                user_msg += (
                    "-"*50 + "\n"
                    "\nBased on these trajectories, provide concise feedback (2 points maximum) on what kind of behaviours are desirable and undesirable during exploration. "
                    "Keep the points very brief. "
                    "Most importantly, for each point, write a brief guide on what you need to do during your exploration of the web shopping task on the next round. "
                    "Also, you can take up to 10 actions in the environment, so please give feedback on how to have a good and concise action sequence."
                    "*****Note that during your exploration, there are 'no instructions, given criteria, or requirements to follow', so you need to provide feedback on which types of actions are beneficial (as there are only two types: search and click, specify which search keywords or clicking on which elements are beneficial). If you do certain actions with your interest, the models are encourage to do more of that action. Thus, do not say something like 'do something to meet criteria', 'follow the criteria, instructions, or given states', or 'match specific attributes'. Just say what you think is good or bad.\n"
                    "The example format could be like this: \n"
                    "1. The current low reward is due to B. Refrain from B during your exploration.\n\n"
                    "2. The current low reward is due to not clicking C. Ensure to do click diverse C during your exploration.\n\n"
                )
            
            # Create request
            request = {
                "model": "gpt-4o",
                "messages": [
                    {"role": "system", "content": system_msg},
                    {"role": "user", "content": user_msg}
                ],
                "max_tokens": self.max_tokens,
                "temperature": 0.5
            }
            
            # Save feedback request
            request_file = os.path.join(feedback_request_dir, "feedback_request.json")
            with open(request_file, "w") as f:
                json.dump(request, f, indent=4)
            
            # Make API request
            async with aiohttp.ClientSession() as session:
                async with session.post(
                    "https://api.openai.com/v1/chat/completions",
                    headers={"Authorization": f"Bearer {self.api_key}"},
                    json=request,
                    timeout=60
                ) as response:
                    if response.status != 200:
                        error_text = await response.text()
                        logging.error(f"Feedback API error: {response.status} - {error_text}")
                        return
                        
                    result = await response.json()

                    # Extract feedback
                    feedback_content = result["choices"][0]["message"]
                    ["content"]

                    # import pdb; pdb.set_trace()
                    
                    # Save feedback response
                    response_file = os.path.join(feedback_response_dir, "feedback_response.json")
                    with open(response_file, "w") as f:
                        json.dump(result, f, indent=4)
                    
                    # Save feedback summary
                    summary_file = os.path.join(self.exp_dir, "feedback_summary.txt")
                    with open(summary_file, "w") as f:
                        f.write(feedback_content)
                    
                    # Store feedback summary
                    self.feedback_summary = feedback_content
                    logging.info("Generated feedback summary")
                    
                    # import pdb; pdb.set_trace()

                    
                    
                    return feedback_content
                    
        except Exception as e:
            logging.error(f"Error generating feedback: {e}")
            return None

    def filter_trajectory_from_last_search(self, task):
        """Return a deep copy of *task* trimmed to begin at its final ``search[...]`` action.

        ``actions`` (and ``reasonings``, when present) keep everything from the
        last search onward. ``observations`` keep the same starting index plus
        one extra trailing entry, apparently because observations hold one more
        element than actions (one before each action, plus the final one). If
        no action starts with ``search[``, the untrimmed deep copy is returned.
        The input dict is never modified.
        """
        trimmed = copy.deepcopy(task)
        actions = task["actions"]

        search_positions = [pos for pos, act in enumerate(actions) if act.startswith("search[")]
        if not search_positions:
            return trimmed

        start = search_positions[-1]
        kept_actions = actions[start:]
        trimmed["actions"] = kept_actions
        trimmed["observations"] = task["observations"][start:start + len(kept_actions) + 1]
        if "reasonings" in task:
            trimmed["reasonings"] = task["reasonings"][start:]
        return trimmed

def main():
    """Command-line entry point.

    Parses CLI options, configures logging to stderr and to a timestamped
    ``alice_openai_*.log`` file in the working directory, builds an
    ``OpenAITaskGenerator`` and runs ``generate_tasks()`` on a fresh event loop.

    Raises:
        ValueError: if ``--log-level`` does not name a ``logging`` level.
    """
    parser = argparse.ArgumentParser()
    # (flag, add_argument keyword arguments), registered in this order.
    cli_options = [
        ("--api-key", dict(type=str, required=True, help="OpenAI API key")),
        ("--model", dict(type=str, default="gpt-4o-mini", help="OpenAI model to use")),
        ("--max-tokens", dict(type=int, default=512, help="Max tokens per response")),
        ("--temperature", dict(type=float, default=0.4, help="Temperature for generation")),
        ("--num-tasks", dict(type=int, default=5, help="Number of tasks to generate")),
        ("--output-dir", dict(type=str, default="output_data", help="Output directory")),
        ("--exp-name", dict(type=str, default=None, help="Experiment name to append to timestamp")),
        ("--wait-for-user", dict(action="store_true", help="Wait for user input between rounds")),
        ("--log-level", dict(type=str, default="INFO", help="Logging level")),
        ("--skip-instruction", dict(action="store_true", help="Skip instruction generation")),
        ("--skip-reasoning", dict(action="store_true", help="Skip reasoning generation")),
        ("--skip-jsonl", dict(action="store_true", help="Skip JSONL conversion")),
        ("--skip-trajectory", dict(action="store_true", help="Skip trajectory generation")),
        ("--feedback-mode", dict(action="store_true", help="Enable feedback mode")),
        ("--feedback-json", dict(type=str, help="Path to feedback JSON file")),
        ("--skill", dict(action="store_true", help="Use predefined skills for feedback instead of open-ended feedback")),
        ("--combine-jsonl", dict(type=str, help="Path to JSONL file to combine with train_w_as_action.jsonl")),
        ("--resume-dir", dict(type=str, help="Path to directory to resume from")),
        ("--persona-path", dict(type=str, default="persona.json", help="Path to persona data")),
    ]
    for flag, options in cli_options:
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    # getattr may also return non-level attributes (functions, classes); only ints are levels.
    level_name = args.log_level.upper()
    numeric_level = getattr(logging, level_name, None)
    if not isinstance(numeric_level, int):
        raise ValueError(f'Invalid log level: {args.log_level}')

    log_file = f"alice_openai_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log"
    logging.basicConfig(
        level=numeric_level,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        handlers=[logging.StreamHandler(), logging.FileHandler(log_file)],
    )

    generator_kwargs = {
        "api_key": args.api_key,
        "model": args.model,
        "max_tokens": args.max_tokens,
        "temperature": args.temperature,
        "num_tasks": args.num_tasks,
        "output_dir": args.output_dir,
        "exp_name": args.exp_name,
        "wait_for_user": args.wait_for_user,
        "skip_instruction": args.skip_instruction,
        "skip_reasoning": args.skip_reasoning,
        "skip_jsonl": args.skip_jsonl,
        "skip_trajectory": args.skip_trajectory,
        "feedback_mode": args.feedback_mode,
        "feedback_json": args.feedback_json,
        "skill_feedback": args.skill,
        "combine_jsonl": args.combine_jsonl,
        "resume_dir": args.resume_dir,
    }
    # NOTE(review): --persona-path is parsed but never forwarded here; confirm
    # whether OpenAITaskGenerator is meant to receive args.persona_path.
    generator = OpenAITaskGenerator(**generator_kwargs)

    asyncio.run(generator.generate_tasks())

if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()