# -*- coding: utf-8 -*-

import jieba
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.feature_extraction.text import TfidfTransformer
from sklearn.cluster import KMeans
import math
import numpy as np


class KmeansClustering():
    def __init__(self, stopwords_path=None):
        self.stopwords = self.load_stopwords(stopwords_path)
        self.vectorizer = CountVectorizer()
        self.transformer = TfidfTransformer()

    def load_stopwords(self, stopwords=None):
        if stopwords:
            with open(stopwords, 'r', encoding='utf-8') as f:
                return [line.strip() for line in f]
        else:
            return []

    def preprocess_data(self, documents):
        corpus = []
        all_probs = []
        for document in documents: 
            words = []
            probs = []
            prob = 1
            for token_items in document: 
                prob *= math.exp(token_items[list(token_items.keys())[0]].logprob)
                token = token_items[list(token_items.keys())[0]].decoded_token.strip()
                if len(token) == 0: 
                    continue
                words.append(token)
                if token == '.': 
                    probs.append(prob)
                    prob = 1
            corpus.append(' '.join(words))
            if len(probs) == 0: 
                all_probs.append(0)
            else: 
                all_probs.append(probs[-1])
        return corpus, all_probs

    def get_text_tfidf_matrix(self, corpus):
        tfidf = self.transformer.fit_transform(self.vectorizer.fit_transform(corpus))
        weights = tfidf.toarray()
        return weights

    def kmeans(self, documents, n_clusters=5):
        corpus, all_probs = self.preprocess_data(documents)
        weights = self.get_text_tfidf_matrix(corpus)

        clf = KMeans(n_clusters=n_clusters)

        y = clf.fit_predict(weights)
        result = {}
        for text_idx, label_idx in enumerate(y):
            if label_idx not in result:
                result[label_idx] = [text_idx]
            else:
                result[label_idx].append(text_idx)
        return result, all_probs