import os
import pandas as pd
from sentence_transformers import SentenceTransformer, util
from transformers import AutoModel
from langchain_openai import ChatOpenAI
import matplotlib.pyplot as plt
import numpy as np
import argparse
import re

"""
====================================
Semantic Similarity & Opinion Analysis
====================================
This script processes, evaluates, and visualizes semantic similarity and opinion trajectories 
from an augmented dataset containing both original and paraphrased text.

It includes:
1. Semantic Similarity Calculation - Computes similarity scores between original and paraphrased text.
2. Stance & Likert Score Prediction - Assigns Likert values (1 to 6; 3.5 when the LLM's label is unrecognised) based on LLM stance classification.
3. Data Visualization - Plots similarity distributions, similarity trajectories, and opinion trajectories.
4. Bias & Variability Analysis - Computes Change in Bias and Change in Standard Deviation across rounds.

-----------------------------------
USAGE INSTRUCTIONS:
-----------------------------------
This script supports partial execution (for debugging) using command-line arguments. 
You can run specific tasks instead of executing the entire script.

Usage:
    python aug_data_semantic_similarity.py <task>

Available tasks:
    - `all`        → Run all tasks (default).
    - `evaluate`   → Compute similarity scores and predict Likert values (stores results in CSV).
    - `similarity` → Plot the semantic similarity trajectory (uses precomputed results).
    - `opinion`    → Plot opinion trajectories for both original and paraphrased text (uses precomputed results).
    - `bias`       → Compute and plot Bias and Standard Deviation change across rounds (uses precomputed results).

Examples:
    1. Run everything:
        python aug_data_semantic_similarity.py all

    2. Compute similarity & Likert scores (only the similarity histogram is plotted):
        python aug_data_semantic_similarity.py evaluate

    3. Generate only the similarity trajectory plot:
        python aug_data_semantic_similarity.py similarity

    4. Generate only the opinion trajectory plots:
        python aug_data_semantic_similarity.py opinion

    5. Compute only the bias and standard deviation change:
        python aug_data_semantic_similarity.py bias

-----------------------------------
DEPENDENCIES:
-----------------------------------
Ensure the following Python packages are installed before running:
- `transformers`
- `sentence-transformers`
- `langchain_openai`
- `matplotlib`
- `numpy`
- `pandas`

To install missing dependencies, run:
    pip install -r requirements.txt

-----------------------------------
NOTES:
-----------------------------------
- If you run `similarity`, `opinion`, or `bias`, the script will load precomputed results instead of recomputing everything.
- Running `evaluate` will overwrite the existing CSV file with new similarity and sentiment predictions.

"""


# ---------------------------------------------------------------------------
# Run configuration. All paths are relative, so they resolve against the
# current working directory at run time.
# ---------------------------------------------------------------------------
# Dataset identifier: extract_topic() parses the topic statement out of it, and
# it is embedded in every input/output path below.
data_prefix = "20241028_153927_A__body_cleanse,__in_which_you_consume_only_particular_kinds_of_nutrients_over_1-3_days,_helps_your_body_to_eliminate_toxins_01JB9V4TTHV4FRSNK02H14T0X8"
#input_file = f"../../data/augmented_data/{data_prefix}/{data_prefix}_merged.csv"
# Augmented dataset read by evaluate_similarity() (needs `orig_text` and `text`).
augmented_file = f"../../data/augmented_data/{data_prefix}/{data_prefix}_augmented.csv"
# Directory receiving every CSV and plot produced by this script.
output_path = f"../../result/eval/augment_data/{data_prefix}/"
output_image = os.path.join(output_path, f"{data_prefix}_similarity_trajectory.png")

# Embedding model used by calculate_similarity(): jina-embeddings-v3 loaded via
# transformers' AutoModel (remote code enabled) and kept on CPU. This runs at
# import time, so tasks that only plot precomputed results still pay the
# model-loading cost. Earlier alternatives are kept commented out below.
# similarity_model = SentenceTransformer('all-MiniLM-L6-v2')
# similarity_model = AutoModel.from_pretrained("jinaai/jina-embeddings-v3", trust_remote_code=True, device_map="cuda")
similarity_model = AutoModel.from_pretrained("jinaai/jina-embeddings-v3", trust_remote_code=True).to("cpu")


# Chat model used by classify_likert_label(); the API key comes from the
# OPENAI_API_KEY environment variable.
# NOTE(review): temperature=0.7 samples, so repeated runs may label the same
# message differently — consider 0 if reproducible labels are required.
llm = ChatOpenAI(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o-mini", temperature=0.7, max_tokens=300)

# Stance label -> Likert value, from 6 ("Certainly Agree") down to 1
# ("Certainly Disagree"). classify_likert_label() uses 3.5 (scale midpoint)
# for unrecognised labels and 0 when classification fails.
labels_to_likert = {
    "Certainly Agree": 6,
    "Probably Agree": 5,
    "Lean Agree": 4,
    "Lean Disagree": 3,
    "Probably Disagree": 2,
    "Certainly Disagree": 1
}

def extract_topic(data_prefix):
    """
    Derive a human-readable topic sentence from a dataset prefix.

    The prefix is expected to look like ``<date>_<time>_<topic>_<unique_id>``.
    The first two underscore-separated fields and the last one are dropped.
    In what remains, double underscores become single spaces first, and then
    every remaining single underscore also becomes a space.
    """
    fields = data_prefix.split("_")
    raw_topic = "_".join(fields[2:len(fields) - 1])
    # order matters: "__" must be collapsed before lone "_" are replaced
    for old, new in (("__", " "), ("_", " ")):
        raw_topic = raw_topic.replace(old, new)
    return raw_topic

def extract_label_and_likert(text):
    """
    Split a raw opinion string into its text and its slider value.

    Raw opinion texts carry a ``[SLIDER_VALUE=<n>]`` tag. Everything before the
    first tag is returned (whitespace-stripped) together with ``int(n)``.

    Args:
        text: The raw text. Non-string values (e.g. ``NaN`` read from an empty
            CSV cell, or ``None``) are treated as carrying no slider value.

    Returns:
        ``(cleaned_text, likert_value)`` when a tag is found, otherwise
        ``(text, None)``.
    """
    # None is the only "no value" marker: callers keep any value that
    # `is not None`, so a numeric sentinel such as 0 would be recorded as a
    # real rating.
    if not isinstance(text, str):
        return text, None
    match = re.search(r"\[SLIDER_VALUE=(\d+)\]", text)
    if match is None:
        return text, None
    return text[:match.start()].strip(), int(match.group(1))

def generate_likert_label_prompt(topic, conversation_context, new_message, sender):
    """
    Build the user prompt asking the LLM to label the stance of a new message.

    Args:
        topic: Topic statement the stance is judged against.
        conversation_context: Prior messages. Only 3-item entries of the form
            ``(round_order, speaker, text)`` are rendered (as ``- speaker: text``);
            entries of any other length are skipped.
        new_message: Message whose stance is to be labelled.
        sender: Participant who wrote ``new_message``.

    Returns:
        The prompt text, ending with the six allowed stance categories.
    """
    history_lines = []
    for entry in conversation_context:
        if len(entry) != 3:
            continue
        history_lines.append(f"- {entry[1]}: {entry[2]}")
    conversation_history = "\n".join(history_lines)

    return f"""
Here is the discussion between {sender} and the other participant so far:
{conversation_history}

What is the stance of the following new message by {sender} with respect to the topic "{topic}"?
Here is the message by {sender}: {new_message}

Please use exactly one of the following 6 categories to label the stance: 
- "Certainly Agree"
- "Probably Agree"
- "Lean Agree"
- "Lean Disagree"
- "Probably Disagree"
- "Certainly Disagree"

Your predicted stance:
"""


def classify_likert_label(text, topic, sender, prompt):
    """
    Ask the LLM for the stance label of a message and map it to a Likert value.

    Only ``prompt`` is sent to the model. Callers build it from the message,
    topic and sender with ``generate_likert_label_prompt``. ``text``, ``topic``
    and ``sender`` stay in the signature for call-site compatibility but are
    not used here.

    Normalisation and matching:
    - Quotes, digits and punctuation are stripped from the reply and
      whitespace is collapsed.
    - The result is matched case-insensitively against ``labels_to_likert``.
    - If the reply is not itself a label but mentions exactly one label as
      whole words (e.g. "Stance: Probably Agree"), that label is used.

    Returns:
        ``(label, likert_value)`` for a recognised label.
        ``(normalised_reply, 3.5)`` when no single label can be identified;
        3.5 is the midpoint of the 1-6 scale, i.e. a neutral fallback.
        ``("Error", 0)`` if the LLM call or response handling raises.
    """
    messages = [
        {
            "role": "system",
            "content": (
                "You are an assistant that classifies the sentiment of a message based on prior conversation context. "
                "Please provide only the likert label based on the context provided."
            ),
        },
        {"role": "user", "content": prompt},
    ]
    try:
        response = llm.invoke(messages)
        reply = response.content.strip()
        reply = reply.replace('"', "").replace("'", "")  # remove quotes
        reply = re.sub(r"[^a-zA-Z\s]", "", reply)  # remove unexpected characters
        reply = " ".join(reply.split())  # trim and collapse spaces/newlines

        canonical = {label.lower(): label for label in labels_to_likert}
        lowered = reply.lower()
        label = canonical.get(lowered)
        if label is None:
            # No label occurs inside another (e.g. "probably agree" is not in
            # "probably disagree"), so a single whole-word hit is unambiguous.
            hits = [
                name
                for key, name in canonical.items()
                if re.search(rf"\b{re.escape(key)}\b", lowered)
            ]
            if len(hits) == 1:
                label = hits[0]

        if label is not None:
            return label, labels_to_likert[label]
        # unrecognised reply: keep it for inspection, use the neutral midpoint
        return reply, 3.5
    except Exception as e:
        print(f"Error in sentiment classification: {e}")
        return "Error", 0

# Event types whose text is a chat message. These rows are labelled by the LLM
# and appended to the participants' conversation context. Both spellings of
# "received" are accepted. The two labelling passes used to check different
# spellings, so one pass silently skipped those rows.
_MESSAGE_EVENT_TYPES = ("tweet", "message_sent", "message_received", "message_recieved")
# Event types holding a participant's stated opinion (not added to context).
_OPINION_EVENT_TYPES = ("Initial Opinion", "Post Opinion")
# Number of the sender's most recent messages passed to the LLM as context.
_CONTEXT_WINDOW = 5


def _resolve_participant(row, id_column, id_mapping):
    """Map ``row[id_column]`` (or ``row["empirica_id"]`` when it is missing) to a worker id."""
    raw_id = row[id_column] if pd.notna(row[id_column]) else row["empirica_id"]
    return id_mapping.get(raw_id, "Unknown Speaker")


def _load_raw_opinion_likerts(raw_data_path):
    """Return ``{cleaned opinion text: slider value}`` from the raw experiment CSV, or None if it is missing."""
    if not os.path.exists(raw_data_path):
        print(f"Error: Original experimental data not found at {raw_data_path}")
        return None
    raw_df = pd.read_csv(raw_data_path)
    raw_df = raw_df[raw_df["event_type"].isin(_OPINION_EVENT_TYPES)]
    likerts = {}
    for raw_text in raw_df["text"]:
        cleaned_text, likert_value = extract_label_and_likert(raw_text)
        if likert_value is not None:
            likerts[cleaned_text] = likert_value
    return likerts


def _label_text_column(df, topic, id_mapping, text_column, opinion_likerts=None):
    """
    Fill ``<text_column>_label``, ``<text_column>_likert`` and ``prompt_<text_column>`` in place.

    Message rows are classified by the LLM, using the sender's last
    ``_CONTEXT_WINDOW`` messages from this same column as context.
    Opinion rows are also classified by the LLM when ``opinion_likerts`` is
    None. Otherwise their value is looked up by exact text in
    ``opinion_likerts`` and left at the column default when absent.
    All other rows are untouched.
    """
    label_col = f"{text_column}_label"
    likert_col = f"{text_column}_likert"
    prompt_col = f"prompt_{text_column}"
    likert_to_label = {value: label for label, value in labels_to_likert.items()}
    histories = {}

    for index, row in df.iterrows():
        event_type = row["event_type"]
        is_message = event_type in _MESSAGE_EVENT_TYPES
        is_opinion = event_type in _OPINION_EVENT_TYPES
        if not (is_message or is_opinion):
            continue

        text = row[text_column]
        sender = _resolve_participant(row, "sender_id", id_mapping)

        if is_opinion and opinion_likerts is not None:
            if text in opinion_likerts:
                value = opinion_likerts[text]
                df.at[index, label_col] = likert_to_label.get(value, "Unknown")
                df.at[index, likert_col] = value
            continue

        context = histories.get(sender, [])[-_CONTEXT_WINDOW:]
        prompt = generate_likert_label_prompt(topic, context, text, sender)
        label, value = classify_likert_label(text, topic, sender, prompt)
        df.at[index, label_col] = label
        df.at[index, likert_col] = value
        df.at[index, prompt_col] = prompt

        if is_message:
            round_order = row["chat_round_order"] if pd.notna(row["chat_round_order"]) else 0
            entry = (round_order, sender, text)
            recipient = _resolve_participant(row, "recipient_id", id_mapping)
            histories.setdefault(sender, []).append(entry)
            # A message without a recipient resolves to the sender's own id;
            # don't record it twice in the same history.
            if recipient != sender:
                histories.setdefault(recipient, []).append(entry)


def predict(df, topic):
    """
    Label the stance of both the paraphrased (``text``) and original (``orig_text``) messages.

    Rows whose ``event_type`` is "idle" or "exit_survey" are dropped. The
    following columns are added:
    - ``text_label``/``text_likert``/``prompt_text``: LLM labels for ``text``,
      covering both message and opinion rows.
    - ``orig_text_label``/``orig_text_likert``/``prompt_orig_text``: LLM labels
      for ``orig_text`` message rows. Opinion rows instead take the slider
      value recorded in ``../../data/raw_data/<data_prefix>.csv``.

    Rows that are not labelled keep label "" and Likert 0.

    Returns:
        The labelled DataFrame (a new object), or an empty DataFrame if the
        pipeline raises; the error is printed.
    """
    try:
        # .copy(): columns are added below, and writing into a filtered view
        # triggers pandas' SettingWithCopyWarning.
        df = df[~df["event_type"].isin(["idle", "exit_survey"])].copy()
        for text_column in ("text", "orig_text"):
            df[f"{text_column}_label"] = ""
            # float so the neutral fallback (3.5) fits without a dtype upcast
            df[f"{text_column}_likert"] = 0.0
            df[f"prompt_{text_column}"] = ""
        id_mapping = dict(zip(df["empirica_id"], df["worker_id"]))

        # step 1: paraphrased text; messages and opinions labelled by the LLM
        _label_text_column(df, topic, id_mapping, "text")

        # step 2: original text; messages labelled by the LLM, opinions taken
        # from the raw experiment's slider values (loaded once, not per row).
        # If the raw file is missing, opinion rows keep their defaults but
        # message rows are still labelled.
        opinion_likerts = {}
        if df["event_type"].isin(_OPINION_EVENT_TYPES).any():
            raw_data_path = f"../../data/raw_data/{data_prefix}.csv"
            opinion_likerts = _load_raw_opinion_likerts(raw_data_path) or {}
        _label_text_column(df, topic, id_mapping, "orig_text", opinion_likerts)

        return df
    except Exception as e:
        print(f"Error in likert classification pipeline: {e}")
        return pd.DataFrame()

def calculate_similarity(original_text, paraphrased_text):
    """
    Return the cosine similarity between the embeddings of two texts.

    Both texts are embedded in one ``similarity_model.encode`` call and compared
    with ``util.cos_sim``. Any exception raised while encoding or scoring is
    printed, and ``None`` is returned instead of a score.
    """
    try:
        vectors = similarity_model.encode([original_text, paraphrased_text], convert_to_tensor=True)
        score = util.cos_sim(vectors[0], vectors[1]).item()
    except Exception as e:
        print(f"Error calculating similarity: {e}")
        return None
    return score

def filter_data(df, exclude_event_types):
    """
    Return the rows of ``df`` whose ``event_type`` is not in ``exclude_event_types``.

    The original index is preserved.
    """
    keep_mask = ~df["event_type"].isin(exclude_event_types)
    return df[keep_mask]
    
def evaluate_similarity(augmented_file):
    """
    Load the augmented dataset and score each paraphrase against its original.

    Two columns are added:
    - ``sem_score``: cosine similarity of ``orig_text`` vs ``text`` as returned
      by ``calculate_similarity``; None where scoring failed.
    - ``time_step``: per-``worker_id`` running row index multiplied by 5.

    Args:
        augmented_file: Path of the CSV to read.

    Returns:
        The scored DataFrame, or an empty DataFrame on any error (the error is
        printed).
    """
    try:
        df = pd.read_csv(augmented_file)

        # worker_id is needed for time_step below; validate it up front so a
        # missing column is reported clearly instead of as a bare KeyError.
        required = ["orig_text", "text", "chat_round_order", "worker_id"]
        missing = [column for column in required if column not in df.columns]
        if missing:
            raise ValueError(
                "The dataset must contain 'orig_text', 'text', 'chat_round_order', "
                f"and 'worker_id' columns (missing: {missing})."
            )

        # score row by row; zipping the two columns avoids building a Series per
        # row and, unlike apply(axis=1), also works on an empty frame
        df["sem_score"] = [
            calculate_similarity(original, paraphrased)
            for original, paraphrased in zip(df["orig_text"], df["text"])
        ]
        # assign time steps per worker_id, with increments of 5
        df["time_step"] = df.groupby("worker_id").cumcount() * 5

        return df
    except Exception as e:
        print(f"Error in evaluating similarity: {e}")
        return pd.DataFrame()

def summarize_similarity(df, threshold=0.8):
    """
    Print the mean similarity score and the percentage of rows above ``threshold``.

    Missing ``sem_score`` values are ignored by the mean but count as "not
    above threshold" in the percentage.
    """
    scores = df["sem_score"]
    above_pct = 100 * (scores > threshold).mean()
    print(f"Mean Semantic Similarity: {scores.mean():.4f}")
    print(f"Percentage Above Threshold ({threshold}): {above_pct:.2f}%")

def plot_similarity_distribution(df, output_dir):
    """
    Save a 20-bin histogram of ``df["sem_score"]`` (missing scores dropped).

    The image is written to ``<output_dir>/similarity_distribution.png``, and
    ``output_dir`` is created if needed. The histogram is drawn on a fresh
    figure, so it cannot pick up artists from a figure left open elsewhere.
    """
    os.makedirs(output_dir, exist_ok=True)
    plt.figure()
    plt.hist(df["sem_score"].dropna(), bins=20, alpha=0.7)
    plt.title("Semantic Similarity Distribution")
    plt.xlabel("Similarity Score")
    plt.ylabel("Frequency")
    output_file = os.path.join(output_dir, "similarity_distribution.png")
    plt.savefig(output_file)
    plt.close()
    print(f"Distribution plot saved to '{output_file}'.")

def plot_similarity_trajectory(df, output_image):
    """
    Plot each worker's semantic-similarity scores in chat order and save the figure.

    The x-axis is a per-worker running index (``time_step``). Red dashed lines
    mark rows where ``chat_round_order`` changes, scanning rows in frame order.
    Round labels are placed between consecutive transitions. "Round 3" is
    drawn just after the last transition when round 3 is present.

    Args:
        df: Frame with ``worker_id``, ``sem_score`` and ``chat_round_order``.
            It is not modified.
        output_image: Destination image path; its directory is created if needed.
    """
    # Work on a copy so the caller's `time_step` column (evaluate_similarity
    # stores a x5-scaled version) is not overwritten by the plotting index.
    df = df.copy()

    plt.figure(figsize=(12, 6))

    # track individual worker's time step
    df["time_step"] = df.groupby("worker_id").cumcount()

    for worker in df["worker_id"].unique():
        worker_df = df[df["worker_id"] == worker]
        # plot similarity trajectory for each worker
        plt.plot(worker_df["time_step"], worker_df["sem_score"], marker="o", linestyle="-", label=worker)

    # identify round transitions (where chat_round_order changes between
    # consecutive rows)
    previous_round = None
    round_labels = {1: "Round 1", 2: "Round 2", 3: "Round 3"}  # labels for rounds
    y_max = 1.05  # adjusted for visibility above max similarity score

    transition_positions = []  # store positions of transition points

    for _, row in df.iterrows():
        current_round = row["chat_round_order"]
        # A NaN round never starts a transition, but a NaN previous round
        # followed by a numbered round does (NaN != n).
        if previous_round is not None and not pd.isna(current_round):
            if previous_round != current_round:  # transition detected
                plt.axvline(x=row["time_step"], color="red", linestyle="dashed", alpha=0.5)
                transition_positions.append(row["time_step"])

        previous_round = current_round

    # add round labels above dashed lines: round N sits midway between the
    # N-th and (N+1)-th transitions
    if transition_positions:
        consecutive = zip(transition_positions, transition_positions[1:])
        for round_number, (start, end) in enumerate(consecutive, start=1):
            if round_number in round_labels:
                plt.text((start + end) / 2, y_max, round_labels[round_number], fontsize=12, color="black", ha="center", fontweight="bold")

        # round 3 has no closing transition, so place its label just after the last one
        if 3 in df["chat_round_order"].values:
            last_transition = transition_positions[-1] + 2  # slightly offset
            plt.text(last_transition, y_max, "Round 3", fontsize=12, color="black", ha="center", fontweight="bold")

    # fix X-axis to show only multiples of 5
    max_time_step = df["time_step"].max()
    plt.xticks(np.arange(0, max_time_step + 1, step=5))

    # titles & labels
    plt.suptitle("Paraphrased vs. Original Similarity", fontsize=16, fontweight="bold", y=1.02)  # Move title above round labels
    plt.xlabel("Time Step (Chat Order)")
    plt.ylabel("Similarity Score")
    plt.legend(fontsize="x-small")
    plt.grid(True)
    plt.ylim(0.6, 1.05)

    # create the directory the image is actually written to
    os.makedirs(os.path.dirname(output_image) or ".", exist_ok=True)
    plt.savefig(output_image, bbox_inches="tight")
    plt.close()
    print(f"Similarity trajectory plot saved to '{output_image}'.")

def plot_opinion_trajectory(df, output_path, data_prefix):
    """
    Plot per-worker opinion trajectories and save three images to ``output_path``.

    The three images are:
    - ``<data_prefix>_orig_opinion_trajectory.png`` for ``orig_text_likert``;
    - ``<data_prefix>_paraphrased_opinion_trajectory.png`` for ``text_likert``;
    - ``<data_prefix>_opinion_diff_trajectory.png`` for their difference
      (original minus paraphrased).

    ``df`` is not modified; the helper columns are added to a copy.
    """
    os.makedirs(output_path, exist_ok=True)

    df = df.copy()
    df["likert_diff"] = df["orig_text_likert"] - df["text_likert"]
    # per-worker position; identical for all three plots, so compute it once
    df["time_step"] = df.groupby("worker_id").cumcount()

    def plot_single_trajectory(df, likert_column, title, output_filename, y_label="Likert Value (1 to 6)"):
        """Draw one per-worker trajectory of ``likert_column`` and save it as ``output_filename``."""
        plt.figure(figsize=(12, 6))

        for worker in df["worker_id"].unique():
            worker_df = df[df["worker_id"] == worker]
            plt.plot(worker_df["time_step"], worker_df[likert_column], marker="o", linestyle="-", label=worker)

        # identify round transitions
        previous_round = None
        round_labels = {1: "Round 1", 2: "Round 2", 3: "Round 3"}  # labels for rounds
        # leave head-room above the data for the round labels
        y_max = df[likert_column].max() + 0.8
        y_min = df[likert_column].min() - 0.2

        transition_positions = []  # store positions of transition points

        for _, row in df.iterrows():
            current_round = row["chat_round_order"]
            if previous_round is not None and not pd.isna(current_round):
                if previous_round != current_round:  # transition detected
                    plt.axvline(x=row["time_step"], color="red", linestyle="dashed", alpha=0.5)
                    transition_positions.append(row["time_step"])

            previous_round = current_round

        # label the rounds: round N sits midway between the N-th and (N+1)-th transitions
        if transition_positions:
            consecutive = zip(transition_positions, transition_positions[1:])
            for round_number, (start, end) in enumerate(consecutive, start=1):
                if round_number in round_labels:
                    plt.text((start + end) / 2, y_max - 0.5, round_labels[round_number], fontsize=12,
                             color="black", ha="center", fontweight="bold",
                             bbox=dict(facecolor="white", alpha=0.7))

            # round 3 has no closing transition, so place its label just after the last one
            if 3 in df["chat_round_order"].values:
                last_transition = transition_positions[-1] + 2  # slightly offset
                plt.text(last_transition, y_max - 0.5, "Round 3", fontsize=12,
                         color="black", ha="center", fontweight="bold",
                         bbox=dict(facecolor="white", alpha=0.7))

        # X-axis should be in intervals of 5
        max_time_step = df["time_step"].max()
        plt.xticks(np.arange(0, max_time_step + 1, step=5))

        # titles & labels
        plt.title(title, fontsize=14, pad=30)
        plt.xlabel("Time Step (Chat Order)")
        plt.ylabel(y_label)
        plt.legend(fontsize="x-small")
        plt.grid(True)
        # dynamically adjust the y-range; ylim rejects NaN limits (all-missing column)
        if np.isfinite(y_min) and np.isfinite(y_max):
            plt.ylim(y_min, y_max)

        output_image = os.path.join(output_path, output_filename)
        plt.savefig(output_image)
        plt.close()
        print(f"Opinion trajectory plot saved to '{output_image}'.")

    # plot original text opinion trajectory
    plot_single_trajectory(
        df, "orig_text_likert",
        "Original Text Opinion Trajectory",
        f"{data_prefix}_orig_opinion_trajectory.png"
    )

    # plot paraphrased text opinion trajectory
    plot_single_trajectory(
        df, "text_likert",
        "Paraphrased Text Opinion Trajectory",
        f"{data_prefix}_paraphrased_opinion_trajectory.png"
    )

    # plot difference between original and paraphrased opinion trajectories;
    # values range over differences, not the 1-6 scale, so relabel the y-axis
    plot_single_trajectory(
        df, "likert_diff",
        "Original vs Paraphrased Text Opinion Trajectory",
        f"{data_prefix}_opinion_diff_trajectory.png",
        y_label="Likert Difference (Original - Paraphrased)",
    )

def compute_bias_and_variability(df, output_csv):
    """
    Measure how opinions shift from round 1 to round 3, pooled and per worker.

    "Bias" is the mean Likert value and "variability" its sample standard
    deviation (pandas default, ddof=1). Each change is the round-3 statistic
    minus the round-1 statistic. Changes are computed for ``text_likert``
    (paraphrased) and ``orig_text_likert`` (original).

    Two CSVs are written:
    - ``output_csv``: one row per worker. It holds the per-round
      ``<metric>_round<N>`` columns followed by ``text_likert_bias_change``,
      ``orig_text_likert_bias_change``, ``text_likert_sd_change`` and
      ``orig_text_likert_sd_change``.
    - ``output_csv`` with ".csv" replaced by "_overall.csv": two rows
      ("Bias Change", "Std Change") pooling all workers.

    Errors (e.g. a missing column, or no round-3 rows) are printed, not raised.
    """
    try:
        needed = ("worker_id", "chat_round_order", "text_likert", "orig_text_likert")
        absent = [name for name in needed if name not in df.columns]
        if absent:
            raise ValueError(f"Missing required column: {absent[0]}")

        round_one = df[df["chat_round_order"] == 1]
        round_three = df[df["chat_round_order"] == 3]

        # pooled across all workers: round 3 minus round 1
        overall_results = pd.DataFrame({
            "Metric": ["Bias Change", "Std Change"],
            "Paraphrased Text": [
                round_three["text_likert"].mean() - round_one["text_likert"].mean(),
                round_three["text_likert"].std() - round_one["text_likert"].std(),
            ],
            "Original Text": [
                round_three["orig_text_likert"].mean() - round_one["orig_text_likert"].mean(),
                round_three["orig_text_likert"].std() - round_one["orig_text_likert"].std(),
            ],
        })

        # per worker: one mean/std per (worker, round), then one row per worker
        per_round = (
            df.groupby(["worker_id", "chat_round_order"])
            .agg(
                text_likert_mean=("text_likert", "mean"),
                text_likert_std=("text_likert", "std"),
                orig_text_likert_mean=("orig_text_likert", "mean"),
                orig_text_likert_std=("orig_text_likert", "std"),
            )
            .reset_index()
        )
        per_worker = per_round.pivot(index="worker_id", columns="chat_round_order")
        per_worker.columns = [f"{stat}_round{int(rnd)}" for stat, rnd in per_worker.columns]
        per_worker = per_worker.reset_index()

        # column order: bias changes (text, orig_text), then sd changes (text, orig_text)
        for suffix, stat in (("bias_change", "mean"), ("sd_change", "std")):
            for source in ("text_likert", "orig_text_likert"):
                per_worker[f"{source}_{suffix}"] = (
                    per_worker[f"{source}_{stat}_round3"] - per_worker[f"{source}_{stat}_round1"]
                )

        per_worker.to_csv(output_csv, index=False)
        overall_csv = output_csv.replace(".csv", "_overall.csv")
        overall_results.to_csv(overall_csv, index=False)

        print(f"Bias and variability results saved to '{output_csv}'.")
        print(f"Overall bias and variability results saved to '{overall_csv}'.")

    except Exception as e:
        print(f"Error in computing bias and variability: {e}")

def plot_bias_and_variability(df, overall_bias_df, output_path, data_prefix):
    """
    Plot per-worker bias/variability trajectories across rounds, plus the pooled changes.

    Args:
        df: Per-worker frame as written by ``compute_bias_and_variability``:
            one row per ``worker_id`` with ``<metric>_round<N>`` columns. The
            rounds plotted are those present as columns, so missing rounds no
            longer raise ``KeyError``.
        overall_bias_df: Frame with ``Metric``, ``Paraphrased Text`` and
            ``Original Text`` columns, one row per metric.
        output_path: Directory for the images (created if needed).
        data_prefix: Prefix used in every output file name.
    """
    os.makedirs(output_path, exist_ok=True)

    def plot_single_metric(df, metric_column, title, output_filename, y_label):
        """Plot ``<metric_column>_round<N>`` per worker and save it as ``output_filename``."""
        plt.figure(figsize=(12, 6))

        # rounds actually present, e.g. [1, 2, 3]
        round_pattern = re.compile(rf"^{re.escape(metric_column)}_round(\d+)$")
        rounds = sorted(
            int(match.group(1))
            for match in map(round_pattern.match, df.columns)
            if match
        )

        for worker in df["worker_id"].unique():
            worker_row = df[df["worker_id"] == worker].iloc[0]
            plt.plot(rounds,
                     [worker_row[f"{metric_column}_round{r}"] for r in rounds],
                     marker="o", linestyle="-", label=worker)

        # labels and titles
        plt.xticks(rounds, [f"Round {r}" for r in rounds])
        plt.title(title, fontsize=14, pad=20)
        plt.xlabel("Round")
        plt.ylabel(y_label)
        plt.legend(fontsize="x-small")
        plt.grid(True)

        output_image = os.path.join(output_path, output_filename)
        plt.savefig(output_image)
        plt.close()
        print(f"Plot saved to '{output_image}'.")

    # plot bias change for original and paraphrased text
    plot_single_metric(df, "text_likert_mean", "Paraphrased Text Bias Change Trajectory",
                       f"{data_prefix}_text_bias_trajectory.png", "Bias (Likert Mean)")
    plot_single_metric(df, "orig_text_likert_mean", "Original Text Bias Change Trajectory",
                       f"{data_prefix}_orig_text_bias_trajectory.png", "Bias (Likert Mean)")

    # plot standard deviation change for original and paraphrased text
    plot_single_metric(df, "text_likert_std", "Paraphrased Text Variability Trajectory",
                       f"{data_prefix}_text_variability_trajectory.png", "Standard Deviation")
    plot_single_metric(df, "orig_text_likert_std", "Original Text Variability Trajectory",
                       f"{data_prefix}_orig_text_variability_trajectory.png", "Standard Deviation")

    # ======= Plot Overall Bias & Variability Change =======
    # take the metric names from the table so labels always match the bar values
    metrics = overall_bias_df["Metric"].tolist()
    text_values = overall_bias_df["Paraphrased Text"].values
    orig_text_values = overall_bias_df["Original Text"].values

    x = np.arange(len(metrics))

    plt.figure(figsize=(10, 6))
    width = 0.3  # Width of bars

    plt.bar(x - width/2, text_values, width, label="Paraphrased Text", color="blue", alpha=0.7)
    plt.bar(x + width/2, orig_text_values, width, label="Original Text", color="red", alpha=0.7)

    # labels and formatting
    plt.xticks(x, metrics)
    plt.ylabel("Change in Bias / Variability")
    plt.title("Overall Bias and Variability Change")
    plt.axhline(0, color="black", linewidth=0.8)  # zero line for reference
    plt.legend()
    plt.grid(axis="y", linestyle="--", alpha=0.7)

    overall_output_image = os.path.join(output_path, f"{data_prefix}_bias_variability_overall.png")
    plt.savefig(overall_output_image)
    plt.close()
    print(f"Overall bias and variability plot saved to '{overall_output_image}'.")


def main(task):
    """
    Run the requested stage(s) of the pipeline.

    Args:
        task: One of "all", "evaluate", "similarity", "opinion" or "bias".
            "evaluate" and "all" recompute similarity scores and Likert labels
            and save them to ``<output_path>/<data_prefix>_semantic_similarity.csv``.
            Every other task loads that CSV instead.

    Errors are printed rather than raised.
    """
    try:
        os.makedirs(output_path, exist_ok=True)
        output_csv_path = os.path.join(output_path, f"{data_prefix}_semantic_similarity.csv")
        bias_output_csv = os.path.join(output_path, f"{data_prefix}_bias_variability.csv")
        overall_bias_output_csv = os.path.join(output_path, f"{data_prefix}_bias_variability_overall.csv")

        # compute semantic similarity
        if task in ("evaluate", "all"):
            evaluated_df = evaluate_similarity(augmented_file)
            # evaluate_similarity and predict signal failure with an empty
            # frame; stop here instead of crashing later with a bare KeyError
            # or overwriting a previous good CSV with nothing.
            if evaluated_df.empty:
                raise RuntimeError("Semantic similarity evaluation produced no data.")
            summarize_similarity(evaluated_df)
            plot_similarity_distribution(evaluated_df, output_path)

            # predict likert values
            topic = extract_topic(data_prefix)
            evaluated_df = predict(evaluated_df, topic)
            if evaluated_df.empty:
                raise RuntimeError(f"Likert prediction produced no data; '{output_csv_path}' was not overwritten.")
            evaluated_df.to_csv(output_csv_path, index=False)
            print(f"Semantic similarity results saved to '{output_csv_path}'.")

        elif os.path.exists(output_csv_path):
            # load previously computed dataset if skipping evaluation
            evaluated_df = pd.read_csv(output_csv_path)
            print(f"Loaded existing semantic similarity results from '{output_csv_path}'.")
        else:
            raise FileNotFoundError("No existing semantic similarity results found. Please run with 'evaluate' first.")

        # plot trajectories
        if task in ("similarity", "all"):
            plot_similarity_trajectory(evaluated_df, output_image)

        if task in ("opinion", "all"):
            plot_opinion_trajectory(evaluated_df, output_path, data_prefix)

        # compute bias & variability analysis
        if task in ("bias", "all"):
            # compute_bias_and_variability prints its own success/error
            # messages and does not raise, so check its outputs exist.
            compute_bias_and_variability(evaluated_df, bias_output_csv)
            if not (os.path.exists(bias_output_csv) and os.path.exists(overall_bias_output_csv)):
                raise FileNotFoundError("Bias and variability results were not written; skipping bias plots.")

            bias_df = pd.read_csv(bias_output_csv)
            overall_bias_df = pd.read_csv(overall_bias_output_csv)
            plot_bias_and_variability(bias_df, overall_bias_df, output_path, data_prefix)

    except Exception as e:
        print(f"Error in main function: {e}")

if __name__ == "__main__":
    # Command-line entry point. The task is optional and defaults to "all",
    # as documented in the module usage notes.
    parser = argparse.ArgumentParser(description="Run specific tasks in the script.")
    parser.add_argument(
        "task",
        nargs="?",
        default="all",
        choices=["all", "evaluate", "similarity", "opinion", "bias"],
        help="Specify which task to run: all, evaluate, similarity, opinion, bias (default: all)",
    )
    args = parser.parse_args()
    main(args.task)

