"""
Count, per year, how often the AI-generated review rating agrees with the accept/reject decision recorded for each submission.
"""
import sqlite3
from contextlib import closing

import numpy as np
from matplotlib import pyplot as plt

from config import DB_PATH, NEGATIVE_COLOR, POSITIVE_COLOR
from dataset_metrics.extract_info_from_reviews import POSITIVE_SCORES

# Decision strings that denote a rejected submission.
# NOTE(review): not referenced anywhere in the visible part of this file; the
# agreement computation below treats every decision that is not in
# ACCEPTANCE_DECISIONS as negative, so this list does not influence it.
REJECT_DECISIONS = [
    "Reject",
    "Desk Reject",
]

# Decision strings counted as a positive (accepted) outcome by
# get_review_agreement(); matching is by exact string equality, so both the
# "Accept (Poster)" and "Accept (poster)" / "Accept: poster" spellings are listed.
# Note that "Invite to Workshop Track" is also counted as an acceptance here.
# NOTE(review): the spelling variants presumably come from different venue
# years/formats -- confirm against the submission table.
ACCEPTANCE_DECISIONS = [
    "Accept (Poster)",
    "Invite to Workshop Track",
    "Accept (Oral)",
    "Accept (Spotlight)",
    "Accept (Talk)",
    "Accept: poster",
    "Accept: notable-top-5%",
    "Accept: notable-top-25%",
    "Accept (poster)",
    "Accept (spotlight)",
    "Accept (oral)",
]


def get_review_agreement():
    """
    Count, per year, how often the AI rating agrees with the recorded decision.

    Reads every generated review of type 'neutral' joined to its submission,
    skipping submissions whose decision is 'Withdrawn' and reviews whose rating
    is 'Score not found'. A review "agrees" when its rating being in
    POSITIVE_SCORES matches the decision being in ACCEPTANCE_DECISIONS (both
    true or both false); any decision not listed in ACCEPTANCE_DECISIONS is
    therefore treated as negative.

    Returns:
        A pair ``(agreement_per_year, disagreement_per_year)`` of dicts, each
        mapping every year from 2018 to 2025 (inclusive) to a count.

    Raises:
        KeyError: if a row's ``when_submitted`` value is not one of the
            pre-initialised year keys (2018-2025).
    """
    # String literals use single quotes: in SQL, double quotes denote
    # identifiers, and SQLite only accepts them as strings through a legacy
    # fallback that can be disabled at build time.
    query_llm_reviews_for_paper = """
        select rating, decision, when_submitted
        from genai_review join submission on paper_id = id
        where type='neutral' and decision != 'Withdrawn' and rating != 'Score not found'
    """

    agreement_per_year = {
        k: 0 for k in range(2018, 2026)
    }
    disagreement_per_year = {
        k: 0 for k in range(2018, 2026)
    }

    # The sqlite3 connection's own context manager only commits or rolls back
    # the transaction; closing() is what actually releases the connection.
    with closing(sqlite3.connect(str(DB_PATH))) as connection:
        ratings = connection.execute(query_llm_reviews_for_paper).fetchall()

    for genai_rating, decision, year in ratings:
        positive_decision = decision in ACCEPTANCE_DECISIONS
        genai_positive_rating = genai_rating in POSITIVE_SCORES
        if positive_decision == genai_positive_rating:
            agreement_per_year[year] += 1
        else:
            disagreement_per_year[year] += 1
    return agreement_per_year, disagreement_per_year


if __name__ == '__main__':
    # Script entry point: draw a grouped bar chart with, for each year, the
    # number of reviews that agree and disagree with the recorded decision.
    agreement_per_year, disagreement_per_year = get_review_agreement()
    plt.figure()
    # Plot the years the counters were built for instead of repeating the range here.
    years = sorted(agreement_per_year)
    x = np.arange(len(years))
    width = 0.2
    n_columns = 2  # bars per year group: agree + disagree
    plt.bar(x, [agreement_per_year[k] for k in years], width=width, label='Agree', color=POSITIVE_COLOR)
    plt.bar(x + width, [disagreement_per_year[k] for k in years], width=width, label='Disagree', color=NEGATIVE_COLOR)
    # Centre each tick under its group of n_columns adjacent bars.
    plt.xticks(x + width * (n_columns - 1) / 2, [str(y) for y in years], rotation=45)
    # One legend, placed above the axes, with one column per bar series.
    plt.legend(loc='upper center', bbox_to_anchor=(0.5, 1.30), ncol=n_columns)
    plt.tight_layout()
    plt.show()
