import ast
import json
from dataclasses import dataclass
from typing import List, Dict

@dataclass
class ExperimentResult:
    """Metrics recorded for one engine in a single experiment.

    Instances are built by ``load_results`` from keyword arguments, so the
    field names below must match the keys stored in the result files.
    """

    # Point estimate; the table printers multiply it by 100 before display.
    acc: float
    # Interval around ``acc``; the table printers treat index 1 as the upper
    # bound when deriving the "+/-" margin.
    conf_interval: tuple[float, float]
    # Presumably a correlation coefficient and its p-value -- TODO confirm
    # against the code that writes the result files.
    r: float
    p: float
    # Printed with two decimals under the 'rmse' and 'auroc' column headers.
    rmse: float
    auroc: float

@dataclass
class FactorExperimentResult:
    """Point estimate and interval for a single factor/category.

    Used as the value type in the ``category_results`` annotation of
    ``FullExperimentResult``.
    """

    # Point estimate for this category.
    acc: float
    # Interval around ``acc`` as a pair of floats.
    conf_interval: tuple[float, float]

@dataclass
class FullExperimentResult:
    """Aggregate and per-category results for one engine.

    Instances are built by ``load_factor_results`` from keyword arguments, so
    the field names below must match the keys stored in the result files.
    """

    # Aggregate point estimate; multiplied by 100 when printed.
    weighted_acc: float
    # Interval around ``weighted_acc``; index 1 is treated as the upper bound
    # by the table printer.
    weighted_acc_conf_interval: tuple[float, float]
    # Keyed by category name.
    # NOTE(review): display_table3 reads these values with ['acc'] and
    # ['conf_interval'] subscripts, so at runtime they appear to be plain
    # dicts rather than FactorExperimentResult instances -- confirm.
    category_results: dict[str, FactorExperimentResult]

def load_results(path: str) -> Dict[str, ExperimentResult]:
    """Load per-engine experiment metrics from a JSON file.

    The file must contain a JSON object mapping an engine name to the fields
    of an ``ExperimentResult``. Each value may be either a string holding a
    Python dict literal (the format the previous ``eval``-based loader read)
    or an already-decoded JSON object.

    Args:
        path: Location of the JSON result file.

    Returns:
        Mapping from engine name to its ``ExperimentResult``, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
        ValueError, SyntaxError: If a string value is not a plain Python
            literal.
        TypeError: If the decoded fields do not match ``ExperimentResult``.
    """
    with open(path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    result = {}
    for engine, raw in data.items():
        # SECURITY: this used to be eval(raw), which executes arbitrary code
        # found in the result file. literal_eval only accepts literals
        # (dicts, tuples, numbers, strings, ...), which is all a serialized
        # metrics dict needs.
        fields = ast.literal_eval(raw) if isinstance(raw, str) else raw
        result[engine] = ExperimentResult(**fields)
    return result

def load_factor_results(path: str) -> Dict[str, FullExperimentResult]:
    """Load per-engine factor-level results from a JSON file.

    The file must contain a JSON object mapping an engine name to the fields
    of a ``FullExperimentResult``. Each value may be either a string holding
    a Python dict literal (the format the previous ``eval``-based loader
    read) or an already-decoded JSON object. Nested ``category_results``
    entries are passed through as decoded, without conversion.

    Args:
        path: Location of the JSON result file.

    Returns:
        Mapping from engine name to its ``FullExperimentResult``, in file
        order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
        ValueError, SyntaxError: If a string value is not a plain Python
            literal.
        TypeError: If the decoded fields do not match
            ``FullExperimentResult``.
    """
    with open(path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    result = {}
    for engine, raw in data.items():
        # SECURITY: this used to be eval(raw), which executes arbitrary code
        # found in the result file. literal_eval only accepts literals
        # (dicts, tuples, numbers, strings, ...), which is all a serialized
        # results dict needs.
        fields = ast.literal_eval(raw) if isinstance(raw, str) else raw
        result[engine] = FullExperimentResult(**fields)
    return result

def display_table1():
    """Print the experiment 1 causal and moral tables as CSV on stdout.

    Reads ``exp1_causal_full_result.json`` and ``exp1_moral_full_result.json``
    (relative to the current working directory) through ``load_results`` and
    prints one row per engine. The accuracy column is rendered as
    ``<acc%> $\\pm$ <margin%>`` so it can be pasted into LaTeX.
    """
    tables = (
        ("Causal Table", 'exp1_causal_full_result.json'),
        ("Moral Table", 'exp1_moral_full_result.json'),
    )
    for index, (title, path) in enumerate(tables):
        results = load_results(path)
        if index:
            # Blank line between consecutive tables.
            print()
        print(title)
        print('engine,acc,r,rmse,auroc')
        for engine, result in results.items():
            # Margin is the distance from the point estimate to the upper
            # bound of the interval, in percentage points.
            margin = (result.conf_interval[1] - result.acc) * 100
            # The backslash is escaped explicitly: a bare "\p" is an invalid
            # escape sequence that newer Python versions warn about.
            print(
                f'{engine},{result.acc * 100:.1f} $\\pm$ {margin:.1f},'
                f'{result.r:.2f},{result.rmse:.2f},{result.auroc:.2f}'
            )

def display_table2():
    """Print the experiment 2 causal and moral tables as CSV on stdout.

    Reads ``exp2_causal_full_result.json`` and ``exp2_moral_full_result.json``
    (relative to the current working directory) through ``load_results`` and
    prints one row per engine. The accuracy column is rendered as
    ``<acc%> $\\pm$ <margin%>`` so it can be pasted into LaTeX.
    """
    tables = (
        ("Causal Table", 'exp2_causal_full_result.json'),
        ("Moral Table", 'exp2_moral_full_result.json'),
    )
    for index, (title, path) in enumerate(tables):
        results = load_results(path)
        if index:
            # Blank line between consecutive tables.
            print()
        print(title)
        print('engine,acc,r,rmse,auroc')
        for engine, result in results.items():
            # Margin is the distance from the point estimate to the upper
            # bound of the interval, in percentage points.
            margin = (result.conf_interval[1] - result.acc) * 100
            # The backslash is escaped explicitly: a bare "\p" is an invalid
            # escape sequence that newer Python versions warn about.
            print(
                f'{engine},{result.acc * 100:.1f} $\\pm$ {margin:.1f},'
                f'{result.r:.2f},{result.rmse:.2f},{result.auroc:.2f}'
            )

def display_table3():
    """Print the experiment 3 per-factor causal and moral tables as CSV.

    Reads ``exp3_causal_full_result.json`` and ``exp3_moral_full_result.json``
    (relative to the current working directory) through
    ``load_factor_results``. Each row holds the engine name, the weighted
    average, and one cell per factor, every cell rendered as
    ``<acc%> $\\pm$ <margin%>``.

    Raises:
        KeyError: If an engine's ``category_results`` lacks one of the
            factors listed for its table.
    """

    def _cell(acc, conf_interval):
        # Margin is upper bound minus point estimate, in percentage points.
        # The backslash is escaped explicitly: a bare "\p" is an invalid
        # escape sequence that newer Python versions warn about.
        return f"{acc * 100:.1f} $\\pm$ {(conf_interval[1] - acc) * 100:.1f}"

    # Each factor is a (key in category_results, column label) pair so the
    # lookup order and the header cannot drift apart.
    sections = (
        (
            "Causal Table",
            'exp3_causal_full_result.json',
            (
                ('action_omission', 'Action Omission'),
                ('norm_type', 'Norm Type'),
                ('time', 'Time'),
                ('agent_awareness', 'Agent Awareness'),
                ('causal_structure', 'Causal Structure'),
                ('event_normality', 'Event Normality'),
            ),
        ),
        (
            "Moral Table",
            'exp3_moral_full_result.json',
            (
                ('personal_force', 'Personal Force'),
                # NOTE(review): label spelling kept verbatim from the
                # original output -- confirm the intended wording.
                ('beneficiary', 'Beneficience'),
                ('locus_of_intervention', 'Locus of Intervention'),
                ('causal_role', 'Causal Role'),
                ('evitability', 'Counterfactual Evitability'),
            ),
        ),
    )

    for index, (title, path, factors) in enumerate(sections):
        results = load_factor_results(path)
        header = ['Model', 'Weighted Average'] + [label for _, label in factors]

        if index:
            # Blank line between consecutive tables.
            print()
        print(title)
        print(",".join(header))
        for engine, result in results.items():
            row = [engine, _cell(result.weighted_acc, result.weighted_acc_conf_interval)]
            for key, _ in factors:
                # Category entries are read with subscripts, i.e. as mappings.
                entry = result.category_results[key]
                row.append(_cell(entry['acc'], entry['conf_interval']))
            print(",".join(row))

def display_table4():
    """Print the experiment 4 TTT/ERE tables with deltas against experiment 1.

    For both the causal and the moral task, loads the experiment 1 results as
    the baseline and then prints one CSV table per experiment 4 result file
    (``*_ttt_result.json`` and ``*_ere_result.json``). Each row holds the
    engine name, ``<acc%> $\\pm$ <margin%>``, and the accuracy change relative
    to the same engine's experiment 1 accuracy, in percentage points. Only
    accuracy and its change are reported; r, rmse and auroc are not printed.

    All files are read relative to the current working directory through
    ``load_results``.

    Raises:
        KeyError: If an engine in an experiment 4 file has no entry in the
            matching experiment 1 baseline file.
    """
    # (baseline file, ((table title, experiment 4 file), ...)) per task.
    groups = (
        (
            'exp1_causal_full_result.json',
            (
                ("Causal TTT Table", 'exp4_causal_ttt_result.json'),
                ("Causal ERE Table", 'exp4_causal_ere_result.json'),
            ),
        ),
        (
            'exp1_moral_full_result.json',
            (
                ("Moral TTT Table", 'exp4_moral_ttt_result.json'),
                ("Moral ERE Table", 'exp4_moral_ere_result.json'),
            ),
        ),
    )

    first_table = True
    for baseline_path, tables in groups:
        baseline = load_results(baseline_path)
        for title, path in tables:
            results = load_results(path)
            if not first_table:
                # Blank line between consecutive tables.
                print()
            first_table = False
            print(title)
            print('engine,acc,Delta')
            for engine, result in results.items():
                # Margin is upper bound minus point estimate; delta is the
                # change from the experiment 1 accuracy. Both are expressed
                # in percentage points.
                margin = (result.conf_interval[1] - result.acc) * 100
                delta = (result.acc - baseline[engine].acc) * 100
                # The backslash is escaped explicitly: a bare "\p" is an
                # invalid escape sequence that newer Python versions warn
                # about.
                print(f'{engine},{result.acc * 100:.1f} $\\pm$ {margin:.1f},{delta:.1f}')

if __name__ == '__main__':
    # Script entry point. Only the table 4 report is currently enabled;
    # uncomment one of the calls below to print another table.
    # display_table1()
    # display_table2()
    # display_table3()
    display_table4()