import argparse
import json
import re
import sys

import numpy as np
import pandas as pd
from anthropic import Anthropic
from openai import OpenAI

import wandb

# Add the parent of the current working directory to the import search path.
# NOTE(review): nothing in this file visibly imports a module from "../" —
# confirm whether this is still required.
sys.path.append("../")

# Vignettes shown to the model before the survey question. The two texts are
# the same news story except that `treatment` frames the permit denial as a
# "cancel culture" controversy (in the headline and in the quoted appeal),
# while `control` only says an appeal was promised.
# These strings are sent to the LLM verbatim; do not reformat them.
# NOTE(review): the dateline reads "Columbus, OH" while the story is set at
# Rutgers University in New Jersey — confirm this matches the original
# survey instrument.
treatment = """We are now going to ask you to imagine you have read about the following scenario, describing a debate on a recent College Campus. Antifa Denied Permit to Protest on Campus, Provoking Debate About “Cancel Culture” Columbus, OH A debate on the merits of free speech erupted recently when the student chapter of the controversial far-left group Antifa attempted to obtain a permit to conduct a demonstration on the main quad of Rutgers University in New Jersey. Citing safety concerns, the president of the organization in charge of Registered Student Organizations (RSOs) initially denied the organization the right to conduct their rally, arguing that their presence would endanger college students. They cited a recent incident in Berkeley, CA where three Antifa members and two bystanders were injured by rocks thrown in an altercation between the group and counter protesters. A member of the local Antifa group, Luke Vargas, is appealing the decision, arguing that the permit denial represented "cancel culture run amok," and the University was simply "afraid to hear the truth." When asked to comment, the University Ombudsman's Office promised that a final decision on whether the rally would be permitted would be made by this Thursday, three days before the march is scheduled to take place on Sunday."""
control = """We are now going to ask you to imagine you have read about the following scenario, describing a debate on a recent College Campus. Antifa Denied Permit to Protest on Campus Columbus, OH A debate on the merits of free speech erupted recently when the student chapter of the controversial far-left group Antifa attempted to obtain a permit to conduct a demonstration on the main quad of Rutgers University in New Jersey. Citing safety concerns, the president of the organization in charge of Registered Student Organizations (RSOs) initially denied the organization the right to conduct their rally, arguing that their presence would endanger college students. They cited a recent incident in Berkeley, CA where three Antifa members and two bystanders were injured by rocks thrown in an altercation between the group and counter protesters. A member of the local Antifa group, Luke Vargas, promised to bring an appeal to the desk of the University President. When asked to comment, the University Ombudsman's Office promised that a final decision on whether the rally would be permitted would be made by this Thursday, three days before the march is scheduled to take place on Sunday."""

# Paraphrased lead-in questions. get_ensemble_prediction_averaged() issues one
# LLM call per entry, so the length of this list is the ensemble size and also
# the number of "num_prompts" steps logged by generate_synthetic_data().
# NOTE(review): several wordings mention a "family's context" and a
# "priority level", which the campus-protest vignette does not contain —
# presumably carried over from another experiment; confirm they are intended.
prompts = [
    "Based on your perspective and the situation, how would you assess this scenario?",
    "Taking into account the family's context, how would you evaluate the priority level?",
    "Given the background and details provided, what numerical response feels appropriate?",
    "Considering all relevant factors, where would you place this on the scale?",
    "Reflecting on the scenario and using your reasoning, what value would you assign?",
    "From your understanding of the situation, how would you quantify this priority?",
    "Given your insights and the context described, what would be your evaluation?",
    "With the provided details in mind, how would you rate this on the scale?",
    "Considering all the information and your perspective, what score seems suitable?",
    "If you were to evaluate the priority here, what number aligns with your reasoning?",
]


def get_ideology_string(ideo_category):
    """Translate a 1-5 ideology code into its text label.

    Args:
        ideo_category: Ideology code; 1 is "Very liberal" and 5 is
            "Very conservative".

    Returns:
        The matching label, or "Unknown" when the code is not one of 1-5.
    """
    labels = (
        "Very liberal",
        "Somewhat liberal",
        "Moderate",
        "Somewhat conservative",
        "Very conservative",
    )
    # Codes start at 1, so number the labels from 1 as well.
    code_to_label = dict(enumerate(labels, start=1))
    return code_to_label.get(ideo_category, "Unknown")


# Lookup tables from integer column codes to the text labels inserted into the
# persona prompt. Outer keys are DataFrame column names; inner keys are the
# integer codes stored in those columns.
# generate_synthetic_data() indexes these with value_mappings[col][int(code)],
# so a code missing from a table raises KeyError there.
# NOTE(review): "RACETHNICITY" is defined but not read anywhere in this file.
value_mappings = {
    "GENDER": {0: "Unknown", 1: "Male", 2: "Female"},
    "RACETHNICITY": {1: "White, non-Hispanic", 2: "Black, non-Hispanic", 3: "Other, non-Hispanic", 4: "Hispanic", 5: "2+, non-Hispanic", 6: "Asian, non-Hispanic"},
    "HOME_TYPE": {
        1: "A one-family house detached from any other house",
        2: "A one-family house attached to one or more houses",
        3: "A building with 2 or more apartments",
        4: "A mobile home or trailer",
        5: "Boat, RV, van, etc",
    },
    # Seven-point party identification; -1 marks an unknown value.
    "PARTYID7": {
        -1: "Unknown",
        1: "Strong Democrat",
        2: "Not so strong Democrat",
        3: "Lean Democrat",
        4: "Don't Lean/Independent/None",
        5: "Lean Republican",
        6: "Not so strong Republican",
        7: "Strong Republican",
    },
    # Religious affiliation; 77/98/99 are non-response codes.
    "RELIG": {
        1: "Protestant",
        2: "Roman Catholic",
        3: "Mormon",
        4: "Orthodox",
        5: "Jewish",
        6: "Muslim",
        7: "Buddhist",
        8: "Hindu",
        9: "Atheist",
        10: "Agnostic",
        11: "Nothing in particular",
        12: "Just Christian",
        13: "Unitarian",
        14: "Something else",
        77: "DON'T KNOW",
        98: "SKIPPED ON WEB",
        99: "REFUSED",
    },
    # Frequency of attending religious services; 77/98/99 are non-response codes.
    "ATTEND": {
        1: "Never",
        2: "Less than once per year",
        3: "About once or twice a year",
        4: "Several times a year",
        5: "About once a month",
        6: "2-3 times a month",
        7: "Nearly every week",
        8: "Every week",
        9: "Several times a week",
        77: "DON'T KNOW",
        98: "SKIPPED ON WEB",
        99: "REFUSED",
    },
}


def _coerce_to_int(value):
    """Return ``int(value)`` for a real number, or ``None`` if it is unusable."""
    # bool is a subclass of int, but a JSON true/false is not a rating.
    if isinstance(value, bool) or not isinstance(value, (int, float)):
        return None
    try:
        return int(value)
    except (ValueError, OverflowError):
        # json.loads accepts NaN / Infinity, which int() cannot convert.
        return None


def extract_numeric_response(response_string):
    """Extract a numeric response from structured or free text.

    The text is first parsed as JSON: a bare number is used directly, and for
    an object the first numeric value is used. If that yields nothing, the
    first run of digits found anywhere in the text is used.

    Args:
        response_string: Raw model reply. ``None`` (a reply with no text) and
            non-string values are tolerated.

    Returns:
        The extracted value as an ``int`` (floats are truncated), or the
        neutral fallback ``3.0`` when no number can be found.
    """
    fallback = 3.0  # middle value

    # A reply without text content arrives as None; json.loads(None) would
    # raise TypeError, so short-circuit to the fallback.
    if response_string is None:
        return fallback
    if isinstance(response_string, (bytes, bytearray)):
        response_string = response_string.decode("utf-8", errors="replace")
    elif not isinstance(response_string, str):
        response_string = str(response_string)

    try:
        parsed = json.loads(response_string)
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        parsed = None

    candidates = parsed.values() if isinstance(parsed, dict) else (parsed,)
    for candidate in candidates:
        number = _coerce_to_int(candidate)
        if number is not None:
            return number

    # Free text such as "I would say 4." — take the first run of digits.
    numbers = re.findall(r"\d+", response_string)
    if numbers:
        return int(numbers[0])
    return fallback


def get_ensemble_prediction_averaged(base_query, scenario, common_part, model_type, clients):
    """Query the model once per entry in ``prompts`` and collect the answers.

    Despite the name, no averaging happens here: the individual responses are
    returned and the caller averages them.

    Args:
        base_query: Persona description passed as the system prompt.
        scenario: Vignette text placed at the start of the user message.
        common_part: Survey question appended after each paraphrased prompt.
        model_type: Backend selector forwarded to ``get_llm_response``.
        clients: Mapping of backend name to API client, forwarded unchanged.

    Returns:
        A list with one numeric response per prompt that produced a value.
        If none did, a single-element list ``[3.0]`` (the scale midpoint), so
        the return value is always a list that callers can pass to ``len()``.
    """
    responses = []
    for prompt in prompts:
        # NOTE(review): scenario and prompt are joined without a separating
        # space; left unchanged so the text sent to the model stays the same.
        full_query = scenario + prompt + common_part + " Reply with a single integer!"
        response = get_llm_response(full_query, base_query, model_type, clients)
        if response is not None:
            responses.append(response)
    if not responses:
        print("All responses were invalid!")
        # Return a list, not a bare float: callers compute running statistics
        # over the responses and would fail on a scalar.
        return [3.0]
    return responses


def calculate_running_mse(y_true, predictions_array):
    """Squared error of the running-mean prediction after each response.

    Args:
        y_true: The observed outcome for one individual.
        predictions_array: Sequence of per-prompt predictions, in order.

    Returns:
        A list as long as ``predictions_array``. Element ``k - 1`` is the
        squared difference between ``y_true`` and the mean of the first ``k``
        predictions. An empty input gives an empty list.
    """
    total = len(predictions_array)
    return [
        (y_true - sum(predictions_array[:k]) / k) ** 2
        for k in range(1, total + 1)
    ]


def generate_synthetic_data(df, model_type, clients):
    """Generate LLM-predicted outcomes under treatment and control for each row.

    For every row a persona prompt is built from the columns ``AGE``,
    ``PARTYID7``, ``GENDER``, ``IDEO``, ``RELIG``, ``ATTEND`` and
    ``HOME_TYPE``. The model is then queried with both the treatment and the
    control vignette, once per entry in ``prompts``.

    Args:
        df: Input DataFrame; must also contain ``T`` (treatment indicator) and
            ``Y`` (observed outcome). Two helper columns are added to it in
            place before it is copied.
        model_type: Backend selector forwarded to the LLM helpers.
        clients: Mapping of backend name to API client.

    Returns:
        A DataFrame with four added columns: ``Y1hat`` / ``Y0hat`` (mean
        response under treatment / control) and ``Y1hat_responses`` /
        ``Y0hat_responses`` (the individual responses, stored as strings).

    Side effects:
        Prints per-row progress and logs, for each ensemble size, the average
        squared error of the running-mean prediction to Weights & Biases.
        Rows with ``T == 1`` contribute to ``mse_treatment``; all others
        contribute to ``mse_control``.
    """
    n_prompts = len(prompts)
    mse_y1 = {i: [] for i in range(1, n_prompts + 1)}
    mse_y0 = {i: [] for i in range(1, n_prompts + 1)}

    df["Y1hat_responses"] = None
    df["Y0hat_responses"] = None
    df = df.astype({"Y1hat_responses": "object", "Y0hat_responses": "object"})

    # The survey question is the same for every respondent, so build it once.
    # NOTE(review): the "definetely" typo is kept on purpose so the text sent
    # to the model is unchanged; fix it only if prompts may be altered.
    common_part = """ Generally speaking, do you agree or disagree with the following statement:
        “Cancel culture is a big problem in today’s society.”
        Choose an integer between 1 (definitely agree) and 5 (definetely disagree) to state your opinions.  """

    total_rows = len(df)
    # Count rows by position so the progress message does not depend on the
    # DataFrame index being a 0-based integer range.
    for position, (index, row) in enumerate(df.iterrows(), start=1):

        base_query = (
            f"You are a {row['AGE']} year old, {value_mappings['PARTYID7'][int(row['PARTYID7'])]}, "
            f"gender {value_mappings['GENDER'][int(row['GENDER'])]}, and hold {get_ideology_string(row['IDEO'])} views. "
            f"Additionally, your religion is {value_mappings['RELIG'][int(row['RELIG'])]} and you "
            f"{value_mappings['ATTEND'][int(row['ATTEND'])]} attend religious services. "
            f"You reside in {value_mappings['HOME_TYPE'][int(row['HOME_TYPE'])]}. "
            f"You are responding to a scenario reflecting a debate involving college campus events and broader social issues. "
            f"Your answer must be a single integer without text"
        )

        y1_responses = get_ensemble_prediction_averaged(base_query, treatment, common_part, model_type, clients)
        y0_responses = get_ensemble_prediction_averaged(base_query, control, common_part, model_type, clients)

        # Only the arm the individual was actually assigned to has an observed
        # outcome to compare against.
        if row['T'] == 1:
            for i, score in enumerate(calculate_running_mse(row['Y'], y1_responses), 1):
                mse_y1[i].append(score)
        else:
            for i, score in enumerate(calculate_running_mse(row['Y'], y0_responses), 1):
                mse_y0[i].append(score)

        # Calculate final predictions
        y1 = np.mean(y1_responses)
        y0 = np.mean(y0_responses)

        df.at[index, "Y1hat"] = y1
        df.at[index, "Y1hat_responses"] = str(y1_responses)
        df.at[index, "Y0hat"] = y0
        df.at[index, "Y0hat_responses"] = str(y0_responses)

        print(f"Processed individual {position}/{total_rows}")
        print(f"Y1 (averaged): {y1}")
        print(f"Y0 (averaged): {y0}")
        print(f"Y: {row['Y']}")
        print(f"T: {row['T']}")
        print("---")

    # Log average squared error for each number of prompts. An arm with no
    # rows has nothing to average, so its metric is omitted instead of
    # dividing by zero.
    for num_prompts in range(1, n_prompts + 1):
        metrics = {"num_prompts": num_prompts}
        treated_scores = mse_y1[num_prompts]
        control_scores = mse_y0[num_prompts]
        if treated_scores:
            metrics["mse_treatment"] = sum(treated_scores) / len(treated_scores)
        if control_scores:
            metrics["mse_control"] = sum(control_scores) / len(control_scores)
        wandb.log(metrics)
    return df


def get_claude_response(prompt, base_prompt, anthropic_client):
    """Ask Claude 3.5 Haiku one question and return the number in its reply.

    Args:
        prompt: User message text.
        base_prompt: Text passed as the system prompt.
        anthropic_client: Client object exposing ``messages.create``.

    Returns:
        Whatever ``extract_numeric_response`` extracts from the text of the
        first content item of the reply.
    """
    request = {
        "model": "claude-3-5-haiku-20241022",
        "max_tokens": 50,
        "system": base_prompt,
        "temperature": 1.0,
        "messages": [{"role": "user", "content": prompt}],
    }
    reply = anthropic_client.messages.create(**request)
    reply_text = reply.content[0].text
    return extract_numeric_response(reply_text)


def get_gpt_response(prompt, base_prompt, openai_client):
    """Ask GPT-4o one question and return the number in its reply.

    Args:
        prompt: User message text.
        base_prompt: Text passed as the system message.
        openai_client: Client object exposing ``chat.completions.create``.

    Returns:
        Whatever ``extract_numeric_response`` extracts from the content of the
        first choice.
    """
    conversation = [
        {"role": "system", "content": base_prompt},
        {"role": "user", "content": prompt},
    ]
    completion = openai_client.chat.completions.create(model="gpt-4o", messages=conversation)
    first_choice = completion.choices[0]
    return extract_numeric_response(first_choice.message.content)


def get_deepseek_response(prompt, base_prompt, openai_client):
    """Ask the ``deepseek-chat`` model one question and return the number in its reply.

    The raw reply text is printed before the number is extracted.

    Args:
        prompt: User message text.
        base_prompt: Text passed as the system message.
        openai_client: Client object exposing ``chat.completions.create``.

    Returns:
        Whatever ``extract_numeric_response`` extracts from the content of the
        first choice.
    """
    conversation = [
        {"role": "system", "content": base_prompt},
        {"role": "user", "content": prompt},
    ]
    completion = openai_client.chat.completions.create(model="deepseek-chat", messages=conversation)
    reply_text = completion.choices[0].message.content
    print(f"{reply_text}")
    return extract_numeric_response(reply_text)


def get_llm_response(prompt, base_prompt, model_type, clients):
    """Route a query to the backend chosen by ``model_type``.

    Args:
        prompt: User message text.
        base_prompt: System prompt text.
        model_type: ``"claude_haiku"`` or ``"deepseek"``; any other value is
            routed to the GPT-4o helper.
        clients: Mapping holding the client under the key ``"claude_haiku"``,
            ``"deepseek"`` or ``"gpt4o"`` respectively.

    Returns:
        The numeric response produced by the selected backend helper.
    """
    if model_type == "claude_haiku":
        handler, client_key = get_claude_response, "claude_haiku"
    elif model_type == "deepseek":
        handler, client_key = get_deepseek_response, "deepseek"
    else:
        # Everything else, including "gpt4o", goes to the OpenAI helper.
        handler, client_key = get_gpt_response, "gpt4o"
    return handler(prompt, base_prompt, clients[client_key])


def main():
    """Command-line entry point: generate synthetic outcomes and save them.

    Parses ``--model`` and ``--entity``, starts a Weights & Biases run, builds
    the API client for the selected backend, reads ``df_processed.csv``, runs
    ``generate_synthetic_data`` and writes the result to ``df_<model>.csv``.
    The result is also logged to W&B as a table.

    API keys are read from the environment variables ``OPENAI_API_KEY``,
    ``ANTHROPIC_API_KEY`` and ``DEEPSEEK_API_KEY``. The placeholder values are
    used when a variable is not set.
    """
    import os  # local import: only needed here, to read API keys

    parser = argparse.ArgumentParser(description="Generate synthetic data using LLM models")
    parser.add_argument("--model", type=str,
                        choices=["gpt4o", "claude_haiku", "deepseek"], required=True)
    parser.add_argument("--entity", type=str, required=True)

    args = parser.parse_args()
    wandb.init(
        project="faheyS78_LLM_Generation",
        entity=args.entity,
        config=args
    )

    # TODO add your API tokens (preferably via the environment variables
    # below, so that secrets are not committed to the repository)
    gpt_key = os.environ.get("OPENAI_API_KEY", "xxx")
    claude_key = os.environ.get("ANTHROPIC_API_KEY", "yyy")
    deepseek_key = os.environ.get("DEEPSEEK_API_KEY", "zzz")

    # Initialize only the client for the chosen model. This must be an
    # if/elif/else chain: with two independent ``if`` statements the final
    # ``else`` would also run for "claude_haiku" and build an unused OpenAI
    # client.
    clients = {}
    if args.model == "claude_haiku":
        clients["claude_haiku"] = Anthropic(api_key=claude_key)
    elif args.model == "deepseek":
        clients["deepseek"] = OpenAI(api_key=deepseek_key, base_url="https://api.deepseek.com")
    else:
        clients["gpt4o"] = OpenAI(api_key=gpt_key)

    df = pd.read_csv("df_processed.csv")
    result_df = generate_synthetic_data(df, args.model, clients)

    # Save results
    output_path = f"df_{args.model}.csv"
    result_df.to_csv(output_path, index=False)

    wandb_table = wandb.Table(dataframe=result_df)
    wandb.log({"final_dataframe": wandb_table})

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
