import gym
import json
import asyncio
import aiohttp
import numpy as np
import os
import random
import argparse
from tqdm import tqdm
from web_agent_site.envs import WebAgentTextEnv
from web_agent_site.utils import DEBUG_PROD_SIZE
from dataclasses import dataclass, field
from typing import List, Dict, Optional
import logging
import re
from datetime import datetime
import time
from datasets import load_dataset

# Seed Python's and NumPy's global RNGs. By default the seed is the current Unix time,
# so persona sampling differs between runs (NOT deterministic). Set the environment
# variable ALICE_SEED=<int> to pin the seed and reproduce a run; SEED always holds the
# value actually used.
SEED = int(os.environ.get("ALICE_SEED", time.time()))
random.seed(SEED)
np.random.seed(SEED)

@dataclass
class TaskState:
    """Mutable record of one task's trajectory through a WebShop text environment.

    Once seeded by ``reset()``, ``observations`` holds one more entry than ``actions``:
    ``observations[0]`` is the page before any action and ``observations[k + 1]`` is the
    page returned after ``actions[k]``. ``rationales`` is kept index-aligned with
    ``actions`` so the two can be zipped safely.
    """
    task_id: int
    instruction: str
    persona: str
    env: WebAgentTextEnv
    actions: List[str] = field(default_factory=list)
    observations: List[str] = field(default_factory=list)
    rationales: List[str] = field(default_factory=list)
    is_finished: bool = False
    current_round: int = 0
    last_action: Optional[str] = None
    # Page on which `last_action` was taken (the observation *before* that action),
    # not the page it produced.
    last_observation: Optional[str] = None
    # Step budget after which the task counts as complete even if not finished.
    max_rounds: int = 10

    def add_step(self, action: str, observation: str, rationale: Optional[str] = None):
        """Record one step of the trajectory.

        Args:
            action: The action string that was executed.
            observation: The page observation returned after executing ``action``.
            rationale: Model reasoning for the action. Missing or empty rationales are
                stored as ``""`` so ``rationales[k]`` always pairs with ``actions[k]``.
        """
        self.last_action = action
        # Remember the page the action was taken on, before appending the new one.
        self.last_observation = self.observations[-1] if self.observations else None

        self.actions.append(action)
        self.observations.append(observation)
        # Always append, even when empty, to keep rationales aligned with actions.
        self.rationales.append(rationale if rationale is not None else "")
        self.current_round += 1

    def is_complete(self) -> bool:
        """Return True once the task has finished or exhausted its step budget."""
        return self.is_finished or self.current_round >= self.max_rounds

    def reset(self):
        """Reset the environment and reinitialise the trajectory with its first observation.

        The observation keeps only the text after the ``'Instruction: \\n'`` marker,
        dropping the first line that follows it.
        """
        obs, _ = self.env.reset()
        obs = clean_obs(obs)
        observation = "\n".join(obs.split('Instruction: \n')[-1].split('\n')[1:])
        self.observations = [observation]
        self.actions = []
        self.rationales = []
        self.is_finished = False
        self.current_round = 0
        self.last_action = None
        self.last_observation = observation

def clean_obs(obs: str) -> str:
    """Normalisation hook for raw environment observations.

    Currently a pass-through: the very same object that was passed in is returned.
    """
    cleaned = obs  # no normalisation is applied yet
    return cleaned

class OpenAITaskProposer:
    """Propose persona-conditioned shopping tasks with an OpenAI chat model and execute them.

    ``generate_tasks`` runs two phases:

    1. Proposal: for each of ``num_tasks`` slots a random PersonaHub persona is sampled
       and the model is asked for a shopping instruction written from that persona.
       Slots whose request fails are skipped.
    2. Execution: every proposed task gets its own ``WebAgentTextEnv``. In each round the
       model is queried concurrently for every unfinished task, and the returned action
       is applied to that task's environment.

    Artifacts are written under ``<output_dir>/<timestamp>[_<exp_name>]/``:
    ``config.json``, ``proposed_tasks.json``, ``final_results.json`` and, unless
    ``skip_jsonl`` is set, ``train.jsonl`` and ``train_w_as_action.jsonl``.
    """

    # OpenAI Chat Completions endpoint used for every request.
    API_URL = "https://api.openai.com/v1/chat/completions"
    # Total per-request timeout, in seconds.
    REQUEST_TIMEOUT_S = 30
    # Hard cap on execution rounds; mirrors the 10-step budget in TaskState.is_complete.
    MAX_EXECUTION_ROUNDS = 10

    def __init__(
        self,
        api_key: str,
        model: str = "gpt-4",
        max_tokens: int = 2048,
        temperature: float = 0.7,
        num_tasks: int = 100,
        output_dir: str = "alice_data_proposed",
        exp_name: Optional[str] = None,
        skip_jsonl: bool = False
    ):
        """Configure the proposer, load personas, and create the experiment directory.

        Args:
            api_key: OpenAI API key, sent as a Bearer token.
            model: Chat model name used for both proposal and execution.
            max_tokens: ``max_tokens`` for every completion request.
            temperature: Sampling temperature for every completion request.
            num_tasks: Number of task proposals to request.
            output_dir: Parent directory for the timestamped experiment directory.
            exp_name: Optional suffix appended to the experiment directory name.
            skip_jsonl: If True, do not write the training JSONL files.

        Note: loads (and may download) the ``proj-persona/PersonaHub`` dataset.
        """
        self.api_key = api_key
        self.model = model
        self.max_tokens = max_tokens
        self.temperature = temperature
        self.num_tasks = num_tasks
        self.output_dir = output_dir
        self.exp_name = exp_name
        self.skip_jsonl = skip_jsonl

        # Task states keyed by task id, populated when execution starts.
        self.task_states: Dict[int, TaskState] = {}

        # Load personas from PersonaHub dataset
        persona_data = load_dataset("proj-persona/PersonaHub", "persona")['train']
        self.personas = [persona['persona'] for persona in persona_data]

        # Create output directories
        self.setup_directories()

        # System prompt for task proposal
        self.proposal_system_prompt = '''You are an AI assistant tasked with proposing interesting shopping tasks. 
You need to adopt the identity of a given persona and generate a specific shopping instruction that would be interesting to execute in a web shopping environment.

The instruction should:

1. Be realistic and achievable in a web shopping environment
2. Be interesting and varied
3. Be consistent with the persona's characteristics and preferences

Generate only the instruction text, starting with "I want to buy" or similar phrases.'''

        # System prompt for task execution
        self.execution_system_prompt = '''You are an agent with a strict task of completing a web shopping assignment based on the page content and the user's instructions.

You need to adopt the identity of a given persona and take actions that are consistent with that persona's characteristics and preferences.

In each step, your actions are strictly limited to two types:

1. search[keywords]: Use this action only when a "[button] Search [button_]" is present in the current web page content. You must replace "keywords" with any valid search query you want to search.

2. click[HTML Element]: Use this action to click on an HTML Element in the page content. "HTML Element" can be any clickable element in the page represented inside "[button]" and "[button_]", such as an item id, action button, or attributes and options like color or size. Note that the 'HTML Element' 'must' be present in the current page content. Also, do not click the "clicked button" or "item name".

Only use search action when a "[button] Search [button_]" is present in the current web page content and otherwise, use click action (click item id, attributes like color and size, or action button).'''

    def setup_directories(self):
        """Create the timestamped experiment directory and write ``config.json``.

        The config also records the module-level RNG ``SEED`` so a run can be reproduced.
        """
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        dir_name = f"{timestamp}_{self.exp_name}" if self.exp_name else timestamp

        self.exp_dir = os.path.join(self.output_dir, dir_name)
        os.makedirs(self.exp_dir, exist_ok=True)

        # NOTE(review): these directories are created but nothing in this class writes to them.
        self.request_dir = os.path.join(self.exp_dir, "requests")
        self.response_dir = os.path.join(self.exp_dir, "responses")
        os.makedirs(self.request_dir, exist_ok=True)
        os.makedirs(self.response_dir, exist_ok=True)

        config = {
            "timestamp": timestamp,
            "exp_name": self.exp_name,
            "model": self.model,
            "max_tokens": self.max_tokens,
            "temperature": self.temperature,
            "num_tasks": self.num_tasks,
            "seed": SEED,
        }
        with open(os.path.join(self.exp_dir, "config.json"), "w") as f:
            json.dump(config, f, indent=4)

    def create_proposal_request(self, persona: str) -> dict:
        """Build the chat-completions payload asking for a task in ``persona``'s voice."""
        return {
            "model": self.model,
            "messages": [
                {"role": "system", "content": self.proposal_system_prompt},
                {"role": "user", "content": f"Please propose an interesting shopping task based on this persona:\n{persona}\n\nYour task:"}
            ],
            "max_tokens": self.max_tokens,
            "temperature": self.temperature
        }

    @staticmethod
    def _build_user_instruction(previous_actions: List[str]) -> str:
        """Return the per-step user instruction, with an action history when there is one.

        Shared by live execution requests and the training JSONL export, so the prompts
        used for training match the prompts used while collecting data.
        """
        if not previous_actions:
            return (
                "Now here is the current page content. Based on your persona, the given task and the current web page content, "
                "give a brief reasoning and then provide a valid action. When outputting the action, "
                "please write your action after the prompt 'Action:'."
            )
        action_summary = "To complete the given task, you have taken the following actions:\n" + "\n".join(
            f"{i+1}. {a}" for i, a in enumerate(previous_actions)
        )
        return (
            f"{action_summary}\nDo not repeat the previous actions.\n\n"
            "Now here is the new page content. Read carefully the page content. Based on your persona, the previous actions, "
            "the given task, and the current web page content, give a brief reasoning and then provide a valid action. "
            "When outputting the action, please write your action after the prompt 'Action:'."
        )

    @staticmethod
    def _page_user_message(user_instruction: str, observation: str) -> str:
        """Append the page content section to a user instruction."""
        return f"{user_instruction}\n\nPage Content:\n{observation}\n\n"

    def create_execution_request(self, task_state: TaskState) -> dict:
        """Build the chat-completions payload for the next action of ``task_state``."""
        system_msg = f"{self.execution_system_prompt}\n\nPersona: {task_state.persona}\nTask: {task_state.instruction}\n"
        # current_round and len(actions) advance together (TaskState.add_step), so round 0
        # means "no history yet".
        previous_actions = task_state.actions if task_state.current_round > 0 else []
        user_msg = self._build_user_instruction(previous_actions)

        return {
            "model": self.model,
            "messages": [
                {"role": "system", "content": system_msg},
                {"role": "user", "content": self._page_user_message(user_msg, task_state.observations[-1])}
            ],
            "max_tokens": self.max_tokens,
            "temperature": self.temperature
        }

    @staticmethod
    def _error_response(message: str) -> dict:
        """Build a response-shaped fallback that is flagged as an error."""
        return {"choices": [{"message": {"content": message}}], "error": True}

    @staticmethod
    def _extract_content(response: dict) -> Optional[str]:
        """Return the assistant message text, or None for error/malformed responses."""
        if response.get("error"):
            return None
        try:
            content = response["choices"][0]["message"]["content"]
        except (KeyError, IndexError, TypeError):
            return None
        return content if isinstance(content, str) else None

    async def make_api_request(self, session: aiohttp.ClientSession, request: dict) -> dict:
        """POST one chat-completions request; never raises for request failures.

        Returns:
            The decoded JSON body on HTTP 200. On a non-200 status, a timeout, or any
            other exception, returns a response-shaped dict whose content starts with
            ``"Error:"`` and which carries ``"error": True``.
        """
        try:
            async with session.post(
                self.API_URL,
                headers={"Authorization": f"Bearer {self.api_key}"},
                json=request,
                timeout=aiohttp.ClientTimeout(total=self.REQUEST_TIMEOUT_S)
            ) as response:
                if response.status != 200:
                    error_text = await response.text()
                    logging.error(f"API error: {response.status} - {error_text}")
                    return self._error_response(f"Error: API request failed with status {response.status}")
                return await response.json()
        except asyncio.TimeoutError:
            logging.error("Timeout for API request")
            return self._error_response("Error: API request timed out")
        except Exception as e:
            logging.error(f"Error making API request: {e}")
            return self._error_response(f"Error: API request failed: {str(e)}")

    async def process_execution_response(self, task_state: TaskState, response: dict) -> bool:
        """Parse the model's action, apply it to the environment, and record the step.

        Returns:
            True only if this step finished the task (a ``click[buy now]`` taken on a page
            that offered "buy now"); False otherwise, including when no usable response
            or action could be parsed (the task's round counter is then not advanced).
        """
        content = self._extract_content(response)
        if content is None:
            logging.warning(f"Task {task_state.task_id}: No usable model output (API error or malformed response)")
            return False

        try:
            action_match = re.search(r'Action\s*:\s*(search\[[^\]]+\]|click\[[^\]]+\])', content, re.IGNORECASE)
            if not action_match:
                logging.warning(f"Task {task_state.task_id}: Failed to extract action from response")
                return False

            action = action_match.group(1).lower()
            action = re.sub(r'\[\s+', '[', action)
            action = re.sub(r'\s+\]', ']', action)

            # The rationale is everything before the matched action line. Cutting at the
            # match (not at the first literal "action:") keeps reasoning that mentions the
            # word intact, and the fallback keeps rationales non-empty and aligned.
            rationale = content[:action_match.start()].strip() or "No explicit reasoning provided."

            obs, _, _, _ = task_state.env.step(action)
            obs = clean_obs(obs)
            match = re.search(r'(?i)instruction:\s*\n.*?\n(.*)', obs, re.DOTALL)
            observation = match.group(1).strip() if match else ""

            task_state.add_step(action, observation, rationale)

            # last_observation is the page the action was taken on.
            if 'click[buy now]' in action and 'buy now' in (task_state.last_observation or "").lower():
                task_state.is_finished = True
                logging.info(f"Task {task_state.task_id} finished successfully after {task_state.current_round} rounds")
                return True

            return False

        except Exception:
            # Best effort per task: one failing environment step must not abort the batch.
            logging.exception(f"Error processing response for task {task_state.task_id}")
            return False

    async def _propose_tasks(self, session: aiohttp.ClientSession) -> List[tuple]:
        """Request ``num_tasks`` proposals; return ``(instruction, persona)`` for the successes.

        ``proposed_tasks.json`` is rewritten after every request as a checkpoint.
        """
        proposed_tasks = []
        proposals_path = os.path.join(self.exp_dir, "proposed_tasks.json")
        for i in range(self.num_tasks):
            persona = random.choice(self.personas)
            response = await self.make_api_request(session, self.create_proposal_request(persona))
            content = self._extract_content(response)
            instruction = content.strip() if content else ""
            if instruction:
                proposed_tasks.append((instruction, persona))
            else:
                # Never turn an API error message into a task instruction.
                logging.warning(f"Proposal {i}: no usable instruction returned; skipping this task")

            with open(proposals_path, "w") as f:
                json.dump([{"instruction": inst, "persona": pers} for inst, pers in proposed_tasks], f, indent=4)
        return proposed_tasks

    def _init_task_states(self, proposed_tasks: List[tuple]) -> None:
        """Create one environment per proposal and register a freshly reset TaskState."""
        for i, (instruction, persona) in enumerate(proposed_tasks):
            env = gym.make('WebAgentTextEnv-v0', observation_mode='text_rich', num_products=DEBUG_PROD_SIZE)
            task_state = TaskState(task_id=i, instruction=instruction, persona=persona, env=env)
            # Register before reset so _close_envs can still close this env if reset fails.
            self.task_states[i] = task_state
            task_state.reset()

    async def _execute_tasks(self, session: aiohttp.ClientSession) -> None:
        """Run up to MAX_EXECUTION_ROUNDS rounds, querying all unfinished tasks concurrently."""
        pbar = tqdm(total=len(self.task_states), desc="Completed tasks")
        try:
            for round_num in range(self.MAX_EXECUTION_ROUNDS):
                active = [state for state in self.task_states.values() if not state.is_complete()]
                if not active:
                    logging.info("No active tasks remaining")
                    break

                logging.info(f"Round {round_num}: Processing {len(active)} active tasks")

                # gather preserves input order, so responses line up with `active`.
                responses = await asyncio.gather(
                    *(self.make_api_request(session, self.create_execution_request(state)) for state in active)
                )

                newly_completed = 0
                for state, response in zip(active, responses):
                    await self.process_execution_response(state, response)
                    if state.is_complete():
                        newly_completed += 1
                pbar.update(newly_completed)
        finally:
            pbar.close()

    def _close_envs(self) -> None:
        """Close every environment, logging (not raising) individual close failures."""
        for state in self.task_states.values():
            try:
                state.env.close()
            except Exception:
                logging.exception(f"Failed to close environment for task {state.task_id}")

    @staticmethod
    def _result_record(task_id: int, state: TaskState) -> dict:
        """Serialise one task's trajectory for ``final_results.json``."""
        return {
            "task_id": task_id,
            "instruction": state.instruction,
            "persona": state.persona,
            "actions": state.actions,
            "observations": state.observations,
            "rationales": state.rationales,
            "is_finished": state.is_finished,
            "rounds_taken": state.current_round
        }

    async def generate_tasks(self):
        """Propose tasks, execute them, and write ``final_results.json`` (and JSONL files).

        A single HTTP session is shared by all requests. Environments are always closed,
        even if execution fails.
        """
        async with aiohttp.ClientSession() as session:
            logging.info("Proposing tasks...")
            proposed_tasks = await self._propose_tasks(session)
            logging.info(f"Proposed {len(proposed_tasks)} tasks")

            try:
                self._init_task_states(proposed_tasks)
                await self._execute_tasks(session)
            finally:
                self._close_envs()

        final_results = [self._result_record(task_id, state) for task_id, state in self.task_states.items()]

        with open(os.path.join(self.exp_dir, "final_results.json"), "w") as f:
            json.dump(final_results, f, indent=4)

        if not self.skip_jsonl:
            logging.info("Creating training JSONL files")
            self.convert_to_jsonl(final_results)

        logging.info("All processing completed successfully!")

    def convert_to_jsonl(self, results):
        """Write chat-format training examples for every successfully finished task.

        Two files are produced, one line per (observation, action, rationale) step:
        ``train.jsonl`` always uses the no-history prompt, while
        ``train_w_as_action.jsonl`` includes the actions taken so far. Tasks whose
        rationales are not aligned with their actions are skipped, because zipping them
        would pair actions with the wrong reasoning.
        """
        successful_tasks = [
            task for task in results
            if task["is_finished"] and task["actions"] and "click[buy now]" in task["actions"][-1].lower()
        ]

        if not successful_tasks:
            logging.warning("No successful tasks found for JSONL conversion")
            return

        logging.info(f"Converting {len(successful_tasks)} tasks to JSONL format")

        # Training system prompt: the execution prompt without the "adopt the persona" paragraph.
        shared_system_msg = '''You are an agent with a strict task of completing a web shopping assignment based on the page content and the user's instructions.

In each step, your actions are strictly limited to two types:

1. search[keywords]: Use this action only when a "[button] Search [button_]" is present in the current web page content. You must replace "keywords" with any valid search query you want to search.

2. click[HTML Element]: Use this action to click on an HTML Element in the page content. "HTML Element" can be any clickable element in the page represented inside "[button]" and "[button_]", such as an item id, action button, or attributes and options like color or size. Note that the 'HTML Element' 'must' be present in the current page content. Also, do not click the "clicked button" or "item name".

Only use search action when a "[button] Search [button_]" is present in the current web page content and otherwise, use click action (click item id, attributes like color and size, or action button).'''

        jsonl_output_path = os.path.join(self.exp_dir, "train.jsonl")
        jsonl_output_path_w_as = os.path.join(self.exp_dir, "train_w_as_action.jsonl")

        with open(jsonl_output_path, 'w') as f_plain, open(jsonl_output_path_w_as, 'w') as f_hist:
            for task in successful_tasks:
                task_id = task["task_id"]
                actions = task["actions"]
                rationales = task["rationales"]
                if len(rationales) != len(actions):
                    logging.warning(
                        f"Task {task_id}: {len(rationales)} rationales for {len(actions)} actions; skipping misaligned task"
                    )
                    continue

                system_msg_instance = f"{shared_system_msg}\n\nPersona: {task['persona']}\nTask: {task['instruction']}\n"

                # observations[k] is the page on which actions[k] was taken.
                for step, (observation, action, rationale) in enumerate(zip(task["observations"][:-1], actions, rationales)):
                    assistant_msg_instance = f"{rationale}\nAction: {action}"
                    for f, previous_actions in ((f_plain, []), (f_hist, actions[:step])):
                        user_msg_instance = self._page_user_message(self._build_user_instruction(previous_actions), observation)
                        messages = [
                            {'role': 'system', 'content': system_msg_instance},
                            {'role': 'user', 'content': user_msg_instance},
                            {'role': 'assistant', 'content': assistant_msg_instance},
                        ]
                        f.write(json.dumps({'dataset': 'webshop', 'id': task_id, 'messages': messages}) + '\n')

        logging.info(f"JSONL files created: {jsonl_output_path} and {jsonl_output_path_w_as}")

def main():
    """Parse CLI arguments, configure logging, and run the propose-and-execute pipeline.

    The API key can be passed with ``--api-key`` or, preferably, through the
    ``OPENAI_API_KEY`` environment variable, because command-line arguments show up in
    shell history and process listings. ``--api-key`` is required only when that
    variable is unset or empty.
    """
    env_api_key = os.environ.get("OPENAI_API_KEY") or None

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--api-key",
        type=str,
        default=env_api_key,
        required=env_api_key is None,
        help="OpenAI API key (defaults to the OPENAI_API_KEY environment variable)",
    )
    parser.add_argument("--model", type=str, default="gpt-4", help="OpenAI model to use")
    parser.add_argument("--max-tokens", type=int, default=2048, help="Max tokens per response")
    parser.add_argument("--temperature", type=float, default=0.7, help="Temperature for generation")
    parser.add_argument("--num-tasks", type=int, default=100, help="Number of tasks to generate")
    parser.add_argument("--output-dir", type=str, default="alice_data_proposed", help="Output directory")
    parser.add_argument("--exp-name", type=str, default=None, help="Experiment name to append to timestamp")
    parser.add_argument("--skip-jsonl", action="store_true", help="Skip JSONL conversion")
    parser.add_argument("--log-level", type=str, default="INFO", help="Logging level")
    args = parser.parse_args()

    # Map the textual level (e.g. "debug") to logging's numeric constant.
    numeric_level = getattr(logging, args.log_level.upper(), None)
    if not isinstance(numeric_level, int):
        raise ValueError(f'Invalid log level: {args.log_level}')
    logging.basicConfig(
        level=numeric_level,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        handlers=[
            logging.StreamHandler(),
            logging.FileHandler(f"alice_propose_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log")
        ]
    )

    proposer = OpenAITaskProposer(
        api_key=args.api_key,
        model=args.model,
        max_tokens=args.max_tokens,
        temperature=args.temperature,
        num_tasks=args.num_tasks,
        output_dir=args.output_dir,
        exp_name=args.exp_name,
        skip_jsonl=args.skip_jsonl
    )

    asyncio.run(proposer.generate_tasks())

if __name__ == "__main__":
    main()