import pandas as pd
import re

def generate_prompts_cities():
    """Build few-shot TRUE/FALSE prompts from ``./data/cities.csv``.

    Rows whose ``city`` is Tokyo or Hanoi, or whose ``country`` is Japan or
    Poland, are discarded (presumably because those names already occur in
    the few-shot examples that prefix every prompt).

    Returns:
        A ``(correct_prompts, incorrect_prompts)`` pair of lists. The first
        is built from the ``statement`` values of rows with ``label == 1``,
        the second from rows with ``label == 0``.
    """
    few_shot = (
        "The city of Tokyo is in Japan. This statement is: TRUE. "
        "The city of Hanoi is in Poland. This statement is: FALSE."
    )

    def as_prompts(statements):
        # Every prompt is: few-shot examples, the statement, an open verdict.
        return [
            f"{few_shot} {statement} This statement is:"
            for statement in statements
        ]

    table = pd.read_csv('./data/cities.csv')

    # Apply the exclusions one value at a time, in a fixed order.
    excluded = (
        ("city", "Tokyo"),
        ("city", "Hanoi"),
        ("country", "Japan"),
        ("country", "Poland"),
    )
    for column, value in excluded:
        table = table[table[column] != value]

    true_statements = table.loc[table['label'] == 1, 'statement']
    false_statements = table.loc[table['label'] == 0, 'statement']
    return as_prompts(true_statements), as_prompts(false_statements)

# Few-shot examples followed by the target statement. Literal periods are
# escaped so that "." only matches an actual period; left unescaped, a
# statement missing its final period would still match and the last
# character of the country would be swallowed by the wildcard.
_CITY_PROMPT_RE = re.compile(
    r"The city of Tokyo is in Japan\. This statement is: TRUE\. "
    r"The city of Hanoi is in Poland\. This statement is: FALSE\. "
    r"The city of (.+?) is in (.+?)\. This statement is:"
)


def extract_city_country(prompt):
    """Extract the city and country named in the target statement of a prompt.

    The prompt must contain the two fixed few-shot examples (Tokyo/Japan and
    Hanoi/Poland) followed by ``The city of <city> is in <country>. This
    statement is:``.

    Args:
        prompt: The full prompt text.

    Returns:
        A ``(city, country)`` tuple of strings with surrounding whitespace
        stripped.

    Raises:
        ValueError: If the prompt does not follow the expected template.
    """
    match = _CITY_PROMPT_RE.search(prompt)
    if match is None:
        raise ValueError(f"Could not parse city/country from prompt:\n{prompt}")
    return match.group(1).strip(), match.group(2).strip()

def parse_city_prompt(prompt, tokenizer):
    """Find the token indices covering the preamble, city and country of a prompt.

    The city and country are taken from ``extract_city_country``. Their
    character spans (each including the single preceding space) are located
    in the prompt and then matched against the tokenizer's offset mapping.

    Args:
        prompt: Prompt text in the few-shot template understood by
            ``extract_city_country``.
        tokenizer: Callable invoked as ``tokenizer(prompt,
            return_offsets_mapping=True, add_special_tokens=True)`` whose
            result exposes ``"offset_mapping"`` as a sequence of
            ``(start, end)`` character offsets, one per token.

    Returns:
        A ``(preamble_tokens, city_tokens, country_tokens)`` tuple of lists
        of token indices. A token is included in a list only when its
        offsets lie entirely inside the corresponding character span.

    Raises:
        ValueError: If the prompt cannot be parsed, or if the expected
            substrings are not found in it.
    """
    city, country = extract_city_country(prompt)
    # Spans start at the space before the name (presumably so that tokens
    # whose offsets include their leading space are still captured).
    city = " " + city
    country = " " + country

    # Anchor on the whole " The city of <city>" lead-in rather than on the
    # bare city name, so an earlier coincidental occurrence of the name in
    # the few-shot text cannot be mistaken for the target statement.
    lead_in = " The city of"
    statement_start = prompt.index(lead_in + city)
    city_start = statement_start + len(lead_in)
    city_end = city_start + len(city)

    # Look for the country only after the city. Otherwise a city whose name
    # contains the country ("Mexico City" / "Mexico") or equals it
    # ("Singapore" / "Singapore") would claim the country span as well.
    country_start = prompt.index(country, city_end)
    country_end = country_start + len(country)

    # NOTE(review): starting at character 3 leaves out the leading "The" of
    # the first few-shot example; confirm this is intentional.
    preamble_start = 3
    # The preamble runs up to and including the space before the statement.
    preamble_end = statement_start + 1

    encoding = tokenizer(prompt, return_offsets_mapping=True, add_special_tokens=True)
    offsets = encoding["offset_mapping"]

    def tokens_within(span_start, span_end):
        # Keep only tokens that lie completely inside [span_start, span_end].
        return [
            index
            for index, (start, end) in enumerate(offsets)
            if start >= span_start and end <= span_end
        ]

    preamble_tokens = tokens_within(preamble_start, preamble_end)
    city_tokens = tokens_within(city_start, city_end)
    country_tokens = tokens_within(country_start, country_end)
    return preamble_tokens, city_tokens, country_tokens