import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

def load_events_ds(path="Crowdflower_Comparisons/EventTime/events.txt"):
    """Build a pairwise comparison-frequency matrix from the events CSV.

    Every CSV row is read as one comparison between two event ids: the
    ``id1`` column and the field that ``DataFrame.itertuples`` exposes as
    ``_22`` (position 22 of the row tuple, counting the index as position 0;
    pandas falls back to such positional names when a header cannot be used
    as an attribute name).  Rows whose ids are not both in the hard-coded
    selection below are ignored.

    For a kept row, with ``i``/``j`` the matrix positions of ``id1``/``_22``:
    ``W[i, j]`` is incremented when ``category == "category1"``, otherwise
    ``W[j, i]`` is incremented.

    Args:
        path: Location of the CSV file.  Defaults to the path the script
            has always read from.

    Returns:
        A ``(K, K)`` float array ``W / (W + W.T + 1e-6)`` where ``K`` is the
        number of selected ids.  Pairs that were never compared yield 0.

    Raises:
        AttributeError: If the file lacks the ``id1``, ``category`` or
            22nd column; a schema problem is reported instead of being
            silently turned into an all-zero matrix.
    """
    df = pd.read_csv(path)
    # selected from original processing script final_nodes in function filter_events()
    ids = (
        4, 5, 6, 19, 32, 34, 47, 55, 57, 67,
        77, 79, 83, 89, 91, 95, 97, 98, 99, 102,
        109, 110, 111, 112, 113, 117, 120, 122, 123, 125,
        127, 128, 131, 133, 135, 136, 137, 140, 141, 143,
        145, 146, 148, 149, 150, 151, 153, 154, 156, 158,
        161, 162, 163, 164, 165, 167, 169, 170, 171, 173,
        174, 175, 176, 177, 179, 180, 183, 184, 185, 187,
        192, 194, 195, 196, 200, 201, 203, 204, 212, 214,
        221, 222, 228, 232, 233, 234, 237, 238, 240, 242,
        251, 256, 259, 261, 262, 264, 278, 283, 290, 316,
    )
    K = len(ids)
    # Event id -> row/column position in the matrix (order of `ids`).
    id_map = {event_id: pos for pos, event_id in enumerate(ids)}
    W = np.zeros((K, K))
    for row in df.itertuples():
        try:
            id1 = np.int32(row.id1)
            id2 = np.int32(row._22)
        except (ValueError, TypeError, OverflowError):
            # Best effort: a row whose ids cannot be read as integers
            # (e.g. an empty cell) is reported and skipped.
            print(row)
            continue
        if id1 not in id_map or id2 not in id_map:
            continue
        i = id_map[id1]
        j = id_map[id2]
        if row.category == "category1":
            W[i, j] += 1
        else:
            W[j, i] += 1
    # The small epsilon keeps never-compared pairs at 0 instead of 0/0 = NaN.
    return W / (W + W.T + 1e-6)


# Script section: these statements run whenever the module is executed or
# imported. They build the matrix once and write it to the current working
# directory in two formats.
events_ds = load_events_ds()
# Interactive preview, left disabled:
# plt.imshow(events_ds)
# Save the matrix as an image file for visual inspection.
plt.imsave("events_ds.png", events_ds)
# Save the raw values; the array is stored in the archive under the key
# "events_ds".
np.savez("events_ds.npz", events_ds=events_ds)
# Pause for interactive runs, left disabled:
# input("Press Enter to continue...")
