import pickle
import numpy as np

# Length, in tagged tokens, of the prefix inspected in every sequence;
# sequences shorter than this are skipped further down.
N = 512
# Upper bound on how many sequences contribute to the statistics.
num_samples = 160000

# Load the tagged corpus from disk.
# NOTE(review): pickle.load can execute arbitrary code while deserialising, so
# this file must come from a trusted source.
# NOTE(review): the code below treats the payload as an iterable of sequences
# of (item, POS tag) pairs — confirm against the script that wrote the pickle.
with open("list_of_tagged_sequence.pkl", mode="rb") as pickle_file:
    list_of_tagged_sequence = pickle.load(pickle_file)

# Closed tag inventory; a tag's position in this list is the integer id
# assigned to it in POS2id.
# NOTE(review): these look like the Universal POS tags plus a SPACE tag —
# confirm against the tagger that produced the input data.
POS_list = [
    'ADJ', 'ADP', 'ADV', 'AUX', 'CCONJ', 'DET',
    'INTJ', 'NOUN', 'NUM', 'PART', 'PRON', 'PROPN',
    'PUNCT', 'SCONJ', 'SPACE', 'SYM', 'VERB', 'X',
]
# Reverse lookup: tag string -> index of that tag in POS_list.
POS2id = dict(zip(POS_list, range(len(POS_list))))

# Build one row per accepted sequence: the fraction of its first N tokens that
# carries each tag, with columns in POS_list order.
data_proportion = []
for sequence in list_of_tagged_sequence:
    # Stop as soon as the requested number of rows has been collected.
    if len(data_proportion) == num_samples:
        break
    # Only sequences with at least N tokens are used, so every row is a
    # proportion over exactly N tokens.
    if len(sequence) >= N:
        # A tag missing from POS2id raises KeyError here.
        tag_ids = np.asarray(
            [POS2id[tag] for _, tag in sequence[:N]], dtype=np.intp
        )
        # minlength keeps a column for tags that never occur in this window.
        tag_counts = np.bincount(tag_ids, minlength=len(POS_list))
        data_proportion.append(tag_counts / N)
# Stack the rows into a single float32 array. If no sequence qualified, this
# is an empty 1-D array rather than a (0, number_of_tags) matrix.
data_proportion = np.array(data_proportion, dtype=np.float32)

# Summary statistics of the per-sequence tag proportions held in
# data_proportion (rows = sequences, columns = tags).
#
# Sums are accumulated in float64: reducing many float32 rows in float32 loses
# digits, and the covariance below subtracts two nearly equal numbers and then
# scales the difference by N, which magnifies that loss.
if data_proportion.ndim != 2 or data_proportion.shape[0] == 0:
    raise ValueError(
        "data_proportion must be a non-empty 2-D array "
        "(one row per sequence, one column per tag)"
    )
_proportions64 = data_proportion.astype(np.float64)
_num_rows = _proportions64.shape[0]

# E[p_a]: average proportion of each tag over all rows.
_mean64 = _proportions64.mean(axis=0)
# E[p_a * p_b]: uncentred second moment of the proportions. The matrix product
# is the same sum over rows as averaging every per-row outer product, but it
# does not materialise a (rows, tags, tags) temporary.
_second_moment64 = np.dot(_proportions64.T, _proportions64) / _num_rows

# Results are cast back to the input dtype so downstream code sees the same
# array types as before.
mean_proportion = _mean64.astype(data_proportion.dtype)
mean_cooccurrence = _second_moment64.astype(data_proportion.dtype)
# N times the population covariance matrix of the proportions (divides by the
# number of rows, not rows - 1). The subtraction is done in float64 before the
# final cast.
integrated_correlation = (
    N * (_second_moment64 - np.outer(_mean64, _mean64))
).astype(data_proportion.dtype)