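"""Run surprisal inference over a range of model checkpoints.

For each selected checkpoint, score every context file with
infer_model_on_contexts and write the surprisal/ranking lists to one JSON
result file per context file.
"""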
import json
import os
from transformers import GPT2Config, GPT2LMHeadModel
from checkpoint_surprisal import (checkpoint_path_to_model, checkpoint_path_to_model_ctrl, lstm_checkpoint_path_to_model, mamba_checkpoint_path_to_model, get_files_sorted,
                                  infer_model_on_contexts, new_mamba_checkpoint_path_to_model, hybrid_mamba_checkpoint_path_to_model,
                                  remove_last_occurrence, word_list2)

# nohup python -u test_batch_surprisal_list2.py > test_list2.log 2>&1 &

# Root directory holding one checkpoint sub-directory per experiment.
checkpoint_dir_template = '/scratch/[censored]_root/[censored]2/[censored]/experiments/checkpoints/{}/'
# Experiment (checkpoint sub-directory) to evaluate.
item_name = 'childes_warmup_s442_4layer_pure'
# Callable that turns a checkpoint path into a model (from checkpoint_surprisal).
load_func = checkpoint_path_to_model_ctrl
# Indices into the sorted checkpoint list: 0..22, then every second one from 24 to 42.
# Alternative: cid_range = list(range(23))
cid_range = list(range(23)) + list(range(24, 43, 2))
model_id = 0  # 0: childes, any other value: vsdiag
# LSTM checkpoints are run with only_logit=False; every other model with only_logit=True.
only_logit = 'lstm' not in item_name

checkpoint_dir = checkpoint_dir_template.format(item_name)
files_sorted = get_files_sorted(checkpoint_dir)
print(files_sorted)

# Fail fast, before any model is loaded, if a requested checkpoint index does not
# exist; otherwise the run crashes part-way through with a bare IndexError.
_num_checkpoints = len(files_sorted)
_bad_cids = [cid for cid in cid_range if cid >= _num_checkpoints or cid < -_num_checkpoints]
if _bad_cids:
    raise IndexError(
        f'checkpoint ids {_bad_cids} are out of range: only {_num_checkpoints} '
        f'checkpoints found in {checkpoint_dir}'
    )
# Context files to evaluate: childes word contexts (model_id == 0) or visual-dialogue
# contexts (any other model_id). These do not depend on the checkpoint, so set them once.
if model_id == 0:
    context_file_template = 'word_context_archive/word_context{}.json'
    context_file_idxs = ['', '2', '5_0', '5_1', '5_2', '5_3', '5_4', '6_0', '6_1', '6_2']
else:
    context_file_template = 'visdiag_archive/vis_context_{}_list2.json'
    context_file_idxs = list(range(1, 11))
updated_context_file_template = 'other/word_context{}_updated.json'
word_list = word_list2

# Manual run toggles. The defaults (empty lists / False) process every file for every
# checkpoint and write plain result names.
# For checkpoints listed here, only the file indices in `only_file_idxs` are processed.
restricted_checkpoints = []
only_file_idxs = ['6_2']
# A (checkpoint, file index) pair in both lists gets a '_val' suffix on its result file.
val_checkpoints = []
val_file_idxs = []
# If True, also write the cleaned env/lan contexts to updated_context_file_template.
save_updated_context = False


def generate_cmp_env(env, word, word_list):
    """Return copies of `env` with `word` replaced by each other word in `word_list`.

    Uses str.replace, so every occurrence of `word` is substituted, including
    matches inside longer words.
    """
    return [env.replace(word, w) for w in word_list if w != word]


for checkpoint_id in cid_range:
    checkpoint_path = files_sorted[checkpoint_id]
    print(f'\n====Now doing: {checkpoint_id} -- {checkpoint_path} ====\n')

    dir_path = f'context_{item_name}_result/context_list2_c{checkpoint_id}_result'
    os.makedirs(dir_path, exist_ok=True)
    result_template = dir_path + '/context{}_list2_envsingle_result.json'
    model = load_func(checkpoint_path, hidden_size=768, layer_num=4, use_model=GPT2LMHeadModel)

    for file_idx in context_file_idxs:
        if checkpoint_id in restricted_checkpoints and file_idx not in only_file_idxs:
            continue
        filename = context_file_template.format(file_idx)
        print('now process: ' + filename)
        with open(filename, encoding='utf-8') as fp:
            content = json.load(fp)

        # NOTE(review): model_id is never -1 with the config above (0 = childes,
        # otherwise vsdiag), so this branch looks intentionally disabled. When enabled
        # it sets each entry's env to its own key (single-word env; childes only).
        if model_id == -1:
            for k in content:
                content[k]['env'] = k

        # Clean each word's context: drop 'The child' and periods from env, drop double
        # quotes from lan, and remove the last occurrence of the target word from lan.
        all_env = []
        all_lan = []
        updated_content = {}
        for word in word_list:
            env = content[word]['env'].replace('The child', '').replace('.', '')
            lan = remove_last_occurrence(content[word]['lan'].replace('"', ''), word)
            updated_content[word] = {'env': env, 'lan': lan}
            all_env.append(env)
            all_lan.append(lan)

        if save_updated_context:
            with open(updated_context_file_template.format(file_idx), 'w', encoding='utf-8') as fp:
                json.dump(updated_content, fp)

        # One list of comparison envs per word, aligned with word_list / all_env.
        all_cmp_env = [generate_cmp_env(env, word, word_list) for env, word in zip(all_env, word_list)]

        surprisal_list, cmp_surprisal_list, ranking_list, cmp_ranking_list = infer_model_on_contexts(
            model, word_list, all_env, all_cmp_env, all_lan, show_output=1, only_logit=only_logit)
        datapoints = {
            'surprisal_list': surprisal_list,
            'cmp_surprisal_list': cmp_surprisal_list,
            'ranking_list': ranking_list,
            'cmp_ranking_list': cmp_ranking_list,
        }

        if checkpoint_id in val_checkpoints and file_idx in val_file_idxs:
            # f-string so integer (vsdiag) file indices work too; `file_idx + '_val'`
            # raised TypeError for them.
            result_name = f'{file_idx}_val'
        else:
            result_name = file_idx
        with open(result_template.format(result_name), 'w') as fp:
            json.dump(datapoints, fp)

    # Drop this checkpoint's model so it can be freed before the next one is loaded,
    # instead of holding two models at once during the next load_func call.
    del model
