import gym
import json
import copy
from tqdm import tqdm
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria
import torch
from tqdm import tqdm
from web_agent_site.envs import WebAgentTextEnv
from web_agent_site.utils import DEBUG_PROD_SIZE
import json
import re
import os
import pandas as pd
import random
from peft import PeftModel
import torch
# parse args
import argparse
from prompt_template import PromptTemplate
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest
import logging


def clean_obs(obs):
    """Preprocess a raw environment observation before it is shown to the agent.

    This is currently a pass-through: the observation is returned unchanged.
    It is kept as the single hook where observation cleanup would go.
    """
    cleaned = obs
    return cleaned


def load_hfmodel(ckpt=None, use_vllm=False):
    """Load a tokenizer and a causal language model from ``ckpt``.

    Args:
        ckpt: Model path or hub id passed to ``from_pretrained`` / vLLM. When
            None, an empty path is used (kept for backward compatibility;
            loading from '' will most likely fail).
        use_vllm: If True, build a vLLM ``LLM`` engine in bfloat16 with LoRA
            support enabled (max LoRA rank 64). Otherwise load a Hugging Face
            ``AutoModelForCausalLM`` in bfloat16 with FlashAttention-2 and move
            it to CUDA.

    Returns:
        tuple: ``(model, tokenizer)``. The tokenizer pads on the left with its
        EOS token.
    """
    path = '' if ckpt is None else ckpt

    tokenizer = AutoTokenizer.from_pretrained(path,
                                              trust_remote_code=True)
    # Reuse EOS as the pad token and left-pad, so generated tokens follow the prompt directly.
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"
    tokenizer.add_eos_token = True

    if use_vllm:
        base_model = LLM(
            model=path,
            enable_lora=True,  # allow per-request LoRA adapters (see hf_llm_inference)
            max_lora_rank=64,
            trust_remote_code=True,
            dtype="bfloat16",
        )
    else:
        # NOTE(review): newer transformers releases deprecate `use_auth_token`
        # (-> `token`) and `use_flash_attention_2` (-> `attn_implementation`);
        # confirm against the pinned transformers version before changing.
        base_model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
            use_auth_token=True,
            use_flash_attention_2=True,
        ).to('cuda')

    print('Loaded Model and Tokenizer')
    return base_model, tokenizer


class EosListStoppingCriteria(StoppingCriteria):
    """Stop generation once the generated ids end with ``eos_sequence``.

    The default ``[32007]`` is the stop id the author noted for Phi-3; other
    ids noted by the author for other tokenizers are ``[58, 4794, 60]`` and
    ``128009``.
    """

    def __init__(self, eos_sequence=None):
        """
        Args:
            eos_sequence: Token ids that mark the end of generation. Defaults
                to ``[32007]``. A fresh list is always stored, so instances
                never share or alias one mutable default.
        """
        if eos_sequence is None:
            eos_sequence = [32007]
        # Store a list: __call__ tests membership against a list of lists,
        # so a tuple argument would otherwise never match.
        self.eos_sequence = list(eos_sequence)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True if any row of ``input_ids`` ends with ``eos_sequence``."""
        # Tail of every row, with the same length as the stop sequence.
        last_ids = input_ids[:, -len(self.eos_sequence):].tolist()
        return self.eos_sequence in last_ids


def hf_llm_inference(base_model, tokenizer, system_msg, user_msg, use_vllm=False, lora_dir=None):
    """Run one chat-formatted generation and return only the generated text.

    The system and user messages are rendered with the tokenizer's chat
    template (with the generation prompt appended). The model then samples
    with temperature 0.5, top-p 0.7 and at most 512 new tokens.

    Args:
        base_model: vLLM ``LLM`` when ``use_vllm`` is True, otherwise a Hugging
            Face causal LM that is already on CUDA.
        tokenizer: Tokenizer providing ``apply_chat_template``.
        system_msg: Content of the system message.
        user_msg: Content of the user message.
        use_vllm: Select the vLLM backend.
        lora_dir: vLLM only. Path of a LoRA adapter applied to this request.

    Returns:
        str: The decoded completion, without the prompt.
    """
    messages = [
        {'role': 'system', 'content': system_msg},
        {'role': 'user', 'content': user_msg},
    ]

    if not use_vllm:
        prompt_ids = tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
            add_special_tokens=False,
        ).to('cuda')
        generated = base_model.generate(
            prompt_ids,
            max_new_tokens=512,
            do_sample=True,
            temperature=0.5,
            top_p=0.7,
            num_beams=1,
            use_cache=True,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            stopping_criteria=[EosListStoppingCriteria()],
        ).squeeze(0)
        # generate() returns prompt + completion; keep only the completion.
        completion_ids = generated[prompt_ids.shape[1]:]
        return tokenizer.decode(
            completion_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)

    prompt_text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    params = SamplingParams(
        temperature=0.5,
        max_tokens=512,
        top_p=0.7,
        stop=[tokenizer.eos_token]
    )
    lora_request = None if lora_dir is None else LoRARequest("agent_lora", 1, lora_dir)
    request_outputs = base_model.generate([prompt_text], params, lora_request=lora_request)
    return request_outputs[0].outputs[0].text


# For opensource model (huggingface)
class HFAgentFeedback:
    """Web-shop exploration agent whose prompt is conditioned on earlier feedback.

    Feedback is read from ``eval_file``, a JSON file with at least a
    ``rewards`` list, plus ``instructions`` / ``action_sequence`` /
    ``obs_sequence`` depending on the options used.

    With ``use_summary`` the model first writes a summary of the skills it
    lacks, and that summary is added to every system prompt. Otherwise four
    random successful and four random failed cases are added on every call to
    :meth:`act`.
    """

    # Explicit "Action: search[...]" / "Action: click[...]" marker.
    _EXPLICIT_ACTION_RE = re.compile(
        r'Action\s*:\s*(search\[[^\]]+\]|click\[[^\]]+\])', re.IGNORECASE)
    # Any bare "search[...]" / "click[...]" in the response.
    _ANY_ACTION_RE = re.compile(r'(search\[[^\]]+\]|click\[[^\]]+\])')

    def __init__(self, model, tokenizer, eval_file, data_file=None, mode="action", save_rationale=False, use_summary=False, use_vllm=False, lora_dir=None):
        """Load the feedback data and prompt templates.

        Args:
            model: vLLM ``LLM`` (``use_vllm=True``) or Hugging Face causal LM
                used for generation.
            tokenizer: Tokenizer matching ``model``.
            eval_file: JSON file with the evaluation results used as feedback.
            data_file: Optional JSON list of ``{'instruction', 'task_results'}``
                records. When given, its action sequences replace those in
                ``eval_file`` for matching instructions.
            mode: ``"action"`` (feedback cases are action sequences) or
                ``"instruction"`` (feedback cases are instructions).
            save_rationale: Make :meth:`act` also return the model's reasoning.
            use_summary: Condition on an LLM-written summary instead of raw cases.
            use_vllm: Select the vLLM backend.
            lora_dir: LoRA adapter passed with the generation requests made by
                :meth:`act` (used by the vLLM backend only).
        """
        self.previous_actions = []
        self.agent = model
        self.tokenizer = tokenizer
        self.mode = mode
        self.prompt_template = PromptTemplate()
        self.save_rationale = save_rationale
        self.use_summary = use_summary
        self.use_vllm = use_vllm
        self.lora_dir = lora_dir
        self.eval_file = eval_file

        # Few-shot action sequences from triplets.csv. Kept as an attribute;
        # act() does not currently insert them into the prompt.
        gt_data = pd.read_csv('triplets.csv')
        self.few_shot_samples = gt_data['action_sequence'].tolist()

        # Ground-truth actions must be loaded before read_eval_file(), which consults them.
        self.data_file = data_file
        self.instruction_to_actions = {}
        if self.data_file:
            self.load_ground_truth_actions()

        self.success_cases, self.fail_cases = self.read_eval_file(eval_file)

        if self.use_summary:
            self.summary = self.generate_pattern_summary()

        self.no_persona_prompt = self.prompt_template.no_persona_prompt
        self.system_prompt = self.prompt_template.system_prompt

    def load_ground_truth_actions(self):
        """Map each instruction in ``self.data_file`` to its list of step actions.

        Only records that contain both ``instruction`` and ``task_results``
        are used. A later record with the same instruction overwrites an
        earlier one.
        """
        with open(self.data_file, 'r') as f:
            data = json.load(f)

        for item in data:
            if 'instruction' in item and 'task_results' in item:
                self.instruction_to_actions[item['instruction']] = [
                    task['action'] for task in item['task_results']]

        print(
            f"Loaded {len(self.instruction_to_actions)} ground truth action sequences from data file")

    def _action_case(self, eval_data, i):
        """Return episode ``i``'s newline-joined action sequence, or None if unavailable.

        Ground-truth actions (matched by instruction) take precedence over the
        ``action_sequence`` logged in the evaluation file.
        """
        instruction = eval_data['instructions'][i] if 'instructions' in eval_data else None
        if instruction and instruction in self.instruction_to_actions:
            return "\n".join(self.instruction_to_actions[instruction])
        if 'action_sequence' in eval_data:
            return "\n".join(eval_data['action_sequence'][i])
        return None

    def read_eval_file(self, eval_file):
        """Split the evaluated episodes into success cases and failure cases.

        An episode is a success when its reward is >= 0.5.

        In ``"action"`` mode each case is a newline-joined action sequence
        (see :meth:`_action_case`); episodes without any action sequence are
        skipped. In ``"instruction"`` mode each case is the episode's
        instruction. Any other mode yields no cases.

        Returns:
            tuple[list, list]: ``(success_cases, fail_cases)``.
        """
        with open(eval_file, 'r') as f:
            eval_data = json.load(f)

        success_cases = []
        fail_cases = []
        for i, reward in enumerate(eval_data['rewards']):
            if self.mode == "action":
                case = self._action_case(eval_data, i)
                if case is None:
                    continue
            elif self.mode == "instruction":
                case = eval_data['instructions'][i]
            else:
                break
            if reward >= 0.5:
                success_cases.append(case)
            else:
                fail_cases.append(case)

        print(
            f"Loaded {len(success_cases)} success cases and {len(fail_cases)} fail cases")
        return success_cases, fail_cases

    @staticmethod
    def _format_trajectories(trajectories):
        """Render trajectories as numbered observation/action listings.

        Each observation is truncated to its first 200 characters. Steps are
        paired up to the shorter of the observation and action lists.
        """
        parts = []
        for i, traj in enumerate(trajectories, start=1):
            parts.append(f"\nTrajectory {i}:\n")
            parts.append(f"Instruction: {traj['instruction']}\n")
            parts.append("Observation-Action Sequence:\n")
            for j, (observation, action) in enumerate(
                    zip(traj['observations'], traj['actions']), start=1):
                parts.append(f"  Observation {j}: {observation[:200]}... (truncated)\n")
                parts.append(f"  Action {j}: {action}\n\n")
            parts.append(f"Reward: {traj['reward']}\n")
        return "".join(parts)

    def generate_pattern_summary(self):
        """Ask the model which skills it lacks, based on sampled trajectories.

        The prompt shows the 2 highest-reward trajectories as successful, and
        up to 5 random trajectories with 0 < reward < 0.5 plus the first 2
        zero-reward trajectories as failed.

        Returns:
            str | None: The model's analysis, or None when ``rewards``,
            ``instructions``, ``action_sequence`` or ``obs_sequence`` is
            missing or empty in the evaluation file.
        """
        with open(self.eval_file, 'r') as f:
            feedback_data = json.load(f)

        rewards = feedback_data.get("rewards", [])
        instructions = feedback_data.get("instructions", [])
        action_sequences = feedback_data.get("action_sequence", [])
        obs_sequences = feedback_data.get("obs_sequence", [])

        if not rewards or not instructions or not action_sequences or not obs_sequences:
            logging.warning("Missing required fields in feedback JSON")
            return None

        # zip stops at the shortest list, so only fully populated episodes are kept.
        trajectory_data = [
            {"reward": r, "instruction": ins, "actions": acts, "observations": obs}
            for r, ins, acts, obs in zip(rewards, instructions, action_sequences, obs_sequences)
        ]

        non_zero_reward_trajectories = [
            t for t in trajectory_data if t["reward"] > 0]
        zero_reward_trajectories = [
            t for t in trajectory_data if t["reward"] == 0]

        # Sorted population keeps random.sample deterministic for a given seed.
        non_zero_reward_trajectories.sort(key=lambda x: x["reward"])

        # Up to 5 random trajectories with 0 < reward < 0.5.
        low_reward_trajectories = [
            t for t in non_zero_reward_trajectories if t["reward"] < 0.5]
        low_non_zero_trajectories = random.sample(
            low_reward_trajectories, min(5, len(low_reward_trajectories)))

        # First 2 zero-reward trajectories.
        zero_reward_trajectories = zero_reward_trajectories[:2]

        # 2 highest-reward trajectories.
        successful_trajectories = sorted(
            non_zero_reward_trajectories, key=lambda x: x["reward"], reverse=True)[:2]

        failed_trajectories = low_non_zero_trajectories + zero_reward_trajectories

        system_msg = (
            "You are an AI assistant tasked with analyzing web shopping trajectories. "
            "To get a high reward, the model needs to complete the task with the given instruction, fulfilling the task requirements such as price, clicking on the right product, correctly clicking the attributes like size and color, etc."
            "Given trajectories of varying rewards, your task is to identify which skills the current model lacks the most:\n\n"
            "Using your feedback, you will explore the web shopping task on the next round, where your trajectories will be used to train the model."
            "Please note that during your exploration, you don't have an instruction to follow and just explore the web shopping task with your feedback."
        )

        user_msg = (
            "Now here are the trajectories of the current model:\n\n"
        )
        user_msg += "\nSuccessful trajectories (high reward):\n"
        user_msg += self._format_trajectories(successful_trajectories)
        user_msg += "\nFailed trajectories (low or zero reward):\n"
        user_msg += self._format_trajectories(failed_trajectories)

        predefined_skills = [
            "Avoiding action repetition (not stuck in the middle)",
            "Performing valid actions",
            "Creating detailed search queries",
            "Selecting appropriate product attributes",
            "Completing purchases",
            "Making informed product selections",
        ]

        predefined_guides = [
            "Avoid repeating the same action in consecutive steps",
            "Only use search when the search button is present, and click only on elements that exist on the current page",
            "Use very detailed search queries (more than 5~8 words)",
            "Try to click many product attributes (size, color, etc.)",
            "Always aim to complete the shopping task by clicking the 'Buy Now' button when you find a suitable product",
            "Explore multiple products before making a selection, considering their attributes and your preferences",
        ]

        user_msg += (
            "-"*50 + "\n"
            "\nBased on these trajectories, analyze which of the following predefined skills the model lacks the most:\n\n"
            + "\n".join(f"{i+1}. {skill}\n   Guide: {guide}" for i, (skill, guide)
                        in enumerate(zip(predefined_skills, predefined_guides))) + "\n\n"
            "Select the top 2 skills that the model needs to improve the most, based on the evidence in the trajectories. "
            "Do not choose skills at random - only select skills where there is clear evidence of deficiency. "
            "For each selected skill, use a corresponding predefined guide on what you need to do during your exploration of the web shopping task. "
            "The example format could be like this: \n"
            "1. Current Model Lacks [Name of skill from the predefined list]\n"
            "   Exploration guide: [What you will do to improve this skill during exploration]\n\n"
        )

        # The summary is generated without the LoRA adapter (lora_dir is not passed).
        return hf_llm_inference(
            self.agent, self.tokenizer, system_msg, user_msg, self.use_vllm)

    @staticmethod
    def _normalize_action(action_text):
        """Remove whitespace just inside the brackets: ``click[ x ]`` -> ``click[x]``."""
        action_text = re.sub(r'\[\s+', '[', action_text)
        return re.sub(r'\s+\]', ']', action_text)

    @classmethod
    def _parse_action_response(cls, action_response, save_rationale):
        """Extract ``(action, rationale)`` from a lower-cased model response.

        Without ``save_rationale``, the action is the last
        ``search[...]``/``click[...]`` in the response and the rationale is None.

        With ``save_rationale``, the first ``Action: <action>`` is preferred
        and the text before it is the rationale. Failing that, the first bare
        action is used and the text before it is the rationale.

        The rationale is sliced at the regex match position itself. It is
        therefore correct for markers such as ``action\\n:`` and for actions
        with spaces inside their brackets.

        Returns ``'fail'`` as the action when none is found.
        """
        if not save_rationale:
            matches = cls._ANY_ACTION_RE.findall(action_response)
            if not matches:
                return 'fail', None
            return cls._normalize_action(matches[-1]), None

        explicit = cls._EXPLICIT_ACTION_RE.search(action_response)
        if explicit:
            action_text = cls._normalize_action(explicit.group(1).lower())
            if explicit.start() > 0:
                rationale = action_response[:explicit.start()].strip()
            else:
                rationale = "No explicit reasoning provided."
            return action_text, rationale

        bare = cls._ANY_ACTION_RE.search(action_response)
        if bare is None:
            return 'fail', "Failed to extract reasoning."
        action_text = cls._normalize_action(bare.group(0))
        rationale = action_response[:bare.start()].strip() if bare.start() > 0 else ""
        return action_text, rationale

    def act(self, user_msg, persona_instance):
        """Ask the model for the next action on the current page.

        Args:
            user_msg: User prompt describing the current page (and history).
            persona_instance: Persona passed to the persona prompt template,
                or a falsy value to use the no-persona prompt.

        Returns:
            The action string (``search[...]``, ``click[...]``, or ``'fail'``
            when no action could be parsed). With ``save_rationale``, returns
            the tuple ``(action, rationale)`` instead.
        """
        if persona_instance:
            persona = self.prompt_template.call_persona_prompt(
                persona_instance)
        else:
            persona = self.no_persona_prompt

        system_prompt = f'''{persona}\n{self.system_prompt}'''

        if self.use_summary:
            system_prompt += self.prompt_template.call_summary_prompt(
                self.summary)
        else:
            # random.sample raises ValueError with fewer than 4 cases of either kind.
            random_success_cases = random.sample(self.success_cases, 4)
            random_fail_cases = random.sample(self.fail_cases, 4)

            if self.mode == "action":
                system_prompt += self.prompt_template.call_feedback_prompt(
                    random_success_cases, random_fail_cases)
            elif self.mode == "instruction":
                system_prompt += self.prompt_template.call_instruction_feedback_prompt(
                    random_success_cases, random_fail_cases)

        action_response = hf_llm_inference(
            self.agent, self.tokenizer, system_prompt, user_msg, self.use_vllm, self.lora_dir)
        action_text, rationale = self._parse_action_response(
            action_response.lower(), self.save_rationale)

        print(f"Action: {action_text}\nRationale: {rationale}")
        if self.save_rationale:
            return action_text, rationale
        return action_text


def run_alice(num_tasks, agent, env, is_persona=False, persona_path='persona.json', output_path=None, overwrite=False, save_rationale=False, two_history=False):
    """
    Run tasks with an agent interacting with an environment and optional personas.

    Args:
        num_tasks (int): Number of tasks to execute.
        agent (object): Agent capable of acting based on prompts.
        env (object): Environment for interaction.
        is_persona (bool): Whether to use personas in the interaction.
        persona_path (str): Path to the persona file.
        output_path (str): Path to save the results.
        overwrite (bool): Whether to overwrite existing results.
        save_rationale (bool): Whether to save rationales during exploration.
        two_history (bool): Whether to use two-history context.

    Returns:
        list: A list of task results with observation-action pairs.
    """
    # check if there is already a file in output_path
    if os.path.exists(output_path):
        # just do nothing if overwrite is False
        if not overwrite:
            return

    # Load personas
    with open(persona_path) as f:
        personas = json.load(f)

    results = []  # Store results for all tasks
    tqdm.write("Running tasks...")
    pbar = tqdm(total=num_tasks)
    n_gen_task = 0
    while n_gen_task < num_tasks:
        task_idx = len(results)
        # Initialize variables for the task
        actions, observations, rationales = [], [], []
        obs, info = env.reset()
        goal = clean_obs(obs)
        observation = "\n".join(goal.split(
            'Instruction: \n')[-1].split('\n')[1:])
        persona_instance = personas[np.random.randint(
            0, len(personas))] if is_persona else ''
        num_trials = 10
        done = False
        trial = 0

        # Prepare initial prompt
        prefix = (
            "Now here is the current page content. Based on the page content, feedback and your persona, give a concise reasoning and then provide any one valid action that seems interesting. Start with searching for something based on your persona. When outputting the action, please write your action "
            "after the prompt 'Action:'."
        )
        prompt = f"{prefix}\n\nPage Content:\n{observation}\n\n"

        result_dict = {}
        result_dict['persona'] = persona_instance
        result_dict['task_id'] = task_idx
        task_results = []  # Store results for the current task

        # Multi-step decision-making loop
        while not done and trial < num_trials:
            # Agent takes action
            if save_rationale:
                action, rationale = agent.act(prompt, persona_instance)
                rationales.append(rationale)
            else:
                action = agent.act(prompt, persona_instance)
            if action == 'fail':
                break
            actions.append(action)

            # Take a step in the environment
            obs, _, _, _ = env.step(action)
            obs = clean_obs(obs)
            prev_observation = observation
            observations.append(prev_observation)
            match = re.search(
                r'(?i)instruction:\s*\n.*?\n(.*)', obs, re.DOTALL)
            observation = match.group(1).strip() if match else ""

            # Print debug information
            print(f"Action: {action}\n\nObservation:\n{observation}\n\n")

            if two_history:
                prefix = (
                    "Now here is the new page content. Based on the previous step, the given task, persona, feedback and the current web page content, you must give a concise reasoning on what action to take next and then provide a new valid action at the end. When outputting the action, please write your action "
                    "after the prompt 'Action:'."
                )
            else:
                prefix = (
                    "Now here is the current page content. Based on this page content, persona, and previous actions, you must give a concise reasoning on what action to take next and then provide a new valid action at the end. When outputting the action, please write your action "
                    "after the prompt 'Action:'."
                )

            # Update prompt based on history option
            if two_history:
                # Include only the last two actions and the last observation
                prompt = (
                    f"In the previous step:\n"
                    f"- The page showed: \n{prev_observation}\n"
                    f"- You took action: \n{actions[-1]}\n\n"
                    f"\n\n{prefix}\n\nPage Content:\n{observation}\n\n"
                )
            else:
                # Original prompt logic
                if save_rationale:
                    # Include previous rationales and actions in the prompt
                    action_history = []
                    for i, (a, r) in enumerate(zip(actions[:-1], rationales[:-1])):
                        action_history.append(
                            f"{i+1}. Reasoning: {r}\n   Action: {a}")

                    # Add the most recent action and rationale
                    if actions:
                        action_history.append(
                            f"{len(actions)}. Reasoning: {rationales[-1]}\n   Action: {actions[-1]}")

                    prompt = "Previously, based on the previous page contents and your persona, you have taken the following actions with reasoning:\n" + \
                        "\n".join(action_history) + \
                        f"\n\n{prefix}\n\nPage Content:\n{observation}\n\n"
                else:
                    prompt = "Previously, based on the previous page contents and your persona, you have taken the following actions:\n" + "\n".join(
                        f"{i+1}. {a}" for i, a in enumerate(actions)
                    ) + f"\n\n{prefix}\n\nPage Content:\n{observation}\n\n"

            # Store observation-action pair with rationale if available
            result_entry = {'observation': prev_observation, 'action': action}
            if save_rationale and rationales:
                result_entry['explore_rationale'] = rationales[-1]
            task_results.append(result_entry)

            # Check for task completion
            if ('click[buy now]' in action.lower() and 'buy now' in prev_observation.lower()):
                done = True
                result_dict['task_results'] = task_results
                break
            elif trial == num_trials - 1:
                break

            trial += 1
        if action == 'fail' or not done:
            continue
        n_gen_task += 1
        pbar.update(1)
        results.append(result_dict)  # Append task results

    with open(output_path, 'w') as f:
        json.dump(results, f, indent=4)


def convert_exp_into_instruction(model=None, tokenizer=None, output_path=None, lora_path=None, filter=True, use_vllm=False):
    """Generate a synthetic user query for each explored trajectory in ``output_path``.

    The model receives each trajectory's action list and final observation,
    and its answer is stored under the trajectory's ``'instruction'`` key.

    Args:
        model: vLLM ``LLM`` (``use_vllm=True``) or a Hugging Face causal LM.
        tokenizer: Tokenizer matching ``model``.
        output_path: Trajectory JSON file (as written by ``run_alice``).
        lora_path: Hugging Face path only. A PEFT adapter loaded on top of
            ``model``. NOTE(review): the vLLM path passes ``lora_dir=None``,
            so the adapter is not applied there. Confirm this is intended.
        filter: Keep only trajectories whose last action is exactly
            ``click[buy now]`` and write to ``*_filtered.json``. Otherwise
            overwrite ``output_path`` in place.
        use_vllm: Select the vLLM backend.
    """
    if use_vllm:
        generator = model
    else:
        base = model.to('cuda').type(torch.bfloat16)
        base.eval()
        generator = PeftModel.from_pretrained(base, lora_path)
        generator = generator.to('cuda').type(torch.bfloat16)
        generator.eval()

    with open(output_path, 'r') as f:
        results = json.load(f)

    # Filter before inference so that only successful trajectories are converted.
    if filter:
        results = [
            episode for episode in results
            if episode['task_results'][-1]['action'] == 'click[buy now]'
        ]

    output_results = copy.deepcopy(results)

    system_msg = (
        "You are a helpful assistant trained to understand web environment and "
        "generate shopping instructions. You are given an action sequence and a "
        "final product description. Your task is to generate only an user query that will "
        "lead to the final product description."
    )

    for idx, episode in tqdm(enumerate(results), total=len(results), desc="Converting to instruction"):
        steps = episode['task_results']
        action_sequence = [step['action'] for step in steps]
        final_state = steps[-1]['observation']

        user_msg = (
            f"Now here are the given action sequence and final product description.\n\n"
            f"Action Sequence:\n{action_sequence}\n\n"
            f"Final Product Description:\n{final_state}\n\n"
            f"Considering both search keywords and product detail, please generate an user query.\n"
            f"Please put more weight on the search keywords than the product detail. Do not directly include the product name in the query and rather give a high-level description of the product.\n"
            f"Note that clicked attributes in action sequence, like size, color, and options should be included in the query. (Buy now is not an attribute)\n"
            f"Attributes without [clicked button] should not be included in the query, as they are not part of the product.\n"
            f"You should also include the price condition in the query (e.g. price lower than XX dollars).\n"
            f"You should not include any other text than the query. Randomly start the query with words 'Find me','Show me', 'I am looking for', 'I need', 'I want', or similar words.\n\n"
            f"User Query: "
        )

        full_msg = hf_llm_inference(
            generator, tokenizer, system_msg, user_msg, use_vllm=use_vllm, lora_dir=None)

        # NOTE(review): splitting on 'assistant' also cuts queries that contain that word.
        output_results[idx]['instruction'] = full_msg.split('assistant')[-1].strip()

    save_path = output_path.replace('.json', '_filtered.json') if filter else output_path
    with open(save_path, 'w') as f:
        json.dump(output_results, f, indent=4)


def add_post_hoc_reasoning(model, tokenizer, output_path, use_vllm=False):
    """
    Annotate every step of the filtered trajectories with a post-hoc rationale.

    Reads ``output_path`` with ``.json`` replaced by ``_filtered.json`` (the
    file ``convert_exp_into_instruction`` writes when ``filter=True``). For
    each step, the model is asked why the step's action was taken, given the
    task's ``instruction``, the current observation and the preceding
    actions. The answer is stored under the step's ``'rationale'`` key, and
    the file is rewritten in place.

    Args:
        model: vLLM ``LLM`` when ``use_vllm`` is True. Otherwise a Hugging Face
            causal LM, which is moved to CUDA in bfloat16 and set to eval mode.
        tokenizer: Tokenizer matching ``model``.
        output_path (str): Path of the unfiltered results JSON. It is only
            used to derive the ``_filtered.json`` path.
        use_vllm (bool): Select the vLLM backend.

    Returns:
        None: The filtered results file is updated in place.
    """
    if not use_vllm:
        model = model.to('cuda').to(torch.bfloat16).eval()

    filtered_output_path = output_path.replace('.json', '_filtered.json')
    with open(filtered_output_path, 'r') as f:
        filtered_results = json.load(f)

    system_msg = (
        "You are an AI assistant tasked with explaining actions taken in a web environment. "
        "Given the instruction you need to follow and the current observation, provide a rationale for why the 'last action' was taken to follow the instruction. "
        "You can also refer to the previous actions to provide a rationale. "
        "The rationale should naturally fit with '[your rationale]. Thus, my action is [chosen action].' "
        "You only need to provide 'your rationale' part. Be very concise and clear."
    )

    for task in tqdm(filtered_results, desc="Adding rationales"):
        instruction = task['instruction']
        task_results = task['task_results']

        for i, result in tqdm(enumerate(task_results), total=len(task_results), desc="Adding rationales"):
            current_observation = result['observation']
            action = result['action']
            # Actions taken before this step, in order.
            previous_actions = [r['action'] for r in task_results[:i]]

            user_msg = (
                "Given the instruction you need to follow and the current observation, provide a rationale for why the 'last action' was taken to follow the instruction. "
                "You can also refer to the previous actions to provide a rationale. "
                "The rationale should naturally fit with '[your rationale]. Thus, my action is [chosen action].' "
                "You only need to provide 'your rationale' part. Be very concise and clear. "
                f"Now, here are the given instruction, previous actions, current observation, and the last action.\n\n"
                f"Instruction: {instruction}\n\n"
                f"Previous actions before the last action: {previous_actions}\n\n"
                f"Current observation: {current_observation}\n\n"
                f"Last action taken based on the current observation: {action}\n\n"
                "Why was this last action taken? Provide a rationale:"
            )

            rationale = hf_llm_inference(
                model, tokenizer, system_msg, user_msg, use_vllm=use_vllm)
            result['rationale'] = rationale.strip()

    with open(filtered_output_path, 'w') as f:
        json.dump(filtered_results, f, indent=4)


def convert_json_into_jsonl(output_path, filter=True):
    """Convert collected WebShop trajectories into chat-format JSONL files.

    The input JSON is a list of task records, each with ``task_id``,
    ``persona``, ``instruction`` and ``task_results`` (a list of steps that
    carry at least ``observation`` and ``action``). One
    system/user/assistant example is emitted per step into two files written
    next to the input:

    * ``<stem>.jsonl`` -- every step on its own, without any history.
    * ``<stem>_w_as.jsonl`` -- every step after the first is prefixed with the
      numbered list of actions already taken in the same task
      ("with action sequence").

    Args:
        output_path (str): Path to the results JSON file.
        filter (bool): If True, read the ``_filtered.json`` variant of
            ``output_path`` instead of ``output_path`` itself.

    Returns:
        None. Writes the two JSONL files.
    """
    if filter:
        # Must match the naming used by the step that writes the filtered file.
        output_path = output_path.replace('.json', '_filtered.json')
    with open(output_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    shared_system_msg = '''You are an agent with a strict task of completing a web shopping assignment based on the page content and the user's instructions.

        In each step, your actions are strictly limited to two types:

            1. search[keywords]: Use this action when there is a "[ Search ]" button. You must replace "keywords" with a valid search query to find the given product.

            2. click[HTML Element]: Use this action to click on HTML Element in page content. "HTML Element" can be any clickable element on the page represented inside "[" and "]", such as an item number, or action button. You must replace "HTML Element" with the element id or action name, not the item name or item number.

        You must take only one action among the above two types. Now, here is the task:\n\nTask:: \n'''

    shared_user_msg = (
        "Now here is the new page content. Based on the given task and the current web page content, give a concise reasoning and generate a valid action. When outputting the action, "
        "please write your action after the prompt 'Action:'."
    )

    shared_second_user_msg = (
        "Now here is the new page content. Based on the previous actions, the given task, and the current web page content, give a concise reasoning and generate a valid action. When outputting the action, "
        "please write your action after the prompt 'Action:'."
    )

    # Derive the outputs from the file stem only. str.replace('.json', ...)
    # would also rewrite a '.json' elsewhere in the path, and would return the
    # input path unchanged (overwriting the input) when it has no '.json'.
    stem = os.path.splitext(output_path)[0]
    jsonl_output_path = stem + '.jsonl'
    jsonl_output_path_w_as = stem + '_w_as.jsonl'

    def _jsonl_line(task_id, persona, system_msg, user_msg, assistant_msg):
        # Serialize one system/user/assistant example as a JSONL line.
        messages = [
            {'role': 'system', 'content': system_msg},
            {'role': 'user', 'content': user_msg},
            {'role': 'assistant', 'content': assistant_msg},
        ]
        return json.dumps({'dataset': 'webshop', 'id': task_id,
                           'persona': persona, 'messages': messages}) + '\n'

    with open(jsonl_output_path, 'w', encoding='utf-8') as f_plain, \
            open(jsonl_output_path_w_as, 'w', encoding='utf-8') as f_hist:
        for d in data:
            task_id = d['task_id']
            persona = d['persona']
            system_msg_instance = f"{shared_system_msg}{d['instruction']}\n"
            previous_actions = []  # actions taken before the current step

            for tr in d['task_results']:
                page_block = f"\n\nPage Content:\n{tr['observation']}\n\n"

                f_plain.write(_jsonl_line(
                    task_id, persona, system_msg_instance,
                    f"{shared_user_msg}{page_block}",
                    f"Here is my action\nAction : {tr['action']}"))

                if previous_actions:
                    # Show the history the prompt refers to ("Based on the
                    # previous actions"); same format as the rationale variant.
                    action_summary = "To complete the given task, you have taken the following actions:\n" + "\n".join(
                        f"{idx}. {a}" for idx, a in enumerate(previous_actions, start=1)
                    )
                    user_instruction = (
                        f"{action_summary}\nDo not repeat the previous actions.\n\n"
                        f"{shared_second_user_msg}"
                    )
                else:
                    user_instruction = shared_user_msg

                f_hist.write(_jsonl_line(
                    task_id, persona, system_msg_instance,
                    f"{user_instruction}{page_block}",
                    f"Based on the given task, previous actions, and the current web page content, here is my action\nAction : {tr['action']}"))

                previous_actions.append(tr['action'])


def convert_json_into_jsonl_with_rationale(output_path, filter=True, two_history=False):
    """
    Convert the JSON data into JSONL format with rationale included in the assistant's messages.

    Each step in ``task_results`` must carry ``observation``, ``action`` and
    ``rationale``. The assistant turn is rendered as
    ``"<rationale> Thus, my action is <action>.\\nAction: <action>"``.

    Args:
        output_path (str): Path to the results JSON file.
        filter (bool): If True, read the ``_filtered.json`` variant of
            ``output_path`` (the file that carries the rationales) instead of
            ``output_path`` itself. No filtering is done here.
        two_history (bool): Whether to use two-history context similar to mode=="prev" in run_eval.
            When True, every step after the first shows only the previous
            observation and action. Otherwise it shows the numbered list of
            all earlier actions.

    Returns:
        None: Writes ``<stem>.jsonl`` (no history) and ``<stem>_w_as.jsonl``
        (with history) next to the input file.
    """
    if filter:
        # Must match the naming used by the step that writes the filtered file.
        output_path = output_path.replace('.json', '_filtered.json')
    with open(output_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    shared_system_msg = '''You are an agent with a strict task of completing a web shopping assignment based on the page content and the user's instructions.

        In each step, your actions are strictly limited to two types:

            1. search[keywords]: Use this action when there is a "[ Search ]" button. You must replace "keywords" with a valid search query to find the given product.

            2. click[HTML Element]: Use this action to click on HTML Element in page content. "HTML Element" can be any clickable element on the page represented inside "[" and "]", such as an item number, or action button. You must replace "HTML Element" with the element id or action name, not the item name or item number.

        You must take only one action among the above two types. Now, here is the task:\n\nTask:: \n'''

    shared_user_msg = (
        "Now here is the new page content. Based on the given task and the current web page content, give a concise reasoning and then generate a valid action. When outputting the action, "
        "please write your action after the prompt 'Action:'."
    )

    shared_second_user_msg = (
        "Now here is the new page content. Based on the previous actions, the given task, and the current web page content, give a concise reasoning and then generate a valid action. When outputting the action, "
        "please write your action after the prompt 'Action:'."
    )

    shared_prev_user_msg = (
        "In the previous step:\n"
        "- The page showed: {prev_observation}\n"
        "- You took action: {prev_action}\n\n"
        "Now here is the new page content. Based on the previous step, the given task, and the current web page content, give a concise reasoning and then generate a valid action. When outputting the action, "
        "please write your action after the prompt 'Action:'."
    )

    # Derive the outputs from the file stem only. str.replace('.json', ...)
    # would also rewrite a '.json' elsewhere in the path, and would return the
    # input path unchanged (overwriting the input) when it has no '.json'.
    stem = os.path.splitext(output_path)[0]
    jsonl_output_path = stem + '.jsonl'
    jsonl_output_path_w_as = stem + '_w_as.jsonl'

    def _assistant_msg(step):
        # Same template for both files. It matches the "[your rationale]. Thus,
        # my action is [chosen action]." shape the rationale prompt asks for.
        # The explicit space keeps "...rationale. Thus" from being glued together.
        return (
            f"{step['rationale'].strip()} Thus, my action is {step['action']}.\n"
            f"Action: {step['action']}"
        )

    def _jsonl_line(task_id, persona, system_msg, user_msg, assistant_msg):
        # Serialize one system/user/assistant example as a JSONL line.
        messages = [
            {'role': 'system', 'content': system_msg},
            {'role': 'user', 'content': user_msg},
            {'role': 'assistant', 'content': assistant_msg},
        ]
        return json.dumps({'dataset': 'webshop', 'id': task_id,
                           'persona': persona, 'messages': messages}) + '\n'

    with open(jsonl_output_path, 'w', encoding='utf-8') as f_plain, \
            open(jsonl_output_path_w_as, 'w', encoding='utf-8') as f_hist:
        for d in data:
            task_id = d['task_id']
            persona = d['persona']
            system_msg_instance = f"{shared_system_msg}{d['instruction']}\n"
            results = d['task_results']
            previous_actions = []  # actions taken before the current step

            for i, tr in enumerate(results):
                page_block = f"\n\nPage Content:\n{tr['observation']}\n\n"
                assistant_msg_instance = _assistant_msg(tr)

                # File without action-sequence history.
                f_plain.write(_jsonl_line(
                    task_id, persona, system_msg_instance,
                    f"{shared_user_msg}{page_block}", assistant_msg_instance))

                # File with action-sequence history.
                if two_history and i > 0:
                    # Only the immediately preceding observation/action pair.
                    prev = results[i - 1]
                    user_instruction = shared_prev_user_msg.format(
                        prev_observation=prev['observation'],
                        prev_action=prev['action'],
                    )
                elif previous_actions:
                    action_summary = "To complete the given task, you have taken the following actions:\n" + "\n".join(
                        f"{idx}. {a}" for idx, a in enumerate(previous_actions, start=1)
                    )
                    user_instruction = (
                        f"{action_summary}\nDo not repeat the previous actions.\n\n"
                        f"{shared_second_user_msg}"
                    )
                else:
                    user_instruction = shared_user_msg

                f_hist.write(_jsonl_line(
                    task_id, persona, system_msg_instance,
                    f"{user_instruction}{page_block}", assistant_msg_instance))

                previous_actions.append(tr['action'])


if __name__ == '__main__':
    # Pipeline: collect trajectories with the agent -> convert them to
    # instructions -> add post-hoc rationales -> export chat-format JSONL.

    parser = argparse.ArgumentParser()
    parser.add_argument('--model-path', type=str,
                        default="meta-llama/Llama-3.2-3B-Instruct")
    parser.add_argument('--output-path', type=str,
                        default="alice_data/raw_instruction.json")
    parser.add_argument('--lora-path', type=str, default="YOUR-LORA-PATH")
    parser.add_argument('--persona-path', type=str, default='persona.json')
    parser.add_argument('--overwrite', action='store_true',
                        help='overwrite the existing file')
    parser.add_argument('--seed', type=int, default=42,
                        help='random seed for reproducibility')
    parser.add_argument('--eval-file', type=str, default='YOUR-EVAL-FILE')
    parser.add_argument('--save-rationale', action='store_true',
                        help='save rationale (ER option)')
    parser.add_argument('--num-tasks', type=int,
                        default=2000, help='number of tasks to run')
    parser.add_argument(
        '--use-summary', action='store_true', help='use summary')
    parser.add_argument('--data-file', type=str, default='YOUR_DATA_FILE')
    parser.add_argument('--use-vllm', action='store_true',
                        help='use vLLM for faster inference')
    parser.add_argument('--two-history', action='store_true',
                        help='use two history context')
    args = parser.parse_args()

    # Set random seeds for every RNG used by the pipeline.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    env = gym.make('WebAgentTextEnv-v0',
                   observation_mode='text_rich', num_products=DEBUG_PROD_SIZE)
    print("Loading Model..")
    model, tokenizer = load_hfmodel(args.model_path, use_vllm=args.use_vllm)
    tokenizer.padding_side = 'left'
    agent = HFAgentFeedback(model, tokenizer, args.eval_file, save_rationale=args.save_rationale,
                            use_summary=args.use_summary, use_vllm=args.use_vllm, lora_dir=None)
    print("Agent Created")
    print("Running Evaluation..")
    run_alice(args.num_tasks, agent, env, is_persona=True, persona_path=args.persona_path, output_path=args.output_path,
              overwrite=args.overwrite, save_rationale=args.save_rationale, two_history=args.two_history)
    print("Evaluation Done")
    print("Converting to Instructions..")
    convert_exp_into_instruction(model, tokenizer, args.output_path,
                                 lora_path=args.lora_path, filter=True, use_vllm=args.use_vllm)
    print("Conversion Done")
    print("Adding Post Hoc Reasoning..")
    add_post_hoc_reasoning(
        model, tokenizer, args.output_path, use_vllm=args.use_vllm)
    print("Post Hoc Reasoning Done")

    print("Converting to Jsonl with Rationale..")
    # Use the same history mode that the trajectories were collected with;
    # previously this was hard-coded to False and ignored --two-history.
    convert_json_into_jsonl_with_rationale(
        args.output_path, filter=True, two_history=args.two_history)
    print("Conversion Done")
