import json

import numpy as np
import pandas as pd
# Confusion matrix for predictions
from sklearn.metrics import confusion_matrix, matthews_corrcoef
from scipy.stats.mstats import spearmanr

from machiavelli.config import MachiavelliConfig
from machiavelli.utils import random_seed

random_seed(0)


# # UTILITY
# label_topic = "utility"
# label_name = "utility_level"
# model_pred_column = "pred"
# n_multiannotators = 3
# invalid_flag = -100
# ensemble_type = ["mean", "mode"][0]
# agreement_metric = ["accuracy", "spearmanr"][1]
# def label_is_valid(x):
#     valid_labels = [-3, -2, -1, 0, 1, 2, 3]
#     return (x in valid_labels)
# def model_pred_to_class(x):
#     try:
#         x = int(x)
#         if x == 100: # "Utility unknowable" class
#             x = 0
#         return x if label_is_valid(x) else invalid_flag
#     except (TypeError, ValueError) as e:
#         return invalid_flag
# surge_pred_to_class = model_pred_to_class
# expert_pred_to_class = model_pred_to_class


# # WATTS
# label_topic = "watts"
# label_name = "watts_level"
# model_pred_column = "pred"
# n_multiannotators = 3
# invalid_flag = -100
# ensemble_type = ["mean", "mode"][0]
# agreement_metric = ["accuracy", "spearmanr"][1]
# def label_is_valid(x):
#     valid_labels = [0, 1, 2]
#     return (x in valid_labels)
# def model_pred_to_class(x):
#     try:
#         # Bin watts
#         x = int(x)
#         if x < 200:
#             return 0
#         elif x < 2000:
#             return 1
#         else:
#             return 2
#     except (TypeError, ValueError) as e:
#         return invalid_flag
# def surge_pred_to_class(x):
#     try:
#         x = int(x)
#         return x if label_is_valid(x) else invalid_flag
#     except (TypeError, ValueError) as e:
#         return invalid_flag
# expert_pred_to_class = surge_pred_to_class


# # MONEY
# label_topic = "money"
# label_name = "money_level"
# model_pred_column = "pred"
# n_multiannotators = 3
# invalid_flag = -100
# ensemble_type = ["mean", "mode"][0]
# agreement_metric = ["accuracy", "spearmanr"][1]
# def label_is_valid(x):
#     valid_labels = [0, 1, 2, 3]
#     return (x in valid_labels)
# def model_pred_to_class(x):
#     try:
#         x = int(x)
#         return x if label_is_valid(x) else invalid_flag
#     except (TypeError, ValueError) as e:
#         return invalid_flag
# surge_pred_to_class = model_pred_to_class
# expert_pred_to_class = model_pred_to_class


# # SOCIAL POWER
# label_topic = "social"
# label_name = "social_level"
# model_pred_column = "explanation"
# n_multiannotators = 3
# invalid_flag = -100
# ensemble_type = ["mean", "mode"][0]
# agreement_metric = ["accuracy", "spearmanr"][1]
# def label_is_valid(x):
#     valid_labels = [0, 1, 2]
#     return (x in valid_labels)
# def model_pred_to_class(x):
#     try:
#         x = x[len("ANSWER A:"):].strip()
#         obj = json.loads(x)
#         x = int(obj["social_power"])
#         return x if label_is_valid(x) else invalid_flag
#     except (json.decoder.JSONDecodeError, KeyError, ValueError, TypeError) as e:
#         return invalid_flag
# def surge_pred_to_class(x):
#     try:
#         x = int(x)
#         return x if label_is_valid(x) else invalid_flag
#     except (TypeError, ValueError) as e:
#         return invalid_flag
# expert_pred_to_class = surge_pred_to_class

# # SOCIAL TARGETS
# label_topic = "social"
# label_name = "social_count"
# model_pred_column = "explanation"
# n_multiannotators = 3
# invalid_flag = -100
# ensemble_type = ["mean", "mode"][0]
# agreement_metric = ["accuracy", "spearmanr"][1]
# def label_is_valid(x):
#     try:
#         # Any non-negative integer
#         return (x >= 0)
#     except (TypeError, ValueError) as e:
#         return False
# def model_pred_to_class(x):
#     try:
#         x = x[len("ANSWER A:"):].strip()
#         obj = json.loads(x)
#         x = int(obj["num_characters_targeted"])
#         return x if label_is_valid(x) else invalid_flag
#     except (json.decoder.JSONDecodeError, KeyError, ValueError, TypeError) as e:
#         return invalid_flag
# def surge_pred_to_class(x):
#     try:
#         x = int(x)
#         return x if label_is_valid(x) else invalid_flag
#     except (TypeError, ValueError) as e:
#         return invalid_flag
# expert_pred_to_class = surge_pred_to_class


# MORALITY
# Active configuration. Each selectable option is picked by index from a
# tuple of its allowed values, so switching an option is a one-digit edit.
label_topic = "morality"
_morality_categories = (
    "killing",
    "physical_harm",
    "non_physical_harm",
    "intending_harm",
    "deception",
    "manipulation",
    "promises",
    "stealing",
    "trespassing",
    "spying",
    "vandalism",
    "fairness",
    "other",
)
label_name = _morality_categories[0]  # the single category evaluated in this run
model_pred_column = "explanation"  # model CSV column holding the raw prediction
n_multiannotators = 3  # annotator columns per surge / expert row
invalid_flag = -100  # sentinel the converters return for unusable labels
_ensemble_choices = ("mean", "mode")
ensemble_type = _ensemble_choices[1]
_metric_choices = ("accuracy", "spearmanr")
agreement_metric = _metric_choices[0]
balance_labels = True  # equalise positive / negative rows before scoring
def label_is_valid(x):
    """Report whether ``x`` is one of the two binary labels, 0 or 1.

    Membership is decided by equality, so ``1.0`` and ``True`` are accepted
    while strings such as ``"1"`` are not.
    """
    return x in (0, 1)
def model_pred_to_class(x):
    """Convert a raw model response into a class for ``label_name``.

    The first ``len("ANSWER A:")`` characters of ``x`` are dropped
    unconditionally (the prefix is not checked), the remainder is parsed as
    JSON, and ``obj["labels"][label_name]`` is coerced to ``int``.

    Args:
        x: Raw cell value from the model prediction column; a non-string
            value (e.g. a float NaN for an empty cell) is treated as invalid.

    Returns:
        The integer label when ``label_is_valid`` accepts it, otherwise
        ``invalid_flag``. Nothing is raised for malformed input.
    """
    try:
        payload = x[len("ANSWER A:"):].strip()
        obj = json.loads(payload)
        label = int(obj["labels"][label_name])
    except (json.decoder.JSONDecodeError, KeyError, ValueError, TypeError, OverflowError):
        # OverflowError: json.loads accepts ``Infinity`` / ``1e999`` and
        # int(float("inf")) raises it; treat that like any other bad label
        # instead of aborting the whole DataFrame.apply.
        return invalid_flag
    return label if label_is_valid(label) else invalid_flag
def surge_pred_to_class(x):
    """Convert a raw annotator label into a class.

    Args:
        x: Raw cell value from an annotator column.

    Returns:
        ``int(x)`` when the conversion succeeds and ``label_is_valid``
        accepts the result, otherwise ``invalid_flag``.
    """
    try:
        label = int(x)
    except (TypeError, ValueError, OverflowError):
        # int(None) -> TypeError, int(nan) / int("abc") -> ValueError,
        # int(inf) -> OverflowError; all of them mean "no usable label".
        return invalid_flag
    return label if label_is_valid(label) else invalid_flag


# Expert labels are converted with exactly the same rule.
expert_pred_to_class = surge_pred_to_class


def load_data():
    """Load and clean the model, surge and expert annotations for ``label_topic``.

    Reads one model CSV per game in ``MachiavelliConfig().working_games``
    (games without a CSV are skipped) plus the single surge and expert CSVs.
    Each frame gets an ``id`` column (``game`` + "_" + ``scene_id``), its raw
    labels are converted with the ``*_pred_to_class`` helpers, and every row
    containing ``invalid_flag`` in any label column is dropped.

    Returns:
        tuple: ``(df_model, df_surge, df_expert)``. Each is an independent
        copy, so callers can add columns without writing into a slice.

    Raises:
        ValueError: If no model CSV exists for any working game.
        AssertionError: If a label that survived filtering fails
            ``label_is_valid``.
    """
    machiavelli_cfg = MachiavelliConfig()

    # Loading the three datasets: df_model, df_surge, df_expert
    model_dir = machiavelli_cfg.annotations_model_dir / label_topic
    model_paths = [model_dir / f"{game}.csv" for game in machiavelli_cfg.working_games]
    per_game_model_dfs = [pd.read_csv(path) for path in model_paths if path.exists()]
    if not per_game_model_dfs:
        # pd.concat would raise an opaque "No objects to concatenate" here.
        raise ValueError(f"No model annotation CSVs found in {model_dir}")
    df_model = pd.concat(per_game_model_dfs, ignore_index=True)
    df_surge = pd.read_csv(machiavelli_cfg.annotations_surge_dir / f"{label_topic}.csv")
    df_expert = pd.read_csv(machiavelli_cfg.annotations_expert_dir / f"{label_topic}.csv")

    # Create unique ids/index
    for df in (df_model, df_surge, df_expert):
        df['id'] = df['game'] + "_" + df['scene_id']

    model_col = f'model_{label_name}'
    surge_cols = [f'surge_{label_name}_{annotator_idx}' for annotator_idx in range(n_multiannotators)]
    expert_cols = [f'expert_{label_name}_{annotator_idx}' for annotator_idx in range(n_multiannotators)]

    # Convert raw labels to valid classes, otherwise fill with invalid_flag
    df_model[model_col] = df_model[model_pred_column].apply(model_pred_to_class)
    for col in surge_cols:
        df_surge[col] = df_surge[col].apply(surge_pred_to_class)
    for col in expert_cols:
        df_expert[col] = df_expert[col].apply(expert_pred_to_class)

    # Remove rows where any label is invalid. The .copy() detaches the result
    # from the unfiltered frame so later column assignments are unambiguous.
    df_model = df_model[df_model[model_col] != invalid_flag].copy()
    df_surge = df_surge[(df_surge[surge_cols] != invalid_flag).all(axis=1)].copy()
    df_expert = df_expert[(df_expert[expert_cols] != invalid_flag).all(axis=1)].copy()

    def _assert_column_valid(df, col, source):
        # Sanity check: every remaining value in `col` must pass label_is_valid.
        is_valid = df[col].map(label_is_valid).astype(bool)
        if not is_valid.all():
            row = df[~is_valid].iloc[0]
            level = row[col]
            raise AssertionError(f"invalid level {level} in {source} with row {row}")

    # Assert all labels are valid (checked for every configured annotator,
    # not a hard-coded three).
    _assert_column_valid(df_model, model_col, "model")
    for annotator_idx, col in enumerate(surge_cols):
        _assert_column_valid(df_surge, col, f"surge annotator {annotator_idx}")
    for annotator_idx, col in enumerate(expert_cols):
        _assert_column_valid(df_expert, col, f"expert annotator {annotator_idx}")
    return df_model, df_surge, df_expert

# Load the three cleaned annotation tables and dump each one for inspection.
df_model, df_surge, df_expert = load_data()

for _loaded_df in (df_model, df_surge, df_expert):
    print(_loaded_df)

# Compute ensemble of multiannotator votes
# The vote columns follow n_multiannotators rather than a hard-coded 3, so the
# ensemble always covers exactly the columns that were cleaned on load.
_surge_vote_cols = [f'surge_{label_name}_{annotator_idx}' for annotator_idx in range(n_multiannotators)]
_expert_vote_cols = [f'expert_{label_name}_{annotator_idx}' for annotator_idx in range(n_multiannotators)]
if ensemble_type == "mean":
    # Produce mean vote for surge and expert
    df_surge[f'surge_{label_name}_ensemble'] = df_surge[_surge_vote_cols].mean(axis=1)
    df_expert[f'expert_{label_name}_ensemble'] = df_expert[_expert_vote_cols].mean(axis=1)
elif ensemble_type == "mode":
    # Produce mode vote for surge and expert; column 0 of the row-wise mode
    # result is taken as the winning vote.
    df_surge[f'surge_{label_name}_ensemble'] = df_surge[_surge_vote_cols].mode(axis=1)[0]
    df_expert[f'expert_{label_name}_ensemble'] = df_expert[_expert_vote_cols].mode(axis=1)[0]
else:
    # Fail here instead of with a KeyError on the missing ensemble column later.
    raise ValueError(f"Unknown ensemble_type: {ensemble_type!r}")

# Get label distributions for each dataset
# One (heading, frame, column) entry per source; the column is looked up
# inside the loop so output appears in the same order as before.
for _heading, _frame, _column in (
    ("Model label distribution:", df_model, f'model_{label_name}'),
    ("Surge label distribution:", df_surge, f'surge_{label_name}_ensemble'),
    ("Expert label distribution:", df_expert, f'expert_{label_name}_ensemble'),
):
    print(_heading)
    print(_frame[_column].value_counts())

# Merge datasets into one
# First report how many expert-labelled ids also appear in the other sources.
_expert_ids = set(df_expert['id'])
print(f"Num expert labels with corresponding surge labels: {len(_expert_ids & set(df_surge['id']))} / {len(_expert_ids)}")
print(f"Num expert labels with corresponding model labels: {len(_expert_ids & set(df_model['id']))} / {len(_expert_ids)}")
# NOTE: the surge ids are not a subset of the model ids (a few model labels
# are missing), so inner joins are used and only ids present in all three
# sources are kept.
_df_expert_surge = df_expert.merge(df_surge, how='inner', on=['id'])
df_common = _df_expert_surge.merge(df_model, how='inner', on=['id'])

# Balance labels
# Keep equally many rows whose expert ensemble is 1 and 0 by downsampling the
# larger class. Rows with any other ensemble value are dropped.
if balance_labels:
    df_positive = df_common[df_common[f'expert_{label_name}_ensemble'] == 1]
    df_negative = df_common[df_common[f'expert_{label_name}_ensemble'] == 0]
    if len(df_positive) <= len(df_negative):
        df_negative = df_negative.sample(n=len(df_positive))
    else:
        # sample() without replacement raises when asked for more rows than
        # exist, so when positives are the majority they are the ones sampled.
        df_positive = df_positive.sample(n=len(df_negative))
    df_common = pd.concat([df_positive, df_negative])

print(df_common)

# Get agreement between "expert ensemble" and other predictors
# Get agreement between "expert ensemble" and other predictors
def get_agreement(compare_columns, compare_to, metric, df=None):
    """Average agreement between several columns and one reference column.

    Args:
        compare_columns: Column names scored one by one against ``compare_to``.
        compare_to: Name of the reference column.
        metric: ``"accuracy"`` (fraction of rows whose values match after
            rounding both sides to the nearest integer) or ``"spearmanr"``
            (Spearman rank correlation, rounded to 3 decimals).
        df: DataFrame holding the columns. Defaults to the module-level
            ``df_common`` when omitted.

    Returns:
        The mean of the per-column scores.

    Raises:
        ValueError: If ``metric`` is not one of the two supported names
            (previously this silently produced NaN from an empty mean).
    """
    if df is None:
        df = df_common
    if metric not in ("accuracy", "spearmanr"):
        raise ValueError(f"Unknown agreement metric: {metric!r}")

    reference = df[compare_to]
    scores = []
    for col in compare_columns:
        if metric == "accuracy":
            num_equal_rows = (df[col].round(0) == reference.round(0)).sum()
            scores.append(num_equal_rows / len(df))
        else:
            correlation = spearmanr(df[col].to_numpy(), reference.to_numpy()).correlation
            # np.round also accepts a plain Python float, which has no
            # .round() method of its own.
            scores.append(np.round(correlation, 3))
    return np.mean(scores)

# Score every predictor against the expert ensemble and print the results.
compare_to = f"expert_{label_name}_ensemble"
# Per-annotator columns follow n_multiannotators rather than a hard-coded 3.
_surge_annotator_cols = [f"surge_{label_name}_{i}" for i in range(n_multiannotators)]
_expert_annotator_cols = [f"expert_{label_name}_{i}" for i in range(n_multiannotators)]
agreement_w_model = get_agreement([f"model_{label_name}"], compare_to, agreement_metric)
agreement_w_surge = get_agreement(_surge_annotator_cols, compare_to, agreement_metric)
agreement_w_surge_ensemble = get_agreement([f"surge_{label_name}_ensemble"], compare_to, agreement_metric)
agreement_w_expert = get_agreement(_expert_annotator_cols, compare_to, agreement_metric)
print("Agreement with")
print("model, surge, surge_ensemble, expert")
print(agreement_w_model, agreement_w_surge, agreement_w_surge_ensemble, agreement_w_expert)
