import sys
import os
import pandas as pd
import datasets
import functools
import typing
from . import util


def pred_label_from_openai(row: dict, client, model_name: str, labels: typing.List[str]):
    """Predict the label of a single row by prompting an OpenAI chat model.

    Args:
        row (dict): a single HuggingFace Datasets row; only ``row["prompt"]`` is read
        client: OpenAI client for generation (``client.chat.completions.create`` is called)
        model_name (str): OpenAI model name for generation
        labels (typing.List[str]): labels for classification for sanity check
    Returns:
        dict: "label_pred" field holding the predicted label, or None when the reply is
            empty or is not one of ``labels``
    """
    # Generate the prediction
    messages = [
        {"role": "user", "content": row["prompt"]},
    ]
    completion = client.chat.completions.create(
        model=model_name,
        messages=messages,
        max_tokens=10,
        temperature=0.0,  # deterministic
    )
    content = completion.choices[0].message.content

    # The message content is optional in the API response; treat a missing reply as an
    # invalid prediction instead of failing on `None.replace`.
    prediction = None if content is None else content.replace('"', '').strip()

    # Sanity check: ensure the prediction is a valid label specified for the classification task
    if prediction not in labels:
        print(f"[WARN] [pred_label_from_openai] Invalid prediction: [{prediction}]", file=sys.stderr)
        prediction = None

    return {"label_pred": prediction}


def pred_label_from_seq2seq(rows: dict, model, tokenizer, labels: typing.List[str]):
    """Predict the labels of a batch of rows with a HuggingFace Seq2Seq generation model.

    Args:
        rows (dict): HuggingFace Datasets batch; only ``rows["prompt"]`` (list of prompts) is read
        model (AutoModelForSeq2SeqLM): HuggingFace model for generation
        tokenizer (AutoTokenizer): HuggingFace tokenizer of the model
        labels (typing.List[str]): labels for classification for sanity check
    Returns:
        dict: "label_pred" field for list of predicted labels, one entry per prompt;
            an entry is None when the decoded text is not one of ``labels``
    """
    # Generate the predictions
    texts = rows["prompt"]
    inputs = tokenizer(texts, return_tensors="pt", padding=True)
    inputs.to(model.device)
    outputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=True, max_new_tokens=10)
    decoded = tokenizer.batch_decode(outputs["sequences"], skip_special_tokens=True)

    # Sanity check: ensure the predictions are valid labels specified for the classification task
    predictions = []
    for text in decoded:
        # Normalise the same way as the OpenAI/chat predictors so a quoted or padded
        # label such as ' "Lean agree" ' is not rejected.
        prediction = text.replace('"', '').strip()
        if prediction not in labels:
            print(f"[WARN] [pred_label_from_seq2seq] Invalid prediction: [{prediction}]", file=sys.stderr)
            prediction = None
        predictions.append(prediction)

    return {"label_pred": predictions}


def pred_label_from_chat(row: dict, pipe, labels: typing.List[str]):
    """Predict the label of a single row by prompting a HuggingFace chat pipeline.

    Args:
        row (dict): a single HuggingFace Datasets row; only ``row["prompt"]`` is read
        pipe: HuggingFace pipeline for generation, called with a one-message chat
        labels (typing.List[str]): labels for classification for sanity check
    Returns:
        dict: "label_pred" field holding the predicted label, or None when the reply
            is not one of ``labels``
    """
    # Send the prompt as a single user turn
    conversation = [{"role": "user", "content": row["prompt"]}]
    generated = pipe(conversation, max_new_tokens=10)

    # The second entry of the returned chat is read as the model's reply
    reply = generated[0]["generated_text"][1]["content"]
    label = reply.replace('"', '').strip()

    # Sanity check: only labels specified for the classification task are accepted
    if label in labels:
        return {"label_pred": label}

    print(f"[WARN] [pred_label_from_chat] Invalid prediction: [{label}]", file=sys.stderr)
    return {"label_pred": None}


def main(data_prefix: str, topic: str, eval_model_config, version: str, exists_llm: bool, llm_column: bool, memory_enabled: bool, llm_name: typing.Optional[str] = None):
    """Predict a 6-point stance label for every message and opinion of a simulation and save the result as CSV.

    Every "tweet"/"message_sent" event and every "Initial Opinion"/"Post Opinion" event is turned
    into a prompt about ``topic``, labelled by the evaluation model, and stored in the "label_pred"
    column together with its numeric "likert_pred" (1 = "Certainly disagree" ... 6 = "Certainly agree").
    Opinion events always use the prompt without conversation history.

    Args:
        data_prefix (str): dataset identifier, forwarded to the ``util`` loaders and used in the output path
        topic (str): topic inserted into the prompts
        eval_model_config: either ``(model_type, model_name)`` for OpenAI models, or
            ``(model_type, model_name, model, tokenizer)`` for HuggingFace models; the model types
            "openai" and "seq2seq" are handled explicitly, any other value is run as a chat
            text-generation pipeline
        version (str): forwarded to ``util.load_simulation_dfs`` (when ``exists_llm``) and used in the output file name
        exists_llm (bool): whether LLM data exists next to the human data
        llm_column (bool): evaluate the "llm_text" column instead of "text"; forced to False when ``exists_llm`` is False
        memory_enabled (bool): include the preceding messages of the same round in the prompt
        llm_name (typing.Optional[str]): LLM name forwarded to the loader and used in the output path when ``exists_llm`` is True
    Raises:
        ValueError: if ``eval_model_config`` does not have 2 or 4 elements, or has 2 elements
            (no model/tokenizer) for a model type other than "openai"
    """
    # Unpack the evaluation model configuration first so a malformed config fails
    # with a clear error before any data is loaded or any model is queried.
    eval_model = eval_tokenizer = None
    if len(eval_model_config) == 2:
        eval_model_type, eval_model_name = eval_model_config
        if eval_model_type != "openai":
            raise ValueError(
                f"eval_model_config for model type {eval_model_type!r} must be "
                "(type, name, model, tokenizer)"
            )
        eval_model_save_name = eval_model_name
    elif len(eval_model_config) == 4:
        eval_model_type, eval_model_name, eval_model, eval_tokenizer = eval_model_config
        # "org/model" is saved as "model"; a name without "/" is kept as is
        name_parts = eval_model_name.split("/")
        eval_model_save_name = name_parts[1] if len(name_parts) > 1 else eval_model_name
    else:
        raise ValueError(
            f"eval_model_config must have 2 or 4 elements, got {len(eval_model_config)}"
        )

    if exists_llm:
        # Load both human and LLM data when LLM data exists, then apply filtering
        human_df, llm_df = util.load_simulation_dfs(
            data_prefix=data_prefix,
            model_name=llm_name,
            version=version,
            filter_strategy="any",  # Filter if either human or LLM message is invalid
            preprocess=True,
            consecutive_messages=True
        )

        # Choose which dataframe to use based on llm_column flag
        simulation_df = llm_df if llm_column else human_df
        field_name = "llm_text" if llm_column else "text"
    else:
        # Load human data when LLM data doesn't exist
        simulation_df = util.load_simulation_df(
            data_prefix=data_prefix,
            model_name=None,  # Human data
            eval_model_name=None,
            version="human",
            filter_invalid=False,
            load_validity=False
        )
        llm_column = False  # Force using human text if no LLM data exists
        field_name = "text"

    # {TEXT}, {CONVERSATION} and {PLAYER} are placeholders filled in per row by get_prompt
    prompt_no_memory = f'What is the stance of the following message with respect to the topic "{topic}"? Here is the message: "{{TEXT}}". Please use exactly one word from the following 6 categories to label it: "Certainly agree", "Probably agree", "Lean agree", "Lean disagree", "Probably disagree", and "Certainly disagree". Your predicted label: '
    prompt_with_memory = f'Here is prior conversation around the topic "{topic}":\n{{CONVERSATION}}\nWhat is the stance of the following new message by {{PLAYER}} with respect to the topic "{topic}"? Here is the message: "{{TEXT}}". Please use exactly one word from the following 6 categories to label it: "Certainly agree", "Probably agree", "Lean agree", "Lean disagree", "Probably disagree", and "Certainly disagree". Your predicted label: '
    label2likert = {
        "Certainly agree": 6,
        "Probably agree": 5,
        "Lean agree": 4,
        "Lean disagree": 3,
        "Probably disagree": 2,
        "Certainly disagree": 1,
    }
    labels = list(label2likert)

    # Process the data
    simulation_df = util.preprocess_simulation_df(simulation_df, consecutive_messages=True)
    simulation_df, _ = util.get_chat_order_and_separators(simulation_df)
    n_rounds = int(simulation_df["chat_round_order"].max())
    simulation_df["message"] = simulation_df[field_name]
    simulation_df["likert_pred"] = None
    simulation_df["label_pred"] = None
    message_events = simulation_df[simulation_df["event_type"].isin(["tweet", "message_sent"])].copy()

    # Parallel lists: the event_order of a message and the conversation that preceded it
    history_event_orders = []
    history_contents = []

    # for every round
    for round_num in range(1, n_rounds + 1):
        round_events = message_events[message_events["chat_round_order"] == round_num]

        # generate pairs of players (A, B) without (B, A) duplicates; iterating the
        # rows directly keeps this well-defined for a round without any messages
        pair_frame = round_events.groupby(["sender_id", "recipient_id"]).size().reset_index().iloc[:, :2]
        pairs = list(dict.fromkeys(
            tuple(sorted(pair)) for pair in pair_frame.itertuples(index=False, name=None)
        ))

        for player_a, player_b in pairs:
            # messages sent by player A or player B in this round, in chronological order
            # NOTE(review): the filter only looks at sender_id, so messages these players sent
            # to a third player in the same round would be included too - confirm each player
            # has a single partner per round.
            messages = round_events[(round_events["sender_id"] == player_a) | (round_events["sender_id"] == player_b)].sort_values("event_order").reset_index()
            event_orders = messages["event_order"].tolist()
            senders = messages["sender_id"].tolist()
            texts = messages["message"].tolist()

            # the i-th message (i >= 1) gets all earlier messages as its conversation;
            # the first message of the pair has no history and therefore no entry
            history = ""
            for prev_sender, prev_text, event_order in zip(senders, texts, event_orders[1:]):
                history += f'{prev_sender}: {prev_text}\n'
                history_event_orders.append(event_order)
                history_contents.append(history)

    message_events = message_events.merge(pd.DataFrame({
        "event_order": history_event_orders,
        "conversation": history_contents
    }), on="event_order", how="left")

    def get_prompt(row):
        # Rows without history (or with memory disabled) use the plain prompt
        if pd.isna(row["conversation"]) or row["conversation"] == "" or not memory_enabled:
            return prompt_no_memory.replace("{TEXT}", str(row["message"]))
        return prompt_with_memory.replace("{TEXT}", str(row["message"])).replace("{CONVERSATION}", str(row["conversation"])).replace("{PLAYER}", str(row["sender_id"]))

    message_events.loc[:, "prompt"] = message_events.apply(get_prompt, axis=1)

    # Opinion events carry an empty conversation, so they always get the prompt without memory
    initial_opinions_idx = simulation_df["event_type"] == "Initial Opinion"
    post_opinion_idx = simulation_df["event_type"] == "Post Opinion"
    opinion_idx = initial_opinions_idx | post_opinion_idx
    simulation_df.loc[opinion_idx, "conversation"] = ""
    if opinion_idx.any():
        simulation_df.loc[opinion_idx, "prompt"] = simulation_df.loc[opinion_idx].apply(get_prompt, axis=1)

    opinions = simulation_df.loc[opinion_idx]
    message_events = pd.concat([message_events, opinions], ignore_index=True).sort_values("event_order")
    simulation_ds = datasets.Dataset.from_pandas(message_events[["event_order", "conversation", "message", "prompt", "label_pred"]])

    # Only rows with a prompt are sent to the model; the rest are passed through unchanged
    eval_ds = simulation_ds.filter(lambda x: x["prompt"] is not None)
    done_ds = simulation_ds.filter(lambda x: x["prompt"] is None)

    if eval_model_type == "openai":
        import openai
        # Fall back to a local key file when the API key is not in the environment
        if not os.getenv("OPENAI_API_KEY"):
            with open("openai-key.txt", "r") as f:
                os.environ["OPENAI_API_KEY"] = f.read().strip()
        client = openai.OpenAI()
        callback_ds = eval_ds.map(functools.partial(pred_label_from_openai,
                client=client, model_name=eval_model_name, labels=labels
            ), desc="Generating opinion predictions (OpenAI)")
    elif eval_model_type == "seq2seq":
        from transformers import set_seed
        set_seed(42)
        callback_ds = eval_ds.map(functools.partial(pred_label_from_seq2seq,
                model=eval_model, tokenizer=eval_tokenizer, labels=labels
            ), batched=True, batch_size=4, desc="Generating opinion predictions (HuggingFace T5)")
    else:  # generation
        from transformers import pipeline, set_seed
        set_seed(42)
        pipe = pipeline("text-generation", model=eval_model, tokenizer=eval_tokenizer, device_map="auto")
        callback_ds = eval_ds.map(functools.partial(pred_label_from_chat,
                pipe=pipe, labels=labels
            ), desc="Generating opinion predictions (HuggingFace Chat)")

    result_ds = datasets.concatenate_datasets([done_ds, callback_ds]).sort("event_order")
    # Predictions are written back positionally: both message_events and result_ds are sorted by event_order.
    # NOTE(review): this relies on event_order being unique (or ties sorting identically) - confirm.
    message_events.loc[:, "label_pred"] = list(result_ds["label_pred"])
    message_events.loc[:, "likert_pred"] = message_events["label_pred"].map(label2likert)

    memory_suffix = "_memory" if memory_enabled else ""
    if not exists_llm:
        output_path = f"../../result/eval/human/{data_prefix}/opinion{memory_suffix}_{eval_model_save_name}_{version}.csv"
    else:
        source = "llm" if llm_column else "human"
        output_path = f"../../result/eval/human_llm/{data_prefix}/{llm_name}/opinion_{source}{memory_suffix}_{eval_model_save_name}_{version}.csv"
    util.save_csv(message_events, output_path)
