from data import CausalDataset, MoralDataset, Example, JsonSerializable
from adapter import GPT3Adapter, HuggingfaceAdapter, DelphiAdapter
from evaluator import AccuracyEvaluator, AccuracyEvaluatorWithCategories, FactorExperimentResult
from prompt import CausalJudgmentPrompt, MoralJudgmentPrompt, CausalAbstractJudgmentPrompt, MoralAbstractJudgmentPrompt, \
                    CausalFactorPrompt, MoralFactorPrompt
from thought_as_text_translator import MoralTranslator, CausalTranslator
from tqdm import tqdm
import pickle
import json
from collections import defaultdict

from dataclasses import dataclass

@dataclass
class FullExperimentResult(JsonSerializable):
    """Combined outcome of one factor experiment run.

    Instances are built by ``exp3_causal_factor`` and ``exp3_moral_factor``
    from the output of the overall accuracy evaluator and the per-category
    accuracy evaluator. Field order matters: those functions construct this
    class positionally.
    """
    # Accuracy computed over every scored factor instance in the run.
    # NOTE(review): what "weighted" refers to is decided inside
    # AccuracyEvaluator, which is not visible here -- confirm.
    weighted_acc: float
    # Confidence interval reported alongside ``weighted_acc``; callers read
    # it as a two-element pair (index 0 and 1). Presumably (lower, upper) --
    # TODO confirm against the evaluator.
    weighted_acc_conf_interval: tuple[float, float]
    # Per-factor results keyed by factor category; each value exposes
    # ``acc`` and ``conf_interval``.
    category_results: dict[str, FactorExperimentResult]

def exp3_causal_factor(engine:str='text-davinci-002', save:bool=True) -> FullExperimentResult:
    """Run the causal factor experiment against a GPT-3 engine.

    Every example of ``CausalDataset`` is expanded by ``CausalFactorPrompt``
    into a list of factor instances with a parallel list of factor
    categories. Each instance is scored by the adapter in
    ``'multiple_choice'`` mode, then accuracy is computed over all instances
    and separately per factor category.

    Args:
        engine: Engine name handed to ``GPT3Adapter``; it is also embedded in
            the name of the saved predictions file.
        save: When true, pickle the tuple ``(instances, choice scores, label
            indices)`` under ``../../results/factor_preds/``.

    Returns:
        A ``FullExperimentResult`` holding the overall accuracy, its
        confidence interval and the per-factor results.
    """
    dataset = CausalDataset()
    evaluator = AccuracyEvaluator()
    category_evaluator = AccuracyEvaluatorWithCategories()
    adapter = GPT3Adapter(engine=engine)
    factor_prompt = CausalFactorPrompt(anno_utils=dataset.anno_utils)

    # Three parallel lists: entry k of each describes the k-th scored instance.
    scored_instances = []
    scores_per_instance = []
    gold_indices = []
    # factor category -> positions (into the lists above) of its instances
    positions_by_factor = defaultdict(list)

    example: Example
    for example in tqdm(dataset, total=len(dataset)):
        categories, instances = factor_prompt.apply(example)
        for slot, instance in enumerate(instances):
            # Position this instance will occupy once appended below.
            position = len(scored_instances)

            scores_per_instance.append(adapter.adapt(instance, method='multiple_choice'))
            gold_indices.append(instance.answer_index)
            scored_instances.append(instance)

            positions_by_factor[categories[slot]].append(position)

    acc, conf_interval = evaluator.evaluate(scores_per_instance, gold_indices)
    category_accs = category_evaluator.evaluate(scores_per_instance, gold_indices, positions_by_factor)

    if save:
        # Raw predictions are pickled so they can be re-analysed without
        # querying the model again.
        dump_path = f"../../results/factor_preds/exp3_{engine}_causal_factor_preds.pkl"
        with open(dump_path, "wb") as f:
            pickle.dump((scored_instances, scores_per_instance, gold_indices), f)

    print()
    print(f"Causal Factor Accuracy: {acc:.4f} ({conf_interval[0]:.4f}, {conf_interval[1]:.4f})")
    for factor, acc_with_ci in category_accs.items():
        print(f"{factor}: {acc_with_ci.acc:.4f} ({acc_with_ci.conf_interval[0]:.4f}, {acc_with_ci.conf_interval[1]:.4f})")

    return FullExperimentResult(acc, conf_interval, category_accs)

def exp3_moral_factor(engine:str='text-davinci-002', save:bool=True) -> FullExperimentResult:
    """Run the moral factor experiment against a GPT-3 engine.

    Every example of ``MoralDataset`` is expanded by ``MoralFactorPrompt``
    into a list of factor instances with a parallel list of factor
    categories. Each instance is scored by the adapter in
    ``'multiple_choice'`` mode, then accuracy is computed over all instances
    and separately per factor category.

    Args:
        engine: Engine name handed to ``GPT3Adapter``; it is also embedded in
            the name of the saved predictions file.
        save: When true, pickle the tuple ``(instances, choice scores, label
            indices)`` under ``../../results/factor_preds/``.

    Returns:
        A ``FullExperimentResult`` holding the overall accuracy, its
        confidence interval and the per-factor results.
    """
    dataset = MoralDataset()
    evaluator = AccuracyEvaluator()
    category_evaluator = AccuracyEvaluatorWithCategories()
    adapter = GPT3Adapter(engine=engine)
    factor_prompt = MoralFactorPrompt(anno_utils=dataset.anno_utils)

    # Three parallel lists: entry k of each describes the k-th scored instance.
    scored_instances = []
    scores_per_instance = []
    gold_indices = []
    # factor category -> positions (into the lists above) of its instances
    positions_by_factor = defaultdict(list)

    example: Example
    for example in tqdm(dataset, total=len(dataset)):
        categories, instances = factor_prompt.apply(example)
        for slot, instance in enumerate(instances):
            # Position this instance will occupy once appended below.
            position = len(scored_instances)

            scores_per_instance.append(adapter.adapt(instance, method='multiple_choice'))
            gold_indices.append(instance.answer_index)
            scored_instances.append(instance)

            positions_by_factor[categories[slot]].append(position)

    acc, conf_interval = evaluator.evaluate(scores_per_instance, gold_indices)
    category_accs = category_evaluator.evaluate(scores_per_instance, gold_indices, positions_by_factor)

    if save:
        # Raw predictions are pickled so they can be re-analysed without
        # querying the model again.
        dump_path = f"../../results/factor_preds/exp3_{engine}_moral_factor_preds.pkl"
        with open(dump_path, "wb") as f:
            pickle.dump((scored_instances, scores_per_instance, gold_indices), f)

    print()
    print(f"Moral Factor Accuracy: {acc:.4f} ({conf_interval[0]:.4f}, {conf_interval[1]:.4f})")
    for factor, acc_with_ci in category_accs.items():
        print(f"{factor}: {acc_with_ci.acc:.4f} ({acc_with_ci.conf_interval[0]:.4f}, {acc_with_ci.conf_interval[1]:.4f})")

    return FullExperimentResult(acc, conf_interval, category_accs)

def produce_table3():
    """Run both factor experiments on every engine and write the JSON summaries.

    For each engine the moral factor experiment is run first and the
    per-engine results are written to
    ``../../results/exp3_moral_full_result.json``; the causal factor
    experiment follows and is written to
    ``../../results/exp3_causal_full_result.json``. Each experiment is run
    with ``save=True``, so it also pickles its own raw predictions.
    """
    engines = ("text-babbage-001", 'text-curie-001', 'text-davinci-002')

    moral_results = {}
    for engine in engines:
        fer = exp3_moral_factor(engine=engine, save=True)
        moral_results[engine] = fer.json

    # Use a context manager so the file is flushed and closed deterministically
    # instead of relying on garbage collection of an anonymous file object.
    with open('../../results/exp3_moral_full_result.json', 'w') as f:
        json.dump(moral_results, f, indent=2)

    causal_results = {}
    for engine in engines:
        fer = exp3_causal_factor(engine=engine, save=True)
        causal_results[engine] = fer.json

    with open('../../results/exp3_causal_full_result.json', 'w') as f:
        json.dump(causal_results, f, indent=2)

if __name__ == '__main__':
    # Default entry point: run the full sweep over all engines.
    #
    # To run a single experiment instead, call one of the experiment
    # functions directly and dump the returned result's ``json``, e.g.:
    #   fer = exp3_causal_factor()
    #   json.dump(fer.json, open("../../results/exp3_text-davinci-002_causal_factor_result.json", "w"))
    #   fer = exp3_moral_factor()
    #   json.dump(fer.json, open("../../results/exp3_text-davinci-002_moral_factor_result.json", "w"))
    produce_table3()