import pandas as pd
from nltk import sent_tokenize
import numpy as np
import seaborn as sns
from tqdm import tqdm
from nltk.sentiment.vader import SentimentIntensityAnalyzer
from utils import *
import matplotlib.pyplot as plt
import os
from datasets import load_dataset

# Register tqdm's pandas integration; this is what gives the pandas Series
# used below their `.progress_apply` method (a progress-bar version of `.apply`).
tqdm.pandas()

def gen_sent_csv(data_path, save_path, analyzer, col='summary'):
    """Average the sentence-level sentiment of each row's text and save it.

    The text in column ``col`` is split into sentences and scored with
    ``get_text_sentiment``. The per-sentence scores for each row are then
    collapsed with ``np.mean``.

    Args:
        data_path: CSV file to read. Its first column becomes the index, and it
            must contain an ``id`` column plus the text column ``col``.
        save_path: Destination CSV for the resulting ``id``/``sentiment`` table
            (written together with its default integer index).
        analyzer: Sentiment analyzer handed on to ``get_text_sentiment``.
        col: Name of the text column to score.

    Returns:
        A DataFrame with columns ``id`` and ``sentiment``. It has one row for
        every input row that survived ``dropna``.
    """
    # dropna() with no subset discards rows missing a value in *any* column,
    # not only in `col`.
    frame = pd.read_csv(data_path, index_col=0).dropna()

    print('tokenizing sentences')
    sentence_lists = frame[col].progress_apply(sent_tokenize)

    print('getting sentiments')
    # NOTE(review): this calls get_text_sentiment(sentences, kwargs=analyzer),
    # i.e. the analyzer arrives as a keyword literally named `kwargs`. Confirm
    # that this matches the signature defined in utils.
    sentence_scores = sentence_lists.progress_apply(get_text_sentiment, kwargs=analyzer)

    # One number per row. An empty score list makes np.mean return NaN.
    row_means = [np.mean(scores) for scores in sentence_scores]

    records = list(zip(frame['id'], row_means))
    result = pd.DataFrame(records, columns=['id', 'sentiment'])
    result.to_csv(save_path)
    return result

def plot_sent_hist(data_path, save_dir, summarizer, analyzer, col='summary', bins=50):
    """Compute per-row sentiment for one summarizer and save a histogram of it.

    This calls ``gen_sent_csv``, which writes
    ``<save_dir>/<summarizer>_sentiment.csv``. The histogram of the
    ``sentiment`` column is then saved to
    ``<save_dir>/<summarizer>_sent_plot.png``.

    Args:
        data_path: Input CSV passed through to ``gen_sent_csv``.
        save_dir: Directory that receives both the CSV and the PNG.
        summarizer: Label used as the file-name prefix for both outputs.
        analyzer: Sentiment analyzer passed through to ``gen_sent_csv``.
        col: Text column to score.
        bins: Number of histogram bins.

    Returns:
        The sentiment DataFrame produced by ``gen_sent_csv``. The original
        returned None, so returning it is purely additive for callers.
    """
    csv_path = os.path.join(save_dir, f'{summarizer}_sentiment.csv')
    plot_path = os.path.join(save_dir, f'{summarizer}_sent_plot.png')
    sentiment_df = gen_sent_csv(data_path, csv_path, analyzer, col)

    # Use a dedicated figure and close it afterwards. That way repeated calls
    # neither clobber the caller's current pyplot figure nor accumulate open
    # figures.
    fig, ax = plt.subplots()
    # Name the column explicitly. Without `x`, seaborn treats the DataFrame
    # as wide-form and would also histogram `id` whenever it is numeric.
    sns.histplot(data=sentiment_df, x='sentiment', bins=bins, ax=ax)
    fig.savefig(plot_path)
    plt.close(fig)
    return sentiment_df


if __name__ == '__main__':
    # Score every summarizer's output CSV and write one sentiment CSV per
    # summarizer into a shared output directory.
    summarizers = ['textrank', 'matchsum', 'presumm_ext', 'presumm_abs', 'bart', 'pegasus', 'azure']
    analyzer = SentimentIntensityAnalyzer()

    # Alternative runs over the source articles instead of the summaries (disabled).
    # NOTE(review): both functions read data_path with pd.read_csv, so the
    # 'cnn_dailymail' argument below would need to be a CSV path to work.
    #plot_sent_hist('cnn_dailymail', '/home/user/user/sentiment_analysis', 'articles', analyzer, col='article')
    #gen_sent_csv('~/user/textrank.csv', '~/user/sentiment_analysis/articles_sentiment.csv', analyzer, col='article')

    # The output directory is the same for every summarizer. Resolve it once,
    # and create it so DataFrame.to_csv does not fail on a missing directory.
    save_dir = os.path.expanduser('~/user/sentiment_analysis')
    os.makedirs(save_dir, exist_ok=True)

    for s in summarizers:
        data_path = os.path.expanduser(f'~/user/{s}.csv')
        save_path = os.path.join(save_dir, f'{s}_sentiment.csv')
        print(f'generating csv for {s}!')
        gen_sent_csv(data_path, save_path, analyzer, col='summary')
    
    
    
