R"""


cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


python3 -i local_scripts/ead/ead_ds_analysis001.py
CUDA_VISIBLE_DEVICES=0 python -i local_scripts/ead/ead_ds_analysis001.py

"""
import collections
import itertools

import tensorflow as tf
from transformers import TFAutoModelForSequenceClassification, AutoTokenizer

from em.datasets.antiderivative import antiderivative_ds
from em.util import vat_da_faak_vpn

###################################################################################

# Previously used dataset locations, kept commented out for reference:
# DS_FILE_PATH0 = "~/Desktop/projects_data/extract_merge1/antiderivative/datasets/expressions001_ead.3M.00.5s.csv"
# DS_FILE_PATH1 = "~/Desktop/projects_data/extract_merge1/antiderivative/datasets/expressions001_ead.3M.01.5s.csv"

# CSV files loaded below. Going by the file names, DS_FILE_PATH0 is the
# "train" file and DS_FILE_PATH1 the "validation" file of ead_ds_002.
DS_FILE_PATH0 = "/fruitbasket/users/m/project_data/extract_merge1/ead1/datasets/ead_ds_002.train.csv"
DS_FILE_PATH1 = "/fruitbasket/users/m/project_data/extract_merge1/ead1/datasets/ead_ds_002.validation.csv"

###################################################################################

# Model/training settings. None of these is referenced in the visible part of
# this script.
# NOTE(review): presumably used further down or carried over from a training
# script -- confirm before removing.
PRETRAINED_MODEL = 'bert-base-uncased'
SEQUENCE_LENGTH = 128
BATCH_SIZE = 8
EVAL_BATCH_SIZE = 32

LR = 3e-5
CLIPNORM = 0.1

###################################################################################

# Load both dataset files and print the stats object computed for each.
# NOTE(review): skip_unlabeled=True presumably drops rows without a label --
# confirm against antiderivative_ds.load_raw_from_file.
ds0 = antiderivative_ds.load_raw_from_file(DS_FILE_PATH0, skip_unlabeled=True)
ds1 = antiderivative_ds.load_raw_from_file(DS_FILE_PATH1, skip_unlabeled=True)
stats0 = antiderivative_ds.get_stats(ds0)
stats1 = antiderivative_ds.get_stats(ds1)
print(stats0)
print(stats1)
# print(stats.total_count)

###################################################################################


def compute_example_to_repeats_count(ds):
    """Count how many times each example string occurs in a dataset.

    Args:
        ds: object exposing ``as_numpy_iterator()`` that yields two-element
            items; the first element is converted to ``str`` and used as the
            key, the second one is ignored.

    Returns:
        A ``collections.defaultdict`` (default value 0) mapping each example
        string to its number of occurrences.
    """
    occurrences = collections.defaultdict(lambda: 0)
    for element in ds.as_numpy_iterator():
        # Two-element unpacking is kept so malformed items still fail loudly.
        raw_example, _ = element
        occurrences[tf.compat.as_str(raw_example)] += 1
    return occurrences


# Per-dataset mapping: example string -> number of times it occurs.
ex_to_repeats0 = compute_example_to_repeats_count(ds0)
ex_to_repeats1 = compute_example_to_repeats_count(ds1)


def get_intersection_count(ex_to_repeats0, ex_to_repeats1):
    """Return the number of occurrences of examples present in both mappings.

    Args:
        ex_to_repeats0: mapping from example to its repeat count (first dataset).
        ex_to_repeats1: mapping from example to its repeat count (second dataset).

    Returns:
        The sum, over every key found in both mappings, of its two repeat
        counts; 0 when the mappings share no key.
    """
    # Repeats recorded in the second mapping for the keys of the first one.
    # ``.get`` is used so that a defaultdict argument is never mutated.
    seen_in_second = sum(ex_to_repeats1.get(key, 0) for key in ex_to_repeats0)
    # Continue the same running total with the symmetric direction.
    return sum(
        (ex_to_repeats0.get(key, 0) for key in ex_to_repeats1),
        seen_in_second,
    )


# Occurrences (summed over both datasets) of examples that appear in both
# datasets, printed as an absolute number and as a fraction of the combined
# total_count of the two stats objects.
intersection_count = get_intersection_count(ex_to_repeats0, ex_to_repeats1)
print(intersection_count)
print(intersection_count / (stats0.total_count + stats1.total_count))
# NOTE(review): the value below is presumably the fraction printed by an
# earlier run -- confirm which dataset files it was produced with.
# 0.14150459983793318


def get_repeats_stats(ex_to_repeats):
    """Build a histogram of repeat counts.

    Args:
        ex_to_repeats: mapping from example to the number of times it occurs.

    Returns:
        A list ``stats`` where ``stats[k]`` is the number of examples that
        occur exactly ``k`` times. Its length is the largest repeat count
        plus one. An empty mapping yields ``[0]`` instead of raising
        ``ValueError``.
    """
    # Number of examples for each distinct repeat count.
    histogram = collections.Counter(ex_to_repeats.values())
    # default=0 keeps max() from failing when there are no examples at all.
    stats = [0] * (max(histogram, default=0) + 1)
    for repeats, num_examples in histogram.items():
        stats[repeats] = num_examples
    return stats


# repeats_statsN[k] is the number of distinct examples that occur exactly k
# times in dataset N.
repeats_stats0 = get_repeats_stats(ex_to_repeats0)
repeats_stats1 = get_repeats_stats(ex_to_repeats1)
