"""Calculate checkpoint's surprisal with respect to context."""

import glob
import json
import os
import sys

import torch
import torch.nn.functional as F
from transformers import GPT2Config, GPT2LMHeadModel
from utils import next_token_surprisal

sys.path.insert(0, '/home/[censored]/mamba_lstm/trabank-dev/train')
from gpt2_train_fromhf import SimpleLMHeadModelNoFFN
from lstm import LSTMLayer
from mamba import MambaLayer
from mamba_ssm import MambaLMHeadModel
from mamba_ssm.models.config_mamba import MambaConfig
sys.path.insert(0, '../train')

from wordlevel_tokenizer import TrainableWordTokenizer  # noqa: E402
from gpt2CustomModel import GPT2LMHeadModelNoResidual
from typing import Optional, Tuple

def extract_step(fname):
    """Extracts the step number from a filename of the form: 'checkpoint_X_YYYY.pt' where X can be any integer index and YYYY is the step number."""
    basename = os.path.basename(fname)
    # e.g., 'checkpoint_0_150.pt' -> parts = ['checkpoint', '0', '150.pt']
    parts = basename.split('_')
    step_str = parts[-1].replace('.pt', '')
    return int(step_str)


def get_files_sorted(dir: str):
    """Generate a list of checkpoint in order, given the checkpoint dir."""
    c_pattern = os.path.join(dir, 'checkpoint_*.pt')
    c_files = glob.glob(c_pattern)
    c_files_sorted = sorted(c_files, key=extract_step)
    return c_files_sorted

import random
from collections import defaultdict
from typing import List, Tuple, Iterable, Optional

def pick_disjoint_pairs(pairs: List[List[int]], 
                        h_range: Iterable[int] = range(12),
                        seed=42) -> List[List[int]]:
    """
    Given a list of unique [l, h] pairs with 0 <= l,h <= 11 and 
    at most 6 pairs per l, return a new list of pairs such that:
      - For each l, we output the same number of pairs as appear with that l.
      - Each output pair has the same l but an h not used by that l in the input.
      - Output pairs for the same l have distinct h's.
    Always possible under the stated constraints (since |h_range|=12 and <=6 used per l).
    
    Args:
        pairs: input list of [l, h].
        h_range: iterable of candidate h values (default 0..11).
        seed: optional RNG seed for reproducibility.

    Returns:
        A list of [l, h'] pairs meeting the criteria.
    """
    if seed is not None:
        random.seed(seed)

    # Map each l -> set of used h's and count
    used_by_l = defaultdict(set)
    for l, h in pairs:
        used_by_l[l].add(h)

    out: List[List[int]] = []
    for l, used_hs in used_by_l.items():
        need = len(used_hs)  # same number as in input for this l
        candidates = [h for h in h_range if h not in used_hs]
        # Under constraints, len(candidates) >= need
        chosen = random.sample(candidates, k=need)
        out.extend([[l, h] for h in chosen])

    return out

def pick_disjoint_pairs_relaxed(
    pairs: List[List[int]],
    h_range: Iterable[int] = range(12),
    seed=42
) -> List[List[int]]:
    """
    Given unique [l, h] pairs with 0 <= l,h <= 11 (but possibly >6 pairs per l),
    produce a list where for each l:
      - If count(l) <= 6: return count(l) pairs of [l, h'] with h' not used by l,
        all distinct among themselves.
      - If count(l) > 6: return as many [l, h'] as possible with unused h',
        and for the remainder, return the original [l, h] pairs (unchanged).
    
    Returns:
        A list of [l, h*] pairs as described.
    """
    if seed is not None:
        random.seed(seed)

    # group by l
    by_l = defaultdict(list)
    for l, h in pairs:
        by_l[l].append(h)

    out: List[List[int]] = []

    for l, hs in by_l.items():
        used_hs = set(hs)  # unique original h's for this l
        cnt = len(hs)
        candidates = [h for h in h_range if h not in used_hs]
        # how many new pairs we can form with unused h
        can_make = min(cnt, len(candidates))

        # choose which unused h to use
        chosen_new_h = random.sample(candidates, k=can_make) if can_make > 0 else []

        # add new pairs
        out.extend([[l, h] for h in chosen_new_h])

        # if we still need more to match the original count, return the rest of originals
        remaining = cnt - can_make
        if remaining > 0:
            # choose which originals to return (random subset for fairness)
            keep_originals = random.sample(hs, k=remaining)
            out.extend([[l, h] for h in keep_originals])

    return out

CHECKPOINTS_DIR = '/scratch/[censored]_root/[censored]2/[censored]/experiments/checkpoints/childes_warmup_s42_shuffled/'
files_sorted = get_files_sorted(CHECKPOINTS_DIR)  # will be overrided

# word_list = ['ball', 'doll', 'chicken', 'fish', 'cat', 'dog', 'banana', 'book', 'candy', 'telephone']

# word_list1 = ['box', 'book', 'ball', 'hand', 'floor', 'paper', 'page', 'table', 'toy', 'head', 'car', 'chair', 'piece', 'room', 'picture', 'doll', 'house', 'cup', 'puppet', 'top', 'towel', 'door', 'lid', 'mouth', 'camera', 'duck', 'phone', 'ring', 'face', 'truck', 'bottle', 'puzzle', 'bird', 'tape', 'desk', 'clown', 'bag', 'finger', 'bucket', 'block', 'stick', 'baby', 'elephant', 'wall', 'magnet', 'hat', 'bed', 'arm', 'dog', 'kitchen', 'pot', 'hole', 'spoon', 'football', 'recorder', 'hair', 'cloth', 'blanket', 'horse', 'tray', 'train', 'cow', 'seat', 'foot', 'tower', 'circle', 'mirror', 'couch', 'necklace', 'cookie', 'jeep', 'plate', 'telephone', 'window', 'microphone', 'brush', 'ear', 'string', 'orange', 'rug', 'pig', 'purse', 'hammer', 'cat', 'shoulder', 'garage', 'wheel', 'button', 'monkey', 'pencil', 'stool', 'shoe', 'drawer', 'teapot', 'leg', 'bear', 'girl', 'bottom', 'milk', 'egg']  # 100 in total. from childes_word_list intersect vsdiag vocab, take concrete noun (w/ chatgpt 4o) and take first 100

word_list2 = ['box', 'book', 'ball', 'hand', 'paper', 'table', 'toy', 'head', 'car', 'chair', 'room', 'picture', 'doll', 'cup', 'towel', 'door', 'mouth', 'camera', 'duck', 'face', 'truck', 'bottle', 'puzzle', 'bird', 'tape', 'finger', 'bucket', 'block', 'stick', 'elephant', 'hat', 'bed', 'arm', 'dog', 'kitchen', 'spoon', 'hair', 'blanket', 'horse', 'tray', 'train', 'cow', 'foot', 'couch', 'necklace', 'cookie', 'plate', 'telephone', 'window', 'brush', 'ear', 'pig', 'purse', 'hammer', 'cat', 'shoulder', 'garage', 'button', 'monkey', 'pencil', 'shoe', 'drawer', 'leg', 'bear', 'milk', 'egg', 'bowl', 'juice', 'ladder', 'basket', 'coffee', 'bus', 'food', 'apple', 'bench', 'sheep', 'airplane', 'comb', 'bread', 'eye', 'animal', 'knee', 'shirt', 'cracker', 'glass', 'light', 'game', 'cheese', 'sofa', 'giraffe', 'turtle', 'stove', 'clock', 'star', 'refrigerator', 'banana', 'napkin', 'bunny', 'farm', 'money']  # 100 in total. from childes_word_list intersect vsdiag vocab intersect CDI nouns catagory and take first 100

# childes_word_list = ['box', 'book', 'ball', 'hand', 'floor', 'paper', 'page', 'table', 'toy', 'crayon', 'head', 'car', 'Chi', 'chair', 'jack', 'piece', 'room', 'picture', 'mother', 'doll', 'front', 'house', 'cup', 'silver', 'man', 'puppet', 'top', 'towel', 'Mother', 'door', 'lid', 'mouth', 'camera', 'put', 'duck', 'phone', 'side', 'ring', 'shape', 'Mot', 'noise', 'face', 'truck', 'bottle', 'cover', 'yellow', 'puzzle', 'bird', 'tape', 'desk', 'clown', 'lap', 'handle', 'bag', 'Patsy', 'finger', 'child', 'bucket', 'block', 'stick', 'baby', 'closer', 'elephant', 'wall', 'magnet', 'hat', 'look', 'bed', 'arm', 'dog', 'play', 'kitchen', 'pot', 'turn', 'hole', 'spoon', 'football', 'Sarah', 'recorder', 'hair', 'cloth', 'blanket', 'horse', 'tray', 'train', 'end', 'rattle', 'cow', 'plastic', 'seat', 'foot', 'tower', 'air', 'circle', 'draw', 'laugh', 'pick', 'stack', 'mirror', 'couch', 'necklace', 'cookie', 'jeep', 'place', 'hold', 'plate', 'blue', 'telephone', 'window', 'microphone', 'brush', 'ear', 'stand', 'string', 'orange', 'rug', 'right', 'fall', 'Roman', 'slide', 'pig', 'way', 'forth', 'purse', 'hammer', 'voice', 'pull', 'figure', 'cat', 'shoulder', 'garage', 'wheel', 'sound', 'talk', 'left', 'button', 'monkey', 'pencil', 'stool', 'living', 'shoe', 'drawer', 'teapot', 'Jenny', 'leg', 'push', 'bear', 'move', 'screwdriver', 'girl', 'bottom', 'milk', 'person', 'set', 'fit', 'egg', 'water', 'bowl', 'nut', 'juice', 'fish', 'part', 'board', 'corner', 'self', 'ladder', 'line', 'boy', 'basket', 'bit', 'saw', 'coffee', 'bus', 'break', 'point', 'food', 'ground', 'time', 'edge', 'register', 'help', 'letter', 'attention', 'barrel', 'drink', 'half', 'reading', 'apple', 'middle', 'pile', 'walk', 'hall', 'music', 'pat', 'Bird', 'show', 'roll', 'round', 'upside', 'name', 'color', 'bench', 'word', 'cry', 'chicken', 'pocketbook', 'rubber', 'song', 'reach', 'sheep', 'neck', 'father', 'airplane', 'wrench', 'opening', 'comb', 'bread', 'stomach', 'building', 'track', 'someone', 'eye', 'tea', 'touch', 'animal', 'sponge', 'iron', 'watch', 'work', 'cash', 'cut', 'teddy', 'dump', 'dinosaur', 'knee', 'shirt', 'bath', 'shelf', 'pretend', 'climb', 'marble', 'cracker', 'glass', 'light', 'straight', 'game', 'try', 'peekaboo', 'cheese', 'bike', 'Mom', 'rabbit', 'throw', 'clip', 'answer', 'triangle', 'chest', 'plane', 'sofa', 'rest', 'giraffe', 'past', 'wagon', 'turtle', 'stove', 'Dad', 'Mark', 'bite', 'pipe', 'clock', 'tire', 'sugar', 'drinking', 'star', 'snake', 'bar', 'square', 'cord', 'lorry', 'snack', 'space', 'card', 'refrigerator', 'thumb', 'metal', 'fire', 'start', 'banana', 'jigsaw', 'napkin', 'bell', 'read', 'catch', 'bunny', 'farm', 'grab', 'fix', 'bridge', 'crawl', 'spot', 'dough', 'barrette', 'money', 'use', 'piano']  # 305 in total. from chldes vocab, take all words with freq >= 100 in both <LAN> and <ENV>, and pass nltk noun check wn.synsets(word, pos='n')

tokenizer = TrainableWordTokenizer(vocab_file='../train/vocab.json')


def checkpoint_path_to_model(path):
    """Load checkpoint to model."""
    checkpoint = torch.load(path)
    model = GPT2LMHeadModel(config=GPT2Config(n_layer=12))
    model.resize_token_embeddings(len(tokenizer))
    # 
    model.load_state_dict(checkpoint['model_state_dict'])
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    return model

def checkpoint_path_to_model_ctrl(path, hidden_size=768, use_model=GPT2LMHeadModelNoResidual, layer_num=4):
    """Load checkpoint to model."""
    checkpoint = torch.load(path)
    model = use_model(config=GPT2Config(n_embd=hidden_size, n_layer=layer_num))
    model.resize_token_embeddings(len(tokenizer))
    # 
    model.load_state_dict(checkpoint['model_state_dict'])
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    return model

def lstm_checkpoint_path_to_model(path):
    checkpoint = torch.load(path)
    print(len(tokenizer))
    # print(tokenizer._vocab)
    model = SimpleLMHeadModelNoFFN(
        layer=LSTMLayer,
        d_model=768,
        vocab_size=len(tokenizer),
        n_layer=4,
        max_position_embeddings=1024,
        device='cuda' if torch.cuda.is_available() else 'cpu',
        embed_dropout=0
    )
    model.load_state_dict(checkpoint['model_state_dict'])
    return model


def mamba_checkpoint_path_to_model(path):
    checkpoint = torch.load(path)
    print(len(tokenizer))
    # print(tokenizer._vocab)
    model = SimpleLMHeadModelNoFFN(
        layer=MambaLayer,
        d_model=768,
        vocab_size=len(tokenizer),
        n_layer=4,
        max_position_embeddings=1024,
        device='cuda' if torch.cuda.is_available() else 'cpu',
        embed_dropout=0
    )
    model.load_state_dict(checkpoint['model_state_dict'])
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    return model

def new_mamba_checkpoint_path_to_model(path):
    checkpoint = torch.load(path)
    mamba_config = MambaConfig(d_model=768, d_intermediate=0, n_layer=12, vocab_size=len(tokenizer), 
        ssm_cfg={'layer': 'Mamba2'}, attn_layer_idx=[], attn_cfg={}, rms_norm=True, residual_in_fp32=True, 
        fused_add_norm=True, pad_vocab_size_multiple=16, tie_embeddings=True)
    model = MambaLMHeadModel(mamba_config)
    model.load_state_dict(checkpoint['model_state_dict'])
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    return model

def hybrid_mamba_checkpoint_path_to_model(path, attn_idx=[]):
    checkpoint = torch.load(path)
    mamba_config = MambaConfig(d_model=768, d_intermediate=0, n_layer=12, vocab_size=len(tokenizer), 
        ssm_cfg={'layer': 'Mamba2'}, attn_layer_idx=attn_idx, attn_cfg={'causal': True, 'num_heads': 12}, rms_norm=True, residual_in_fp32=True, 
        fused_add_norm=True, pad_vocab_size_multiple=16, tie_embeddings=True)
    model = MambaLMHeadModel(mamba_config)
    model.load_state_dict(checkpoint['model_state_dict'])
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    return model


def add_tag(text, tag=':<LAN>'):
    """Add tag to plain text."""
    words = text.split()
    for i in range(len(words)):
        words[i] += tag
    return ' '.join(words)


def make_context(env, lan):
    """Make context from env and lan. Example: env = 'holds a book', lan = 'this is a'."""
    context = '<CHI> ' + add_tag(env, ':<ENV>') + ' <CHI> ' + add_tag(lan)
    return context


def make_plain_context(env, lan):
    """Make context from env and lan without adding tags."""
    return env + ' ' + lan


def make_auto_context(word):
    """Automaticaly makes context from a word."""
    vowel = 'aeioAEIO'  # do not take u
    if word[0] in vowel:
        return make_context(f'holds an {word}', 'this is an')
    else:
        return make_context(f'holds a {word}', 'this is a')


def infer_model(model, word_list, show_output=False):
    """Surprisal of word as next token of default context and cmp_context; ranking compare."""
    model.eval()
    surprisal_list = []
    cmp_surprisal_list = []
    ranking_list = []
    cmp_ranking_list = []
    for next_token in word_list:
        env = f'holds a {next_token}'
        cmp_env = 'holds a <unk>'
        context = '<CHI> ' + add_tag(env, ':<ENV>') + ' <CHI> ' + add_tag('this is a')
        cmp_context = '<CHI> ' + add_tag(cmp_env, ':<ENV>') + ' <CHI> ' + add_tag('this is a')
        surprisal = next_token_surprisal('', context, add_tag(next_token), tokenizer, model)
        cmp_surprisal = next_token_surprisal('', cmp_context, add_tag(next_token), tokenizer, model)
        if show_output:
            print(f"for context='{context}', cmp_context='{cmp_context}' next_token='{next_token}' surprisal={surprisal}, cmp_surprisal={cmp_surprisal}")

        surprisal_list.append(surprisal)
        cmp_surprisal_list.append(cmp_surprisal)
        # rank
        input_ids = tokenizer.encode(context, return_tensors='pt')
        cmp_input_ids = tokenizer.encode(cmp_context, return_tensors='pt')
        candidate_id = tokenizer.convert_tokens_to_ids(add_tag(next_token))
        if candidate_id is None:
            print(f"Token '{next_token}' not found in the vocabulary!")
        else:
            with torch.no_grad():
                outputs = model(input_ids)
                logits = outputs.logits
            next_token_logits = logits[0, -1, :]
            candidate_logit = next_token_logits[candidate_id]
            rank = (next_token_logits > candidate_logit).sum().item() + 1
            ranking_list.append(rank)
        # cmp_rank
        if cmp_input_ids is None:
            print(f"Token '{next_token}' not found in the vocabulary!")
        else:
            with torch.no_grad():
                outputs = model(cmp_input_ids)
                logits = outputs.logits
            next_token_logits = logits[0, -1, :]
            candidate_logit = next_token_logits[candidate_id]
            rank = (next_token_logits > candidate_logit).sum().item() + 1
            cmp_ranking_list.append(rank)
    return surprisal_list, cmp_surprisal_list, ranking_list, cmp_ranking_list


def get_embedding(model, word_list):
    """Given model and word_list, calculate each token's embedding and return as list of tensor (1, embedding_dim)."""
    result = []
    for token in word_list:
        token_id = tokenizer.convert_tokens_to_ids(token)
        input_tensor = torch.tensor([[token_id]])
        embedding_layer = model.transformer.wte
        embedding_vector = embedding_layer(input_tensor)[0]
        result.append(embedding_vector)
    return result


def remove_last_occurrence(lan: str, word: str) -> str:
    """Remove last occurrence of word in lan."""
    lower_lan = lan.lower()
    lower_word = word.lower()

    pos = lower_lan.rfind(lower_word)

    # If the word is found, return everything before that occurrence.
    # Otherwise, return the original string.
    return lan[:pos] if pos != -1 else lan


@torch.no_grad()
def surprisal_with_head_mask(
    model,                    # GPT2LMHeadModel
    tokenizer,                # GPT2 tokenizer (byte-level BPE)
    context: str,
    target: str,              # the next word you want probability for
    mask_place: Optional[Tuple[int, int]] = None,  # (layer_index, head_index) to zero
    device: Optional[torch.device] = None,
) -> float:
    """
    Return surprisal of `target` given `context` in **bits**:
        surprisal = -log2 P(target | context)
    If mask_place=(L,H) is provided, zeros out that attn head for the forward pass.
    """
    model.eval()
    if device is None:
        device = next(model.parameters()).device

    # --- Context-aware target id (ensure it's a single token in this context) ---
    context_ids = tokenizer.encode(context, add_special_tokens=False)
    full_ids = tokenizer.encode(context + " " + target, add_special_tokens=False)

    if len(full_ids) != len(context_ids) + 1:
        raise ValueError(
            f"Target is not a single token in this context. "
            f"len(full_ids)={len(full_ids)}, len(context_ids)={len(context_ids)}. "
            f"Consider adjusting spacing (e.g., leading space) so it tokenizes as one token."
        )

    target_id = full_ids[-1]

    # --- Build optional head mask ---
    # Shape: [num_layers, num_heads], ones everywhere except the (L,H) to zero if requested.
    head_mask = None
    if mask_place is not None:
        L, H = mask_place
        n_layers = model.config.n_layer
        n_heads  = model.config.n_head
        if not (0 <= L < n_layers and 0 <= H < n_heads):
            raise ValueError(f"mask_place {(L, H)} out of range: layers=[0..{n_layers-1}], heads=[0..{n_heads-1}]")
        head_mask = torch.ones(n_layers, n_heads, device=device)
        head_mask[L, H] = 0.0

    # --- Forward: get next-token distribution given the context only ---
    inp = torch.tensor([context_ids], dtype=torch.long, device=device)
    out = model(input_ids=inp, head_mask=head_mask, use_cache=False)
    logits_last = out.logits[:, -1, :]  # [batch=1, vocab]
    probs_last  = torch.softmax(logits_last, dim=-1)

    p = probs_last[0, target_id].clamp(min=1e-45)  # avoid log(0)
    surprisal_bits = -torch.log(p).item()
    return surprisal_bits

def infer_model_on_contexts(model, word_list, env_list, cmp_env_lists, lan_list, show_output=0, env_tag=':<ENV>', only_logit=False):
    """Surprisal of word as next token of given env_list, cmp_env_lists (list of env list), lan_list (removed target word) (untagged); ranking compare."""
    model.eval()
    device='cuda' if torch.cuda.is_available() else 'cpu'
    surprisal_list = []
    cmp_surprisal_list = []
    ranking_list = []
    cmp_ranking_list = []


    for i, next_token in enumerate(word_list):
        env = '<CHI> ' + add_tag(env_list[i].replace('The child', ''), env_tag)  # <CHI> <ENV>(holds an apple)
        plain_cmp_envs = cmp_env_lists[i]  # [holds an animal, holds an egg, ...]
        cmp_envs = list(map(lambda x: '<CHI> ' + add_tag(x.replace('The child', ''), env_tag), plain_cmp_envs))  # [<CHI> <ENV>(holds an animal), <CHI> <ENV>(holds an egg), ...]
        lan = add_tag(lan_list[i])
        context = env + ' <CHI> ' + lan
        all_contexts = [(cmp_env + ' <CHI> ' + lan) for cmp_env in cmp_envs]
        all_contexts.insert(0, context)  # first item is context, rest is cmp. Same length
        # rank, todo
        token_ids = tokenizer.encode(add_tag(next_token), add_special_tokens=False)
        if len(token_ids) != 1:
            raise ValueError('next_token should tokenize to exactly one token.')
        target_token_id = token_ids[0]

        # Tokenize all contexts in a batch with padding.
        encoded = tokenizer(all_contexts, return_tensors='pt', padding=False)
        input_ids = encoded['input_ids'].to(device)
        # print(input_ids)
        with torch.no_grad():
            if not only_logit:
                outputs = model(input_ids)[0]
            else:
                outputs = model(input_ids)
            # Get the logits for the next-token prediction for each context.
            # outputs.logits has shape (batch_size, seq_length, vocab_size)
            # print(outputs)
            logits = outputs.logits[:, -1, :]  # Shape: (batch_size, vocab_size)
            # print(logits.shape)
            log_probs = F.log_softmax(logits, dim=-1)
            surprisals = (-log_probs[:, target_token_id]).tolist()
            surprisal_list.append(surprisals[0])
            cmp_surprisal_list.append(surprisals[1:])

        if show_output:
            print(f"for context='{context}' cmp_contexts='{all_contexts[1:]}' next_token='{next_token}' surprisal={surprisals[0]}, cmp_surprisal={surprisals[1:]}")
            show_output -= 1

        # For each context, get the logit for the target token.
        target_logits = logits[:, target_token_id].unsqueeze(1)  # Shape: (batch_size, 1)

        # Compute the rank: count how many tokens have a logit greater than the target token.
        # The rank is 1 plus the number of tokens with higher logit.
        ranks = ((logits > target_logits).sum(dim=1) + 1).tolist()
        ranking_list.append(ranks[0])
        cmp_ranking_list.append(ranks[1:])
        if i % 10 == 0 or i == 99:
            print(f'{i} completed')

    return surprisal_list, cmp_surprisal_list, ranking_list, cmp_ranking_list

def infer_model_on_contexts_no_cmp_masking(model, word_list, env_list, cmp_env_lists, lan_list, show_output=0, env_tag=':<ENV>', only_logit=False, mask_place: Optional[Tuple[int, int]] = None, rev_mask=False, return_rank=False, logit_lens=False):
    """Surprisal of word as next token of given env_list, cmp_env_lists (list of env list), lan_list (removed target word) (untagged); ranking compare."""
    model.eval()
    device='cuda' if torch.cuda.is_available() else 'cpu'
    surprisal_list = []
    rank_list = []
    head_mask = None
    if mask_place is not None:
        L, H = mask_place
        n_layers = model.config.n_layer
        n_heads  = model.config.n_head
        if not (0 <= L < n_layers and 0 <= H < n_heads):
            raise ValueError(f"mask_place {(L, H)} out of range: layers=[0..{n_layers-1}], heads=[0..{n_heads-1}]")
        head_mask = torch.ones(n_layers, n_heads, device=device)
        if not rev_mask:
            head_mask[L, H] = 0.0
        else:
            for h in range(12):
                if h != H:
                    head_mask[L, h] = 0.0
        if logit_lens:
            for l in range(L+1, 12):
                for h in range(12):
                    head_mask[l, h] = 0.0

    for i, next_token in enumerate(word_list):
        env = '<CHI> ' + add_tag(env_list[i].replace('The child', ''), env_tag)  # <CHI> <ENV>(holds an apple)
        lan = add_tag(lan_list[i])
        context = env + ' <CHI> ' + lan
        all_contexts=[context]  # first item is context, rest is cmp. Same length
        # rank, todo
        token_ids = tokenizer.encode(add_tag(next_token), add_special_tokens=False)
        if len(token_ids) != 1:
            raise ValueError('next_token should tokenize to exactly one token.')
        target_token_id = token_ids[0]

        # Tokenize all contexts in a batch with padding.
        encoded = tokenizer(all_contexts, return_tensors='pt', padding=False)
        input_ids = encoded['input_ids'].to(device)
        # print(input_ids)
        with torch.no_grad():
            if not only_logit:
                outputs = model(input_ids, head_mask=head_mask)[0]
            else:
                outputs = model(input_ids, head_mask=head_mask)
            # Get the logits for the next-token prediction for each context.
            # outputs.logits has shape (batch_size, seq_length, vocab_size)
            # print(outputs)
            logits = outputs.logits[:, -1, :]  # Shape: (batch_size, vocab_size)
            # print(logits.shape)
            log_probs = F.log_softmax(logits, dim=-1)
            surprisals = (-log_probs[:, target_token_id]).tolist()
            surprisal_list.append(surprisals[0])
            target_logp = log_probs[:, target_token_id]                     # shape (1,)
            rank = (log_probs > target_logp.unsqueeze(-1)).sum(dim=-1) + 1  # shape (1,)
            rank_list.append(int(rank.item()))

        if show_output:
            print(f"for context='{context}' next_token='{next_token}' surprisal={surprisals[0]}")
            show_output -= 1

        if i % 10 == 0 or i == 99:
            print(f'{i} completed')
    if return_rank:
        return surprisal_list, rank_list
    return surprisal_list


def infer_model_on_contexts_free_masking(model, word_list, env_list, cmp_env_lists, lan_list, show_output=0, env_tag=':<ENV>', only_logit=False, mask_place = [[]], return_rank=False, randomize=False):
    """Surprisal of word as next token of given env_list, cmp_env_lists (list of env list), lan_list (removed target word) (untagged); ranking compare."""
    model.eval()
    device='cuda' if torch.cuda.is_available() else 'cpu'
    surprisal_list = []
    rank_list = []
    

    for i, next_token in enumerate(word_list):
        head_mask = torch.ones(12, 12, device=device)
        if randomize:
            masks = pick_disjoint_pairs_relaxed(mask_place[i])
            for mask_head in masks:
                l, h = mask_head[0], mask_head[1]
                head_mask[l, h] = 0.0
        else:
            for mask_head in mask_place[i]:
                l, h = mask_head[0], mask_head[1]
                head_mask[l, h] = 0.0
        env = '<CHI> ' + add_tag(env_list[i].replace('The child', ''), env_tag)  # <CHI> <ENV>(holds an apple)
        lan = add_tag(lan_list[i])
        context = env + ' <CHI> ' + lan
        all_contexts=[context]  # first item is context, rest is cmp. Same length
        # rank, todo
        token_ids = tokenizer.encode(add_tag(next_token), add_special_tokens=False)
        if len(token_ids) != 1:
            raise ValueError('next_token should tokenize to exactly one token.')
        target_token_id = token_ids[0]

        # Tokenize all contexts in a batch with padding.
        encoded = tokenizer(all_contexts, return_tensors='pt', padding=False)
        input_ids = encoded['input_ids'].to(device)
        # print(input_ids)
        with torch.no_grad():
            if not only_logit:
                outputs = model(input_ids, head_mask=head_mask)[0]
            else:
                outputs = model(input_ids, head_mask=head_mask)
            # Get the logits for the next-token prediction for each context.
            # outputs.logits has shape (batch_size, seq_length, vocab_size)
            # print(outputs)
            logits = outputs.logits[:, -1, :]  # Shape: (batch_size, vocab_size)
            # print(logits.shape)
            log_probs = F.log_softmax(logits, dim=-1)
            surprisals = (-log_probs[:, target_token_id]).tolist()
            surprisal_list.append(surprisals[0])
            target_logp = log_probs[:, target_token_id]                     # shape (1,)
            rank = (log_probs > target_logp.unsqueeze(-1)).sum(dim=-1) + 1  # shape (1,)
            rank_list.append(int(rank.item()))

        if show_output:
            print(f"for context='{context}' next_token='{next_token}' surprisal={surprisals[0]}")
            show_output -= 1

        if i % 10 == 0 or i == 99:
            print(f'{i} completed')
    if return_rank:
        return surprisal_list, rank_list
    return surprisal_list


def infer_model_on_contexts_rev(model, word_list, env_list, cmp_env_lists, lan_list, show_output=0, env_tag=':<ENV>'):
    """Surprisal of word as next token of given env_list, cmp_env_lists (list of env list), lan_list (removed target word) (untagged); ranking compare."""
    model.eval()
    surprisal_list = []
    cmp_surprisal_list = []
    ranking_list = []
    cmp_ranking_list = []
    for i, next_token in enumerate(word_list):
        env = '<CHI> ' + add_tag(env_list[i].replace('The child', ''), env_tag)  # <CHI> <ENV>(holds an apple)
        plain_cmp_envs = cmp_env_lists[i]  # [holds an animal, holds an egg, ...]
        cmp_envs = list(map(lambda x: '<CHI> ' + add_tag(x.replace('The child', ''), env_tag), plain_cmp_envs))  # [<CHI> <ENV>(holds an animal), <CHI> <ENV>(holds an egg), ...]
        lan = add_tag(lan_list[i])
        context = '<CHI> ' + lan + ' ' + env
        all_contexts = [('<CHI> ' + lan + ' ' + cmp_env) for cmp_env in cmp_envs]
        all_contexts.insert(0, context)  # first item is context, rest is cmp. Same length
        # rank, todo
        token_ids = tokenizer.encode(add_tag(next_token, env_tag), add_special_tokens=False)
        if len(token_ids) != 1:
            raise ValueError('next_token should tokenize to exactly one token.')
        target_token_id = token_ids[0]

        # Tokenize all contexts in a batch with padding.
        encoded = tokenizer(all_contexts, return_tensors='pt', padding=False)
        input_ids = encoded['input_ids'].to(model.device)
        # print(input_ids)
        with torch.no_grad():
            outputs = model(input_ids)
            # Get the logits for the next-token prediction for each context.
            # outputs.logits has shape (batch_size, seq_length, vocab_size)
            logits = outputs.logits[:, -1, :]  # Shape: (batch_size, vocab_size)
            log_probs = F.log_softmax(logits, dim=-1)
            surprisals = (-log_probs[:, target_token_id]).tolist()
            surprisal_list.append(surprisals[0])
            cmp_surprisal_list.append(surprisals[1:])

        if show_output:
            print(f"for context='{context}' cmp_contexts='{all_contexts[1:]}' next_token='{next_token}' surprisal={surprisals[0]}, cmp_surprisal={surprisals[1:]}")
            show_output -= 1

        # For each context, get the logit for the target token.
        target_logits = logits[:, target_token_id].unsqueeze(1)  # Shape: (batch_size, 1)

        # Compute the rank: count how many tokens have a logit greater than the target token.
        # The rank is 1 plus the number of tokens with higher logit.
        ranks = ((logits > target_logits).sum(dim=1) + 1).tolist()
        ranking_list.append(ranks[0])
        cmp_ranking_list.append(ranks[1:])
        if i % 10 == 0:
            print(f'{i} completed')

    return surprisal_list, cmp_surprisal_list, ranking_list, cmp_ranking_list


def get_similarity(model, token1, token2):
    """Given model and two tokens, calculate their cosine similarity of embeddings."""
    embeddings = get_embedding(model, [token1, token2])
    cos_sim = F.cosine_similarity(embeddings[0], embeddings[1])
    return cos_sim.item()


def get_noun_word_list():
    """Get the word list of all nouns with 100+ occurance in env and lan."""
    with open('../data/noun_token_count_filtered.json') as fp:
        words = json.load(fp)
    return list(words.keys())


def get_bert_similarity_list_batch(word_list, context_list, bert_model, bert_tokenizer):
    """Use BERT to get top 5 similar words from word_list for each context.

    For each index i, the intended word is word_list[i] and the corresponding context
    (in context_list[i]) contains exactly one mask token (e.g., '[MASK]'). For each context,
    this function computes the probability distribution at the mask position and selects
    the top 5 words from word_list (excluding the intended word) with the highest probability.

    Returns:
        A list of lists, where the i-th sublist contains the top 5 words for the corresponding context.
    """
    # Tokenize all contexts in batch with padding.
    tokenized_word = bert_tokenizer.tokenize(context_list[0])
    encoded = bert_tokenizer(context_list, return_tensors='pt', padding=True, truncation=True)
    input_ids = encoded['input_ids']  # shape: [batch_size, seq_length]
    attention_mask = encoded['attention_mask']
    print(context_list[0], tokenized_word, input_ids[0], attention_mask[0], sep='\n')

    # Run the model once over the batch.
    outputs = bert_model(input_ids, attention_mask=attention_mask)
    logits = outputs.logits  # shape: [batch_size, seq_length, vocab_size]
    probs = torch.nn.functional.softmax(logits, dim=-1)  # probability distribution over vocab

    # Find the mask token id.
    mask_token_id = bert_tokenizer.mask_token_id

    # Determine the position of the mask token for each sample.
    batch_mask_indices = []
    batch_size = input_ids.size(0)
    for i in range(batch_size):
        # Find indices where mask token appears.
        mask_positions = (input_ids[i] == mask_token_id).nonzero(as_tuple=False)
        if mask_positions.size(0) != 1:
            raise ValueError(f'Each context must contain exactly one [MASK] token, found {mask_positions.size(0)} at {i}th index')
        batch_mask_indices.append(mask_positions.item())

    # Precompute candidate token ids for the given word list.
    candidate_ids = [bert_tokenizer.convert_tokens_to_ids(word) for word in word_list]

    result = []
    # Process each sample in the batch.
    for i, intended_word in enumerate(word_list):
        mask_idx = batch_mask_indices[i]
        mask_prob_vector = probs[i, mask_idx]  # probability vector for the mask token

        # Build a list of (word, probability) for words in word_list excluding the intended word.
        candidate_probs = [
            (word, mask_prob_vector[cid].item())
            for word, cid in zip(word_list, candidate_ids)
            if word != intended_word
        ]
        # Sort the candidates in descending order of probability.
        candidate_probs.sort(key=lambda x: x[1], reverse=True)
        # Select the top 5 words.
        top_5_words = [word for word, _ in candidate_probs[:5]]
        result.append(top_5_words)

    return result


def get_bert_similarity_list_batch_prob(word_list, context_list, bert_model, bert_tokenizer):
    """Use bert to get similar words that fits into context."""
    # Tokenize all contexts in batch with padding.
    tokenized_word = bert_tokenizer.tokenize(context_list[0])
    encoded = bert_tokenizer(context_list, return_tensors='pt', padding=True, truncation=True)
    input_ids = encoded['input_ids']  # shape: [batch_size, seq_length]
    attention_mask = encoded['attention_mask']
    print(context_list[0], tokenized_word, input_ids[0], attention_mask[0], sep='\n')

    # Run the model once over the batch.
    outputs = bert_model(input_ids, attention_mask=attention_mask)
    logits = outputs.logits  # shape: [batch_size, seq_length, vocab_size]
    probs = torch.nn.functional.softmax(logits, dim=-1)  # probability distribution over vocab

    # Find the mask token id.
    mask_token_id = bert_tokenizer.mask_token_id

    # Determine the position of the mask token for each sample.
    batch_mask_indices = []
    batch_size = input_ids.size(0)
    for i in range(batch_size):
        # Find indices where mask token appears.
        mask_positions = (input_ids[i] == mask_token_id).nonzero(as_tuple=False)
        if mask_positions.size(0) != 1:
            raise ValueError(f'Each context must contain exactly one [mask] token, found {mask_positions.size(0)} at {i}th index')
        batch_mask_indices.append(mask_positions.item())

    # Precompute candidate token ids for the given word list.
    candidate_ids = [bert_tokenizer.convert_tokens_to_ids(word) for word in word_list]

    result = []
    # Process each sample in the batch.
    for i, intended_word in enumerate(word_list):
        mask_idx = batch_mask_indices[i]
        mask_prob_vector = probs[i, mask_idx]  # probability vector for the mask token

        intended_id = bert_tokenizer.convert_tokens_to_ids(intended_word)
        intended_prob = mask_prob_vector[intended_id].item()
        threshold = 0.5 * intended_prob

        similar_words = []
        for word, cid in zip(word_list, candidate_ids):
            if word == intended_word:
                continue
            if mask_prob_vector[cid].item() >= threshold:
                similar_words.append(word)
        result.append(similar_words)

    return result


def neg_log_perplexity_before_word(model, tokenizer, sentence, word) -> float:
    """Calculate the perplexity of a sentence until and excluding the given word."""
    model.eval()

    # Tokenize the sentence and the word
    sentence_ids = tokenizer.encode(sentence, return_tensors='pt')[0]
    word_ids = tokenizer.encode(word, add_special_tokens=False)

    # Search for the first occurrence of word_ids in sentence_ids
    for i in range(len(sentence_ids) - len(word_ids) + 1):
        if torch.equal(sentence_ids[i:i+len(word_ids)], torch.tensor(word_ids)):
            if i == 0:
                return -1  # 'word' is the first token
            prefix_ids = sentence_ids[:i].unsqueeze(0)
            labels = prefix_ids.clone()

            with torch.no_grad():
                outputs = model(input_ids=prefix_ids, labels=labels)
                loss = outputs.loss  # average cross-entropy loss (in nats)
                return loss.item()  # negative log perplexity

    return -1  # 'word' not found in sentence


def neg_log_perplexity_at_word(model, tokenizer, sentence, word) -> float:
    """Calculate the perplexity of a sentence until and including the given word."""
    model.eval()
    # Tokenize the sentence and the word
    sentence_ids = tokenizer.encode(sentence, return_tensors='pt')[0]
    word_ids = tokenizer.encode(word, add_special_tokens=False)

    # Search for the first occurrence of word_ids in sentence_ids
    for i in range(len(sentence_ids) - len(word_ids) + 1):
        if torch.equal(sentence_ids[i:i+len(word_ids)], torch.tensor(word_ids)):
            if i == 0:
                return -1  # 'word' is the first token
            prefix_ids = sentence_ids[:i+1].unsqueeze(0)
            labels = prefix_ids.clone()

            with torch.no_grad():
                outputs = model(input_ids=prefix_ids, labels=labels)
                loss = outputs.loss  # average cross-entropy loss (in nats)
                return loss.item()  # negative log perplexity

    return -1  # 'word' not found in sentence


if __name__ == '__main__':
    model = checkpoint_path_to_model(files_sorted[-1])
