import sys
import time
import gym
import numpy as np
from rich.markup import escape
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer, StoppingCriteria
import torch
from tqdm import tqdm
from web_agent_site.envs import WebAgentTextEnv
from web_agent_site.utils import DEBUG_PROD_SIZE 
import json
import re
import os
from peft import PeftModel
from rouge_score import rouge_scorer
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest
from argparse import ArgumentParser
import requests

def clean_obs(obs):
    """Return the observation text unchanged.

    Pass-through hook for observation preprocessing. The '[button]' and
    '[button_]' markers are deliberately kept, because the agent's system
    prompt tells the model to look for them.
    """
    cleaned = obs
    return cleaned

class EosListStoppingCriteria(StoppingCriteria):
    """Stop generation when the newest token matches one of several stop ids.

    The stop set is ``eos_token_id`` followed by ``extra_stop_ids``. The
    defaults 128009 and 128006 are presumably Llama-3's ``<|eot_id|>`` and
    ``<|start_header_id|>`` tokens (the CLI's default model is a Llama-3.2
    checkpoint). Confirm them against the tokenizer in use, and pass other ids
    for other model families.

    Generation stops only when one single stop id equals the last token of
    *every* sequence in the batch.
    """

    def __init__(self, eos_token_id: int, extra_stop_ids=(128009, 128006)):
        self.stop_token_ids = [eos_token_id, *extra_stop_ids]

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        last_tokens = input_ids[:, -1]
        # Reduce on the tensor itself instead of iterating it element by
        # element with Python's built-in all().
        return any(bool((last_tokens == stop_id).all()) for stop_id in self.stop_token_ids)


def load_hfmodel(ckpt=None, lora_dir=None, use_vllm=False):
    """Load a tokenizer plus either a vLLM engine or a Hugging Face causal LM.

    Args:
        ckpt: Model name or local path handed to ``from_pretrained`` / ``LLM``.
            ``None`` maps to the empty string (kept for backward
            compatibility); callers should pass a real path.
        lora_dir: Optional LoRA adapter directory. On the vLLM path it only
            enables LoRA support on the engine; the adapter itself is attached
            per request at generation time. On the HF path it is merged in via
            ``PeftModel.from_pretrained``.
        use_vllm: Build a ``vllm.LLM`` engine instead of a transformers model.

    Returns:
        ``(model, tokenizer)``. The tokenizer pads on the left and reuses EOS
        as the pad token.
    """
    path = '' if ckpt is None else ckpt

    tokenizer = AutoTokenizer.from_pretrained(path, use_fast=True,
                                              trust_remote_code=False)
    # Left padding so decoder-only generation continues from the real prompt end.
    tokenizer.padding_side = "left"
    tokenizer.add_eos_token = True
    tokenizer.pad_token = tokenizer.eos_token

    if use_vllm:
        base_model = LLM(
            model=path,
            enable_lora=bool(lora_dir),
            max_lora_rank=64,
            trust_remote_code=True,
            dtype="bfloat16",
            gpu_memory_utilization=0.75,  # Adjust this value (0.6-0.8)
            tensor_parallel_size=1,  # Explicitly set to 1 if using single GPU
        )
    else:
        # The config is only consumed by the transformers path, so it is not
        # loaded when vLLM is used.
        config = AutoConfig.from_pretrained(
            path,
            trust_remote_code=False,
            revision="main",
        )
        base_model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.bfloat16,
            config=config,
            trust_remote_code=True,
            use_auth_token=True,
            use_flash_attention_2=True,
        ).to('cuda')

        base_model.config.use_cache = False
        base_model.config.pretraining_tp = 1

        # Wrap with the LoRA adapter if one was provided.
        if lora_dir is not None:
            base_model = PeftModel.from_pretrained(base_model, lora_dir).to('cuda')

        base_model.eval()

    print('Loaded Model and Tokenizer')

    return base_model, tokenizer

def _vllm_generate(engine, tokenizer, messages, lora_dir):
    """Greedy-decode one chat prompt with a vLLM engine and return the completion text."""
    prompt_text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    # temperature=0.0 -> deterministic (greedy) sampling.
    params = SamplingParams(temperature=0.0, max_tokens=512, stop=["</s>"])
    adapter = None if lora_dir is None else LoRARequest("agent_lora", 1, lora_dir)
    results = engine.generate([prompt_text], params, lora_request=adapter)
    return results[0].outputs[0].text


def _hf_generate(model, tokenizer, messages):
    """Greedy-decode one chat prompt with a transformers model on CUDA and return the new text."""
    prompt_ids = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        add_special_tokens=False,
    ).to('cuda')

    generated = model.generate(
        prompt_ids,
        max_new_tokens=512,
        do_sample=False,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        stopping_criteria=[EosListStoppingCriteria(tokenizer.eos_token_id)],
        num_beams=1,
        use_cache=True,
        temperature=None,
        top_p=None,
    ).squeeze(0)

    # generate() echoes the prompt; decode only the continuation.
    new_tokens = generated[prompt_ids.shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=True)


def hf_llm_inference(base_model, tokenizer, system_msg, user_msg, use_vllm=False, lora_dir=None):
    '''
    Run one greedy chat completion and return the generated text.

    base_model: vllm.LLM when use_vllm is True, otherwise an
        AutoModelForCausalLM (possibly wrapped in a PeftModel)
    tokenizer: AutoTokenizer whose chat template formats the prompt
    system_msg: content of the system-role message
    user_msg: content of the user-role message
    use_vllm: dispatch to the vLLM engine instead of HF generate()
    lora_dir: LoRA adapter path; only used on the vLLM path, where it is
        attached per request
    '''
    messages = [
        {'role': 'system', 'content': system_msg},
        {'role': 'user', 'content': user_msg},
    ]
    if not use_vllm:
        return _hf_generate(base_model, tokenizer, messages)
    return _vllm_generate(base_model, tokenizer, messages, lora_dir)

def get_custom_rewards(final_obs, prev_observation, reward_type):
    """Score the page the agent finished on against a reference final page.

    Args:
        final_obs: Reference observation text for the task.
        prev_observation: Observation the agent was looking at when it finished.
        reward_type: ``"exact_match"`` gives 1.0 if the two strings are equal,
            else 0.0. ``"rouge"`` gives the ROUGE-L F-measure, with stemming.

    Returns:
        A float reward in [0, 1].

    Raises:
        ValueError: If ``reward_type`` is not supported.
    """
    if reward_type == "exact_match":
        # A mismatch must explicitly score 0.0. Leaving the variable unset
        # would raise UnboundLocalError.
        return 1.0 if prev_observation == final_obs else 0.0
    if reward_type == "rouge":
        scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)
        return scorer.score(final_obs, prev_observation)['rougeL'].fmeasure
    raise ValueError(f"Unsupported reward type: {reward_type}")

    
# For opensource model (huggingface)
class HFAgent:
    """Web-shopping agent that prompts a local LLM and parses its reply into an action.

    The model is asked for either ``search[...]`` or ``click[...]``. The last
    such action in the reply is returned, lower-cased and with spaces inside
    the brackets trimmed. If nothing parses, ``click[back to search]`` is
    returned.
    """

    # Matches search[...] / click[...] actions in the (lower-cased) model output.
    _ACTION_RE = re.compile(r'(search\[[^\]]+\]|click\[[^\]]+\])')
    # Returned when the model output contains no parseable action.
    _FALLBACK_ACTION = 'click[back to search]'

    def __init__(self, model, tokenizer, lora_dir=None, use_vllm=False):
        self.agent = model
        self.tokenizer = tokenizer
        self.system_prompt = '''You are an agent with a strict task of completing a web shopping assignment based on the page content and the user's instructions.

In each step, your actions are strictly limited to two types:

1. search[keywords]: Use this action only when a "[button] Search [button_]" is present in the current web page content. You must replace "keywords" with any valid search query you want to search.

2. click[HTML Element]: Use this action to click on an HTML Element in the page content. "HTML Element" can be any clickable element in the page represented inside "[button]" and "[button_]", such as an item id, action button, or attributes and options like color or size. Note that the 'HTML Element' 'must' be present in the current page content. Also, do not click the "clicked button" or "item name".

Only use search action when a "[button] Search [button_]" is present in the current web page content and otherwise, use click action (click item id, attributes like color and size, or action button).:
'''
        self.use_vllm = use_vllm
        self.lora_dir = lora_dir

    @staticmethod
    def _parse_action(raw_output):
        """Extract the last search[...]/click[...] action from ``raw_output``.

        Returns the action lower-cased, with whitespace stripped just inside
        the brackets, or ``'click[back to search]'`` if no action is found.
        """
        matches = HFAgent._ACTION_RE.findall(raw_output.lower())
        if not matches:
            return HFAgent._FALLBACK_ACTION
        action = matches[-1]  # The model may reason first; the final action wins.
        action = re.sub(r'\[\s+', '[', action)  # Removes space after [
        action = re.sub(r'\s+\]', ']', action)  # Removes space before ]
        return action

    def act(self, goal, prompt):
        """Query the model with ``goal`` in the system message and return a parsed action."""
        system_msg = f"{self.system_prompt}\n\nTask: \n{goal}\n"
        action = hf_llm_inference(self.agent, self.tokenizer, system_msg, prompt, self.use_vllm, self.lora_dir)
        print("---- Action ----")
        print(action)
        print("---- Action Finished----")
        return self._parse_action(action)
    
# Captures the page content that follows the instruction line of an observation.
_PAGE_AFTER_INSTRUCTION_RE = re.compile(r'(?i)instruction:\s*\n.*?\n(.*)', re.DOTALL)


def _base_prefix(prefix_term):
    """Return the history-free instruction placed before the page content."""
    return (
        f"Now here is the new page content. Based on the given task and the current web page content, {prefix_term}provide a valid action. When outputting the action, "
        "please write your action after the prompt 'Action:'."
    )


def _build_prefix(mode, prefix_term, actions, prev_observation):
    """Return the instruction text for the next step according to ``mode``.

    ``"action"`` lists every action taken so far. ``"prev"`` repeats the
    previous page and the last action. Any other mode (including ``"base"``)
    uses the history-free prefix.
    """
    if mode == "action":
        action_summary = "To complete the given task, you have taken the following actions:\n" + "\n".join(
            f"{i+1}. {a}" for i, a in enumerate(actions)
        )
        return (
            f"{action_summary}\nDo not repeat the previous actions.\n\n"
            f"Now here is the new page content. Read carefully the page content. Based on the previous actions, the given task, and the current web page content, {prefix_term}provide a valid action. When outputting the action, "
            "please write your action after the prompt 'Action:'."
        )
    if mode == "prev":
        return (
            f"In the previous step:\n"
            f"- The page showed: {prev_observation}\n"
            f"- You took action: {actions[-1]}\n\n"
            f"Now here is the new page content. Based on the previous step, the given task, and the current web page content, {prefix_term}generate a valid action. When outputting the action, "
            "please write your action after the prompt 'Action:'."
        )
    return _base_prefix(prefix_term)


def _summarize_rewards(rewards):
    """Return ``(average reward, success rate)``; success means a reward of exactly 1.0.

    An empty list yields ``(0.0, 0.0)`` instead of dividing by zero.
    """
    if not rewards:
        return 0.0, 0.0
    avg_reward = sum(rewards) / len(rewards)
    sr = (np.array(rewards) == 1.0).mean()
    return avg_reward, sr


def run_eval(args, agent, env, add_CoT):
    """Roll out ``agent`` in ``env`` over a range of sessions and collect rewards.

    Args:
        args: Parsed CLI namespace. Reads ``custom``, ``custom_path``,
            ``num_tasks``, ``validation``, ``mode`` and ``reward_type``.
        agent: Object exposing ``act(goal, prompt) -> str``.
        env: Environment where ``reset(session=i)`` returns ``(obs, info)`` and
            ``step(action)`` returns ``(obs, reward, done, info)``.
        add_CoT: If True, ask the model for concise reasoning before the action.

    Returns:
        ``(avg_reward, success_rate, result_dict)``. ``result_dict`` holds, per
        task, ``rewards``, ``instructions``, ``action_sequence`` and
        ``obs_sequence``.

    Sessions start at 500 in validation mode and at 0 otherwise. Each episode
    takes at most 11 steps. In custom mode, the episode ends only on
    ``click[buy now]``, and the reward is computed by ``get_custom_rewards``.
    """
    max_trial = 10
    result_dict = {
        'rewards': [],
        'instructions': [],
        'action_sequence': [],
        'obs_sequence': [],
    }

    # Set starting task index based on validation mode
    start_idx = 500 if args.validation else 0

    tasks = None
    if args.custom:
        with open(args.custom_path, 'r') as f:
            tasks = json.load(f)
        # Custom tasks are looked up by session id (tasks[task_idx]), as in
        # the non-validation path. So clamp the count to what is available
        # from start_idx on, instead of truncating the list. Truncating made
        # validation mode index past the end.
        # NOTE(review): assumes entry i of the custom file pairs with env session i.
        num_tasks = max(0, min(args.num_tasks, len(tasks) - start_idx))
    else:
        num_tasks = args.num_tasks

    print(f"Number of tasks: {num_tasks}")

    prefix_term = "give a concise reasoning and then " if add_CoT else ""

    for task_idx in tqdm(range(start_idx, start_idx + num_tasks), desc="Evaluating tasks"):
        obs, _ = env.reset(session=task_idx)
        goal = clean_obs(obs)
        actions = []
        observations = []

        # The cleaned goal goes into the system message. Custom instructions,
        # when provided, replace the env's own instruction.
        if args.custom:
            final_obs = tasks[task_idx]['task_results'][-1]['observation']
            cleaned_goal = tasks[task_idx]['instruction']
        else:
            cleaned_goal = goal.split('Instruction: \n')[-1].split('\n')[0]

        result_dict['instructions'].append(cleaned_goal)
        observation = "\n".join(goal.split('Instruction: \n')[-1].split('\n')[1:])
        prompt = f"{_base_prefix(prefix_term)}\n\nPage Content:\n{observation}\n\n"

        done = False
        trial = 0

        # Start multi-step decision-making
        while not done and trial <= max_trial:
            prev_observation = observation
            action = agent.act(cleaned_goal, prompt)
            actions.append(action)

            # Take a step in the environment
            obs, reward, done, _ = env.step(action)
            obs = clean_obs(obs)

            match = _PAGE_AFTER_INSTRUCTION_RE.search(obs)
            observation = match.group(1).strip() if match else ""
            observations.append(prev_observation)

            print(f"Action: {action}\n\nObservation:\n{observation}\n\n")

            prefix = _build_prefix(args.mode, prefix_term, actions, prev_observation)
            prompt = f"{prefix}\n\nPage Content:\n{observation}\n\n"

            # In custom mode, 'click[buy now]' alone ends the episode, and the
            # page seen before it is scored against the reference final page.
            if args.custom:
                done = action == 'click[buy now]'
                if done:
                    reward = get_custom_rewards(final_obs, prev_observation, args.reward_type)

            # Record exactly one reward per task: on completion, or on the last trial.
            if done or trial == max_trial:
                if not done:
                    if args.custom:
                        reward = get_custom_rewards(final_obs, prev_observation, args.reward_type)
                    else:
                        reward = 0.0
                result_dict['rewards'].append(reward)
                result_dict['action_sequence'].append(actions)
                result_dict['obs_sequence'].append(observations)
            trial += 1

        # Display interim metrics
        avg_reward, sr = _summarize_rewards(result_dict['rewards'])
        print(f"Current averaged reward: {avg_reward:.2f}")
        print(f"Current success rate: {sr:.2f}")
        print(result_dict['rewards'])

    # Compute final metrics
    avg_reward, sr = _summarize_rewards(result_dict['rewards'])
    print(f"Current averaged reward: {avg_reward:.2f}")
    print(f"Current success rate: {sr:.2f}")
    return avg_reward, sr, result_dict


def main():
    """Parse CLI args, load the model, run the evaluation and save the results.

    The full ``result_dict`` (plus run metadata) is written as JSON under
    ``--output-dir``. That directory is created if it is missing. A readable
    summary is appended to ``./eval_results.txt``.
    """
    # parse arguments
    parser = ArgumentParser()
    parser.add_argument("--model-path", type=str, default="meta-llama/Llama-3.2-3B-Instruct")
    parser.add_argument("--lora-dir", type=str, default=None)
    parser.add_argument("--custom", action="store_true", default=False, help="Use custom instructions to evaluate")
    parser.add_argument("--custom-path", type=str, default=None, help="Path to custom instructions")
    parser.add_argument("--num-tasks", type=int, default=100, help="Number of tasks to evaluate")
    parser.add_argument("--reward-type", type=str, default="rouge", help="Type of reward to use")
    parser.add_argument("--output-dir", type=str, default="results/bc_new", help="Output directory")
    parser.add_argument("--exp-name", type=str, default="", help="Experiment name")
    parser.add_argument("--add-CoT", action="store_true", default=False, help="Add CoT to the system prompt")
    parser.add_argument("--validation", action="store_true", default=False, help="Run validation mode starting from task 500")
    parser.add_argument("--use-vllm", action="store_true", default=False, help="Use vLLM to generate actions")
    parser.add_argument("--mode", type=str, default="action", choices=["base", "action", "prev"],
                       help="Mode for context display: base (no history), action (action history), prev (previous observation)")
    args = parser.parse_args()

    # Fail fast, before loading a model, on an inconsistent flag combination.
    if args.custom and not args.custom_path:
        parser.error("--custom requires --custom-path")

    model_path = args.model_path
    env = gym.make('WebAgentTextEnv-v0', observation_mode='text_rich', num_products=DEBUG_PROD_SIZE)
    print("Loading Model..")
    model, tokenizer = load_hfmodel(model_path, lora_dir=args.lora_dir, use_vllm=args.use_vllm)
    print("Model Loaded")
    tokenizer.padding_side = 'left'
    agent = HFAgent(model, tokenizer, lora_dir=args.lora_dir, use_vllm=args.use_vllm)
    print("Agent Created")
    print("Running Evaluation..")
    avg_reward, sr, result_dict = run_eval(args, agent, env, args.add_CoT)

    # Save the run metadata alongside the per-task results.
    result_dict['avg_reward'] = avg_reward
    result_dict['success_rate'] = sr
    result_dict['model_path'] = model_path
    result_dict['lora_dir'] = args.lora_dir
    result_dict['num_tasks'] = args.num_tasks
    result_dict['reward_type'] = args.reward_type
    result_dict['custom_path'] = args.custom_path

    directory = args.output_dir
    if directory:
        os.makedirs(directory, exist_ok=True)

    if args.custom:
        # Strip only the final extension. Splitting on the first '.' turned
        # './data/tasks.json' into '' and 'data/v1.2/x.json' into 'data_v1',
        # so different runs collided on the same output file.
        safe_custom_path = os.path.splitext(args.custom_path)[0].replace("/", "_")
    else:
        safe_custom_path = "default"

    if args.validation:
        filename = f"results_{args.exp_name}_{safe_custom_path}_validation.json"
    else:
        filename = f"results_{args.exp_name}_{safe_custom_path}.json"
    full_name = os.path.join(directory, filename)

    with open(full_name, "w") as f:
        json.dump(result_dict, f)

    print(f"Saved results to {full_name}")

    # Append a readable summary to eval_results.txt
    with open("./eval_results.txt", "a") as f:
        f.write(f"\nExperiment: {args.exp_name}\n")
        f.write("Validation Mode\n" if args.validation else "Test Mode\n")
        f.write(f"Model: {model_path}\n")
        f.write(f"LoRA: {args.lora_dir}\n")
        f.write(f"Custom Instructions: {args.custom_path if args.custom else 'default'}\n")
        f.write(f"Number of Tasks: {args.num_tasks}\n")
        f.write(f"Reward Type: {args.reward_type}\n")
        f.write(f"Average Reward: {avg_reward:.4f}\n")
        f.write(f"Success Rate: {sr:.4f}\n")
        f.write("-" * 50 + "\n")


# Script entry point: run the evaluation only when executed directly, not on import.
if __name__ == '__main__':
    main()

