import argparse 
import numpy as np 
import openai
import pandas as pd 
import pickle
import time
import torch 

from pathlib import Path 
from torch.utils.data import DataLoader 

from mi_estimation_utils import load_data, MyDataset
from openai_gpt import make_gpt3_query


def _query_with_retry(prompt: str, max_retries: int, retry_wait: float) -> str:
    """Send `prompt` through `make_gpt3_query`, retrying on `openai.error.APIError`.

    Args:
        prompt: The full prompt text to send.
        max_retries: How many times to retry after the first failed attempt.
            Negative values are treated as 0 (a single attempt, no retry).
        retry_wait: Seconds to sleep before each retry.

    Returns:
        Whatever `make_gpt3_query` returns for the prompt.

    Raises:
        openai.error.APIError: If the final attempt also fails.
    """
    attempts = max(max_retries, 0) + 1
    for attempt in range(attempts):
        try:
            return make_gpt3_query(prompt, max_tokens=7)
        except openai.error.APIError as e:
            if attempt == attempts - 1:
                # Out of retries: surface the original error to the caller.
                raise
            print(
                f"API error! Waiting for {retry_wait:g} seconds before retrying. "
                "Following is the error msg:",
                e,
            )
            time.sleep(retry_wait)


def gpt_score(
    dataset_name: str,
    ds: MyDataset,
    *,
    max_retries: int = 3,
    retry_wait: float = 300.0,
) -> pd.DataFrame:
    """
    Ask GPT to rate every (input, label, explanation) row of `ds` on several aspects.

    For each row, one prompt per entry of the statement bank is sent through
    `make_gpt3_query`; the normalised answer (stripped, trailing "." removed,
    lower-cased) is stored under the aspect's name.

    Args:
        dataset_name: Key selecting the label names; only "esnli" is defined.
        ds: Dataset supporting `len()` and integer indexing, where each item
            exposes the keys "X", "Y" and "E". "Y" is used as an index into the
            label-name list (presumably an integer class id -- confirm against
            `MyDataset`).
        max_retries: Number of retries per query after an `openai.error.APIError`.
        retry_wait: Seconds to wait before each retry.

    Returns:
        A DataFrame with one row per dataset item: one column per aspect,
        followed by "Xs", "Ys" (label as a string) and "Es".

    Raises:
        KeyError: If `dataset_name` has no label list defined.
        openai.error.APIError: If a query still fails after all retries.
    """
    labels2str = {"esnli": ["contradiction", "entailment", "neutral"]}[dataset_name]

    prompt_template = """Following are two sentences, a label and an explanation. \nThe two sentences are: {}\nThe label is: {}\nThe explanation is {}\nPlease use one of 'strongly disagree', 'somewhat disagree', 'somewhat agree' and 'strongly agree' to describe your attitude towards the following statement: {}. Do not add additional words."""

    # NOTE(review): the statements below contain typos ("unsuprising", "Ths") and
    # "convincingness" mentions "the question" instead of the two sentences.
    # They are kept verbatim because they are part of the prompt sent to the
    # model; changing them would make new scores incomparable with earlier runs.
    statement_bank = {
        "informativeness": "The explanation provides sufficient information to support how the two sentences are associated to the label.",
        "causal_support": "The explanation explains why these two sentences are associated to the label.",
        "convincingness": "The explanation is persuasive, and convinces me to believe that the question is associated to the label.",
        "coherence": "The explanation bridges the gap between the two sentences and the label, in a coherent and unsuprising manner.",
        "label_relevance": "Given the two sentences and the label, the explanation is relevant.",
        "input_relevance": "Given the two sentences, the explanation is relevant.",
        "clarity4student": "The explanation is easy to understand for a high school student.",
        "clarity4graduate": "The explanation is easy to understand for a university graduate.",
        "importance": "Ths explanation highlights the most important parts in the two sentences that associate to the label."
    }

    # Aspect columns come first, then the raw fields, which fixes the column
    # order of the returned DataFrame.
    scores = {aspect: [] for aspect in statement_bank}
    scores["Xs"] = []
    scores["Ys"] = []
    scores["Es"] = []

    # Index-based access: only `__len__` and `__getitem__` are relied upon.
    for i in range(len(ds)):
        row = ds[i]  # fetch once; indexing may be non-trivial for a dataset object
        Xs, Ys, Es = row["X"], labels2str[row["Y"]], row["E"]
        scores["Xs"].append(Xs)
        scores["Ys"].append(Ys)
        scores["Es"].append(Es)
        for aspect, statement in statement_bank.items():
            prompt = prompt_template.format(Xs, Ys, Es, statement)
            answer = _query_with_retry(prompt, max_retries, retry_wait)
            answer = answer.strip().strip(".").lower()
            time.sleep(0.02)  # clamp to <3000 RPM
            scores[aspect].append(answer)

    return pd.DataFrame(scores)


if __name__ == "__main__":
    # Command-line options; `args` is also handed to `load_data` as-is.
    cli = argparse.ArgumentParser()
    cli.add_argument("--method", type=str, choices=["rationale", "nle"], default="nle")
    cli.add_argument("--dataset", type=str, default="esnli")
    cli.add_argument("--embedding", type=str, default="openai", help="Not used actually; just need this argument to load the prepared data pkl")
    cli.add_argument("--downsample", type=int, default=1200)
    cli.add_argument("--report_dir", type=str, default="../data/scored/")
    args = cli.parse_args()

    print(args)

    # Make sure the output directory exists before any scoring starts.
    out_dir = Path(args.report_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Only the test split is scored; the other two splits are loaded but unused.
    _train_ds, _val_ds, test_ds = load_data(args, rawtext=True)
    for split, ds in (("test", test_ds),):
        started = time.time()
        report_df = gpt_score(args.dataset, ds)
        csv_name = f"{args.dataset}_{args.method}_{args.downsample}_{split}.csv"
        report_df.to_csv(out_dir / csv_name, index=False)
        elapsed = time.time() - started
        print("{} split done. Time slapsed: {:.4f} seconds".format(split, elapsed))