import json
from src.classical_methods import rank_unigram_diff
from src.utils import get_similarity_scores
import os
import numpy as np
import pandas as pd
from scipy.stats import pearsonr, spearmanr
from collections import defaultdict
from sklearn.metrics import roc_auc_score


# Result tables filled by the evaluation loop below. Both are keyed as
# abbreviated data name ("w/ g" / "wo/ g") -> proposer model ->
# num_descriptions_given -> mean rating.
# Ratings of the single top-ranked (highest V') hypothesis:
data_name2system2num_example2performance = dict()
# Ratings averaged over every proposed hypothesis:
all_data_name2system2num_example2performance = dict()


# The benchmark file is the same for both passes of the loop, so read it once.
data_path = "data/synd5.json"
with open(data_path, "r", encoding="utf-8") as f:
    all_data = json.load(f)

# Template goal sentence: problems whose goal contains it are treated as
# having no specific research goal ("without_goal").
GENERIC_GOAL = (
    "I want to find out the differences between texts from Group A and Group B."
)

# Per-problem settings copied into every evaluation record.
config_keys = ["fraction", "num_descriptions_given", "num_research_text"]

for with_goal in [
    True,
    False,
]:

    data_name = "with_goal" if with_goal else "without_goal"
    # selected[i] is True when problem i belongs to the current split; it is
    # kept because `evaluate` reuses it to filter the results files by index.
    selected = [(GENERIC_GOAL not in d["goal"]) == with_goal for d in all_data]
    data = [d for d, keep in zip(all_data, selected) if keep]

    def get_rating(x):
        """Turn a similarity judgement string into a 0/1 score.

        The text is lower-cased first. Any response mentioning "no" or
        "related" scores 0; otherwise it scores 1 if it mentions "yes" and 0
        if it does not.
        """
        # NOTE(review): these are plain substring tests, so "no" also matches
        # words such as "not", "note" or "know" — confirm the similarity model
        # only ever answers with the bare labels.
        text = x.lower()
        if "no" in text or "related" in text:
            return 0
        return 1 if "yes" in text else 0

    def evaluate(path):
        """Score one results file against the gold descriptions, with caching.

        For every selected problem the hypotheses are sorted by ``V_prime``
        (descending). The first one is used as the group-A prediction and the
        last one as the group-B prediction. Each prediction is compared with
        the matching gold description via ``get_similarity_scores``. For
        group A, every hypothesis is also compared with the gold A description.

        Args:
            path: Path to a results JSON file. Its entries are assumed to be
                index-aligned with the unfiltered dataset, because they are
                filtered with ``selected``.

        Returns:
            A list of evaluation dicts, two per problem (groups "A" and "B").
            The list is loaded from the ``eval/`` cache when present.
            ``None`` is returned if the filtered results and ``data`` differ
            in length.
        """
        # data_name is prefixed with no separator; kept so existing caches match.
        eval_result_path = os.path.join("eval/", data_name + os.path.basename(path))

        # Similarity scoring queries an LLM, so reuse an earlier evaluation.
        if os.path.exists(eval_result_path):
            with open(eval_result_path, "r") as f:
                return json.load(f)

        with open(path, "r") as f:
            results = json.load(f)
        # Keep only the problems belonging to the current data_name split.
        results = [r for i, r in enumerate(results) if selected[i]]

        if len(results) != len(data):
            print("len(results) != len(data)")
            print(len(results), len(data))
            return None

        to_eval_dicts = []
        for data_idx, (r, d) in enumerate(zip(results, data)):
            hypotheses_info = sorted(
                r["hypotheses_info"], key=lambda x: x["V_prime"], reverse=True
            )
            all_hyps = [h["hypothesis"] for h in hypotheses_info]
            config = {k: d[k] for k in config_keys}
            # The highest-V' hypothesis is the group-A prediction and the
            # lowest-V' hypothesis is the group-B prediction.
            for group, gold, top in (
                ("A", d["gold_A_description"], hypotheses_info[0]),
                ("B", d["gold_B_description"], hypotheses_info[-1]),
            ):
                eval_dict = {
                    "data_idx": data_idx,
                    "gold": gold,
                    "pred": top["hypothesis"],
                    "all": list(all_hyps),
                    "group": group,
                    "V_prime": top["V_prime"],
                }
                eval_dict.update(config)
                to_eval_dicts.append(eval_dict)

        hyp_pairs = [{"text_a": e["gold"], "text_b": e["pred"]} for e in to_eval_dicts]
        similarity_ratings = get_similarity_scores(hyp_pairs, model="claude-v1.3")
        for i, e in enumerate(to_eval_dicts):
            e["similarity"] = similarity_ratings[i]

        # Also score every hypothesis (not only the top one) against gold A.
        # all_idxes records which eval dict each returned score belongs to.
        all_idxes, similarity_to_eval = [], []
        for i, e in enumerate(to_eval_dicts):
            if e["group"] == "A":
                for h in e["all"]:
                    all_idxes.append(i)
                    similarity_to_eval.append({"text_a": e["gold"], "text_b": h})

        similarity_scores = get_similarity_scores(
            similarity_to_eval, model="claude-v1.3"
        )
        for e in to_eval_dicts:
            e["all_similarity_scores"] = []
        for i, s in zip(all_idxes, similarity_scores):
            to_eval_dicts[i]["all_similarity_scores"].append(s)

        # Make sure the cache directory exists so the expensive scores persist.
        os.makedirs(os.path.dirname(eval_result_path), exist_ok=True)
        with open(eval_result_path, "w") as f:
            json.dump(to_eval_dicts, f, indent=2)

        return to_eval_dicts

    def get_config_tuple(d):
        """Return ``((key, d[key]), ...)`` for each name in ``config_keys``, in order.

        The result is a tuple of pairs, so it is hashable and can be used as a
        dict key or grouping label.
        """
        pairs = []
        for key in config_keys:
            pairs.append((key, d[key]))
        return tuple(pairs)

    def evaluate_run_name(run_name):
        """Summarise the group-A evaluation of one run stored under ``results/``.

        Prints a header line containing ``run_name`` and the overall mean
        rating of the top hypothesis.

        Args:
            run_name: File name of the run inside ``results/``.

        Returns:
            ``None`` if the run file is missing or ``evaluate`` returns
            ``None``. Otherwise a dict with these keys:

            - ``mean_w_2_desc``: mean rating where ``num_descriptions_given == 2``.
            - ``num_examples2acc``: mean top-hypothesis rating per
              ``num_descriptions_given``.
            - ``all_num_examples2acc``: mean all-hypotheses rating per
              ``num_descriptions_given``.
            - ``avg``: overall mean rating.
            - ``df``: the per-problem frame.
            - ``spearman``: Spearman correlation between rating and ``V_prime``.
            - ``auc``: ROC AUC of ``V_prime`` against rating; NaN when only
              one rating class is present.
        """
        path = os.path.join("results/", run_name)
        if not os.path.exists(path):
            print(f"{path} does not exist")
            return None
        results = evaluate(path)
        if results is None:
            print("results is None")
            return None

        # Only group-A rows are summarised.
        df = pd.DataFrame([r for r in results if r["group"] == "A"])
        df["rating"] = df["similarity"].apply(get_rating)
        df["all_ratings"] = [
            np.mean([get_rating(s) for s in scores])
            for scores in df["all_similarity_scores"].tolist()
        ]

        print(
            f"===================================={run_name}===================================="
        )
        avg = df["rating"].mean()
        by_num_desc = df.groupby("num_descriptions_given")
        num_examples2acc = by_num_desc["rating"].mean().to_dict()
        all_num_examples2acc = by_num_desc["all_ratings"].mean().to_dict()
        mean_w_2_desc = np.mean(
            df.loc[df["num_descriptions_given"] == 2, "rating"].tolist()
        )

        print(avg)
        spearman = spearmanr(df["rating"], df["V_prime"])[0]
        # roc_auc_score raises ValueError when only one rating class is
        # present; report NaN instead of aborting the whole script.
        if df["rating"].nunique() > 1:
            auc = roc_auc_score(df["rating"], df["V_prime"])
        else:
            auc = float("nan")

        return {
            "mean_w_2_desc": mean_w_2_desc,
            "num_examples2acc": num_examples2acc,
            "all_num_examples2acc": all_num_examples2acc,
            "avg": avg,
            "df": df,
            "spearman": spearman,
            "auc": auc,
        }

    # NOTE(review): run_names is not read anywhere in this loop, and
    # os.listdir raises if results/ is missing — confirm it is still needed.
    run_names = os.listdir("results/")

    # Short column label for the current split; it does not vary per run.
    abbr_data_name = "w/ g" if with_goal else "wo/ g"

    # validator -> proposer -> mean rating at num_descriptions_given == 2.
    acc = defaultdict(dict)
    for validator_name in [
        "d5_t5_validator_3B",
    ]:
        for propose_name in ["text-davinci-003", "gpt-4"]:
            run_name = f"d5_{validator_name}_{propose_name}_orig_1_5_varied_fraction_synd5_all_problems.json"
            r = evaluate_run_name(run_name)
            if r is None:
                continue
            acc[validator_name][propose_name] = r["mean_w_2_desc"]
            data_name2system2num_example2performance.setdefault(abbr_data_name, {})[
                propose_name
            ] = r["num_examples2acc"]
            all_data_name2system2num_example2performance.setdefault(
                abbr_data_name, {}
            )[propose_name] = r["all_num_examples2acc"]

    # NOTE(review): this frame is rebuilt on every pass and never printed.
    acc = pd.DataFrame(acc)


def avg_vals(ds, keys=(0, 2)):
    """Collapse each innermost score dict to the mean of selected entries.

    Args:
        ds: Nested mapping ``{data_name: {system: {count: score}}}``. In this
            script the innermost keys are ``num_descriptions_given`` values.
        keys: Innermost keys to average. The default ``(0, 2)`` averages the
            scores at 0 and 2 given descriptions, as the script always did.

    Returns:
        ``{data_name: {system: mean score}}``.

    Raises:
        KeyError: if a system's score dict lacks one of ``keys``.
        ZeroDivisionError: if ``keys`` is empty and ``ds`` has any system.
    """
    return {
        data_name: {
            system: sum(scores[k] for k in keys) / len(keys)
            for system, scores in systems.items()
        }
        for data_name, systems in ds.items()
    }


# Print two tables. Columns are the "w/ g" / "wo/ g" splits, rows are the
# proposer models, and each cell is averaged over num_descriptions_given by
# avg_vals.
#   1st table: rating of the top-V' hypothesis (uses the validator's ranking).
#   2nd table: rating averaged over all hypotheses (ignores the validator).
for system_scores in (
    data_name2system2num_example2performance,
    all_data_name2system2num_example2performance,
):
    df = pd.DataFrame(avg_vals(system_scores))
    print(df.round(2))
