"""Script to automatically log in to each configured website, save the session cookies, and verify them."""

import os
import subprocess
import sys
from pathlib import Path

# Debugger-only setup: when a trace function is installed (sys.gettrace() is
# truthy, e.g. under an IDE debugger), put the project root on sys.path and load
# environment variables via the project's bash script. This runs before the
# imports below, presumably because browser_env.env_config reads these
# environment variables at import time — NOTE(review): confirm.
if sys.gettrace():
    root_path = Path(__file__).parent.parent  # Parent of this file's directory, assumed to be the project root
    sys.path.insert(0, str(root_path))  # Add project root to sys.path
    from utils.debug_utils import set_env_variables

    set_env_variables(bash_script="scripts/environments/set_env_variables.sh", arg1="local_vwebarena")


import argparse
import glob
import os
import subprocess
import time
from concurrent.futures import ThreadPoolExecutor
from itertools import combinations
from pathlib import Path

from playwright.sync_api import sync_playwright

from browser_env.env_config import ACCOUNTS, CLASSIFIEDS, GITLAB, REDDIT, SHOPPING, SHOPPING_ADMIN

# Playwright launch options. SLOW_MO is the delay in milliseconds Playwright adds
# between operations; it is only passed to the browser launched by is_expired.
HEADLESS = True
SLOW_MO = 100


# Parallel per-site tables; entry i of each list belongs to SITES[i] and is
# looked up elsewhere via SITES.index(site):
#   URLS        - page visited by is_expired to test a saved session
#   EXACT_MATCH - whether the final page URL must equal that URL exactly
#   KEYWORDS    - text that must appear on the page when logged in
#                 ("" means fall back to the URL comparison)
SITES = []
URLS = []
EXACT_MATCH = []
KEYWORDS = []

# A site whose configured base URL is the sentinel string "PASS" is skipped.

# Main site configurations
if SHOPPING != "PASS":
    SITES.append("shopping")
    URLS.append(f"{SHOPPING}/wishlist/")
    EXACT_MATCH.append(True)
    KEYWORDS.append("")

if REDDIT != "PASS":
    SITES.append("reddit")
    URLS.append(f"{REDDIT}/user/{ACCOUNTS['reddit']['username']}/account")
    EXACT_MATCH.append(True)
    KEYWORDS.append("Delete")

if CLASSIFIEDS != "PASS":
    SITES.append("classifieds")
    URLS.append(f"{CLASSIFIEDS}/index.php?page=user&action=items")
    EXACT_MATCH.append(True)
    KEYWORDS.append("My listings")

# Additional sites (used by WebArena tasks)
if SHOPPING_ADMIN != "PASS":
    SITES.append("shopping_admin")
    URLS.append(f"{SHOPPING_ADMIN}/dashboard")
    EXACT_MATCH.append(True)
    KEYWORDS.append("Dashboard")

if GITLAB != "PASS":
    SITES.append("gitlab")
    URLS.append(f"{GITLAB}/-/profile")
    EXACT_MATCH.append(True)
    KEYWORDS.append("")

# REVIEW[mandrade]: increased timeout
# Was failing for shopping_admin because of timeout.
# Navigation timeout passed to page.goto, in milliseconds (500 s).
TIMEOUT = 5000 * 100

assert len(SITES) == len(URLS) == len(EXACT_MATCH) == len(KEYWORDS)

from utils.logger_utils import logger


def is_expired(storage_state: Path, url: str, keyword: str, url_exact: bool = True) -> bool:
    """Test whether the cookie saved in ``storage_state`` is expired for ``url``.

    Opens ``url`` in a fresh Chromium context seeded with the saved storage
    state, waits for the page to settle, then decides:

    * if ``keyword`` is non-empty, the session is expired when ``keyword`` does
      not appear in the page HTML;
    * otherwise the final page URL (which differs from ``url`` if the site
      redirected elsewhere) is compared with ``url`` — for equality when
      ``url_exact`` is true, by substring otherwise.

    Args:
        storage_state: Path to a Playwright storage-state JSON file.
        url: Page to open with the saved session.
        keyword: Text expected in the page when logged in ("" to use the URL check).
        url_exact: Whether the final URL must equal ``url`` exactly.

    Returns:
        True if the file does not exist or the session check fails.
    """
    if not storage_state.exists():
        return True

    # Imported lazily; the missing-file path above does not need the helper.
    from vwa_utils.vwa_utils import wait_for_page_to_stabilize

    # `with` stops the Playwright driver even if navigation raises (e.g. a goto
    # timeout), and `finally` closes the browser, so worker threads calling this
    # concurrently do not leak driver/browser processes.
    with sync_playwright() as playwright:
        browser = playwright.chromium.launch(headless=HEADLESS, slow_mo=SLOW_MO)
        try:
            context = browser.new_context(storage_state=storage_state)
            page = context.new_page()
            page.goto(url, timeout=TIMEOUT)
            wait_for_page_to_stabilize(
                page=page,
                logger=logger,
                min_num_trues=3,
                return_after=3,
            )
            d_url = page.url
            content = page.content()
        finally:
            browser.close()

    if keyword:
        return keyword not in content
    if url_exact:
        return d_url != url
    return url not in d_url


def renew_comb(comb: list[str], auth_folder: str = "./.auth") -> None:
    """Log in to every site in ``comb`` within one browser context and save the session.

    All logins share a single context, so the saved storage state covers every
    site in ``comb``. It is written to
    ``{auth_folder}/{'.'.join(comb)}_state.json``; the file name therefore
    depends on the order of ``comb`` (``main`` passes sorted combinations).

    Args:
        comb: Site names to log in to. Membership is tested with ``in`` on the
            list, so "shopping" does not match "shopping_admin".
        auth_folder: Directory where the storage-state file is written.
    """
    # `with` stops the Playwright driver and `finally` closes the browser even
    # when a login step raises (e.g. a locator or navigation timeout).
    with sync_playwright() as playwright:
        browser = playwright.chromium.launch(headless=HEADLESS)
        try:
            context = browser.new_context()
            page = context.new_page()

            if "shopping" in comb:
                username = ACCOUNTS["shopping"]["username"]
                password = ACCOUNTS["shopping"]["password"]
                page.goto(f"{SHOPPING}/customer/account/login/", timeout=TIMEOUT)
                page.get_by_label("Email", exact=True).fill(username)
                page.get_by_label("Password", exact=True).fill(password)
                page.get_by_role("button", name="Sign In").click()

            if "reddit" in comb:
                username = ACCOUNTS["reddit"]["username"]
                password = ACCOUNTS["reddit"]["password"]
                page.goto(f"{REDDIT}/login", timeout=TIMEOUT)
                page.get_by_label("Username").fill(username)
                page.get_by_label("Password").fill(password)
                page.get_by_role("button", name="Log in").click()

            if "classifieds" in comb:
                username = ACCOUNTS["classifieds"]["username"]
                password = ACCOUNTS["classifieds"]["password"]
                page.goto(f"{CLASSIFIEDS}/index.php?page=login", timeout=TIMEOUT)
                page.locator("#email").fill(username)
                page.locator("#password").fill(password)
                page.get_by_role("button", name="Log in").click()

            if "shopping_admin" in comb:
                username = ACCOUNTS["shopping_admin"]["username"]
                password = ACCOUNTS["shopping_admin"]["password"]
                # REVIEW[mandrade]: increased timeout
                page.goto(SHOPPING_ADMIN, timeout=TIMEOUT)
                page.get_by_placeholder("user name").fill(username)
                page.get_by_placeholder("password").fill(password)
                page.get_by_role("button", name="Sign in").click()

            if "gitlab" in comb:
                username = ACCOUNTS["gitlab"]["username"]
                password = ACCOUNTS["gitlab"]["password"]
                # Use the shared TIMEOUT like every other site (was Playwright's default).
                page.goto(f"{GITLAB}/users/sign_in", timeout=TIMEOUT)
                page.get_by_test_id("username-field").click()
                page.get_by_test_id("username-field").fill(username)
                page.get_by_test_id("username-field").press("Tab")
                page.get_by_test_id("password-field").fill(password)
                page.get_by_test_id("sign-in-button").click()

            context.storage_state(path=f"{auth_folder}/{'.'.join(comb)}_state.json")
        finally:
            browser.close()


def get_site_comb_from_filepath(file_path: str) -> list[str]:
    """Return the site names encoded in a cookie file's name.

    Cookie files are named ``<site>[.<site>...]_<suffix>`` (for example
    ``gitlab.reddit_state.json``). Everything before the last underscore is
    split on dots; a name without an underscore is split as a whole.
    """
    file_name = os.path.basename(file_path)
    head, separator, _tail = file_name.rpartition("_")
    site_part = head if separator else file_name
    return site_part.split(".")


def get_cookie_paths_for_site(
    site: str,
    all_sites: list[str] = SITES,
    auth_folder: str = "./.auth",
    exc_comb: bool = False,
) -> list[str]:
    all_cookies = []

    all_cookies.append(f"{auth_folder}/{site}_state.json")
    if exc_comb:
        return all_cookies

    # All possible combinations:
    all_pairs = list(combinations(all_sites, 2))
    # Keep only the pairs whose one of the sites is in the sites list
    pairs = [pair for pair in all_pairs if site in pair]
    for pair in pairs:
        if "reddit" in pair and ("shopping" in pair or "shopping_admin" in pair):
            continue
        all_cookies.append(f"{auth_folder}/{'.'.join(sorted(pair))}_state.json")
    return all_cookies


def is_expired_for_sites(
    sites: list[str] | str, auth_folder: str = "./.auth", exc_comb: bool = False
) -> dict[str, list[str]]:
    """Check the cookie files relevant to ``sites`` and report the expired ones.

    For each site, the candidate files come from ``get_cookie_paths_for_site``
    (its own file plus, unless ``exc_comb``, the pair files it appears in).
    Every site named in a cookie file is checked with ``is_expired``; a file is
    reported as expired if any of those checks returns True. Checks run on a
    thread pool when more than one cookie file is involved.

    Args:
        sites: A site name or a list of site names (entries of ``SITES``).
        auth_folder: Directory holding the ``*_state.json`` files.
        exc_comb: If True, only consider each site's single-site cookie file.

    Returns:
        Mapping from site to its expired cookie files, in discovery order and
        without duplicates. Sites with no expired file are absent.
    """
    if isinstance(sites, str):
        sites = [sites]

    site_to_cookies = {
        site: get_cookie_paths_for_site(site, exc_comb=exc_comb, auth_folder=auth_folder) for site in sites
    }

    # One check per (requested site, cookie file, site stored in that file).
    # The is_expired arguments are resolved here so both execution paths share them.
    checks = []
    for site, cookies in site_to_cookies.items():
        for c_file in cookies:
            for cur_site in get_site_comb_from_filepath(c_file):
                idx = SITES.index(cur_site)
                checks.append((site, c_file, (Path(c_file), URLS[idx], KEYWORDS[idx], EXACT_MATCH[idx])))

    num_cookie_files = sum(len(cookies) for cookies in site_to_cookies.values())
    if num_cookie_files > 1:
        with ThreadPoolExecutor(max_workers=8) as executor:
            futures = [executor.submit(is_expired, *args) for _, _, args in checks]
            results = [future.result() for future in futures]
    else:
        results = [is_expired(*args) for _, _, args in checks]

    expired_site_to_cookies: dict[str, list[str]] = {}
    for (site, c_file, _), expired in zip(checks, results):
        if expired:
            expired_files = expired_site_to_cookies.setdefault(site, [])
            if c_file not in expired_files:
                expired_files.append(c_file)
    return expired_site_to_cookies


def main(sites: list[str], auth_folder: str = "./.auth", exc_comb: bool = False) -> None:
    """Renew login cookies for ``sites`` and verify the cookie files in ``auth_folder``.

    Phase 1 logs in to each requested site on its own and, unless ``exc_comb``,
    to every pair of configured sites that includes a requested site (except
    reddit paired with shopping/shopping_admin). Renewal failures are reported
    on stderr. Phase 2 runs ``is_expired`` for every ``*.json`` file in
    ``auth_folder``. With ``exc_comb`` set, only single-site files of requested
    sites are checked.

    Raises:
        AssertionError: if a checked cookie file is found to be expired.
    """
    if not exc_comb:
        # Keep only the pairs that contain at least one requested site.
        pairs = [pair for pair in combinations(SITES, 2) if any(site in pair for site in sites)]
    else:
        print("[INFO] Cookie creation: Excluding combinations of sites.")
        pairs = []

    renew_futures = []
    with ThreadPoolExecutor(max_workers=8) as executor:
        for pair in pairs:
            # Auth doesn't work on this pair as they share the same cookie
            if "reddit" in pair and ("shopping" in pair or "shopping_admin" in pair):
                continue
            comb = sorted(pair)
            renew_futures.append((comb, executor.submit(renew_comb, comb, auth_folder=auth_folder)))

        for site in sites:
            renew_futures.append(([site], executor.submit(renew_comb, [site], auth_folder=auth_folder)))

    # Surface login failures instead of silently discarding them; verification
    # below still runs so every remaining cookie file is checked.
    for comb, future in renew_futures:
        exc = future.exception()
        if exc is not None:
            print(f"[ERROR] Cookie renewal failed for {'.'.join(comb)}: {exc!r}", file=sys.stderr)

    # Parallel check of whether the cookies are expired. Each future is stored
    # with the file it checks: a file may yield zero checks (skipped) or several
    # (one per site it holds), so future positions do not match cookie_files.
    checks = []
    cookie_files = glob.glob(f"{auth_folder}/*.json")
    with ThreadPoolExecutor(max_workers=8) as executor:
        for c_file in cookie_files:
            comb = get_site_comb_from_filepath(c_file)
            if exc_comb and len(comb) > 1:
                continue
            for cur_site in comb:
                if exc_comb and cur_site not in sites:
                    continue
                idx = SITES.index(cur_site)
                future = executor.submit(is_expired, Path(c_file), URLS[idx], KEYWORDS[idx], EXACT_MATCH[idx])
                checks.append((c_file, future))

    # Explicit raise (not `assert`) so the check still runs under `python -O`.
    for c_file, future in checks:
        if future.result():
            raise AssertionError(f"Cookie {c_file} expired.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Log in to each configured website, save the session cookies, and verify them."
    )
    parser.add_argument(
        "--exc_comb",
        action="store_true",
        default=False,
        help="Only create/verify single-site cookies (skip site-pair combinations).",
    )
    # choices=SITES rejects unknown or unconfigured site names up front instead of
    # failing later with a ValueError from SITES.index during verification.
    parser.add_argument(
        "--site_list",
        nargs="+",
        default=SITES,
        choices=SITES,
        help="Sites to renew cookies for (default: every configured site).",
    )
    parser.add_argument(
        "--auth_folder",
        type=str,
        default="./.auth",
        help="Directory where the *_state.json cookie files are written.",
    )
    args = parser.parse_args()

    # Under a debugger, override the CLI with a fixed configuration.
    if sys.gettrace():
        args.auth_folder = "./.auth"
        args.exc_comb = False
        args.site_list = ["reddit"]

    os.makedirs(args.auth_folder, exist_ok=True)
    main(auth_folder=args.auth_folder, sites=args.site_list, exc_comb=args.exc_comb)
