import csv
from collections import Counter
from pathlib import Path

import plotly.graph_objects as go

from src.path import task_stats_dir


question_error_categories = ["correct", "year constraints", "insufficient information", "number", "year / date", "noun", "proper noun", "modifier"]

# Sunburst sectors as (label, parent, colour), in exactly the order of the
# values returned by ``count_sector_values``. A parent of "" marks a root sector.
# The <br> tags and padding spaces in two labels are kept verbatim; presumably
# they hand-tune where the text lands inside those sectors.
SUNBURST_SECTORS = [
    # Aggregate-only parents: they are given a value of 0, so (with Plotly's
    # default branchvalues="remainder") their size comes from their children.
    ("Unanswerable", "", "cornflowerblue"),
    ("Wrong Information", "Unanswerable", "dodgerblue"),
    # One sector per entry of question_error_categories, in the same order.
    ("Answerable", "", "salmon"),  # "correct"
    (" <br><br>Future           <br>Information              ", "Unanswerable", "dodgerblue"),  # "year constraints"
    (" <br><br><br>Insufficient<br>Information", "Unanswerable", "lightgray"),  # "insufficient information"
    ("Number", "Wrong Information", "deepskyblue"),  # "number"
    ("Year/Date", "Number", "skyblue"),  # "year / date"
    ("Noun", "Wrong Information", "deepskyblue"),  # "noun"
    ("Proper Noun", "Noun", "skyblue"),  # "proper noun"
    ("Modifier", "Wrong Information", "deepskyblue"),  # "modifier"
]


def count_sector_values(annotations):
    """Return the sunburst ``values`` list for a sequence of annotation labels.

    The first two entries are 0 (the aggregate-only parent sectors). They are
    followed by the number of occurrences of each entry of
    ``question_error_categories``, in that order. Annotations that match no
    category are ignored.

    Args:
        annotations: iterable of ``error_category_annotation`` strings.

    Returns:
        A list of ints aligned index-for-index with ``SUNBURST_SECTORS``.
    """
    counts = Counter(annotations)
    return [0, 0] + [counts[category] for category in question_error_categories]


if __name__ == "__main__":
    # newline="" is what the csv module requires for files it reads.
    # UTF-8 is an assumption about how the annotation CSV was written.
    csv_path = task_stats_dir / "answerability_classification_mistakes.csv"
    with open(csv_path, "r", newline="", encoding="utf-8") as f:
        annotations = [row["error_category_annotation"] for row in csv.DictReader(f)]

    count_list = count_sector_values(annotations)

    labels = [label for label, _, _ in SUNBURST_SECTORS]
    parents = [parent for _, parent, _ in SUNBURST_SECTORS]
    colors = [color for _, _, color in SUNBURST_SECTORS]

    fig = go.Figure(
        go.Sunburst(
            labels=labels,
            parents=parents,
            values=count_list,
            insidetextorientation="horizontal",
            marker=dict(colors=colors),
            textfont=dict(color="black"),
        )
    )

    # Zero all margins for a tightly cropped figure; mode="show" keeps sector
    # text visible instead of hiding text that does not fit at minsize.
    fig.update_layout(
        uniformtext=dict(minsize=16, mode="show"),
        margin=dict(t=0, l=0, r=0, b=0),
    )

    # Save into the same stats directory the CSV was read from.
    fig.write_image(task_stats_dir / "answerability_mistakes_stats.pdf")
