import json
import math
import re

import numpy as np
import torch
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline

import wandb

# Module-level text-generation pipeline shared by the helpers below.
# It stays None until load_pipeline() assigns it; get_template() and
# process_outputs() read it directly, so load_pipeline() must run first.
chat_pipeline = None


def load_pipeline(name):
    """Load a Llama-3 instruct model into the module-level ``chat_pipeline``.

    Args:
        name: ``"llama"`` for Meta-Llama-3-70B-Instruct loaded in 4-bit NF4
            (double quantization, bf16 compute), or ``"llama_small"`` for
            Meta-Llama-3-8B-Instruct loaded unquantized in bfloat16.

    Raises:
        ValueError: if ``name`` is not one of the keys above.
    """
    model_configs = {
        "llama": {
            "model_name": "meta-llama/Meta-Llama-3-70B-Instruct",
            "quantization_config": BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
            ),
        },
        "llama_small": {
            "model_name": "meta-llama/Meta-Llama-3-8B-Instruct",
            # No quantization. A BitsAndBytesConfig with load_in_4bit=False is
            # not a reliable "off" switch: transformers' quantizer dispatch still
            # treats any BitsAndBytesConfig as a bitsandbytes request. None
            # loads plain bf16 weights (torch_dtype below).
            "quantization_config": None,
        },
    }
    if name not in model_configs:
        raise ValueError(f"Invalid model name: {name}.")
    config = model_configs[name]
    model_name = config["model_name"]
    # Left padding so batched decoder-only generation continues right after
    # each prompt rather than after pad tokens.
    tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        quantization_config=config["quantization_config"],
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    )
    global chat_pipeline
    chat_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer, device_map="auto")
    # Batched generation (process_outputs uses batch_size > 1) needs a pad id;
    # presumably the tokenizer ships without one, so reuse EOS. Some model
    # configs store eos_token_id as a list, but pad_token_id must be a single id.
    eos_token_id = model.config.eos_token_id
    if isinstance(eos_token_id, (list, tuple)):
        eos_token_id = eos_token_id[0]
    chat_pipeline.tokenizer.pad_token_id = eos_token_id


# Paraphrased closing questions. get_ensemble_prediction() appends each one to
# the rating instruction, so every persona/scenario pair is sampled once per
# paraphrase and the valid answers are averaged.
# NOTE(review): several paraphrases ("family's context", "priority level",
# "quantify this priority") do not match the campus free-speech scenario and
# look carried over from another study; confirm they are intended.
prompts = [
    line.strip()
    for line in """
    Based on your perspective and the situation, how would you assess this scenario?
    Taking into account the family's context, how would you evaluate the priority level?
    Given the background and details provided, what numerical response feels appropriate?
    Considering all relevant factors, where would you place this on the scale?
    Reflecting on the scenario and using your reasoning, what value would you assign?
    From your understanding of the situation, how would you quantify this priority?
    Given your insights and the context described, what would be your evaluation?
    With the provided details in mind, how would you rate this on the scale?
    Considering all the information and your perspective, what score seems suitable?
    If you were to evaluate the priority here, what number aligns with your reasoning?
    """.splitlines()
    if line.strip()
]


# Experimental stimuli shown to the model as the user turn. The two vignettes
# are identical except for the manipulated passages:
#   1. the treatment headline adds ', Provoking Debate About “Cancel Culture”';
#   2. the treatment has Luke Vargas call the denial "cancel culture run amok",
#      where the control says he will appeal to the University President.
# Keep the text byte-identical: any edit changes what the model reads.
# NOTE(review): the dateline reads "Columbus, OH" while the story is set at
# Rutgers in New Jersey; confirm this matches the original survey stimulus.
treatment = """We are now going to ask you to imagine you have read about the following scenario, describing a debate on a recent College Campus. Antifa Denied Permit to Protest on Campus, Provoking Debate About “Cancel Culture” Columbus, OH A debate on the merits of free speech erupted recently when the student chapter of the controversial far-left group Antifa attempted to obtain a permit to conduct a demonstration on the main quad of Rutgers University in New Jersey. Citing safety concerns, the president of the organization in charge of Registered Student Organizations (RSOs) initially denied the organization the right to conduct their rally, arguing that their presence would endanger college students. They cited a recent incident in Berkeley, CA where three Antifa members and two bystanders were injured by rocks thrown in an altercation between the group and counter protesters. A member of the local Antifa group, Luke Vargas, is appealing the decision, arguing that the permit denial represented "cancel culture run amok," and the University was simply "afraid to hear the truth." When asked to comment, the University Ombudsman's Office promised that a final decision on whether the rally would be permitted would be made by this Thursday, three days before the march is scheduled to take place on Sunday."""

control = """We are now going to ask you to imagine you have read about the following scenario, describing a debate on a recent College Campus. Antifa Denied Permit to Protest on Campus Columbus, OH A debate on the merits of free speech erupted recently when the student chapter of the controversial far-left group Antifa attempted to obtain a permit to conduct a demonstration on the main quad of Rutgers University in New Jersey. Citing safety concerns, the president of the organization in charge of Registered Student Organizations (RSOs) initially denied the organization the right to conduct their rally, arguing that their presence would endanger college students. They cited a recent incident in Berkeley, CA where three Antifa members and two bystanders were injured by rocks thrown in an altercation between the group and counter protesters. A member of the local Antifa group, Luke Vargas, promised to bring an appeal to the desk of the University President. When asked to comment, the University Ombudsman's Office promised that a final decision on whether the rally would be permitted would be made by this Thursday, three days before the march is scheduled to take place on Sunday."""


# Labels for the 5-point IDEO self-placement code (1 = most liberal).
_IDEOLOGY_LABELS = {
    1: "Very liberal",
    2: "Somewhat liberal",
    3: "Moderate",
    4: "Somewhat conservative",
    5: "Very conservative",
}


def get_ideology_string(ideo_category):
    """Translate an IDEO survey code into its label; unknown codes give ``"Unknown"``.

    Lookup is by dict key, so equal numeric values (e.g. ``3.0`` for ``3``)
    resolve to the same label.
    """
    try:
        return _IDEOLOGY_LABELS[ideo_category]
    except KeyError:
        return "Unknown"


# Labels for non-substantive answers, shared by the RELIG and ATTEND tables.
_NONRESPONSE_LABELS = {77: "DON'T KNOW", 98: "SKIPPED ON WEB", 99: "REFUSED"}


def _numbered(*labels):
    """Return ``{1: labels[0], 2: labels[1], ...}`` for consecutively coded answers."""
    return dict(enumerate(labels, start=1))


# Survey codebook: column name -> {response code: human-readable label}.
# generate_synthetic_data() indexes these tables directly, so a code missing
# from a table raises KeyError there.
value_mappings = {
    "GENDER": {0: "Unknown", 1: "Male", 2: "Female"},
    "RACETHNICITY": _numbered(
        "White, non-Hispanic",
        "Black, non-Hispanic",
        "Other, non-Hispanic",
        "Hispanic",
        "2+, non-Hispanic",
        "Asian, non-Hispanic",
    ),
    "HOME_TYPE": _numbered(
        "A one-family house detached from any other house",
        "A one-family house attached to one or more houses",
        "A building with 2 or more apartments",
        "A mobile home or trailer",
        "Boat, RV, van, etc",
    ),
    "PARTYID7": {
        -1: "Unknown",
        **_numbered(
            "Strong Democrat",
            "Not so strong Democrat",
            "Lean Democrat",
            "Don't Lean/Independent/None",
            "Lean Republican",
            "Not so strong Republican",
            "Strong Republican",
        ),
    },
    "RELIG": {
        **_numbered(
            "Protestant",
            "Roman Catholic",
            "Mormon",
            "Orthodox",
            "Jewish",
            "Muslim",
            "Buddhist",
            "Hindu",
            "Atheist",
            "Agnostic",
            "Nothing in particular",
            "Just Christian",
            "Unitarian",
            "Something else",
        ),
        **_NONRESPONSE_LABELS,
    },
    "ATTEND": {
        **_numbered(
            "Never",
            "Less than once per year",
            "About once or twice a year",
            "Several times a year",
            "About once a month",
            "2-3 times a month",
            "Nearly every week",
            "Every week",
            "Several times a week",
        ),
        **_NONRESPONSE_LABELS,
    },
}


def _json_number_to_int(value):
    """Return ``int(value)`` for a finite JSON number, otherwise ``None``.

    JSON booleans are rejected even though ``bool`` subclasses ``int``.
    NaN/Infinity are rejected because ``json.loads`` accepts them but
    ``int()`` raises on them.
    """
    if isinstance(value, bool) or not isinstance(value, (int, float)):
        return None
    if isinstance(value, float) and not math.isfinite(value):
        return None
    return int(value)


def extract_numeric_response(response_string, default=3):
    """Extract an integer rating from a model reply.

    The reply is first parsed as JSON. A bare number is used directly; for an
    object, the first numeric value (in key order) is used. If that yields
    nothing, the first run of digits anywhere in the text is returned.

    Args:
        response_string: Raw text generated by the model.
        default: Value returned when no number is found. Defaults to 3, the
            midpoint of the 1-5 scale. Pass ``None`` to let callers drop
            unparseable replies instead of imputing the midpoint.

    Returns:
        An ``int`` or ``default``. JSON floats are truncated toward zero by
        ``int()``. The text fallback only matches unsigned digits.
    """
    try:
        parsed = json.loads(response_string)
    except json.JSONDecodeError:
        parsed = None

    if isinstance(parsed, dict):
        for value in parsed.values():
            number = _json_number_to_int(value)
            if number is not None:
                return number
    else:
        number = _json_number_to_int(parsed)
        if number is not None:
            return number

    # Free-text fallback (also covers fenced or trailing-text JSON).
    match = re.search(r"\d+", response_string)
    if match:
        return int(match.group())
    return default


def get_template(persona, scenario_text, question):
    """Render a system + user chat as a single prompt string.

    The system turn carries the persona followed by a JSON-only answer
    instruction. The user turn is the scenario, a blank line, then the
    question. Formatting uses the tokenizer of the module-level
    ``chat_pipeline``, so ``load_pipeline`` must have been called first.
    """
    system_turn = {
        "role": "system",
        "content": f"{persona}. Your answer must be in JSON format with an integer, without additional text.",
    }
    user_turn = {"role": "user", "content": f"{scenario_text}\n\n{question}"}
    tokenizer = chat_pipeline.tokenizer
    return tokenizer.apply_chat_template(
        [system_turn, user_turn],
        tokenize=False,
        add_generation_prompt=True,
    )


def process_outputs(example):
    """Generate one sampled reply per prompt and parse each into a 1-5 rating.

    Args:
        example: Batch mapping whose ``"text"`` key holds a list of prompt
            strings already rendered with ``get_template``.

    Returns:
        ``{"responses": [...]}`` with one entry per prompt, in input order:
        the parsed integer rating when it lies in 1..5, else ``None``.
    """
    texts = example["text"]
    if not texts:
        return {"responses": []}

    tokenizer = chat_pipeline.tokenizer
    # Stop on either the generic EOS token or Llama-3's end-of-turn token.
    stop_ids = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|eot_id|>")]
    outputs = chat_pipeline(
        texts,
        max_new_tokens=100,
        eos_token_id=stop_ids,
        do_sample=True,
        temperature=1.2,
        top_p=0.9,
        batch_size=len(texts),
        # Ask the pipeline for the continuation only, instead of echoing the
        # prompt and slicing it off by character length.
        return_full_text=False,
    )

    responses = []
    for output_item in outputs:
        answer_str = output_item[0]["generated_text"].strip()
        print(answer_str)
        numeric_val = extract_numeric_response(answer_str)
        in_range = numeric_val is not None and 1 <= numeric_val <= 5
        responses.append(numeric_val if in_range else None)
    return {"responses": responses}


def get_ensemble_prediction(persona, scenario_text, prompts, question=""):
    """Query the model once per prompt paraphrase and average the valid ratings.

    Args:
        persona: System-prompt persona description.
        scenario_text: Vignette placed in the user turn.
        prompts: Paraphrased closing questions; each yields one sampled reply.
        question: Instruction placed before each paraphrase, separated by a
            newline. Empty parts are skipped, so the default does not leave a
            stray leading newline.

    Returns:
        ``(mean, responses)``. ``responses`` holds the in-range ratings in
        prompt order, with invalid replies dropped. ``mean`` is their
        ``np.mean``, or ``None`` (with ``[]``) when no reply was valid.
    """
    batch_prompts = [
        get_template(persona, scenario_text, "\n".join(part for part in (question, p) if part))
        for p in prompts
    ]
    if not batch_prompts:
        return None, []

    # Call the batch function directly. Wrapping a single in-memory batch in
    # datasets.Dataset.map adds nothing, and map() hashes the transform
    # function for its cache fingerprint, which can be costly when that
    # function refers to a loaded model.
    responses = process_outputs({"text": batch_prompts})["responses"]
    valid_responses = [r for r in responses if r is not None]
    if valid_responses:
        return np.mean(valid_responses), valid_responses
    return None, []


def calculate_running_mse(y_true, predictions_array):
    """Squared error of the running-mean prediction as samples are added.

    Element ``k - 1`` of the result is
    ``(y_true - sum(predictions_array[:k]) / k) ** 2``. It shows how the error
    of the ensemble mean evolves with ``k`` members. Empty input gives ``[]``.
    """
    return [
        (y_true - sum(predictions_array[:count]) / count) ** 2
        for count in range(1, len(predictions_array) + 1)
    ]


def _build_persona(row):
    """Return the persona system-prompt text for one respondent row."""
    return (
        f"You are a {row['AGE']} year old, {value_mappings['PARTYID7'][int(row['PARTYID7'])]}, "
        f"gender {value_mappings['GENDER'][int(row['GENDER'])]}, and hold {get_ideology_string(row['IDEO'])} views. "
        f"Additionally, your religion is {value_mappings['RELIG'][int(row['RELIG'])]} and you "
        f"{value_mappings['ATTEND'][int(row['ATTEND'])]} attend religious services. "
        f"You reside in {value_mappings['HOME_TYPE'][int(row['HOME_TYPE'])]}. "
        f"You are responding to a scenario reflecting a debate involving college campus events and broader social issues. "
        f"Your answer must be a single integer without additional text, in JSON format with a key-value pair"
    )


def _record_running_mse(store, y_true, responses):
    """Append each running-mean squared error to ``store[k]`` (k = samples averaged)."""
    for num_samples, score in enumerate(calculate_running_mse(y_true, responses), start=1):
        store[num_samples].append(score)


def _average(values):
    """Arithmetic mean of ``values``, or NaN when empty (instead of ZeroDivisionError)."""
    return sum(values) / len(values) if values else float("nan")


def generate_synthetic_data(df, model_name):
    """Simulate treatment and control ratings for every respondent in ``df``.

    Loads ``model_name`` via ``load_pipeline`` and builds a persona from each
    row's demographics. The model then rates the cancel-culture statement
    after both the treatment and the control vignette, once per paraphrase in
    the module-level ``prompts``.

    Running-mean MSE against the observed outcome ``Y`` is accumulated only
    for the arm the respondent received (``T == 1`` means treatment). Its
    average for each ensemble size is logged to wandb at the end.

    Args:
        df: DataFrame with columns AGE, PARTYID7, GENDER, IDEO, RELIG, ATTEND,
            HOME_TYPE, T and Y. It is not modified.
        model_name: A key accepted by ``load_pipeline``.

    Returns:
        A copy of ``df`` with ``Y1hat_responses``/``Y0hat_responses`` (lists
        of valid ratings) and ``Y1hat``/``Y0hat`` (their mean, NaN if empty).
    """
    load_pipeline(model_name)

    rating_instruction = """Generally speaking, do you agree or disagree with the following statement:
“Cancel culture is a big problem in today’s society.”
Reply with a JSON numerical answer using numbers between 1 (definitely agree) and 5 (definitely disagree)."""

    # Keyed by number of samples averaged (1..len(prompts)).
    mse_y1 = {i: [] for i in range(1, len(prompts) + 1)}
    mse_y0 = {i: [] for i in range(1, len(prompts) + 1)}

    # Work on a copy so the caller's frame is not mutated.
    df = df.copy()
    df["Y1hat_responses"] = None
    df["Y0hat_responses"] = None
    df = df.astype({"Y1hat_responses": "object", "Y0hat_responses": "object"})
    df["Y1hat"] = np.nan
    df["Y0hat"] = np.nan

    total = len(df)
    for position, (index, row) in enumerate(df.iterrows(), start=1):
        persona = _build_persona(row)
        # Treatment and control scenarios, each with the full prompt ensemble.
        _, y1_responses = get_ensemble_prediction(persona, treatment, prompts, question=rating_instruction)
        _, y0_responses = get_ensemble_prediction(persona, control, prompts, question=rating_instruction)

        # Only the assigned arm has an observed outcome to score against.
        if row["T"] == 1:
            _record_running_mse(mse_y1, row["Y"], y1_responses)
        else:
            _record_running_mse(mse_y0, row["Y"], y0_responses)

        # np.mean([]) warns and returns NaN; make the empty case explicit.
        y1 = np.mean(y1_responses) if y1_responses else np.nan
        y0 = np.mean(y0_responses) if y0_responses else np.nan

        df.at[index, "Y1hat"] = y1
        df.at[index, "Y0hat"] = y0
        df.at[index, "Y1hat_responses"] = y1_responses
        df.at[index, "Y0hat_responses"] = y0_responses

        # Report progress by position; the index label need not be 0..n-1.
        print(f"Processed {position}/{total}")
        print(f"Y1 (treatment): {y1}, Responses: {y1_responses}")
        print(f"Y0 (control): {y0}, Responses: {y0_responses}")
        print(f"Y: {row['Y']}")
        print(f"T: {row['T']}")
        print("---")

    # Rows with fewer valid replies contribute to fewer ensemble sizes, so
    # each average may cover a different subset of respondents. Empty buckets
    # (e.g. no treated rows) are logged as NaN instead of raising.
    for num_prompts in range(1, len(prompts) + 1):
        wandb.log({
            "num_prompts": num_prompts,
            "mse_treatment": _average(mse_y1[num_prompts]),
            "mse_control": _average(mse_y0[num_prompts]),
        })
    return df
