import gym
import json
import copy
from tqdm import tqdm
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria
import torch
from tqdm import tqdm
from web_agent_site.envs import WebAgentTextEnv
from web_agent_site.utils import DEBUG_PROD_SIZE 
import json
import re
import os
import pandas as pd
from peft import PeftModel
import torch
# parse args
import random
import argparse
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest


# Set seeds for determinism. Every RNG this file imports is seeded, including
# Python's built-in ``random`` (imported above), so runs are reproducible.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
np.random.seed(SEED)
# Trade cuDNN autotuning speed for deterministic kernel selection.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

def clean_obs(obs):
    """Replace the ``[button]`` / ``[button_]`` markers with plain brackets.

    ``'[button] Search [button_]'`` becomes ``'[ Search ]'``.
    """
    for marker, bracket in (('[button]', '['), ('[button_]', ']')):
        obs = obs.replace(marker, bracket)
    return obs

def load_hfmodel(ckpt=None, use_vllm=False):
    """Load a tokenizer and a causal LM from a checkpoint.

    Args:
        ckpt: Checkpoint path or hub id. ``None`` is mapped to an empty path,
            kept for backward compatibility.
            NOTE(review): ``from_pretrained('')`` is unlikely to resolve, so
            callers presumably always pass a real checkpoint; confirm.
        use_vllm: If True, build a vLLM ``LLM`` engine (bfloat16, LoRA support
            enabled up to rank 64). Otherwise load a ``transformers``
            ``AutoModelForCausalLM`` in bfloat16 on CUDA.

    Returns:
        tuple: ``(model, tokenizer)``.
    """
    path = '' if ckpt is None else ckpt

    tokenizer = AutoTokenizer.from_pretrained(path,
                                            trust_remote_code=True)
    # Reuse EOS as the pad token and pad on the left, so prompts stay aligned
    # with the position where generation starts.
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"
    tokenizer.add_eos_token = True

    if use_vllm:
        base_model = LLM(
            model=path,
            enable_lora=True,  # adapters are attached per request via LoRARequest
            max_lora_rank=64,
            trust_remote_code=True,
            dtype="bfloat16",
        )
    else:
        base_model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
            use_auth_token=True,
            use_flash_attention_2=True,
        ).to('cuda')

    print('Loaded Model and Tokenizer')
    return base_model, tokenizer

class EosListStoppingCriteria(StoppingCriteria):
    """Stop generation once any row of the batch ends with ``eos_sequence``.

    The default ``[32007]`` is the id the original code noted for Phi-3
    (``128009`` was noted as an alternative, presumably Llama-3's
    ``<|eot_id|>``; verify against the tokenizer in use).
    """

    def __init__(self, eos_sequence=None):
        # ``None`` instead of a list literal avoids one mutable default shared
        # by every instance; the effective default is unchanged.
        if eos_sequence is None:
            eos_sequence = [32007]
        # Stored as a list because ``__call__`` compares it with the per-row
        # lists produced by ``Tensor.tolist()``. A tuple would never compare equal.
        self.eos_sequence = list(eos_sequence)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True if any row of ``input_ids`` currently ends with ``eos_sequence``."""
        last_ids = input_ids[:, -len(self.eos_sequence):].tolist()
        return self.eos_sequence in last_ids
    
    
def hf_llm_inference(base_model, tokenizer, system_msg, user_msg, use_vllm=False, lora_dir=None):
    '''Run one chat-formatted sampling step and return only the generated text.

    Args:
        base_model: a vLLM ``LLM`` when ``use_vllm`` is True, otherwise a
            ``transformers`` (or PEFT-wrapped) causal LM already on CUDA.
        tokenizer: tokenizer whose chat template formats the two messages.
        system_msg (str): content of the system turn.
        user_msg (str): content of the user turn.
        use_vllm (bool): select the vLLM or the ``transformers`` backend.
        lora_dir (str | None): vLLM only. LoRA adapter directory applied
            through a ``LoRARequest``; ignored by the ``transformers`` path.

    Returns:
        str: the sampled completion, without the prompt.

    Note:
        The two backends do NOT share sampling settings. vLLM samples with
        temperature 0.7 / top_p 0.7; ``transformers`` uses temperature 0.5 /
        top_p 0.8. Both cap the completion at 512 new tokens.
    '''
    chat = [
        {'role': 'system', 'content': system_msg},
        {'role': 'user', 'content': user_msg},
    ]

    if use_vllm:
        # vLLM takes the rendered prompt string and tokenizes it itself.
        prompt = tokenizer.apply_chat_template(
            chat,
            tokenize=False,
            add_generation_prompt=True
        )

        sampling_params = SamplingParams(
            temperature=0.7,
            max_tokens=512,
            top_p=0.7,
            stop=[tokenizer.eos_token, '</s>']
        )

        # Attach the adapter only when one is given (fixed name and id).
        lora_request = None
        if lora_dir is not None:
            lora_request = LoRARequest(
                "agent_lora",
                1,
                lora_dir
            )

        outputs = base_model.generate(
            [prompt],
            sampling_params,
            lora_request=lora_request
        )
        return outputs[0].outputs[0].text

    input_ids = tokenizer.apply_chat_template(
        chat,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        add_special_tokens=False,
    ).to('cuda')

    output_ids = base_model.generate(
        input_ids,
        # pad_token_id == eos_token_id below, so generate() would fall back to
        # an all-ones mask anyway. Passing it explicitly avoids the
        # missing-attention-mask warning without changing the result.
        attention_mask=torch.ones_like(input_ids),
        max_new_tokens=512,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        stopping_criteria=[EosListStoppingCriteria()],
        num_beams=1,
        use_cache=True,
        temperature=0.5,
        top_p=0.8,
    ).squeeze(0)

    # Drop the prompt tokens (single-row batch) and decode only the completion.
    output_ids = output_ids[input_ids.shape[1]:]
    return tokenizer.decode(output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)

# For opensource model (huggingface)
class HFAgent:
    """WebShop exploration agent backed by a local (HF or vLLM) chat model.

    Each call to ``act`` builds a persona-conditioned system prompt, queries
    the model through ``hf_llm_inference`` and parses one ``search[...]`` /
    ``click[...]`` action out of the reply. ``'fail'`` is returned as the
    action when nothing valid can be parsed.
    """

    # A valid action: search[...] or click[...] with a non-empty argument.
    _ACTION_PATTERN = r'(search\[[^\]]+\]|click\[[^\]]+\])'

    def __init__(self, model, tokenizer, save_rationale=False, use_vllm=False, lora_dir=None):
        self.previous_actions = []  # NOTE(review): not read anywhere in the visible code
        self.agent = model
        self.tokenizer = tokenizer
        self.save_rationale = save_rationale
        self.use_vllm = use_vllm
        self.lora_dir = lora_dir
        # Example action sequences from triplets.csv (column "action_sequence").
        gt_data = pd.read_csv('triplets.csv')
        self.few_shot_samples = gt_data['action_sequence'].tolist()

        # Persona used when no explicit persona is supplied to ``act``.
        self.free_persona = "You are a web-shop-agent that can interact with the given shopping webpage by taking actions and finally buy something that you want. You can freely take any action and buy anything that you think is appropriate and interesting."
        self.system_prompt = '''Your actions are strictly limited to two types:

            1. search[keywords]: Use this action when a "[ Search ]" button is present in the page content. You must replace "keywords" with any valid search query you want to search.

            2. click[HTML Element]: Use this action to click on HTML Element in the page content. "HTML Element" can be any clickable element in the page represented inside "[" and "]", such as a item id, or action button. Note that item id or action name must be present in the page content. Also, do not click the "clicked button".

You must take only one action among the above two types.'''

    def act(self, user_msg, persona_instance):
        """Ask the model for the next action on the current page.

        Args:
            user_msg (str): user-turn prompt (instructions plus page content).
            persona_instance: persona description; a falsy value selects the
                generic ``free_persona``.

        Returns:
            ``(action_text, rationale)`` when ``save_rationale`` is True,
            otherwise ``action_text``. ``action_text`` is ``'fail'`` when no
            valid action could be parsed.
        """
        if persona_instance:
            persona = f'''You are a web-shop-agent that can interact with the webpage by taking actions. You need to buy something that you want at the end. Also, you must adopt the identity of following persona :\n            "{persona_instance}" \nYou should take actions that are consistent with the persona you have adopted.\n
            '''

        else:
            persona = self.free_persona

        # Few-shot examples are currently not added to the prompt. The draw is
        # kept only because it advances the global NumPy RNG, which later
        # sampling (e.g. persona choice) also uses. Removing it would change
        # previously reproducible runs.
        np.random.choice(self.few_shot_samples, 3)

        system_prompt = f'''{persona}\n{self.system_prompt}'''

        action_response = hf_llm_inference(self.agent, self.tokenizer, system_prompt, user_msg, self.use_vllm, self.lora_dir)
        action_text, rationale = self._parse_action(action_response)

        if self.save_rationale:
            return action_text, rationale
        return action_text

    @staticmethod
    def _tidy_brackets(action_text):
        """Drop whitespace just inside the brackets: ``click[ x ]`` -> ``click[x]``."""
        action_text = re.sub(r'\[\s+', '[', action_text)
        return re.sub(r'\s+\]', ']', action_text)

    def _parse_action(self, action_response):
        """Split a raw model reply into ``(action_text, rationale)``.

        Without ``save_rationale``, the LAST action span of the lower-cased
        reply is used and the rationale is ``None``. With ``save_rationale``,
        the first ``Action: <action>`` span wins, falling back to the first
        bare action span. The text before the chosen span is the rationale.
        """
        try:
            if not self.save_rationale:
                matches = re.findall(self._ACTION_PATTERN, action_response.lower())
                if not matches:
                    return 'fail', None
                return self._tidy_brackets(matches[-1]), None

            labelled = re.search(r'Action\s*:\s*' + self._ACTION_PATTERN, action_response, re.IGNORECASE)
            if labelled:
                action_text = self._tidy_brackets(labelled.group(1).lower())
                # Cut at the match itself. Searching for a literal "action:"
                # could hit an earlier, unrelated occurrence, or miss spacing
                # variants such as "Action  :".
                start = labelled.start()
                if start > 0:
                    rationale = action_response[:start].strip()
                else:
                    rationale = "No explicit reasoning provided."
                return action_text, rationale

            bare = re.search(self._ACTION_PATTERN, action_response.lower())
            if bare:
                start = bare.start()
                rationale = action_response[:start].strip() if start > 0 else ""
                return self._tidy_brackets(bare.group(0)), rationale

            return 'fail', "Failed to extract reasoning."
        except (AttributeError, TypeError):
            # A non-string reply cannot be parsed. Report failure instead of
            # crashing the exploration loop.
            if self.save_rationale:
                return 'fail', "Error extracting reasoning."
            return 'fail', None
    


def run_alice(num_tasks, agent, env, is_persona=False, persona_path='persona.json', output_path=None, overwrite=False, save_rationale=False, two_history=False):
    """
    Collect successful exploration trajectories and save them to ``output_path``.

    Each task resets ``env``, optionally samples a persona, and lets ``agent``
    act for at most 10 steps. A trajectory is kept only when the agent issues
    ``click[buy now]`` on a page whose content contains "buy now". Any other
    outcome (an unparsable action or running out of steps) is discarded and a
    new task is started. The loop therefore runs until ``num_tasks``
    trajectories succeed, and never terminates if the agent never succeeds.

    Args:
        num_tasks (int): Number of successful trajectories to collect.
        agent (object): Object whose ``act(prompt, persona)`` returns an action
            string, or ``(action, rationale)`` when ``save_rationale`` is True.
            The action ``'fail'`` signals an unparsable reply.
        env (object): Environment with ``reset() -> (obs, info)`` and a
            ``step(action)`` returning a 4-tuple whose first item is the
            observation text.
        is_persona (bool): Whether to sample a persona for each task.
        persona_path (str): JSON file holding a list of personas.
        output_path (str): Where the results are written (required).
        overwrite (bool): Re-run even if ``output_path`` already exists.
        save_rationale (bool): Also record the agent's rationale per step.
        two_history (bool): Prompt with only the previous page and action
            instead of the full action history.

    Returns:
        list | None: The saved results (dicts with ``persona``, ``task_id`` and
        ``task_results``), or ``None`` when ``output_path`` already exists and
        ``overwrite`` is False.

    Raises:
        ValueError: If ``output_path`` is None.
    """
    if output_path is None:
        raise ValueError("run_alice requires an output_path")
    # Keep existing results unless explicitly asked to overwrite them.
    if os.path.exists(output_path) and not overwrite:
        return None

    with open(persona_path) as f:
        personas = json.load(f)

    num_trials = 10  # maximum number of steps per trajectory

    # Prompt prefixes are loop-invariant: one for the first step, one for later steps.
    initial_prefix = (
        "Now here is the current page content. Based on the page content and your persona, give a concise reasoning and then provide any one valid action that seems interesting. Start with searching for something based on your persona. When outputting the action, please write your action "
        "after the prompt 'Action:'."
    )
    if two_history:
        step_prefix = (
            "Now here is the new page content. Based on the previous step, the given task, and the current web page content, you must give a concise reasoning on what action to take next and then provide a new valid action at the end. When outputting the action, please write your action "
            "after the prompt 'Action:'."
        )
    else:
        step_prefix = (
            "Now here is the current page content. Based on this page content, persona, and previous actions, you must give a concise reasoning on what action to take next and then provide a new valid action at the end. When outputting the action, please write your action "
            "after the prompt 'Action:'."
        )

    results = []  # successful trajectories only
    tqdm.write("Running tasks...")
    with tqdm(total=num_tasks) as pbar:
        while len(results) < num_tasks:
            task_idx = len(results)
            actions, rationales = [], []
            obs, _ = env.reset()
            goal = clean_obs(obs)
            # Page content = everything after the line following "Instruction: \n".
            observation = "\n".join(goal.split('Instruction: \n')[-1].split('\n')[1:])
            persona_instance = personas[np.random.randint(0, len(personas))] if is_persona else ''
            done = False
            trial = 0

            prompt = f"{initial_prefix}\n\nPage Content:\n{observation}\n\n"

            result_dict = {'persona': persona_instance, 'task_id': task_idx}
            task_results = []  # observation/action pairs of the current task

            while not done and trial < num_trials:
                if save_rationale:
                    action, rationale = agent.act(prompt, persona_instance)
                    rationales.append(rationale)
                else:
                    action = agent.act(prompt, persona_instance)

                if action == 'fail':
                    break
                actions.append(action)

                obs, _, _, _ = env.step(action)
                obs = clean_obs(obs)
                prev_observation = observation
                match = re.search(r'(?i)instruction:\s*\n.*?\n(.*)', obs, re.DOTALL)
                observation = match.group(1).strip() if match else ""

                if two_history:
                    # Only the previous page and the action taken on it.
                    prompt = (
                        f"In the previous step:\n"
                        f"- The page showed: \n{prev_observation}\n"
                        f"- You took action: \n{actions[-1]}\n\n"
                        f"\n\n{step_prefix}\n\nPage Content:\n{observation}\n\n"
                    )
                elif save_rationale:
                    # actions and rationales grow in lockstep (a 'fail' breaks
                    # out before reaching this point), so zip pairs them exactly.
                    action_history = [
                        f"{i}. Reasoning: {r}\n   Action: {a}"
                        for i, (a, r) in enumerate(zip(actions, rationales), start=1)
                    ]
                    prompt = "Previously, based on the previous page contents and your persona, you have taken the following actions with reasoning:\n" + "\n".join(action_history) + f"\n\n{step_prefix}\n\nPage Content:\n{observation}\n\n"
                else:
                    prompt = "Previously, based on the previous page contents and your persona, you have taken the following actions:\n" + "\n".join(
                        f"{i+1}. {a}" for i, a in enumerate(actions)
                    ) + f"\n\n{step_prefix}\n\nPage Content:\n{observation}\n\n"

                result_entry = {'observation': prev_observation, 'action': action}
                if save_rationale and rationales:
                    result_entry['explore_rationale'] = rationales[-1]
                task_results.append(result_entry)

                # Success: "buy now" clicked on a page that actually offers it.
                if 'click[buy now]' in action.lower() and 'buy now' in prev_observation.lower():
                    done = True
                    result_dict['task_results'] = task_results
                    break

                trial += 1

            if not done:
                continue  # discard failed / unfinished trajectories
            results.append(result_dict)
            pbar.update(1)

    with open(output_path, 'w') as f:
        json.dump(results, f, indent=4)
    return results


def convert_exp_into_instruction(model=None, tokenizer=None, output_path=None, lora_path=None, filter=True, use_vllm=False):
    """Generate a synthetic user instruction for each explored trajectory.

    Reads the trajectories saved by ``run_alice`` from ``output_path``. For each
    one, asks the model to write a user query that would lead to the final
    product page, and stores it under the ``'instruction'`` key.

    Args:
        model: vLLM ``LLM`` (``use_vllm=True``) or ``transformers`` base model.
        tokenizer: matching tokenizer.
        output_path (str): JSON file written by ``run_alice``.
        lora_path (str): PEFT adapter wrapped around ``model`` on the
            ``transformers`` path.
            NOTE(review): the vLLM path passes ``lora_dir=None``, so the
            adapter is NOT applied there. Confirm this is intended.
        filter (bool): If True, keep only trajectories whose last action is
            exactly ``click[buy now]`` and write ``*_filtered.json``.
            Otherwise overwrite ``output_path`` in place.
        use_vllm (bool): select the generation backend.
    """
    if not use_vllm:
        model = model.to('cuda').type(torch.bfloat16)
        model.eval()

        peft_model = PeftModel.from_pretrained(model, lora_path)
        peft_model = peft_model.to('cuda').type(torch.bfloat16)
        peft_model.eval()

    with open(output_path, 'r') as f:
        results = json.load(f)

    # Filter before inference so no generation is spent on failed trajectories.
    if filter:
        results = [
            result for result in results
            if result['task_results'][-1]['action'] == 'click[buy now]'
        ]

    output_results = copy.deepcopy(results)

    system_msg = (
        "You are a helpful assistant trained to understand web environment and "
        "generate shopping instructions. You are given an action sequence and a "
        "final product description. Your task is to generate only an user query that will "
        "lead to the final product description."
    )

    for i, result in tqdm(enumerate(results), total=len(results), desc="Converting to instruction"):
        action_sequence = [task['action'] for task in result['task_results']]
        final_state = result['task_results'][-1]['observation']

        user_msg = (
            f"Now here are the given action sequence and final product description.\n\n"
            f"Action Sequence:\n{action_sequence}\n\n"
            f"Final Product Description:\n{final_state}\n\n"
            f"Considering both search keywords and product detail, please generate a user query.\nPlease put more weight on the final search query than the product detail and avoid copying and pasting the product name in the query.\n"
            f"Note that attributes, like size, color, and options, with [clicked button] should be included in the query.\n"
            f"Attributes without [clicked button] should not be included in the query, as they are not part of the product.\n"
            f"You should also include the price condition in the query (e.g. price lower than XX dollars).\n"
            f"You should not include any other text than the query. Start the query with something like 'I need', 'I want', 'Find me', 'I am looking for', etc.\n\n"
            f"User Query: "
        )

        if use_vllm:
            full_msg = hf_llm_inference(model, tokenizer, system_msg, user_msg, use_vllm=True, lora_dir=None)
        else:
            full_msg = hf_llm_inference(peft_model, tokenizer, system_msg, user_msg, use_vllm=False)

        # hf_llm_inference already returns only the completion. Splitting on
        # every "assistant" would truncate queries that mention the word
        # (e.g. "voice assistant"), so only drop a leading role header.
        parsed_msg = full_msg.strip()
        if parsed_msg.startswith('assistant'):
            parsed_msg = parsed_msg[len('assistant'):].strip()

        output_results[i]['instruction'] = parsed_msg

    save_path = output_path.replace('.json', '_filtered.json') if filter else output_path
    with open(save_path, 'w') as f:
        json.dump(output_results, f, indent=4)

def add_post_hoc_reasoning(model, tokenizer, output_path, use_vllm=False):
    """
    Annotate every step of the filtered trajectories with a post-hoc rationale.

    Reads the ``_filtered.json`` sibling of ``output_path`` (the file
    ``convert_exp_into_instruction`` writes when filtering). For each step, the
    model explains the action given the task instruction, the step's
    observation and the earlier actions. The answer is stored under
    ``result['rationale']`` and the same file is rewritten.

    Args:
        model: vLLM ``LLM`` or ``transformers`` causal LM used for generation.
        tokenizer: matching tokenizer.
        output_path (str): Path of the unfiltered results JSON. Its
            ``_filtered.json`` sibling is the file actually read and updated.
        use_vllm (bool): select the generation backend.

    Returns:
        None: Updates the filtered results file in place.
    """
    if not use_vllm:
        model = model.to('cuda').to(torch.bfloat16).eval()

    filtered_output_path = output_path.replace('.json', '_filtered.json')
    with open(filtered_output_path, 'r') as f:
        filtered_results = json.load(f)

    # The trailing space after "[chosen action]'" is needed: the adjacent
    # literals are concatenated and would otherwise run the sentences together.
    system_msg = (
        "You are an AI assistant tasked with explaining actions taken in a web environment. "
        "Given the instruction, current observation, and previous actions, provide a very concise rationale for why "
        "the action was taken as you are providing an action. "
        "This means that your rationale should be naturally fit with '[your rationale]. Thus, my action is [chosen action]' "
        "You only need to provide 'your rationale' part. Try to start with 'I am'. Be very concise and clear."
    )

    for task in tqdm(filtered_results, desc="Adding rationales"):
        instruction = task['instruction']
        task_results = task['task_results']

        # leave=False: the per-task inner bar disappears once the task is done.
        for i, result in tqdm(enumerate(task_results), total=len(task_results), desc="Adding rationales", leave=False):
            current_observation = result['observation']
            action = result['action']
            previous_actions = [r['action'] for r in task_results[:i]]

            user_msg = (
                f"Instruction: {instruction}\n\n"
                f"Current Observation: {current_observation}\n\n"
                f"Previous Actions: {previous_actions}\n\n"
                f"Action Taken: {action}\n\n"
                "Why was this action taken? Provide a concise rationale:"
            )

            rationale = hf_llm_inference(model, tokenizer, system_msg, user_msg, use_vllm)
            result['rationale'] = rationale.strip()

    with open(filtered_output_path, 'w') as f:
        json.dump(filtered_results, f, indent=4)

def convert_json_into_jsonl(output_path, filter=True):
    """Export trajectories as chat-format JSONL fine-tuning records.

    The source is ``output_path``, or its ``_filtered.json`` sibling when
    ``filter`` is True. Two files are written next to it:

    * ``<stem>.jsonl``: one record per step, with the first-step user prompt
      and the assistant reply ``"Here is my action\\nAction : <action>"``.
    * ``<stem>_w_as.jsonl``: the same steps, but every step after the first
      uses the "previous actions" user prompt and a longer assistant preamble.

    Each record has the keys ``dataset`` (always ``'webshop'``), ``id`` (the
    task id), ``persona`` and ``messages`` (system / user / assistant turns).

    NOTE(review): the "previous actions" prompt refers to earlier actions, but
    the action history itself is never included in the message. Confirm this
    is intended.
    """
    source_path = output_path.replace('.json', '_filtered.json') if filter else output_path
    with open(source_path, 'r') as fh:
        trajectories = json.load(fh)

    system_prefix = '''You are an agent with a strict task of completing a web shopping assignment based on the page content and the user's instructions.

        In each step, your actions are strictly limited to two types:

            1. search[keywords]: Use this action when there is a "[ Search ]" button. You must replace "keywords" with a valid search query to find the given product.

            2. click[HTML Element]: Use this action to click on HTML Element in page content. "HTML Element" can be any clickable element on the page represented inside "[" and "]", such as an item number, or action button. You must replace "HTML Element" with the element to click.

        You must take only one action among the above two types. Now, here is the task:\n\nTask:: \n'''

    first_step_prompt = (
            "Now here is the new page content. Based on the given task and the current web page content, generate a valid action. When outputting the action, "
            "please write your action after the prompt 'Action:'."
        )

    later_step_prompt = (
                   "Now here is the new page content. Based on the previous actions, the given task, and the current web page content, generate a valid action. When outputting the action, "
                "please write your action after the prompt 'Action:'."
                )

    def _line(task_id, persona, system_turn, user_turn, assistant_turn):
        # Serialize one three-turn chat example as a JSONL line.
        messages = [
            {'role': 'system', 'content': system_turn},
            {'role': 'user', 'content': user_turn},
            {'role': 'assistant', 'content': assistant_turn},
        ]
        return json.dumps({'dataset': 'webshop', 'id': task_id, 'persona': persona, 'messages': messages}) + '\n'

    plain_path = source_path.replace('.json', '.jsonl')
    history_path = source_path.replace('.json', '_w_as.jsonl')

    with open(plain_path, 'w') as fh:
        for traj in trajectories:
            task_id, persona, instruction = traj['task_id'], traj['persona'], traj['instruction']
            system_turn = f"{system_prefix}{instruction}\n"
            for step in traj['task_results']:
                fh.write(_line(
                    task_id,
                    persona,
                    system_turn,
                    f"{first_step_prompt}\n\nPage Content:\n{step['observation']}\n\n",
                    f"Here is my action\nAction : {step['action']}",
                ))

    with open(history_path, 'w') as fh:
        for traj in trajectories:
            task_id, persona, instruction = traj['task_id'], traj['persona'], traj['instruction']
            system_turn = f"{system_prefix}{instruction}\n"
            for idx, step in enumerate(traj['task_results']):
                step_prompt = later_step_prompt if idx else first_step_prompt
                fh.write(_line(
                    task_id,
                    persona,
                    system_turn,
                    f"{step_prompt}\n\nPage Content:\n{step['observation']}\n\n",
                    f"Based on the given task, previous actions, and the current web page content, here is my action\nAction : {step['action']}",
                ))

def convert_json_into_jsonl_with_rationale(output_path, filter=True, two_history=False):
    """
    Convert the JSON results into chat-format JSONL with the rationale included in the assistant's messages.

    Two files are written, named by replacing '.json' in the (possibly filtered) input path:
      * '<name>.jsonl': one record per step, with no action history in the user turn.
      * '<name>_w_as_action.jsonl' (or '<name>_w_as_two_history.jsonl' when
        ``two_history`` is True): the same steps, with history context in the user turn.

    Each line is a JSON object ``{'dataset': 'webshop', 'id', 'persona', 'messages'}``.
    ``messages`` is a system / user / assistant triple. The assistant turn is
    "<rationale>\\nAction: <action>".

    Args:
        output_path (str): Path to the results JSON file. It must contain '.json',
            because the output paths are derived by replacing that substring.
        filter (bool): If True, read the '_filtered.json' variant of ``output_path``
            instead of ``output_path`` itself.
        two_history (bool): If True, every step after the first shows only the previous
            observation and action (similar to mode=="prev" in run_eval). Otherwise, it
            shows a numbered list of all previous actions.

    Returns:
        None: Writes two JSONL files with and without action sequence history.

    Raises:
        ValueError: If the (possibly filtered) path contains no '.json'. Every derived
            output path would then equal the input path, and writing would destroy the input.
        KeyError: If a task or step record lacks an expected key (e.g. 'rationale').
    """
    if filter:
        output_path = output_path.replace('.json', '_filtered.json')

    # Output names come from substring replacement on '.json'. Without it, every derived
    # path equals the input path: the writes below would truncate the source file and
    # then overwrite each other.
    if '.json' not in output_path:
        raise ValueError(
            f"output_path must contain '.json' so the .jsonl output paths can be derived; got {output_path!r}"
        )

    with open(output_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    # NOTE: the leading spaces on the continuation lines of this literal are part of the prompt text.
    shared_system_msg = '''You are an agent with a strict task of completing a web shopping assignment based on the page content and the user's instructions.

        In each step, your actions are strictly limited to two types:

            1. search[keywords]: Use this action when there is a "[ Search ]" button. You must replace "keywords" with a valid search query to find the given product.

            2. click[HTML Element]: Use this action to click on HTML Element in page content. "HTML Element" can be any clickable element on the page represented inside "[" and "]", such as an item number, or action button. You must replace "HTML Element" with the element id or action name, not the item name or item number.

        You must take only one action among the above two types. Now, here is the task:\n\nTask:: \n'''

    # User prompt for the first step, or for every step in the no-history file.
    shared_user_msg = (
                "Now here is the new page content. Based on the given task and the current web page content, give a concise reasoning and then generate a valid action. When outputting the action, "
                "please write your action after the prompt 'Action:'."
            )

    # User prompt placed after the numbered action list (history file, two_history=False).
    shared_second_user_msg = (
                   "Now here is the new page content. Based on the previous actions, the given task, and the current web page content, give a concise reasoning and then generate a valid action. When outputting the action, "
                "please write your action after the prompt 'Action:'."
                )

    # User prompt template for two_history=True; filled with the previous step via str.format.
    shared_prev_user_msg = (
                   "In the previous step:\n"
                   "- The page showed: {prev_observation}\n"
                   "- You took action: {prev_action}\n\n"
                   "Now here is the new page content. Based on the previous step, the given task, and the current web page content, give a concise reasoning and then generate a valid action. When outputting the action, "
                   "please write your action after the prompt 'Action:'."
                )

    jsonl_output_path = output_path.replace('.json', '.jsonl')
    if two_history:
        jsonl_output_path_w_as = output_path.replace('.json', '_w_as_two_history.jsonl')
    else:
        jsonl_output_path_w_as = output_path.replace('.json', '_w_as_action.jsonl')

    def _record_line(task_id, persona, system_msg, user_msg, assistant_msg):
        # Serialize one system/user/assistant example as a single JSONL line.
        messages = [
            {'role': 'system', 'content': system_msg},
            {'role': 'user', 'content': user_msg},
            {'role': 'assistant', 'content': assistant_msg},
        ]
        return json.dumps({'dataset': 'webshop', 'id': task_id, 'persona': persona, 'messages': messages}) + '\n'

    # Generate JSONL without action sequence history
    with open(jsonl_output_path, 'w', encoding='utf-8') as f:
        for d in data:
            task_id = d['task_id']
            persona = d['persona']
            instruction = d['instruction']
            system_msg_instance = f"{shared_system_msg}{instruction}\n"

            for tr in d['task_results']:
                user_msg_instance = f"{shared_user_msg}\n\nPage Content:\n{tr['observation']}\n\n"
                assistant_msg_instance = f"{tr['rationale']}\nAction: {tr['action']}"
                f.write(_record_line(task_id, persona, system_msg_instance,
                                     user_msg_instance, assistant_msg_instance))

    # Generate JSONL with action sequence history
    with open(jsonl_output_path_w_as, 'w', encoding='utf-8') as f:
        for d in data:
            task_id = d['task_id']
            persona = d['persona']
            instruction = d['instruction']
            steps = d['task_results']
            system_msg_instance = f"{shared_system_msg}{instruction}\n"

            for i, tr in enumerate(steps):
                if two_history and i > 0:
                    # Show only the immediately preceding observation/action pair.
                    prev_step = steps[i - 1]
                    user_instruction = shared_prev_user_msg.format(
                        prev_observation=prev_step['observation'],
                        prev_action=prev_step['action'],
                    )
                elif i > 0:
                    # List every action taken so far, numbered from 1.
                    action_summary = "To complete the given task, you have taken the following actions:\n" + "\n".join(
                        f"{idx+1}. {r['action']}" for idx, r in enumerate(steps[:i])
                    )
                    user_instruction = (
                        f"{action_summary}\nDo not repeat the previous actions.\n\n"
                        f"{shared_second_user_msg}"
                    )
                else:
                    # First step: there is no history to show.
                    user_instruction = shared_user_msg

                user_msg_instance = f"{user_instruction}\n\nPage Content:\n{tr['observation']}\n\n"
                assistant_msg_instance = f"{tr['rationale']}\nAction: {tr['action']}"
                f.write(_record_line(task_id, persona, system_msg_instance,
                                     user_msg_instance, assistant_msg_instance))



if __name__ == '__main__':
    # Pipeline: run the agent on WebShop tasks, convert the trajectories into
    # instructions, add post-hoc reasoning, then export chat-format JSONL.
    parser = argparse.ArgumentParser()
    parser.add_argument('--model-path', type=str, required=True, help="Path to base model")
    parser.add_argument('--output-path', type=str, default="output/raw_instruction.json", help="Path to save output")
    parser.add_argument('--lora-path', type=str, help="Path to LoRA weights")
    parser.add_argument('--persona-path', type=str, default='persona.json', help="Path to persona data")
    parser.add_argument('--overwrite', action='store_true', help='overwrite the existing file')
    parser.add_argument('--seed', type=int, default=42, help='random seed for reproducibility')
    parser.add_argument('--save-rationale', action='store_true', help='save rationale (ER option)')
    parser.add_argument('--num-tasks', type=int, default=1000, help='number of tasks')
    parser.add_argument('--two-history', action='store_true', help='use two history')
    parser.add_argument('--use-vllm', action='store_true', help='use vLLM for faster inference')
    args = parser.parse_args()

    # Re-seed with the CLI seed; module level seeded torch/numpy with the fixed SEED constant.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    env = gym.make('WebAgentTextEnv-v0', observation_mode='text_rich', num_products=DEBUG_PROD_SIZE)
    print("Loading Model..")
    model, tokenizer = load_hfmodel(args.model_path, use_vllm=args.use_vllm)
    tokenizer.padding_side = 'left'
    agent = HFAgent(model, tokenizer, save_rationale=args.save_rationale, use_vllm=args.use_vllm, lora_dir=None)
    print("Agent Created")
    print("Running Evaluation..")
    run_alice(args.num_tasks, agent, env, is_persona=True, persona_path=args.persona_path,
              output_path=args.output_path, overwrite=args.overwrite,
              save_rationale=args.save_rationale, two_history=args.two_history)
    print("Evaluation Done")
    print("Converting to Instructions..")
    convert_exp_into_instruction(model, tokenizer, args.output_path, lora_path=args.lora_path,
                                 filter=True, use_vllm=args.use_vllm)
    print("Conversion Done")
    print("Adding Post Hoc Reasoning..")
    add_post_hoc_reasoning(model, tokenizer, args.output_path, args.use_vllm)
    print("Post Hoc Reasoning Done")
    print("Converting to Jsonl with Rationale..")
    # Export with the same history format used for the rollouts above (--two-history),
    # so the training prompts match what the agent actually saw.
    convert_json_into_jsonl_with_rationale(args.output_path, filter=True, two_history=args.two_history)
    print("Conversion Done")