"""
Compute and plot the mean number of words that ChatPDF uses in the summary and
main parts of the reviews it generates, grouped by bias and score.
"""
import pickle
import sqlite3
from pathlib import Path

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from numpy import mean
from tqdm import tqdm

from config import DB_PATH
from dataset_metrics.extract_info_from_reviews import biases, get_summary_main_length, color_biases, \
    EXTENDED_SCORES, TICKS_SCORES

font_size_legend = 12


def get_main_summary_review_length():
    """Return the summary and main-review lengths of every generated review.

    Both results are nested dicts keyed first by bias (``"positive"``,
    ``"negative"``, ``"neutral"``) and then by score (one key per entry of
    ``EXTENDED_SCORES``). Each leaf is the list of values returned by
    ``get_summary_main_length`` for the matching rows of the ``genai_review``
    table; rows for which that helper returns ``None`` are skipped.

    The two dicts are cached as pickle files (``summary_hist.np`` and
    ``main_hist.np``, resolved against the current working directory). When
    both files exist they are loaded and returned without opening the
    database; delete them to force a recomputation.

    Returns:
        tuple: ``(summary_rev_histograms, main_rev_histograms)``.
    """
    # Local import so the block stays self-contained without touching the
    # file-level import list.
    from contextlib import closing

    main_hist_np = "main_hist.np"
    summary_hist_np = "summary_hist.np"
    if Path(summary_hist_np).exists() and Path(main_hist_np).exists():
        # NOTE: pickle.load can execute arbitrary code. These files are only
        # meant to be caches written by this very function; never point this
        # at files coming from an untrusted source.
        with open(summary_hist_np, 'rb') as f:
            summary_rev_histograms = pickle.load(f)
        with open(main_hist_np, 'rb') as f:
            main_rev_histograms = pickle.load(f)
        return summary_rev_histograms, main_rev_histograms

    bias_names = ("positive", "negative", "neutral")
    summary_rev_histograms = {b: {s: [] for s in EXTENDED_SCORES} for b in bias_names}
    main_rev_histograms = {b: {s: [] for s in EXTENDED_SCORES} for b in bias_names}

    # A sqlite3 connection used as a context manager only commits/rolls back;
    # it does NOT close the connection. closing() releases the handle as soon
    # as the single query has been read into memory.
    with closing(sqlite3.connect(str(DB_PATH))) as connection:
        all_genai_reviews = pd.read_sql_query("select * from genai_review", connection)

    # iterrows() is a generator, so tqdm needs an explicit total to be able to
    # render a progress bar and an ETA.
    for _, genai_review in tqdm(all_genai_reviews.iterrows(), total=len(all_genai_reviews)):
        bias = genai_review['type']
        rating = genai_review['rating']
        result = get_summary_main_length(genai_review['generated'])
        if result is None:
            # The helper could not extract lengths from this review: skip it.
            continue
        summary, main = result
        summary_rev_histograms[bias][rating].append(summary)
        main_rev_histograms[bias][rating].append(main)

    # Persist both results so that later runs can skip the database entirely.
    with open(summary_hist_np, 'wb') as f:
        pickle.dump(summary_rev_histograms, f)
    with open(main_hist_np, 'wb') as f:
        pickle.dump(main_rev_histograms, f)
    return summary_rev_histograms, main_rev_histograms


# Load the per-bias / per-score lengths (read from the pickle cache if present).
summary_rev_histograms, main_rev_histograms = get_main_summary_review_length()

# Grouped-bar geometry: one group per score, one bar per bias within a group.
x = np.arange(len(EXTENDED_SCORES))
width = 0.2
n_biases = len(biases)
columns = 1
bbox_to_anchor = (0.15, 1.15)


def _draw_mean_length_bars(ax, histograms):
    """Draw, on ``ax``, one bar per (score, bias) showing the mean of ``histograms``."""
    for position, bias_name in enumerate(biases):
        heights = [mean(histograms[bias_name][score]) for score in EXTENDED_SCORES]
        # Shift each bias by one bar width so the bars of a score sit side by side.
        ax.bar(x + position * width, heights, width=width,
               label=bias_name, color=color_biases[position])
    # Put each tick under the centre of its group of bars.
    ax.set_xticks(x + width * (n_biases - 1) / 2)
    ax.set_xticklabels(TICKS_SCORES)


# Plot 1: Summary (no legend is drawn on this figure).
fig1, ax1 = plt.subplots()
_draw_mean_length_bars(ax1, summary_rev_histograms)

# Plot 2: Main Review (this figure carries the legend).
fig2, ax2 = plt.subplots()
_draw_mean_length_bars(ax2, main_rev_histograms)
ax2.legend(prop={'size': font_size_legend}, bbox_to_anchor=bbox_to_anchor, ncol=columns)

plt.show()
