"""
Compute the agreement between the human-generated and the AI-generated reviews.
"""

import sqlite3

import numpy as np
from matplotlib import pyplot as plt

from config import DB_PATH, POSITIVE_COLOR, NEGATIVE_COLOR
from dataset_metrics.extract_info_from_reviews import POSITIVE_SCORES

# Decision labels that denote a rejection.
# NOTE: not referenced elsewhere in this module; the agreement computation
# treats every decision absent from ACCEPTANCE_DECISIONS as a rejection.
REJECT_DECISIONS = ["Reject", "Desk Reject"]

# Decision labels that count as an acceptance. Label spelling and casing vary
# between venues/years, so every variant is listed explicitly (matching is an
# exact string comparison).
ACCEPTANCE_DECISIONS = [
    # Parenthesised, capitalised variants.
    "Accept (Poster)",
    "Invite to Workshop Track",
    "Accept (Oral)",
    "Accept (Spotlight)",
    "Accept (Talk)",
    # Colon-separated variants.
    "Accept: poster",
    "Accept: notable-top-5%",
    "Accept: notable-top-25%",
    # Parenthesised, lower-case variants.
    "Accept (poster)",
    "Accept (spotlight)",
    "Accept (oral)",
]


def get_accept_reject_agreement():
    """Count agreement between neutral AI-review ratings and paper decisions.

    Every ``genai_review`` row of type ``'neutral'`` is joined to its
    ``submission``. Withdrawn papers and reviews whose rating is
    ``'Score not found'`` are excluded. Rows whose decision or rating is NULL
    are also dropped, because ``NULL != x`` is not true in SQL.

    A decision is positive when it appears in ``ACCEPTANCE_DECISIONS``; any
    other decision is treated as a rejection. An AI rating is positive when it
    appears in ``POSITIVE_SCORES``. The review agrees when both are positive or
    both are negative.

    Returns:
        ``(accept_agreement, reject_agreement)``: two dicts with keys
        ``'agree'`` and ``'disagree'``. The first counts reviews of accepted
        papers; the second counts reviews of all other papers.
    """
    from contextlib import closing

    # String literals are single-quoted: in SQLite, double quotes denote
    # identifiers. They only fall back to string literals through a legacy
    # quirk, which SQLite builds compiled with SQLITE_DQS=0 reject.
    query_llm_reviews_for_paper = """
        select rating, decision
        from genai_review join submission on paper_id = id
        where type = 'neutral' and decision != 'Withdrawn' and rating != 'Score not found'
    """

    accept_agreement = {'agree': 0, 'disagree': 0}
    reject_agreement = {'agree': 0, 'disagree': 0}

    # A sqlite3.Connection used as a context manager only commits or rolls
    # back; it does not close. closing() makes sure the connection is released.
    with closing(sqlite3.connect(str(DB_PATH))) as connection:
        rows = connection.execute(query_llm_reviews_for_paper).fetchall()

    for genai_rating, decision in rows:
        positive_decision = decision in ACCEPTANCE_DECISIONS
        genai_positive_rating = genai_rating in POSITIVE_SCORES
        key = 'agree' if genai_positive_rating == positive_decision else 'disagree'
        decision_dict = accept_agreement if positive_decision else reject_agreement
        decision_dict[key] += 1
    return accept_agreement, reject_agreement


if __name__ == "__main__":
    accept_agreement, reject_agreement = get_accept_reject_agreement()
    plt.figure()
    x = np.array([0, 1])
    n_columns = 2  # agree or disagree
    width = 0.1

    plt.barh(x + width / 2, [accept_agreement['agree'], reject_agreement['agree']], height=width, label='Agree',
             color=POSITIVE_COLOR)
    plt.barh(x - width / 2, [accept_agreement['disagree'], reject_agreement['disagree']], height=width,
             label='Disagree', color=NEGATIVE_COLOR)
    plt.legend(loc='upper center', bbox_to_anchor=(0.5, 1.30), ncol=2)
    plt.yticks([0, 1], ['Accept', 'Reject'])
    plt.xscale('log')
    plt.tight_layout()
    plt.show()
