"""Inference caller."""
import json
import os
from transformers import GPT2Config, GPT2LMHeadModel
from checkpoint_surprisal import (checkpoint_path_to_model, checkpoint_path_to_model_ctrl, lstm_checkpoint_path_to_model, mamba_checkpoint_path_to_model, get_files_sorted,
                                  infer_model_on_contexts, new_mamba_checkpoint_path_to_model, hybrid_mamba_checkpoint_path_to_model,
                                  remove_last_occurrence, word_list2, infer_model_on_contexts_no_cmp_masking, infer_model_on_contexts_free_masking)

# nohup python -u test_batch_surprisal_list2.py > test_list2.log 2>&1 &



# Checkpoint directory pattern; '{}' is filled with the run name (trailing slash kept).
checkpoint_dir_template = '/scratch/[censored]_root/[censored]2/[censored]/experiments/checkpoints/{}/'
# Variant without the trailing slash:
# checkpoint_dir_template = '/scratch/[censored]_root/[censored]2/[censored]/experiments/checkpoints/{}'

# Key used to pick a head set from the head table; change the index to switch.
mode = ('agg', 'gather')[0]

# Forwarded to the inference call as ``randomize``; also determines the name suffix below.
is_randomized = False
randomize_name = '_randomized' if is_randomized else ''

def generate_cmp_env(env, word, word_list):
    """Build the comparison environments for ``word``.

    Args:
        env: Environment string that contains ``word``.
        word: The target word to substitute.
        word_list: Candidate words; every entry different from ``word``
            yields one comparison environment.

    Returns:
        A list with one copy of ``env`` per other candidate, in
        ``word_list`` order, where each occurrence of ``word`` has been
        replaced by that candidate.

    NOTE(review): ``str.replace`` works on substrings, so a target that also
    appears inside a longer token is replaced there as well -- confirm this
    is acceptable for the context files in use.
    """
    return [env.replace(word, w) for w in word_list if w != word]


# The head table is indexed by seed id further down, so the file itself does
# not depend on the seed: read it once instead of once per seed.
with open('all_heads_5000.json') as fp:
    all_head_dict = json.load(fp)

for sid in range(5):
    seed = sid * 100 + 42
    print(f'seed {seed}')
    # Run names for seeds 42 and 142 carry a '_shuffled' suffix.
    if seed <= 142:
        item_name = f'childes_warmup_s{seed}_shuffled'
    else:
        item_name = f'childes_warmup_s{seed}'

    load_func = checkpoint_path_to_model
    cid_range = [12]
    # cid_range = list(range(23))
    model_id = 0  # 0: childes, 1: vsdiag
    only_logit = 'lstm' not in item_name

    checkpoint_dir = checkpoint_dir_template.format(item_name)
    files_sorted = get_files_sorted(checkpoint_dir)
    print(files_sorted)
    for checkpoint_id in cid_range:
        print(f'\n====Now doing: {checkpoint_id} -- {files_sorted[checkpoint_id]} ====\n')

        # Pick the set of context files that matches the selected data source.
        if model_id == 0:
            context_file_template = 'word_context_archive/word_context{}.json'
            context_file_idxs = ['', '2', '5_0', '5_1', '5_2', '5_3', '5_4', '6_0', '6_1', '6_2']
        else:
            context_file_template = 'visdiag_archive/vis_context_{}_list2.json'
            context_file_idxs = list(range(1, 11))

        # One output directory per (run, checkpoint); two JSON files per context file.
        dir_path = f'context_{item_name}_ab_mul_result/context_list2_c{checkpoint_id}_result'
        os.makedirs(dir_path, exist_ok=True)
        result_template = os.path.join(dir_path, 'context{}_list2_ablation_result.json')
        ranking_result_template = os.path.join(dir_path, 'context{}_list2_ablation_result_ranking.json')

        model = load_func(files_sorted[checkpoint_id])
        word_list = word_list2

        for context_id, file_idx in enumerate(context_file_idxs):
            # Resume filter: the checkpoint list is empty, so nothing is skipped.
            # Add checkpoint ids to it to process only context file '6_2' for them.
            if checkpoint_id in [] and file_idx not in ['6_2']:
                continue
            filename = context_file_template.format(file_idx)
            print('now process: ' + filename)
            with open(filename) as fp:
                content = json.load(fp)
            # word_list = list(content.keys())

            # Disabled toggle (model_id is never -1 above): use each key as its own 'env'.
            if model_id == -1:
                for k in content:
                    content[k]['env'] = k

            # Clean the environment and language strings for every target word.
            all_env = []
            all_lan = []
            for word in word_list:
                entry = content[word]
                env = entry['env'].replace('The child', '').replace('.', '')
                lan = entry['lan'].replace('"', '')
                all_env.append(env)
                all_lan.append(remove_last_occurrence(lan, word))

            # For each target word: its environment with the word swapped for every other word.
            all_cmp_env = [
                generate_cmp_env(env, word, word_list)
                for env, word in zip(all_env, word_list)
            ]

            # One head-mask entry per target word. Built from word_list itself
            # (not a fixed count of 100) so it always has the same length as
            # the word/env/lan lists passed alongside it.
            heads_to_mask = [
                all_head_dict[word][sid][context_id][mode] for word in word_list
            ]

            surprisal_list, ranking_list = infer_model_on_contexts_free_masking(
                model, word_list, all_env, all_cmp_env, all_lan,
                show_output=1, only_logit=only_logit, mask_place=heads_to_mask,
                return_rank=True, randomize=is_randomized,
            )

            # Results are wrapped in a one-element list to keep the output
            # layout that was written before.
            with open(result_template.format(file_idx), 'w') as fp:
                json.dump([surprisal_list], fp)
            with open(ranking_result_template.format(file_idx), 'w') as fp:
                json.dump([ranking_list], fp)