import argparse
import os
import pandas as pd
import logging
import re
from glob import glob

# Configure the root logger: INFO and above, printed as "<LEVEL>: <message>".
logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.INFO)

# --- Filesystem layout -------------------------------------------------------
# The project root is two directory levels above the folder holding this file.
_SCRIPT_DIR = os.path.dirname(__file__)
PROJECT_ROOT = os.path.abspath(os.path.join(_SCRIPT_DIR, "..", ".."))
_RESULT_DIR = os.path.join(PROJECT_ROOT, "result")
_GROUP_EVAL_DIR = os.path.join(_RESULT_DIR, "group_level_eval")

# Root folder that is scanned for experiment directories.
INPUT_ROOT = os.path.join(_RESULT_DIR, "eval", "human_llm")
# Table used to decide which datasets count as valid.
VALID_INFO_PATH = os.path.join(PROJECT_ROOT, "data", "table_file_info_20250710.csv")
# Output tables: valid-only subset and the full set.
VALID_OUTPUT_PATH = os.path.join(_GROUP_EVAL_DIR, "preprocessed_depth_valid.csv")
OUTPUT_PATH = os.path.join(_GROUP_EVAL_DIR, "preprocessed_depth.csv")

# Name of the CSV looked up inside every model folder.
TARGET_FILENAME = "opinion_memory_gpt-4o-mini-2024-07-18_v0.csv"

# Topics (as they appear in experiment directory names) kept by the depth mode.
DEPTH_TOPICS = {
    "Angels_are_real",
    "Regular_fasting_will_improve_your_health",
    "Everything_that_happens_can_eventually_be_explained_by_science",
    "The_US_deficit_increased_after_President_Obama_was_elected",
    "The_United_States_has_the_highest_federal_income_tax_rate_of_any_Western_country",
    "The_position_of_the_planets_at_the_time_of_your_birth_can_influence_your_personality",
    'A_"body_cleanse,"_in_which_you_consume_only_particular_kinds_of_nutrients_over_1-3_days,_helps_your_body_to_eliminate_toxins',
}

# Agreement labels ordered from strongest disagreement to strongest agreement;
# a label's 1-based position in this tuple is its numeric value.
_AGREEMENT_LABELS = (
    "Certainly disagree",
    "Probably disagree",
    "Lean disagree",
    "Lean agree",
    "Probably agree",
    "Certainly agree",
)
AGREEMENT_TO_LIEKRT_VALUE = {
    label: value for value, label in enumerate(_AGREEMENT_LABELS, start=1)
}

def extract_info_from_exp_dir(exp_dir):
    """Split an experiment directory name into its time stamp and topic.

    The name must start with eight digits, an underscore and six digits
    (the time stamp), followed by ``_<topic>`` and a closing ``_01`` plus one
    or more uppercase letters or digits.  The topic part is matched greedily,
    so only the last ``_01...`` run is treated as the closing identifier.

    Args:
        exp_dir: Directory name to parse.

    Returns:
        tuple: ``(time_stamp, topic)`` as strings.

    Raises:
        ValueError: If ``exp_dir`` does not follow the expected layout.
    """
    parsed = re.match(r"(\d{8}_\d{6})_(.+)_01[A-Z0-9]+$", exp_dir)
    if parsed is None:
        raise ValueError(f"exp_dir format invalid: {exp_dir}")
    # group(1, 2) yields the (time_stamp, topic) pair in one call.
    return parsed.group(1, 2)

# helper func to extract valid filenames from table file
def get_valid_filenames():
    """Return the experiment directory names of all valid datasets.

    Reads the table at ``VALID_INFO_PATH`` and keeps the rows where
    ``num_players`` is 4, ``is_fully_complete`` is ``True`` and ``source`` is
    ``"prolific"``.  Every remaining ``csv_filename`` is converted into a
    directory name by removing a trailing ``.csv`` together with an optional
    ``_0.0.1`` tag directly in front of it.

    Rows without a ``csv_filename`` are skipped, so a single blank cell no
    longer causes the whole result to be discarded.

    Returns:
        set: The derived names.  An empty set is returned, and a warning is
        logged, when the table cannot be read or a required column is
        missing; this best-effort behaviour is kept on purpose.
    """
    def extract_exp_dir(name):
        # Anchor at the end so only the real extension (and the optional
        # version tag before it) is stripped, never a ".csv" inside the name.
        return re.sub(r"(?:_0\.0\.1)?\.csv$", "", str(name))

    try:
        df_info = pd.read_csv(VALID_INFO_PATH)
        keep = (
            (df_info["num_players"] == 4)
            # Element-wise comparison on a Series; `is True` would not work.
            & (df_info["is_fully_complete"] == True)  # noqa: E712
            & (df_info["source"] == "prolific")
        )
        # Missing filenames come back as NaN; drop them instead of letting
        # the string handling fail for the entire table.
        filenames = df_info.loc[keep, "csv_filename"].dropna()
        return {extract_exp_dir(name) for name in filenames}

    except Exception as e:
        logging.warning(f"Error loading valid dataset info: {e}")
        return set()

def get_ft_type(model_name):
    """Map a model name to its fine-tuning category.

    Args:
        model_name: Exact model identifier.

    Returns:
        str: ``"base"`` for the two base models, ``"round"``, ``"topic"`` or
        ``"group"`` for the known fine-tuned checkpoints, and ``"other"`` for
        every name that is not listed.
    """
    ft_type_by_model = {
        # Base checkpoints
        "gpt-4o-mini-2024-07-18": "base",
        "Llama-3.1-8B-Instruct": "base",
        # Round-split fine-tunes
        "ft:gpt-4o-mini-2024-07-18:camer:round-split-all:BOvS862Y": "round",
        "ft:gpt-4o-mini-2024-07-18:camer:round-split-valid:BRTJQLtG": "round",
        "ft:Llama-3.1-8B-Instruct:round-split-valid-1epoch": "round",
        "ft:Llama-3.1-8B-Instruct:round-split-valid-5epochs": "round",
        # Topic-split fine-tune
        "ft:gpt-4o-mini-2024-07-18:camer:topic-split-all:BOqtcdMB": "topic",
        # Group-split fine-tune
        "ft:gpt-4o-mini-2024-07-18:camer:group-split-all:BOvZjzvU": "group",
    }
    return ft_type_by_model.get(model_name, "other")

# Event types that are carried over into the preprocessed tables.
_DEPTH_EVENT_TYPES = ["Initial Opinion", "tweet", "Post Opinion"]

# Column order of the output CSVs; mirrors the keys of the records built in
# _build_depth_records so that even an empty table is written with a header.
_DEPTH_COLUMNS = [
    "topic",
    "exp_dir",
    "time_stamp",
    "type",
    "model_name",
    "ft_type",
    "event_type",
    "chat_order",
    "human_id",
    "partner1_id",
    "partner2_id",
    "partner3_id",
    "human_likert_pred",
    "llm_likert_pred",
    "human_slider",
    "llm_slider",
    "human_text",
    "llm_text",
]


def _build_depth_records(df, topic, exp_dir, time_stamp, model_name):
    """Turn the already-filtered events of one model CSV into output records."""
    # (worker_id, chat_round_order) -> recipient_id, taken from tweet rows
    # that have a recipient.
    tweets_with_partner = df[(df["event_type"] == "tweet") & (df["recipient_id"].notna())]
    partner_map = {
        (row["worker_id"], row["chat_round_order"]): row["recipient_id"]
        for _, row in tweets_with_partner.iterrows()
    }
    # Identical for every row of this file, so resolve it once.
    ft_type = get_ft_type(model_name)

    records = []
    for _, row in df.iterrows():
        event_type = row.get("event_type")
        human_id = row.get("worker_id")

        if event_type == "Initial Opinion":
            llm_slider_val = row.get("sliderValue")
        elif event_type == "Post Opinion":
            # Text label -> numeric value; unknown labels become None.
            llm_slider_val = AGREEMENT_TO_LIEKRT_VALUE.get(row.get("agreement_level"), None)
        else:
            llm_slider_val = row.get("llm_likert_pred")

        if event_type in ("Initial Opinion", "Post Opinion"):
            human_slider_val = row.get("sliderValue")
        else:
            human_slider_val = row.get("human_likert_pred")

        records.append({
            "topic": topic,
            "exp_dir": exp_dir,
            "time_stamp": time_stamp,
            "type": "TBD",
            "model_name": model_name,
            "ft_type": ft_type,
            "event_type": event_type,
            "chat_order": row.get("chat_round_order"),
            "human_id": human_id,
            "partner1_id": partner_map.get((human_id, 1), None),
            "partner2_id": partner_map.get((human_id, 2), None),
            "partner3_id": partner_map.get((human_id, 3), None),
            "human_likert_pred": row.get("human_likert_pred"),
            "llm_likert_pred": row.get("llm_likert_pred"),
            "human_slider": human_slider_val,
            "llm_slider": llm_slider_val,
            "human_text": row.get("text"),
            "llm_text": row.get("llm_text"),
        })
    return records


def preprocess_depth():
    """Collect the depth-topic events of every experiment into flat CSV tables.

    Walks ``INPUT_ROOT/<exp_dir>/<model_name>/TARGET_FILENAME``.  Experiment
    directories whose name cannot be parsed are skipped with a warning, and
    those whose topic is not in ``DEPTH_TOPICS`` are skipped silently.  From
    each CSV only "Initial Opinion", "tweet" and "Post Opinion" events are
    kept; one output record is produced per event.

    Three files are written:

    * ``OUTPUT_PATH`` - all records.
    * ``VALID_OUTPUT_PATH`` - records whose experiment directory is returned
      by ``get_valid_filenames()``.
    * ``preprocessed_depth_topics.txt`` next to ``VALID_OUTPUT_PATH`` - the
      sorted topics that were encountered, one per line.

    Directory listings are processed in sorted order so the output row order
    is reproducible, and both CSVs always carry a header row, even when no
    record was collected.  A CSV that cannot be processed is skipped with a
    warning and contributes no rows.
    """
    rows = []
    rows_valid = []
    topics_set = set()  # unique depth topics that were actually found

    valid_exp_dirs = get_valid_filenames()

    # os.listdir returns entries in arbitrary order; sort for stable output.
    for exp_dir in sorted(os.listdir(INPUT_ROOT)):
        exp_path = os.path.join(INPUT_ROOT, exp_dir)
        if not os.path.isdir(exp_path):
            continue

        try:
            time_stamp, topic = extract_info_from_exp_dir(exp_dir)
        except ValueError as e:
            logging.warning(str(e))
            continue

        # Only include depth topics
        if topic not in DEPTH_TOPICS:
            continue
        topics_set.add(topic)

        is_valid = exp_dir in valid_exp_dirs

        for model_name in sorted(os.listdir(exp_path)):
            model_path = os.path.join(exp_path, model_name)
            # Stray files (e.g. OS metadata) are not model folders; skipping
            # them avoids a misleading "Missing file" warning.
            if not os.path.isdir(model_path):
                continue

            csv_path = os.path.join(model_path, TARGET_FILENAME)
            if not os.path.isfile(csv_path):
                logging.warning(f"Missing file: {csv_path}")
                continue

            try:
                df = pd.read_csv(csv_path)
                df = df[df["event_type"].isin(_DEPTH_EVENT_TYPES)]
                if df.empty:
                    continue
                records = _build_depth_records(df, topic, exp_dir, time_stamp, model_name)
            except Exception as e:
                logging.warning(f"Error reading {csv_path}: {e}")
                continue

            # Records are added only after the whole file was processed, so a
            # failing file never leaves partial rows behind.
            rows.extend(records)
            if is_valid:
                rows_valid.extend(records)

    # Save preprocessed data
    outputs = (
        (OUTPUT_PATH, rows, "rows"),
        (VALID_OUTPUT_PATH, rows_valid, "valid rows"),
    )
    for out_path, out_rows, label in outputs:
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        # Explicit columns keep the header present for an empty table.
        pd.DataFrame(out_rows, columns=_DEPTH_COLUMNS).to_csv(out_path, index=False)
        logging.info(f"Saved {len(out_rows)} {label} to {out_path}")

    # Save extracted topics
    topic_list_path = os.path.join(os.path.dirname(VALID_OUTPUT_PATH), "preprocessed_depth_topics.txt")
    with open(topic_list_path, "w", encoding="utf-8") as f:
        f.writelines(topic + "\n" for topic in sorted(topics_set))
    logging.info(f"Saved {len(topics_set)} unique topics to {topic_list_path}")

def main():
    """Parse the command line and run the requested preprocessing mode.

    The required ``--mode`` option currently accepts only ``depth``, which
    runs ``preprocess_depth()``.

    Raises:
        NotImplementedError: If a mode other than ``depth`` gets past
            argument parsing.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--mode", required=True, choices=["depth"], help="Preprocessing mode")
    mode = cli.parse_args().mode

    # Guard clause kept as a safety net in case more choices are added later.
    if mode != "depth":
        raise NotImplementedError("Only 'depth' mode is supported for now.")

    preprocess_depth()


if __name__ == "__main__":
    main()