import argparse
import pandas as pd

# Maps each spatial-ability category to the SceneName values whose rows are
# counted toward it. Iteration order is the order in which categories are
# counted and returned by count_difficulty_by_category.
# "Pyramid" is listed under two categories, so its rows contribute to both
# counts. NOTE(review): confirm this double counting is intended.
DEFAULT_CATEGORY_MAPPING = {
    "Spatial perception": ["Pyramid"],
    "Spatial orientation": ["Agent Sight", "Sun Direction"],
    "Mental objects rotation": ["Revolution", "Unfolded", "Pyramid"],
    "Spatial visualization": ["Polyomino", "Full Views"],
}


def count_difficulty_by_category(dataset_csv, category_mapping=None):
    """Count easy/medium/hard rows for each spatial-ability category.

    Args:
        dataset_csv: Path (or any source accepted by ``pandas.read_csv``)
            of a CSV that has ``SceneName`` and ``Difficulty`` columns.
        category_mapping: Optional mapping of category name -> list of
            ``SceneName`` values belonging to it. Defaults to
            ``DEFAULT_CATEGORY_MAPPING``. A scene listed under several
            categories is counted once in each of them.

    Returns:
        dict mapping each category (in mapping order) to
        ``{'easy': int, 'medium': int, 'hard': int}``. Only the exact
        lowercase labels ``'easy'``, ``'medium'`` and ``'hard'`` are
        counted; any other ``Difficulty`` value is ignored.
    """
    if category_mapping is None:
        category_mapping = DEFAULT_CATEGORY_MAPPING

    df = pd.read_csv(dataset_csv)

    category_counts = {}
    for category, scenes in category_mapping.items():
        category_data = df[df['SceneName'].isin(scenes)]
        difficulty_counts = category_data['Difficulty'].value_counts()
        # value_counts() yields numpy integers; cast to plain int so the
        # result serializes and compares like ordinary Python numbers.
        category_counts[category] = {
            level: int(difficulty_counts.get(level, 0))
            for level in ('easy', 'medium', 'hard')
        }

    return category_counts


def generate_tikz_code(category_counts):
    """Render per-category difficulty counts as a pgfplots stacked bar chart.

    Args:
        category_counts: Dict mapping a category name to a dict with
            ``'easy'``, ``'medium'`` and ``'hard'`` integer counts (the shape
            returned by ``count_difficulty_by_category``). Bars and x
            coordinates follow the dict's iteration order.

    Returns:
        LaTeX source (str) for a ``tikzpicture`` containing one stacked
        ``ybar`` series per difficulty level (Easy, Medium, Hard). Non-zero
        segments carry their count as an explicit point-meta label; zero
        segments are written without one.
    """
    categories = list(category_counts.keys())

    short_names = {
        "Spatial perception": "SP",
        "Spatial orientation": "SO",
        "Mental objects rotation": "MOR",
        "Spatial visualization": "SV",
    }
    # Derive x coordinates from the categories actually plotted so the axis
    # always matches the data; a category without an abbreviation is shown
    # under its full name instead of raising KeyError.
    labels = [short_names.get(category, category) for category in categories]

    # (count key, legend text, bar fill) for each stacked series, bottom first.
    series = (
        ("easy", "Easy", "green!30"),
        ("medium", "Medium", "yellow!40"),
        ("hard", "Hard", "red!30"),
    )

    # Keep the historical ymax of 300, but raise it to the next multiple of
    # 50 when the tallest stacked bar would otherwise be cut off.
    max_total = max(
        (
            sum(int(counts[key]) for key, _, _ in series)
            for counts in category_counts.values()
        ),
        default=0,
    )
    ymax = max(300, -(-max_total // 50) * 50)

    tikz_code = [
        "\\begin{tikzpicture}",
        "\\begin{axis}[",
        "    ybar stacked,",
        "    bar width=16pt,",
        f"    ymin=0, ymax={ymax},",
        "    ylabel={Count},",
        "    width=0.95\\linewidth,",
        "    height=0.6\\linewidth,",
        "    enlargelimits=0.15,",
        "    axis x line=bottom,",
        "    axis y line=left,",
        f"    symbolic x coords={{{', '.join(labels)}}},",
        "    xtick=data,",
        "    tick label style={font=\\scriptsize},",
        "    xticklabel style={rotate=60, anchor=north east},",
        "    legend style={",
        "        at={(0.01,0.99)},",
        "        anchor=north west,",
        "        font=\\scriptsize,",
        "        draw=none,",
        "        fill=white,",
        "        fill opacity=0.8",
        "    },",
        "    legend cell align={left},",
        "    nodes near coords,",
        "    nodes near coords align={center},",
        "    every node near coord/.append style={font=\\tiny,text=black},",
        "    point meta=explicit symbolic,",
        "]",
        "",
    ]

    for key, legend, fill in series:
        tikz_code.append(f"\\addplot+[ybar, fill={fill}] coordinates {{")
        coords = []
        for label, category in zip(labels, categories):
            count = category_counts[category][key]
            # Zero counts get no [meta] label, presumably so no "0" node is
            # drawn on an empty segment (point meta=explicit symbolic).
            meta = f" [{count}]" if count > 0 else ""
            coords.append(f"  ({label},{count}){meta}")
        tikz_code.append("\n".join(coords))
        tikz_code.append("};")
        tikz_code.append(f"\\addlegendentry{{{legend}}}")
        tikz_code.append("")

    tikz_code.append("\\end{axis}")
    tikz_code.append("\\end{tikzpicture}")

    return "\n".join(tikz_code)


def main():
    """Command-line entry point: print a TikZ difficulty plot for a CSV.

    Requires ``--dataset-csv``; the difficulty counts computed by
    ``count_difficulty_by_category`` are rendered with ``generate_tikz_code``
    and written to stdout.
    """
    cli = argparse.ArgumentParser(
        description="Generate TikZ code for difficulty distribution plot",
    )
    cli.add_argument(
        "--dataset-csv",
        required=True,
        type=str,
        help="Path to dataset CSV with difficulty labels",
    )
    options = cli.parse_args()

    counts_per_category = count_difficulty_by_category(options.dataset_csv)
    print(generate_tikz_code(counts_per_category))


# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
