# %%

import wikipedia

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from tqdm import tqdm
from pathlib import Path


def beepboopbeep() -> None:
    """Print the results of a few sample ``wikipedia`` calls.

    Runs, in order: a search for "Barack", a spelling suggestion for
    "Barak Obama", a search for "Ford" limited to three results, and the
    summary of the "Catgut" page.
    """
    barack_hits = wikipedia.search("Barack")
    print(barack_hits)

    suggestion = wikipedia.suggest("Barak Obama")
    print(suggestion)

    ford_hits = wikipedia.search("Ford", results=3)
    print(ford_hits)

    catgut_summary = wikipedia.summary("Catgut")
    print(catgut_summary)


def get_random_page_summary() -> str:
    """Return the summary of a randomly selected Wikipedia page.

    A fresh random title is requested with ``wikipedia.random(1)`` and looked
    up with ``wikipedia.page``.  If the lookup raises ``DisambiguationError``
    or ``PageError``, another random title is drawn and the lookup is retried.

    Returns:
        The ``summary`` attribute of the first page that loads successfully.
    """
    # Retry in a loop rather than by recursion: a long streak of failing
    # titles would otherwise grow the call stack until RecursionError.
    while True:
        title = wikipedia.random(1)
        try:
            return wikipedia.page(title).summary
        except (wikipedia.exceptions.DisambiguationError, wikipedia.exceptions.PageError):
            continue


def collect_wikibabble(path_save: Path, nsamples: int = 100, max_len: int = 1000) -> pd.Series:
    """Collect random Wikipedia summaries and save them to a CSV file.

    Args:
        path_save: Destination CSV path.  Missing parent directories are
            created.
        nsamples: Number of summaries to collect.
        max_len: Maximum summary length in characters; longer summaries are
            discarded and replaced by a new random one.

    Returns:
        A Series named ``random_page_summaries`` (index named ``index``)
        holding the collected summaries, identical to what is written to
        ``path_save``.
    """
    # Make sure the destination is writable *before* the slow network loop,
    # so a missing directory cannot throw away a fully collected batch.
    path_save = Path(path_save)
    path_save.parent.mkdir(parents=True, exist_ok=True)

    summaries = []
    for _ in tqdm(range(nsamples)):
        summary = get_random_page_summary()
        # Redraw until the summary fits within the length limit.
        while len(summary) > max_len:
            summary = get_random_page_summary()
        summaries.append(summary)

    random_page_summaries = pd.Series(summaries, name='random_page_summaries')
    random_page_summaries.index.name = 'index'
    random_page_summaries.to_csv(path_save)

    return random_page_summaries
    

def get_length_stats(random_page_summaries: pd.Series, show_hist: bool = False) -> None:
    """Print summary-length statistics and optionally plot their histogram.

    Args:
        random_page_summaries: Series of summary strings.
        show_hist: When true, display a histogram of the lengths using
            50-character-wide bins.

    Prints the mean, median and standard deviation of the lengths, each
    rounded to a whole number.
    """
    lengths = random_page_summaries.apply(len)

    print(f"Mean: {lengths.mean():,.0f}")
    print(f"Median: {lengths.median():,.0f}")
    print(f"Std: {lengths.std():,.0f}")

    if show_hist:
        bin_width = 50
        # np.arange excludes its stop value, so arange(0, 1000, 50) ends at
        # 950 and lengths in (950, 1000] fell outside every bin.  Extend the
        # edges to at least 1000 and far enough to cover the longest summary.
        longest = int(lengths.max()) if len(lengths) else 0
        upper = max(1000, longest)
        bins = np.arange(0, upper + bin_width, bin_width)

        plt.figure('Distribution of lengths')
        plt.hist(lengths, bins=bins)
        plt.show()


def main() -> None:
    """Collect Wikipedia summaries in batches and merge them into one CSV.

    Each batch is written to ``lm-understanding/datasets/wikibabble{i}.csv``
    and its length statistics are printed.  The batch files are then read
    back and concatenated into ``lm-understanding/datasets/wikibabble.csv``.
    """
    path_proj = Path('lm-understanding')
    path_data = path_proj / 'datasets'
    # Create the output directory up front so no batch is lost at save time.
    path_data.mkdir(parents=True, exist_ok=True)

    nbatches = 20
    nsamples = 500

    # One list of paths drives both collection and collation, so the two
    # phases cannot disagree on the number of batches.
    batch_paths = [path_data / f'wikibabble{i}.csv' for i in range(nbatches)]

    for i, batch_path in enumerate(batch_paths):
        print(f"Batch {i}")
        rps = collect_wikibabble(batch_path, nsamples=nsamples)
        get_length_stats(rps)

    # Collate: read every batch and concatenate once, instead of growing a
    # Series with one concat per batch.
    rps = pd.concat(
        [pd.read_csv(batch_path)['random_page_summaries'] for batch_path in batch_paths],
        ignore_index=True,
    )

    rps.name = 'random_page_summaries'
    rps.index.name = 'index'
    rps.to_csv(path_data / 'wikibabble.csv')


# %%

# Run the collection pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()

