import gym
import json
import random
import string
import time
import torch

from bs4 import BeautifulSoup
from bs4.element import Comment
from collections import defaultdict
from flask import Flask
from web_agent_site.engine.engine import (
    load_products,
    init_search_engine,
    get_top_n_product_from_keywords,
    map_action_to_html,
    parse_action,
    get_product_per_page,
    ACTION_TO_TEMPLATE,
    END_BUTTON, NEXT_PAGE, PREV_PAGE, BACK_TO_SEARCH,
)
from web_agent_site.engine.goal import get_reward, get_goals
from web_agent_site.utils import (
    DEFAULT_FILE_PATH,
    FEAT_CONV,
    FEAT_IDS
)

# Module-level Flask application. The `@app.route` decorators on `SimServer`
# register against it, and `SimServer.receive` enters its application and
# test-request contexts while rendering pages.
app = Flask(__name__)
class WebAgentTextEnv(gym.Env):
    """Gym environment for Text mode of WebShop environment"""
    def __init__(
            self,
            observation_mode='html',
            file_path=DEFAULT_FILE_PATH,
            server=None,
            seed=None,
            **kwargs
        ):
        """
        Constructor for text environment

        Arguments:
        observation_mode (`str`) -- ['html' | 'text' | 'text_rich' | 'url'] (default 'html')
        file_path -- product file path forwarded to `SimServer`
        server -- existing server to reuse; a `SimServer` is built when None
        seed -- seed for the environment RNG; falls back to `kwargs['seed']`,
            and an unseeded `random.Random` is used when both are None

        Keyword arguments read from `kwargs`:
        get_image -- when truthy, load `FEAT_CONV` / `FEAT_IDS` with torch
        filter_goals, limit_goals, num_products, human_goals, show_attrs --
            forwarded to `SimServer`
        session, session_prefix -- session id and optional session id prefix
        num_prev_obs, num_prev_actions -- how many earlier observations /
            actions `step` folds into the returned state (default 0)
        enable_random_feedback (default True), random_feedback_probability
            (default 0.3), randomization_mode (default 'append'),
            one_time_random (default True) -- search randomization settings
        enable_null_feedback (default False), null_feedback_rounds
            (default empty), null_feedback_probability (default 0.1) --
            null feedback settings
        """
        super(WebAgentTextEnv, self).__init__()
        self.observation_mode = observation_mode
        self.kwargs = kwargs

        # Search randomization: with the configured probability a search
        # query is altered before it reaches the browser (see `step`).
        self.enable_random_feedback = self.kwargs.get('enable_random_feedback', True)
        self.random_feedback_probability = self.kwargs.get('random_feedback_probability', 0.3)

        # Single RNG shared by the environment and the server it creates, so
        # a fixed seed reproduces the whole run.
        self.seed = seed if seed is not None else self.kwargs.get('seed')
        self.rng = random.Random(self.seed) if self.seed is not None else random.Random()

        # Unknown modes silently fall back to 'append'.
        allowed_modes = {'replace', 'append', 'prefix', 'typo', 'null'}
        self.randomization_mode = self.kwargs.get('randomization_mode', 'append')
        if self.randomization_mode not in allowed_modes:
            self.randomization_mode = 'append'
        self.randomization_count = 0

        # Null feedback: selected conversation rounds return a "null"
        # observation instead of executing the action.
        self.enable_null_feedback = self.kwargs.get('enable_null_feedback', False)
        # Copy the rounds so later in-place edits never touch the caller's
        # list; a None value is treated as "no scheduled rounds".
        self.null_feedback_rounds = list(self.kwargs.get('null_feedback_rounds') or [])
        self.current_conversation_round = 0  # number of `step` calls in this episode
        self.null_feedback_probability = self.kwargs.get('null_feedback_probability', 0.1)

        # One-time randomization: at most one randomized search per episode.
        self.one_time_random = self.kwargs.get('one_time_random', True)
        self.has_randomized = False
        # Keyword pool used by the 'replace', 'append' and 'prefix' modes.
        self.random_search_keywords = [
            "women's running shoes",
            "laptop computer",
            "coffee maker",
            "wireless headphones",
            "smartphone case",
            "kitchen blender",
            "yoga mat",
            "desk lamp",
            "backpack",
            "sunglasses"
        ]

        self.file_path = file_path

        self.base_url = 'http://127.0.0.1:3000'
        self.server = SimServer(
            self.base_url,
            self.file_path,
            self.kwargs.get('filter_goals'),
            self.kwargs.get('limit_goals', -1),
            self.kwargs.get('num_products'),
            self.kwargs.get('human_goals'),
            self.kwargs.get('show_attrs', False),
            rng=self.rng,
        ) if server is None else server
        self.browser = SimBrowser(self.server)

        self.session = self.kwargs.get('session')
        self.session_prefix = self.kwargs.get('session_prefix')
        if self.kwargs.get('get_image', 0):
            # Image features are only loaded on request; `get_image` relies
            # on `self.feats` / `self.ids` being present.
            self.feats = torch.load(FEAT_CONV)
            self.ids = torch.load(FEAT_IDS)
            self.ids = {url: idx for idx, url in enumerate(self.ids)}
        self.prev_obs = []
        self.prev_actions = []
        self.num_prev_obs = self.kwargs.get('num_prev_obs', 0)
        self.num_prev_actions = self.kwargs.get('num_prev_actions', 0)
        # Start the first session right away (also sets the per-episode counters).
        self.reset()

    def step(self, action):
        """
        Takes an action, updates WebShop environment, and returns (state, reward, done, info)

        Arguments:
        action (`str`): An action should be of the following structure:
          - search[keywords]
          - click[value]
        If action not valid, perform nothing.

        Returns:
        state (`str`): the new observation joined by ' [SEP] ' with up to
            `num_prev_obs` / `num_prev_actions` earlier observations / actions
        reward: reward reported by the browser (0 for ignored actions)
        done (`bool`): whether the browser reported the session as finished
        info (`dict`): randomization / null-feedback bookkeeping for this step
        """
        # Refreshes `self.text_to_clickable` for the page currently displayed.
        self.get_available_actions()

        # Determine action type (click, search) and argument
        action_name, action_arg = parse_action(action)
        if action_arg is not None:
            action_arg = action_arg.lower()
        is_search = action_name == 'search'

        self.current_conversation_round += 1

        info = {
            'randomization_applied': False,
            'original_search': action_arg if is_search else None,
            'effective_search': action_arg if is_search else None,
            'randomization_mode': self.randomization_mode if is_search else None,
            'random_count_in_episode': self.randomization_count,
            'feedback_text': None,
            'null_feedback_applied': False,
            'current_conversation_round': self.current_conversation_round,
        }

        if self._should_apply_null_feedback():
            info.update({
                'null_feedback_applied': True,
                'original_action': action,
                'effective_action': None,
                'randomization_mode': 'null',
                'feedback_text': f"Null feedback applied in {self.current_conversation_round}th conversation: action ignored",
                'current_conversation_round': self.current_conversation_round,
            })
            print(f"Null feedback applied in {self.current_conversation_round}th conversation: action ignored")

            # The action is neither executed nor recorded; only the null
            # observation enters the history.
            state = self._compose_state(self._get_null_observation())
            return state, 0, False, info

        # handle search action (normal case)
        if is_search and action_arg is not None:
            self.current_search_round += 1
            info['current_search_round'] = self.current_search_round

            # Evaluation order matters: the RNG is only consumed when
            # randomization is enabled and still allowed in this episode.
            if (self.enable_random_feedback
                    and (not self.one_time_random or not self.has_randomized)
                    and self.rng.random() < self.random_feedback_probability):
                effective_arg, feedback_text = self._apply_randomization(action_arg)
                info.update({
                    'randomization_applied': True,
                    'original_search': action_arg,
                    'effective_search': effective_arg,
                    'randomization_mode': self.randomization_mode,
                    'feedback_text': feedback_text,
                })
                action_arg = effective_arg
                self.randomization_count += 1
                info['random_count_in_episode'] = self.randomization_count

                if self.one_time_random:
                    self.has_randomized = True
                    info['one_time_random_completed'] = True

                print(f"Randomized search ({self.randomization_mode}): '{info['original_search']}' -> '{info['effective_search']}'")

        if is_search and action_arg is not None and action_arg != '':
            status = self.browser.search(action_arg)
        elif (action_name == 'click'
              and action_arg in self.text_to_clickable
              and action_arg != 'search'):
            status = self.browser.click(action_arg, self.text_to_clickable)
        else:
            # Invalid or empty action: nothing happens.
            status = dict(reward=0, done=False)

        # Update observation, state with the new action
        ob = self.observation
        self.prev_actions.append(action)
        state = self._compose_state(ob)
        return state, status['reward'], status['done'], info

    def _compose_state(self, ob):
        """Join `ob` with the configured history, then record `ob` in `prev_obs`.

        History is walked from most recent to oldest (action before
        observation at each depth) and the list is reversed before joining,
        so the returned string reads oldest first and ends with `ob`.
        """
        text_list = [ob]
        for i in range(1, 1 + max(self.num_prev_obs, self.num_prev_actions)):
            if len(self.prev_actions) >= i and self.num_prev_actions >= i:
                text_list.append(self.prev_actions[-i])
            if len(self.prev_obs) >= i and self.num_prev_obs >= i:
                text_list.append(self.prev_obs[-i])
        state = ' [SEP] '.join(text_list[::-1])
        self.prev_obs.append(ob)
        return state

    def get_available_actions(self):
        """Collect what can be acted on in the current page.

        Rebuilds `self.text_to_clickable` (label -> HTML element) and returns
        a dict with `has_search_bar` and the list of clickable labels.
        """
        page = self._parse_html()

        # Buttons first, then product links: a later element with the same
        # lower-cased text overrides an earlier one.
        clickables = {}
        for element in page.find_all(class_='btn') + page.find_all(class_='product-link'):
            clickables[str(element.get_text()).lower()] = element
        # Radio inputs (buying options) are keyed by their raw `value`.
        for radio in page.select('input[type="radio"]'):
            clickables[str(radio.get('value'))] = radio
        self.text_to_clickable = clickables

        return {
            'has_search_bar': page.find(id='search_input') is not None,
            'clickables': list(clickables),
        }
    
    def get_image(self):
        """Return the stored feature entry for the product image on the current page.

        Falls back to `torch.zeros(512)` when the page has no element with id
        `product-image` or its `src` is not a key of `self.ids`.
        """
        page = self._parse_html(self.browser.page_source)
        image_tag = page.find(id='product-image')
        if image_tag is None:
            return torch.zeros(512)
        src = image_tag['src']
        if src not in self.ids:
            return torch.zeros(512)
        return self.feats[self.ids[src]]

    def get_instruction_text(self):
        """Read the instruction shown on the current page (the `h4` inside `#instruction-text`)."""
        page = self._parse_html(self.browser.page_source)
        container = page.find(id='instruction-text')
        return container.h4.text

    def _parse_html(self, html=None):
        """
        Returns HTML wrapped in a BeautifulSoup object

        Arguments:
        html (`str`): HTML to parse. If omitted, the HTML of the current
            state is used.
        """
        source = self.state['html'] if html is None else html
        return BeautifulSoup(source, 'html.parser')
    
    @property
    def observation(self):
        """Compiles state into either the `html` or `text` observation mode"""
        html = self.state['html']
        if self.observation_mode == 'html':
            return html
        elif self.observation_mode == 'text':
            return self.convert_html_to_text(html, simple=True)
        elif self.observation_mode == 'text_rich':
            return self.convert_html_to_text(html, simple=False)
        elif self.observation_mode == 'url':
            return self.state['url']
        else:
            raise ValueError(
                f'Observation mode {self.observation_mode} not supported.'
            )
    
    @property
    def state(self):
        """
        State that includes all information. The actual observation are
        likely to be a subset or reduced form of the state.
        """
        return dict(
            url=self.browser.current_url,
            html=self.browser.page_source,
            instruction_text=self.instruction_text,
        )
    
    def convert_html_to_text(self, html, simple=False):
        """Strip HTML of tags and add separators to convert observation into text

        Arguments:
        html (`str`): page HTML to convert
        simple (`bool`): if True, return the visible strings (stripped) joined
            by ' [SEP] '. Otherwise return one line per visible string, with
            clickable elements wrapped in `[button]` / `[clicked button]`
            markers and a "You have clicked ..." notice prepended for every
            option label whose quoted text appears in the current URL.
        """
        # `find_all(string=True)` is the current spelling of the deprecated
        # `findAll(text=True)`.
        texts = self._parse_html(html).find_all(string=True)
        visible_texts = [t for t in texts if tag_visible(t) and t != '\n']
        if simple:
            # For `simple` mode, return just [SEP] separators
            return ' [SEP] '.join(t.strip() for t in visible_texts)

        # Otherwise, return an observation with tags mapped to specific,
        # unique separators. Notices are collected separately because each
        # one goes in front of everything emitted so far.
        clicked_notices = []
        lines = []
        for t in visible_texts:
            parent = t.parent
            if parent.name == 'button':  # button
                processed_t = f'[button] {t} [button_]'
            elif parent.name == 'label':  # options
                if f'"{t}"' in self.state['url']:
                    processed_t = f'  [clicked button] {t} [clicked button_]'
                    clicked_notices.append(f'You have clicked {t}.\n')
                else:
                    processed_t = f'  [button] {t} [button_]'
            elif parent.get('class') == ["product-link"]:  # product asins
                if f'{t}' in self.server.user_sessions[self.session]['asins']:
                    processed_t = f'\n[clicked button] {t} [clicked button_]'
                else:
                    processed_t = f'\n[button] {t} [button_]'
            else:  # regular, unclickable text
                processed_t = str(t)
            lines.append(processed_t + '\n')
        # The most recent notice comes first, matching repeated prepending.
        return ''.join(reversed(clicked_notices)) + ''.join(lines)
    
    def reset(self, session=None, instruction_text=None):
        """Create a new session and reset environment variables"""
        session_int = None
        if session is not None:
            self.session = str(session)
            if isinstance(session, int):
                session_int = session
        else:
            self.session = ''.join(self.rng.choices(string.ascii_lowercase, k=10))
        if self.session_prefix is not None:
            self.session = self.session_prefix + self.session

        init_url = f'{self.base_url}/{self.session}'
        self.browser.get(init_url, session_id=self.session, session_int=session_int)

        self.text_to_clickable = None
        self.instruction_text = self.get_instruction_text() if instruction_text is None else instruction_text
        obs = self.observation
        self.prev_obs = [obs]
        self.prev_actions = []
        
        # reset randomization state and count - each new task starts fresh
        if self.one_time_random:
            self.has_randomized = False
        self.randomization_count = 0
        self.current_search_round = 0  
        self.current_conversation_round = 0 
        print("New task started, randomization state and count reset")
        
        return obs, {}

    def render(self, mode='human'):
        pass

    def close(self):
        pass
    
    def reset_randomization(self):
        self.has_randomized = False
        print("Randomization state reset")

    def set_randomization(self, enable=True, one_time=True):
        self.enable_random_feedback = enable
        self.one_time_random = one_time
        if not one_time:
            self.has_randomized = False
        print(f"Randomization settings: enable={enable}, one-time={one_time}")

    def add_random_keywords(self, keywords):
        if isinstance(keywords, str):
            keywords = [keywords]
        self.random_search_keywords.extend(keywords)
        print(f"Added random keywords: {keywords}")

    def set_null_feedback(self, enable=True, rounds=None, probability=None):
        """set null feedback parameters
        
        Args:
            enable (bool): whether to enable null feedback
            rounds (list): specify which rounds to apply null feedback, e.g. [1, 3, 5]
            probability (float): random null feedback probability, between 0 and 1
        """
        self.enable_null_feedback = enable
        if rounds is not None:
            self.null_feedback_rounds = rounds
        if probability is not None:
            self.null_feedback_probability = max(0, min(1, probability))
        print(f"Null feedback settings: enable={enable}, specified rounds={self.null_feedback_rounds}, random probability={self.null_feedback_probability}")

    def add_null_feedback_rounds(self, rounds):
        if isinstance(rounds, int):
            rounds = [rounds]
        self.null_feedback_rounds.extend(rounds)
        self.null_feedback_rounds = list(set(self.null_feedback_rounds)) 
        print(f"Added null feedback rounds: {rounds}, current rounds: {self.null_feedback_rounds}")

    def remove_null_feedback_rounds(self, rounds):
        if isinstance(rounds, int):
            rounds = [rounds]
        self.null_feedback_rounds = [r for r in self.null_feedback_rounds if r not in rounds]
        print(f"Removed null feedback rounds: {rounds}, current rounds: {self.null_feedback_rounds}")

    def clear_null_feedback_rounds(self):
        self.null_feedback_rounds = []
        print("Cleared all null feedback rounds")

    def get_null_feedback_status(self):
        return {
            'enabled': self.enable_null_feedback,
            'rounds': self.null_feedback_rounds,
            'probability': self.null_feedback_probability,
            'current_round': self.current_search_round,
        }

    def _apply_randomization(self, original_arg):
        mode = self.randomization_mode
        if mode == 'replace':
            kw = self.rng.choice(self.random_search_keywords)
            return kw, f"Replaced search with random keyword: {kw}"
        elif mode == 'append':
            kw = self.rng.choice(self.random_search_keywords)
            effective = f"{original_arg} {kw}".strip()
            return effective, f"Appended random keyword: {kw}"
        elif mode == 'prefix':
            kw = self.rng.choice(self.random_search_keywords)
            effective = f"{kw} {original_arg}".strip()
            return effective, f"Prefixed random keyword: {kw}"
        elif mode == 'typo':
            return self._inject_typo(original_arg)
        else:
            return original_arg, None

    def _inject_typo(self, text):
        if not text:
            return text, None
        ops = ['swap', 'replace', 'delete']
        op = self.rng.choice(ops)
        s = list(text)
        if op == 'swap' and len(s) > 1:
            i = self.rng.randrange(0, len(s) - 1)
            s[i], s[i+1] = s[i+1], s[i]
            return ''.join(s), f"Introduced swap typo at position {i}"
        elif op == 'replace' and len(s) > 0:
            i = self.rng.randrange(0, len(s))
            s[i] = self.rng.choice(string.ascii_lowercase)
            return ''.join(s), f"Replaced character at position {i}"
        elif op == 'delete' and len(s) > 0:
            i = self.rng.randrange(0, len(s))
            del s[i]
            return ''.join(s), f"Deleted character at position {i}"
        return text, None

    def _should_apply_null_feedback(self):
        if not self.enable_null_feedback:
            return False
        

        if self.null_feedback_rounds and self.current_conversation_round in self.null_feedback_rounds:
            return True
        

        if self.null_feedback_probability > 0 and self.rng.random() < self.null_feedback_probability:
            return True
        
        return False

    def _get_null_observation(self):

        return "null"


def tag_visible(element):
    """Return True for text nodes that should appear in a text observation.

    Text whose parent tag is one of style, script, head, title, meta or
    [document] is hidden, and so are HTML comments.
    """
    hidden_parents = {'style', 'script', 'head', 'title', 'meta', '[document]'}
    if element.parent.name in hidden_parents:
        return False
    return not isinstance(element, Comment)


class SimServer:
    """Lightweight simulator of WebShop Flask application for generating HTML observations"""
    def __init__(
        self,
        base_url,
        file_path,
        filter_goals=None,
        limit_goals=-1,
        num_products=None,
        human_goals=0,
        show_attrs=False,
        rng=None,
    ):
        """
        Constructor for simulated server serving WebShop application

        Arguments:
        base_url (`str`) -- Prefix of every URL the server generates
        file_path -- Product file handed to `load_products`
        filter_goals (`func`) -- Select specific goal(s) for consideration based on criteria of custom function;
            called as `filter_goals(index, goal)` on the shuffled goals
        limit_goals (`int`) -- Limit to number of goals available (-1 keeps all)
        num_products (`int`) -- Number of products to search across
        human_goals (`bool`) -- If true, load human goals; otherwise, load synthetic goals
        show_attrs (`bool`) -- Forwarded to the item page template
        rng (`random.Random`) -- Random source for shuffling and sampling goals;
            defaults to `random.Random(233)`
        """
        # Load all products, goals, and search engine
        self.base_url = base_url
        self.rng = rng if rng is not None else random.Random(233)
        self.all_products, self.product_item_dict, self.product_prices, _ = \
            load_products(filepath=file_path, num_products=num_products, human_goals=human_goals)
        self.search_engine = init_search_engine(num_products=num_products)
        self.goals = get_goals(self.all_products, self.product_prices, human_goals)
        self.show_attrs = show_attrs
        print(f'Loaded {len(self.goals)} goals.')

        # HBY: Fix outcome for random shuffling of goals (use local RNG to avoid global state)
        self.rng.shuffle(self.goals)

        # Apply `filter_goals` parameter if exists to select specific goal(s)
        if filter_goals is not None:
            self.goals = [
                goal for (i, goal) in enumerate(self.goals)
                if filter_goals(i, goal)
            ]

        # Imposes `limit` on goals via weighted random selection without replacement
        if limit_goals != -1 and limit_goals < len(self.goals):
            print("has limit")
            self.weights, self.cum_weights = self._cumulative_weights(self.goals)
            idxs = []      # draw order is kept: it defines the order of the goals
            seen = set()   # O(1) duplicate check for the rejection loop
            # NOTE(review): the loop only ends once `limit_goals` distinct
            # indices were drawn; presumably enough goals carry a positive
            # weight -- confirm against the goal data.
            while len(idxs) < limit_goals:
                idx = self._random_idx(self.cum_weights)
                if idx not in seen:
                    seen.add(idx)
                    idxs.append(idx)
            self.goals = [self.goals[i] for i in idxs]
        print(f'Loaded {len(self.goals)} goals.')

        # Set extraneous housekeeping variables
        self.weights, self.cum_weights = self._cumulative_weights(self.goals)
        self.user_sessions = dict()
        self.search_time = 0
        self.render_time = 0
        self.sample_time = 0
        self.assigned_instruction_text = None  # TODO: very hacky, should remove

    @staticmethod
    def _cumulative_weights(goals):
        """Return `(weights, cum_weights)` for `goals`; `cum_weights` starts with 0 and has one more entry than `weights`."""
        weights = [goal['weight'] for goal in goals]
        cum_weights = [0]
        for w in weights:
            cum_weights.append(cum_weights[-1] + w)
        return weights, cum_weights
        
    @app.route('/', methods=['GET', 'POST'])
    def index(self, session_id, **kwargs):
        """Render the start (search) page for `session_id`.

        Expects `instruction_text` in `kwargs`; returns `(html, url)`.
        """
        start_html = map_action_to_html(
            'start',
            session_id=session_id,
            instruction_text=kwargs['instruction_text'],
        )
        return start_html, f'{self.base_url}/{session_id}'
    
    @app.route('/', methods=['GET', 'POST'])
    def search_results(self, session_id, **kwargs):
        """Run a search for the session and render the results page.

        Expects `keywords` (a list) and optionally `page` (default 1) in
        `kwargs`. Records the search in the session, clears the selected
        product and options, and returns `(html, url)`.
        """
        session = self.user_sessions[session_id]
        keywords = kwargs['keywords']  # TODO: why is this using kwargs? why not session?
        assert isinstance(keywords, list)
        page = kwargs.get('page', 1)

        session.update({"page": page, "keywords": keywords})
        session["actions"]["search"] += 1
        session.update({"asin": None, "options": {}})

        # Time the search itself
        search_started = time.time()
        top_n_products = get_top_n_product_from_keywords(
            keywords,
            self.search_engine,
            self.all_products,
            self.product_item_dict,
        )
        self.search_time += time.time() - search_started

        # Keep only the products that belong on the requested page
        products = get_product_per_page(top_n_products, page)

        joined_keywords = '+'.join(keywords)
        url = f'{self.base_url}/search_results/{session_id}/{joined_keywords}/{page}'

        # Time the rendering of the results page
        render_started = time.time()
        html = map_action_to_html(
            'search',
            session_id=session_id,
            products=products,
            keywords=session["keywords"],
            page=page,
            total=len(top_n_products),
            instruction_text=session["goal"]["instruction_text"],
        )
        self.render_time += time.time() - render_started
        return html, url
    
    @app.route('/', methods=['GET', 'POST'])
    def item_page(self, session_id, **kwargs):
        """Render the item page for the session's selected product.

        Expects `clickable_name` and `text_to_clickable` in `kwargs`. A
        clicked product link becomes the selected product; a clicked element
        with a `name` attribute is stored as a chosen option. Returns
        `(html, url)`.
        """
        session = self.user_sessions[session_id]
        clickable_name = kwargs['clickable_name']
        clickable = kwargs['text_to_clickable'][clickable_name]

        # Update session logs with information of last product asin selected
        css_classes = clickable.get('class')
        if css_classes is not None and css_classes[0] == 'product-link':
            selected = clickable_name.upper()
            session["asin"] = selected
            session["actions"]["asin"] += 1
            session["asins"].add(selected)
        elif clickable.get('name') is not None:
            session["options"][clickable['name'].lower()] = clickable_name
            session["actions"]["options"] += 1

        # Set fields + url of page, then render page's HTML
        asin = session["asin"]
        product_info = self.product_item_dict[asin]
        joined_keywords = '+'.join(session["keywords"])
        options_json = json.dumps(session['options'])
        current_page = session["page"]

        url = (
            f'{self.base_url}/item_page/{session_id}/'
            f'{asin}/{joined_keywords}/'
            f'{current_page}/{options_json}'
        )

        html = map_action_to_html(
            'click',
            session_id=session_id,
            product_info=product_info,
            keywords=session["keywords"],
            page=current_page,
            asin=asin,
            options=session["options"],
            instruction_text=session["goal"]["instruction_text"],
            show_attrs=self.show_attrs,
        )
        return html, url

    @app.route('/', methods=['GET', 'POST'])
    def item_sub_page(self, session_id, **kwargs):
        """Render a sub page (one of the `ACTION_TO_TEMPLATE` entries) of the selected product.

        Expects `clickable_name` in `kwargs`; it is matched case-insensitively
        against the keys of `ACTION_TO_TEMPLATE`. Returns `(html, url)`.
        """
        session = self.user_sessions[session_id]
        requested = kwargs['clickable_name']
        wanted = requested.lower()
        # Use the template key's own spelling when there is a match.
        clickable_name = next(
            (key for key in ACTION_TO_TEMPLATE if key.lower() == wanted),
            requested,
        )

        # Set fields + url of page, then render page's HTML
        asin = session["asin"]
        product_info = self.product_item_dict[asin]
        session["actions"][clickable_name] += 1
        joined_keywords = '+'.join(session["keywords"])
        current_page = session["page"]
        options = session["options"]
        url = (
            f'{self.base_url}/item_sub_page/{session_id}/'
            f'{asin}/{joined_keywords}/{current_page}/'
            f'{clickable_name}/{options}'
        )
        html = map_action_to_html(
            f'click[{clickable_name}]',
            session_id=session_id,
            product_info=product_info,
            keywords=session["keywords"],
            page=current_page,
            asin=asin,
            options=options,
            instruction_text=session["goal"]["instruction_text"],
        )
        return html, url

    @app.route('/', methods=['GET', 'POST'])
    def done(self, session_id, **kwargs):
        """Score the purchase of the selected product and render the done page.

        Stores the verbose reward info, the reward and the done flag in the
        session. Returns `(html, url, reward)`.
        """
        session = self.user_sessions[session_id]
        goal = session['goal']
        asin = session["asin"]
        purchased_product = self.product_item_dict[asin]
        session["actions"]["purchase"] += 1
        price = self.product_prices.get(asin)
        options = session["options"]

        # Calculate reward for selected product and set variables for page details
        reward, info = get_reward(
            purchased_product,
            goal,
            price=price,
            options=options,
            verbose=True
        )
        session.update({'verbose_info': info, 'done': True, 'reward': reward})

        url = f'{self.base_url}/done/{session_id}/{asin}/{options}'
        html = map_action_to_html(
            f'click[{END_BUTTON}]',
            session_id=session_id,
            reward=reward,
            asin=asin,
            options=options,
            instruction_text=session["goal"]["instruction_text"],
        )
        return html, url, reward
    
    def receive(self, session_id, current_url, session_int=None, **kwargs):
        """Map action to the corresponding page

        Arguments:
        session_id -- key into `self.user_sessions`; an unknown id creates a new session
        current_url -- URL the browser is currently on; used to tell which page a
            "next"/"prev" click came from
        session_int -- when an int, index of the goal to use for a new session;
            otherwise the goal is sampled by weight
        kwargs -- the action: empty for "open/reset the start page", `keywords`
            (and optionally `page`) for a search, or `clickable_name` plus
            `text_to_clickable` for a click

        Returns `(html, url, status)` where `status` is a dict with `reward` and `done`.
        """
        status = dict(reward=0.0, done=False)

        with app.app_context(), app.test_request_context():
            # Create/determine goal, instruction_text from current session
            if session_id not in self.user_sessions:
                # New session: explicit goal index if given, else a weighted random draw
                idx = session_int if (session_int is not None and isinstance(session_int, int)) else self._random_idx(self.cum_weights)
                print(f"---------------1----------------")  # debug trace
                goal = self.goals[idx]
                instruction_text = goal['instruction_text']
                self.user_sessions[session_id] = {'goal': goal, 'done': False}
            else:
                print(f"---------------2----------------")  # debug trace
                instruction_text = \
                    self.user_sessions[session_id]['goal']['instruction_text']
            if self.assigned_instruction_text is not None:
                print(f"---------------3----------------")  # debug trace
                instruction_text = self.assigned_instruction_text  # TODO: very hacky, should remove
                # Writes into the goal dict stored in the session (the same object taken from `self.goals`)
                self.user_sessions[session_id]['goal']['instruction_text'] = instruction_text
            session = self.user_sessions[session_id]

            if not kwargs:
                print(f"---------------4----------------")  # debug trace
                # If no action, reset the session variables
                kwargs['instruction_text'] = instruction_text
                html, url = self.index(session_id, **kwargs)
                self.user_sessions[session_id].update(
                    {
                        'keywords': None,
                        'page': None,
                        'asin': None,
                        'asins': set(),
                        'options': dict(),
                        'actions': defaultdict(int)
                    }
                )
            elif 'keywords' in kwargs:
                # If search keywords are available, run a search
                html, url = self.search_results(session_id, **kwargs)
            elif 'clickable_name' in kwargs:
                clickable_name = kwargs['clickable_name'].lower()
                if clickable_name == END_BUTTON.lower():
                    # If "buy now" clicked, calculate reward and flag session as terminated
                    html, url, reward = self.done(session_id, **kwargs)
                    status['reward'] = reward
                    status['done'] = True
                elif clickable_name == BACK_TO_SEARCH.lower():
                    # If "back to search" clicked, recursively reset the session back to search page
                    html, url, status = self.receive(session_id, current_url)
                elif (clickable_name == NEXT_PAGE.lower() and 
                      self.get_page_name(current_url) == 'search_results'):
                    # If "next page" clicked from search results, re-render with `page` enumerated
                    html, url, status = self.receive(
                        session_id,
                        current_url,
                        keywords=session["keywords"],
                        page=session["page"] + 1,
                    )
                elif (clickable_name == PREV_PAGE.lower() and 
                      self.get_page_name(current_url) == 'search_results'):
                    # If "prev page" clicked from search results, re-render with `page` denumerated
                    html, url, status = self.receive(
                        session_id,
                        current_url,
                        keywords=session["keywords"],
                        page=session["page"] - 1,
                    )
                elif (clickable_name == PREV_PAGE.lower() and 
                      self.get_page_name(current_url) == 'item_sub_page'):
                    # If "prev page" clicked from sub page, return to corresponding item page
                    html, url = self.item_page(session_id, **kwargs)
                elif (clickable_name == PREV_PAGE.lower() and 
                      self.get_page_name(current_url) == 'item_page'):
                    # If "prev page" clicked from item page, return to search results page
                    html, url = self.search_results(
                        session_id,
                        keywords=session["keywords"],
                        page=session["page"],
                        **kwargs
                    )
                elif clickable_name in [k.lower() for k in ACTION_TO_TEMPLATE]:
                    # Render item_sub_page if clickable is description, features, or reviews
                    html, url = self.item_sub_page(session_id, **kwargs)
                else:
                    # Otherwise, render current item page
                    html, url = self.item_page(session_id, **kwargs)
            # NOTE(review): non-empty kwargs containing neither 'keywords' nor
            # 'clickable_name' leave `html` / `url` unassigned, so this return
            # would raise UnboundLocalError.
            return html, url, status
    
    def get_page_name(self, url):
        """Determine which page (i.e. item_page, search_results) the given URL is pointing at"""
        if url is None:
            return None
        page_names = [
            'search_results',
            'item_page',
            'item_sub_page',
            'done'
        ]
        for page_name in page_names:
            if page_name in url:
                return page_name
        return ''  # index page

    def _random_idx(self, cum_weights):
        """Sample an index according to cumulative weights using the local RNG."""
        total = cum_weights[-1]
        if total <= 0:
            return 0
        r = self.rng.uniform(0, total)
        for i in range(len(cum_weights) - 1):
            if cum_weights[i] <= r < cum_weights[i + 1]:
                return i
        return len(cum_weights) - 2  # fallback to last valid index


class SimBrowser:
    """Minimal browser stand-in: forwards actions to a server and remembers the last page."""

    def __init__(self, server):
        """Attach to `server`; no page is loaded yet."""
        self.server = server
        self.current_url = None
        self.page_source = None
        self.session_id = None

    def get(self, url, session_id=None, session_int=None):
        """Open `url` for a session and store the returned page HTML.

        The session id defaults to the last path segment of `url`. The server
        is called with the URL the browser was on before this call, and the
        URL it returns is ignored in favour of `url`.
        """
        if session_id is None:
            session_id = url.rsplit('/', 1)[-1]
        self.session_id = session_id
        html, _url, _status = self.server.receive(
            self.session_id,
            self.current_url,
            session_int=session_int,
        )
        self.page_source = html
        self.current_url = url

    def click(self, clickable_name, text_to_clickable):
        """Send a click on `clickable_name` to the server, update page and URL, return the status."""
        html, url, status = self.server.receive(
            self.session_id,
            current_url=self.current_url,
            clickable_name=clickable_name,
            text_to_clickable=text_to_clickable,
        )
        self.page_source = html
        self.current_url = url
        return status

    def search(self, keywords):
        """Send a search to the server, update page and URL, return the status.

        A string is split on single spaces into a keyword list; any other
        value is passed through unchanged.
        """
        tokens = keywords.split(' ') if isinstance(keywords, str) else keywords
        html, url, status = self.server.receive(
            self.session_id,
            current_url=self.current_url,
            keywords=tokens,
        )
        self.page_source = html
        self.current_url = url
        return status
