"""Util functions for CHILDES-GPT analysis"""

import math
from typing import List, Optional

import numpy as np
import torch
import json
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          PreTrainedTokenizer, modeling_outputs,
                          modeling_utils)
from train.wordlevel_tokenizer import TrainableWordTokenizer  # noqa: E402


# Shared word-level tokenizer built from the vocabulary file at
# 'train/vocab.json' (a relative path; presumably resolved against the
# current working directory, so run from the repo root — TODO confirm).
tokenizer = TrainableWordTokenizer(vocab_file='train/vocab.json')
# 100 target nouns: the intersection of childes_word_list, the vsdiag
# vocabulary, and the CDI noun categories, truncated to the first 100.
# Listed ten per row so the total is easy to verify.
word_list = [
    'box', 'book', 'ball', 'hand', 'paper', 'table', 'toy', 'head', 'car', 'chair',
    'room', 'picture', 'doll', 'cup', 'towel', 'door', 'mouth', 'camera', 'duck', 'face',
    'truck', 'bottle', 'puzzle', 'bird', 'tape', 'finger', 'bucket', 'block', 'stick', 'elephant',
    'hat', 'bed', 'arm', 'dog', 'kitchen', 'spoon', 'hair', 'blanket', 'horse', 'tray',
    'train', 'cow', 'foot', 'couch', 'necklace', 'cookie', 'plate', 'telephone', 'window', 'brush',
    'ear', 'pig', 'purse', 'hammer', 'cat', 'shoulder', 'garage', 'button', 'monkey', 'pencil',
    'shoe', 'drawer', 'leg', 'bear', 'milk', 'egg', 'bowl', 'juice', 'ladder', 'basket',
    'coffee', 'bus', 'food', 'apple', 'bench', 'sheep', 'airplane', 'comb', 'bread', 'eye',
    'animal', 'knee', 'shirt', 'cracker', 'glass', 'light', 'game', 'cheese', 'sofa', 'giraffe',
    'turtle', 'stove', 'clock', 'star', 'refrigerator', 'banana', 'napkin', 'bunny', 'farm', 'money',
]

# Path template for the word-context JSON archives. The `{}` placeholder
# takes one of the suffixes below, e.g. '' -> word_context.json and
# '5_0' -> word_context5_0.json.
context_file_template = 'test/word_context_archive/word_context{}.json'
context_file_idxs = [
    '',
    '2',
    '5_0', '5_1', '5_2', '5_3', '5_4',
    '6_0', '6_1', '6_2',
]
# Checkpoint steps: three early checkpoints, then every 500 steps up to
# and including 20000 (43 entries in total).
steps = [0, 150, 300] + list(range(500, 20001, 500))

# Maps a step value to its position in `steps`.
steps2cid = dict(zip(steps, range(len(steps))))

# Subset of `steps` used for plotting: every checkpoint up to 10000, then
# only the multiples of 1000 beyond that.
steps_for_plot = [step for step in steps if step <= 10000 or step % 1000 == 0]
