import seaborn as sns
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import math, re
from sklearn.feature_extraction.text import TfidfVectorizer
from nltk.tokenize import sent_tokenize, word_tokenize
from datasets import Dataset
from transformers import AutoTokenizer


def get_token_length_abstractive(article, tokenizer=None):
    """Return the number of token ids ``tokenizer`` produces for ``article``.

    Args:
        article: a single text string.
        tokenizer: a callable (e.g. a Hugging Face tokenizer) returning an
            object with an ``input_ids`` attribute. Required despite the
            ``None`` default, which is kept only for backward compatibility.

    Returns:
        Length of ``input_ids`` for the single input string.

    Raises:
        ValueError: if no tokenizer is given.
    """
    if tokenizer is None:
        raise ValueError('get_token_length_abstractive requires a tokenizer')
    # A plain list of ids is enough to measure length; requesting 'pt' tensors
    # would needlessly require PyTorch to be installed.
    return len(tokenizer(article, max_length=None).input_ids)

def bucket_scores(sent_scores, num_buckets=10):
    """Average consecutive runs of per-sentence scores into positional buckets.

    The score sequence is split into ``num_buckets`` contiguous, nearly equal
    slices (boundaries at ``len * i // num_buckets``) and each slice is
    averaged.

    Args:
        sent_scores: sequence of per-sentence scores, in document order.
        num_buckets: number of buckets to return.

    Returns:
        Float array of length ``num_buckets``. A bucket that receives no
        sentences (possible when there are fewer sentences than buckets) is NaN.
    """
    values = np.asarray(sent_scores, dtype=float)
    n = len(values)
    scores = np.full(num_buckets, np.nan)
    for i in range(num_buckets):
        # Integer arithmetic: the float form n / k * i can land just below an
        # integer (e.g. 1 / 49 * 49 < 1) and silently drop the final score.
        chunk = values[n * i // num_buckets:n * (i + 1) // num_buckets]
        if chunk.size:
            scores[i] = chunk.mean()
    return scores


def bin_articles_by_length_abstractive(articles, model_name, max_range=100):
    """Group articles into equal-width buckets of model-token length.

    Args:
        articles: iterable of article strings.
        model_name: name/path passed to ``AutoTokenizer.from_pretrained``; its
            tokenizer defines the length of each article.
        max_range: target width of each length bucket, in tokens.

    Returns:
        One dict per non-empty bucket, ordered from shortest to longest:
        ``{'text': Dataset built from the bucket's rows, 'min_length': ...,
        'max_length': ...}``. Empty input gives an empty list.
    """
    df = pd.DataFrame(articles, columns=['article'])
    if df.empty:
        return []

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    lengths = df.article.apply(get_token_length_abstractive, tokenizer=tokenizer)

    # At least one bin: when all articles have the same length the span is 0
    # and pd.cut rejects bins=0.
    num_bins = max(1, math.ceil((lengths.max() - lengths.min()) / max_range))
    # labels=False yields integer bin codes 0..num_bins-1 aligned with df.
    labels = pd.cut(lengths, num_bins, labels=False)

    bins = []
    for b in range(num_bins):
        mask = labels == b
        if not mask.any():
            continue
        bins.append({'text': Dataset.from_pandas(df[mask]),
                     'min_length': lengths[mask].min(),
                     'max_length': lengths[mask].max()})
    return bins

def bin_articles_by_length_extractive(articles, max_range=75):
    """Group articles into equal-width buckets of word-token length.

    Args:
        articles: iterable of article strings.
        max_range: target width of each length bucket, in ``word_tokenize``
            tokens.

    Returns:
        One dict per non-empty bucket, ordered from shortest to longest:
        ``{'text': list of article strings, 'min_length': ...,
        'max_length': ...}``. Empty input gives an empty list.
    """
    df = pd.DataFrame(articles, columns=['article'])
    if df.empty:
        return []

    lengths = df.article.apply(lambda x: len(word_tokenize(x)))

    # At least one bin: when all articles have the same length the span is 0
    # and pd.cut rejects bins=0.
    num_bins = max(1, math.ceil((lengths.max() - lengths.min()) / max_range))
    # labels=False yields integer bin codes 0..num_bins-1 aligned with df.
    labels = pd.cut(lengths, num_bins, labels=False)

    bins = []
    for b in range(num_bins):
        mask = labels == b
        if not mask.any():
            continue
        bins.append({'text': df[mask].article.to_list(),
                     'min_length': lengths[mask].min(),
                     'max_length': lengths[mask].max()})
    return bins


def get_text_sentiment(text, **analyzer):
    """Score every line of ``text`` with a sentiment analyzer.

    The analyzer object arrives as the keyword argument ``kwargs``, so it can
    be forwarded through ``Series.apply(get_text_sentiment, kwargs=analyzer)``.
    It must expose ``polarity_scores(str)`` returning a dict that has a
    ``'compound'`` key.

    Returns:
        List of compound scores, one per element of ``text``, in order.
    """
    return [analyzer['kwargs'].polarity_scores(sentence)['compound']
            for sentence in text]


def remove_words(text, word_list, remove=False):
    """Join lines with newlines, then delete or substitute whole words.

    Args:
        text: iterable of strings, joined with a newline separator.
        word_list: words matched literally as whole words. With
            ``remove=True`` any iterable of words; otherwise a mapping from
            each word to its replacement string.
        remove: if True, delete every match; if False, replace each match of
            ``word`` with ``word_list[word]``.

    Returns:
        The joined and edited string.
    """
    new_text = '\n'.join(text)

    for word in word_list:
        replacement = '' if remove else word_list[word]
        # re.escape makes punctuation in a word ("Mr.", "c++") literal. The
        # lookarounds act like \b for alphanumeric words but, unlike \b, still
        # match words that begin or end with punctuation.
        pattern = re.compile(rf'(?<!\w){re.escape(word)}(?!\w)')
        # A callable replacement is inserted verbatim (no backslash/group
        # expansion), so replacement strings are taken literally.
        new_text = pattern.sub(lambda _match: replacement, new_text)

    return new_text

def get_overlap_scores(sentences, document):
    """TF-IDF similarity between each sentence and each document sentence.

    A single ``TfidfVectorizer`` is fit on both lists together so they share
    one vocabulary and IDF weighting.

    Args:
        sentences: list of query strings (e.g. summary sentences).
        document: list of strings to compare against (e.g. article sentences).

    Returns:
        Dense array of shape ``(len(sentences), len(document))`` whose entry
        ``[i, j]`` is the dot product of the TF-IDF rows of ``sentences[i]``
        and ``document[j]``.
    """
    sentences = list(sentences)
    corpus = sentences + list(document)
    tfidf = TfidfVectorizer().fit_transform(corpus)
    n = len(sentences)
    # Compute only the sentence-vs-document block. `@` is a matrix product for
    # both scipy sparse matrices and sparse arrays; `*` is element-wise for
    # the latter.
    return (tfidf[:n] @ tfidf[n:].T).toarray()


def get_summary_indices(article, summary, top_k=1, tolerance=0.00001):
    """Return sorted indices of the article sentences a summary draws from.

    Each summary sentence is matched to the article sentence with the highest
    score from ``get_overlap_scores``. Summary sentences whose best score is 0
    (no overlap at all) are ignored. When ``top_k > 1``, each summary sentence
    whose best score is below ``1 - tolerance`` (i.e. not a near-exact match)
    also contributes its ``top_k`` highest-scoring article sentences.

    Args:
        article: list of article sentences.
        summary: list of summary sentences.
        top_k: candidates taken per non-exact summary sentence; values larger
            than the article are clamped to the article length.
        tolerance: slack below 1.0 still treated as an exact match.

    Returns:
        Sorted 1-D array of unique article sentence indices.
    """
    scores = get_overlap_scores(summary, article)

    # Drop summary sentences that share nothing with the article.
    scores = scores[scores.max(axis=1) != 0]
    idx = scores.argmax(axis=1)

    if top_k > 1 and len(article) > 1:
        # argpartition requires kth within the row length.
        k = min(top_k, scores.shape[1])
        inexact = scores[scores.max(axis=1) < 1 - tolerance]
        candidates = np.argpartition(inexact, -k, axis=1)[:, -k:]
        idx = np.concatenate((idx, candidates.ravel()))

    # np.unique already returns the values sorted.
    return np.unique(idx)



def bin_sentences(num_bins, sentence_index, article_length):
    """Histogram selected sentence positions into equal-width buckets.

    Args:
        num_bins: number of buckets spanning the article.
        sentence_index: sized collection of selected sentence indices, each
            expected to lie in ``[0, article_length)``.
        article_length: number of sentences in the article.

    Returns:
        Float array of length ``num_bins``. Each selected sentence adds
        ``1 / len(sentence_index)`` to its bucket; all zeros if nothing is
        selected.
    """
    histogram = np.zeros(num_bins)
    if len(sentence_index) == 0:
        return histogram

    share = 1 / len(sentence_index)
    for position in sentence_index:
        histogram[math.floor(num_bins * position / article_length)] += share
    return histogram


def get_binned_word_counts(article, word_list, num_bins):
    """Histogram whole-word occurrences over relative position in ``article``.

    Args:
        article: sequence of strings (e.g. sentences). Element ``i`` falls in
            bucket ``num_bins * i // len(article)``. A plain string would be
            processed character by character.
        word_list: words to count, matched literally as whole words.
        num_bins: number of equal-width position buckets.

    Returns:
        Float array of length ``num_bins`` normalized to sum to 1, or all
        zeros if no word occurs.
    """
    bins = np.zeros(num_bins)
    n = len(article)
    # Compile once. re.escape makes punctuation literal; the lookarounds act
    # like \b for alphanumeric words but also work for words that begin or
    # end with punctuation.
    patterns = [re.compile(rf'(?<!\w){re.escape(w)}(?!\w)') for w in word_list]

    for i, sentence in enumerate(article):
        bucket = num_bins * i // n
        bins[bucket] += sum(len(p.findall(sentence)) for p in patterns)

    total = bins.sum()
    return bins / total if total > 0 else bins


def visualize_summary(articles, summaries, summary_labels=None, num_bins=10, amortize=False, figsize=(10, 5), tokenize=True):
    """Plot a heatmap of where summary sentences come from in their articles.

    Each summary's source sentences are located with ``get_summary_indices``,
    and their positions are histogrammed with ``bin_sentences`` into
    ``num_bins`` buckets of relative article position.

    Args:
        articles: iterable of articles. Raw strings when ``tokenize`` is True,
            otherwise sequences of sentences.
        summaries: iterable of summary sets, where ``summaries[s][a]`` is
            set ``s``'s summary of ``articles[a]`` (same format as articles).
        summary_labels: heatmap row labels. Defaults to one label per
            (summary set, article) row, or one per summary set when
            ``amortize`` is True.
        num_bins: number of position buckets; 0 picks
            ``min(20, shortest article length in sentences)``.
        amortize: if True, average each summary set's histograms over all
            articles, giving one row per summary set.
        figsize: matplotlib figure size.
        tokenize: split articles and summaries with ``sent_tokenize`` first.

    Returns:
        ``(fig, ax)`` of the created matplotlib figure.
    """
    articles = list(articles)
    summaries = list(summaries)

    if tokenize:
        articles = [sent_tokenize(article) for article in articles]
        summaries = [[sent_tokenize(s) for s in summary_set] for summary_set in summaries]

    if num_bins == 0:
        # Auto-size: never more buckets than the shortest article has sentences.
        num_bins = min(20, min(map(len, articles)))

    # hist[s, a] is the position histogram of summary set s on article a.
    hist = np.asarray([
        [bin_sentences(num_bins, get_summary_indices(article, summary_set[a]), len(article))
         for a, article in enumerate(articles)]
        for summary_set in summaries
    ]).reshape(len(summaries), len(articles), num_bins)

    if amortize:
        matrix = hist.mean(axis=1)
        default_labels = [f'summary {s}' for s in range(len(summaries))]
    else:
        # Rows are summary-major (row s * len(articles) + a); labels follow suit.
        matrix = hist.reshape(len(summaries) * len(articles), num_bins)
        default_labels = [f'summary {s}, article {a}'
                          for s in range(len(summaries)) for a in range(len(articles))]

    if summary_labels is None:
        summary_labels = default_labels

    cols = [f'{100*(x+1)/num_bins}%' for x in range(num_bins)]
    bins_df = pd.DataFrame(matrix, index=summary_labels, columns=cols)

    fig, ax = plt.subplots(figsize=figsize)
    sns.heatmap(bins_df, ax=ax, cmap='OrRd')
    ax.tick_params(rotation=0)

    return fig, ax


def visualize_word_dists(articles, word_lists, index=None, num_bins=10, amortize=False, figsize=(10,5)):
    """Plot a heatmap of where words from each word list occur in articles.

    Args:
        articles: iterable of articles, each a sequence of strings (e.g.
            sentences) as expected by ``get_binned_word_counts``.
        word_lists: sequence of word lists; each yields its own row(s).
        index: heatmap row labels. Defaults to one label per
            (word list, article) row, or one per word list when ``amortize``
            is True.
        num_bins: number of position buckets (heatmap columns).
        amortize: if True, average each word list's histograms over all
            articles.
        figsize: matplotlib figure size.

    Returns:
        ``(fig, ax)`` of the created matplotlib figure.
    """
    articles = list(articles)

    # counts[w, a] is the normalized position histogram of word list w in article a.
    counts = np.zeros((len(word_lists), len(articles), num_bins))
    for w, words in enumerate(word_lists):
        for a, article in enumerate(articles):
            counts[w, a] = get_binned_word_counts(article, words, num_bins)

    if amortize:
        matrix = counts.mean(axis=1)
        default_index = [f'word list {w}' for w in range(len(word_lists))]
    else:
        # Rows are word-list-major (row w * len(articles) + a); labels follow suit.
        matrix = counts.reshape(-1, num_bins)
        default_index = [f'word list {w}, article {a}'
                         for w in range(len(word_lists)) for a in range(len(articles))]

    if index is None:
        index = default_index

    cols = [f'{100*(x+1)/num_bins}%' for x in range(num_bins)]
    word_counts_df = pd.DataFrame(matrix, index=index, columns=cols)
    fig, ax = plt.subplots(figsize=figsize)
    sns.heatmap(word_counts_df, ax=ax, cmap='BuGn')
    ax.tick_params(rotation=0)

    return fig, ax

def visualize_sentiments(text, analyzer, index=None, num_bins=10, amortize=False, figsize=(10,5)):
    """Plot a heatmap of sentiment across positional buckets of each text.

    Args:
        text: pandas Series whose elements are iterables of sentences; each
            is scored with ``get_text_sentiment``.
        analyzer: object exposing ``polarity_scores(str)`` returning a dict
            with a ``'compound'`` key.
        index: heatmap row labels; defaults to ``text.index``.
        num_bins: number of positional buckets (heatmap columns).
        amortize: if True, collapse each text to a single ``'score'`` column,
            the mean over its non-empty buckets.
        figsize: matplotlib figure size.

    Returns:
        ``(fig, ax)`` of the created matplotlib figure.
    """
    cols = [f'{100*(x+1)/num_bins}%' for x in range(num_bins)]

    # Per-text list of compound scores, one per sentence.
    sentiments = text.apply(get_text_sentiment, kwargs=analyzer)
    # Bucket count must match the number of heatmap columns.
    sentiments = sentiments.apply(bucket_scores, num_buckets=num_bins)
    sentiments = np.asarray(sentiments.to_list(), dtype=float).reshape(len(text), num_bins)

    if amortize:
        # Buckets with no sentences are NaN; skip them rather than letting a
        # short text's whole score become NaN.
        sentiments = np.nanmean(sentiments, axis=1)
        cols = ['score']
    if index is None:
        index = text.index

    sentiment_df = pd.DataFrame(sentiments, index=index, columns=cols)
    fig, ax = plt.subplots(figsize=figsize)
    sns.heatmap(sentiment_df, ax=ax, cmap='PuOr')
    ax.tick_params(rotation=0)

    return fig, ax
