import asyncio
import click
import pandas as pd
import os
from typing import List
from zeno_build.models.lm_config import LMConfig
from zeno_build.models.text_generate import multiple_generate_from_text_prompt

from text_generation import AsyncClient
from utils import Bias, Response

# Dispatch table: supported model name -> *name* of the module-level runner
# function that handles it. The values are deliberately strings; they are
# resolved to the actual callables inside get_predictions.
model_fxn = {
    # Models sampled through run_llm.
    **dict.fromkeys(
        (
            "llama2-7b",
            "llama2-13b",
            "llama2-70b",
            "llama2-70b-ift",
            "llama2-70b-chat",
        ),
        "run_llm",
    ),
    # Models sampled through run_zeno.
    **dict.fromkeys(
        (
            "gpt-3.5-turbo",
            "gpt-3.5-turbo-instruct",
        ),
        "run_zeno",
    ),
}


def run_llm(
    bias: str,
    model: str,
    filename: str,
    key: str,
    data: pd.DataFrame,
    all_preds: pd.DataFrame,
    num_samples: int,
    perturbation=None,
    checkpoint_df=None,
) -> pd.DataFrame:
    """Sample single-token answers for one question from a text-generation server.

    The row of ``data`` whose ``"key"`` equals ``key`` is looked up, and every
    column named by ``Bias.get_col_names(bias)`` is treated as one question
    variant. For each variant the responses are either restored from
    ``checkpoint_df`` or sampled from the server in concurrent batches until
    enough valid answers have been collected. One row per variant is appended
    to ``all_preds`` and the accumulated frame is pickled to
    ``checkpoint/{model}/{filename}.pickle`` after every variant.

    Args:
        bias: Bias type, forwarded to ``Bias.get_col_names``.
        model: Model name; used here only to name the checkpoint directory.
        filename: Base name (no extension) of the checkpoint pickle.
        key: Value of the ``"key"`` column identifying the question.
        data: Question table with a ``"key"`` column, a ``"num options"``
            column, a ``"num options new"`` column and the variant columns.
        all_preds: Predictions accumulated so far; not modified in place.
        num_samples: Minimum number of valid responses to collect per variant.
            ``None`` falls back to 50. Because whole batches are kept, the
            stored list may be longer than this.
        perturbation: When not ``None``, the original ``"num options"`` value
            is always used, even for variants that change the option count.
        checkpoint_df: Optional frame of earlier results with ``"key"``,
            ``"type"`` and ``"responses"`` columns; a unique match is reused
            instead of querying the server.

    Returns:
        ``all_preds`` with one additional row per question variant.

    Raises:
        AssertionError: If ``key`` does not match exactly one row of ``data``.
    """
    # Number of completions requested concurrently per round trip.
    batch_size = 100
    # Fall back to the historical target so callers passing None keep working.
    target = 50 if num_samples is None else num_samples

    async def batch_generate(prompt: str):
        # Issue all requests at once and wait for every one of them.
        return await asyncio.gather(
            *[
                client.generate(
                    prompt,
                    max_new_tokens=1,
                    temperature=1,
                    do_sample=True,
                )
                for _ in range(batch_size)
            ]
        )

    print(f"Key: {key}")
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(f"checkpoint/{model}", exist_ok=True)

    model_addr = 'model_address_placeholder'

    client = AsyncClient(f"http://{model_addr}")

    question_df = data[data["key"] == key]
    print(question_df.head(), flush=True)
    # The comparison must sit outside len(); with it inside, the check only
    # verified that at least one row matched.
    assert len(question_df.index) == 1, "question key should be unique"
    columns = Bias.get_col_names(bias)
    for _, row in question_df.iterrows():
        for column in columns:
            print(column, flush=True)
            question = row[column]
            num_options = row["num options"]
            if Bias.change_num_options(column) and perturbation is None:
                num_options = row["num options new"]

            # Reuse a previously saved result when exactly one row matches.
            responses = None
            if checkpoint_df is not None:
                saved_row = checkpoint_df.loc[
                    (checkpoint_df["key"] == key) & (checkpoint_df["type"] == column)
                ]
                if len(saved_row.index) == 1:
                    print("row found", flush=True)
                    responses = saved_row.iloc[0].responses

            if responses is None:
                valid_predictions = []
                prompt = f"Please answer the following question with one of the alphabetical options provided.\nQuestion: {question}\nAnswer: "
                # NOTE: there is no upper bound on the number of batches; a
                # model that never produces a valid option loops forever.
                while len(valid_predictions) < target:
                    for response in asyncio.run(batch_generate(prompt)):
                        p = response.generated_text.lower().strip()
                        if Response.is_valid_prediciton(p, num_options):
                            valid_predictions.append(p)
                responses = ",".join(valid_predictions)

            pred_row = pd.DataFrame(
                {
                    "key": key,
                    "num_options": num_options,
                    "question": question,
                    "type": column,
                    "responses": responses,
                },
                index=[0],
            )
            all_preds = pd.concat([all_preds, pred_row])
            # Persist after every variant so an interrupted run can resume.
            all_preds.to_pickle(f"checkpoint/{model}/{filename}.pickle")
    return all_preds


def run_zeno(
    bias: str,
    model: str,
    filename: str,
    key: str,
    data: pd.DataFrame,
    all_preds: pd.DataFrame,
    num_samples: int,
    perturbation=None,
    checkpoint_df=None,
) -> pd.DataFrame:
    """Sample single-token answers for one question through zeno_build.

    The row of ``data`` whose ``"key"`` equals ``key`` is looked up, and every
    column named by ``Bias.get_col_names(bias)`` is treated as one question
    variant. For each variant the responses are either restored from
    ``checkpoint_df`` or requested via ``multiple_generate_from_text_prompt``
    until enough valid answers have been collected. One row per variant is
    appended to ``all_preds`` and the accumulated frame is pickled to
    ``checkpoint/{model}/{filename}.pickle`` after every variant.

    Args:
        bias: Bias type, forwarded to ``Bias.get_col_names``.
        model: Model name; must be ``"gpt-3.5-turbo"`` or
            ``"gpt-3.5-turbo-instruct"``. Also names the checkpoint directory.
        filename: Base name (no extension) of the checkpoint pickle.
        key: Value of the ``"key"`` column identifying the question.
        data: Question table with a ``"key"`` column, a ``"num options"``
            column, a ``"num options new"`` column and the variant columns.
        all_preds: Predictions accumulated so far; not modified in place.
        num_samples: Minimum number of valid responses to collect per variant.
            ``None`` falls back to 50. Because whole batches are kept, the
            stored list may be longer than this.
        perturbation: When not ``None``, the original ``"num options"`` value
            is always used, even for variants that change the option count.
        checkpoint_df: Optional frame of earlier results with ``"key"``,
            ``"type"`` and ``"responses"`` columns; a unique match is reused
            instead of calling the API.

    Returns:
        ``all_preds`` with one additional row per question variant.

    Raises:
        KeyError: If ``model`` has no configured provider.
        AssertionError: If ``key`` does not match exactly one row of ``data``.
        Exception: Whatever the generation call raised, once it has failed
            ``max_consecutive_failures`` times in a row.
    """
    # Give up after this many back-to-back API failures instead of retrying
    # forever on a persistent error (bad credentials, invalid request, ...).
    max_consecutive_failures = 5
    # Fall back to the historical target so callers passing None keep working.
    target = 50 if num_samples is None else num_samples

    print(f"Key: {key}")
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(f"checkpoint/{model}", exist_ok=True)
    model_provider = {
        "gpt-3.5-turbo": "openai_chat",
        "gpt-3.5-turbo-instruct": "openai",
    }
    lm_config = LMConfig(provider=model_provider[model], model=model)
    # "{{text}}" is a literal placeholder (these are not f-strings); it is
    # passed through unchanged to the generation helper.
    prompt_templates = {
        "openai_chat": (
            "Please answer the following question with one of the alphabetical options provided.\nQuestion: {{text}}\nAnswer: "
        ),
        "openai": (
            "Please answer the following question with one of the alphabetical options provided.\nQuestion: {{text}}\nAnswer: "
        ),
    }
    question_df = data[data["key"] == key]
    print(question_df.head())
    # The comparison must sit outside len(); with it inside, the check only
    # verified that at least one row matched.
    assert len(question_df.index) == 1, "question key should be unique"
    columns = Bias.get_col_names(bias)
    for _, row in question_df.iterrows():
        for column in columns:
            print(column)
            question = row[column]
            num_options = row["num options"]
            if Bias.change_num_options(column) and perturbation is None:
                num_options = row["num options new"]

            # Reuse a previously saved result when exactly one row matches.
            responses = None
            if checkpoint_df is not None:
                saved_row = checkpoint_df.loc[
                    (checkpoint_df["key"] == key) & (checkpoint_df["type"] == column)
                ]
                if len(saved_row) == 1:
                    # Take only the "responses" field, not the whole row.
                    responses = saved_row.iloc[0].responses

            if responses is None:
                valid_responses = []
                failures = 0
                while len(valid_responses) < target:
                    try:
                        predictions = multiple_generate_from_text_prompt(
                            [{"text": question}],
                            prompt_template=prompt_templates[lm_config.provider],
                            model_config=lm_config,
                            temperature=1,
                            max_tokens=1,
                            top_p=1.0,
                            num_responses=100,
                            requests_per_minute=100,
                        )
                    except Exception as err:
                        # The provider's exception classes are not imported in
                        # this file, so catch broadly, report, and retry a
                        # bounded number of times before re-raising.
                        failures += 1
                        print("not working")
                        print(f"\t{err!r}", flush=True)
                        if failures >= max_consecutive_failures:
                            raise
                        continue
                    failures = 0
                    # One input prompt was sent, so only the first result
                    # list is relevant.
                    for prediction in predictions[0]:
                        prediction = prediction.lower().strip()
                        if Response.is_valid_prediciton(prediction, num_options):
                            valid_responses.append(prediction)
                responses = ",".join(valid_responses)
                print(f"\t{responses}", flush=True)

            pred_row = pd.DataFrame(
                {
                    "key": key,
                    "num_options": num_options,
                    "question": question,
                    "type": column,
                    "responses": responses,
                },
                index=[0],
            )
            all_preds = pd.concat([all_preds, pred_row])
            # Persist after every variant so an interrupted run can resume.
            all_preds.to_pickle(f"checkpoint/{model}/{filename}.pickle")
    return all_preds


def get_predictions(
    bias: str,
    model: str,
    filename: str,
    data_df: pd.DataFrame,
    result_df: pd.DataFrame,
    key: str,
    num_qs=None,
    num_samples=50,
    perturbation=None,
    checkpoint_df=None,
) -> pd.DataFrame:
    """Dispatch one question key to the runner registered for ``model``.

    Args:
        bias: Bias type, forwarded to the runner.
        model: Model name; must be a key of ``model_fxn``.
        filename: Base name of the checkpoint pickle, forwarded to the runner.
        data_df: Full question table.
        result_df: Predictions accumulated so far.
        key: Question key to process.
        num_qs: Only the first ``num_qs`` rows of ``data_df`` are handed to
            the runner; ``None`` means all rows.
        num_samples: Forwarded to the runner.
        perturbation: Forwarded to the runner.
        checkpoint_df: Forwarded to the runner.

    Returns:
        Whatever the selected runner returns (the updated predictions frame).

    Raises:
        KeyError: If ``model`` is not listed in ``model_fxn``.
    """
    try:
        runner_name = model_fxn[model]
    except KeyError:
        raise KeyError(
            f"unsupported model {model!r}; expected one of {sorted(model_fxn)}"
        ) from None

    # Resolve the runner by name from this module's globals. This performs the
    # same lookup the former eval() of a formatted call string did, without
    # executing dynamically built source code.
    runner = globals()[runner_name]

    # A slice bound of None selects every row, so num_qs=None needs no
    # special-casing.
    data = data_df.iloc[:num_qs]

    return runner(
        bias,
        model,
        filename,
        key,
        data,
        result_df,
        num_samples,
        perturbation,
        checkpoint_df,
    )


@click.command()
@click.argument("bias_type")
@click.option("--model", type=str, required=True)
@click.option("--num_qs", type=int, default=None)
@click.option("--num_samples", type=int, default=50)
@click.option("--reduced", type=bool, default=False)
@click.option("--perturbation", type=str, default=None)
@click.option("--checkpoint", type=str, default=None)
def main(
    bias_type: str,
    model: str,
    num_qs: int,
    num_samples: int,
    reduced: bool,
    perturbation: str,
    checkpoint: str,
):
    """Collect model responses for every question of a bias type and save them.

    Questions are read from ``data/pew_prompts/<name>.csv`` where ``<name>``
    is BIAS_TYPE, plus ``-50`` when --reduced is set, plus ``-<perturbation>``
    when --perturbation is given. Results are written to
    ``results/<model>/<name>.pickle``.
    """
    filename = bias_type
    # load the csv file and read in questions.
    if reduced:
        filename += "-50"
    if perturbation is not None:
        filename += f"-{perturbation}"
    print(filename)
    all_data = pd.read_csv(f"data/pew_prompts/{filename}.csv")

    result_df = pd.DataFrame(
        columns=["key", "num_options", "question", "type", "responses"]
    )

    if checkpoint is not None:
        # SECURITY: unpickling can execute arbitrary code; only pass
        # checkpoint files that this tool produced or that are otherwise
        # trusted.
        checkpoint_df = pd.read_pickle(checkpoint)
        print(f"loaded checkpoint from {checkpoint}")
    else:
        checkpoint_df = None

    # get_predictions only looks at the first num_qs rows, so restrict the
    # keys to that same window; iterating over every key in the file would
    # ask for questions that were cut off and fail the runner's uniqueness
    # assertion. A slice bound of None keeps all rows.
    keys = all_data["key"].iloc[:num_qs].unique()

    for key in keys:
        result_df = get_predictions(
            bias_type,
            model,
            filename,
            all_data,
            result_df,
            key,
            num_qs,
            num_samples,
            perturbation,
            checkpoint_df,
        )

    # Each appended row carried index 0; renumber before saving.
    result_df.reset_index(inplace=True, drop=True)
    print(result_df.head())

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(f"results/{model}", exist_ok=True)
    # save results
    result_df.to_pickle(f"results/{model}/{filename}.pickle")
    print(f"Results saved to results/{model}/{filename}.pickle")


# Invoke the click command only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
