from generate import generate
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F
import numpy as np
from likelihood_data import data


@torch.no_grad()
def get_likelihood(step_map, model, prompt, gen_length=128, block_length=32, mask_id=126336):
    """
    Replay a recorded generation trajectory and record the model's probability
    for every token that was committed at each step.

    The sequence starts as the prompt followed by ``gen_length`` copies of
    ``mask_id``. For each step in ``step_map`` the model is run on the current
    sequence, the softmax probability of each recorded token at its recorded
    position is stored under ``step_map[step]['likelihood']``, and the recorded
    tokens are then written into the sequence before the next step.

    Args:
        step_map: Indexable by ``0 .. len(step_map) - 1``; each entry is a dict
            with ``'idx_list'`` (positions in the full sequence, prompt
            included) and ``'tok_list'`` (token ids committed at those
            positions). Entries are mutated in place.
        model: Callable as ``model(x, attention_mask=None)`` returning an
            object with a ``.logits`` tensor indexed as [batch, position,
            vocab]; must expose ``.device``.
        prompt: 2-D tensor of token ids, ``[batch, prompt_len]``.
        gen_length: Number of masked positions appended after the prompt.
        block_length: Unused; kept so existing callers keep working.
        mask_id: Token id used for positions that have not been written yet.

    Returns:
        ``(x, step_map)`` where ``x`` is the sequence after all recorded tokens
        were written and ``step_map`` is the same object that was passed in,
        now carrying a float32 NumPy array under ``'likelihood'`` per step.

    Note:
        Only batch row 0 is scored and filled; any further rows of ``x`` keep
        their mask tokens.
    """
    prompt_len = prompt.shape[1]
    x = torch.full(
        (prompt.shape[0], prompt_len + gen_length),
        mask_id,
        dtype=torch.long,
        device=model.device,
    )
    x[:, :prompt_len] = prompt.clone()

    # Index-based access (rather than direct iteration) so that both a list
    # and an int-keyed mapping are accepted for ``step_map``.
    for step_idx in range(len(step_map)):
        step = step_map[step_idx]
        logits = model(x, attention_mask=None).logits

        idx = torch.as_tensor(step['idx_list'], device=logits.device, dtype=torch.long).reshape(-1)
        tok = torch.as_tensor(step['tok_list'], device=logits.device, dtype=torch.long).reshape(-1)

        # Softmax in float32 so low-precision model outputs do not degrade
        # the probabilities; result is [num_positions, vocab_size].
        probs = F.softmax(logits[0, idx, :].float(), dim=-1)
        tok_prob = probs.gather(1, tok.unsqueeze(1)).squeeze(1)

        # Commit this step's tokens only after scoring them, so the model saw
        # the sequence exactly as it was before they were chosen.
        x[0, idx] = tok
        step['likelihood'] = tok_prob.cpu().numpy()

    return x, step_map


def get_correction_factor(likelihood_map, stop_token_ids=None, device='cuda'):
    """
    Compute W = log(conf) - log(likelihood) per token and aggregate it.

    Positions whose token id is in ``stop_token_ids`` (EOS, PAD, etc.) are set
    to 0 in the raw grid and are excluded from every sum and mean. Steps may
    hold different numbers of tokens: the grid is as wide as the longest step
    and the padding slots of shorter steps are likewise 0 and excluded.

    Args:
        likelihood_map: Indexable by ``0 .. len(likelihood_map) - 1``; each
            entry is a dict with ``'conf'`` and ``'likelihood'`` (per-token
            probabilities) and, when stop tokens are given, ``'tok_list'``
            (the matching token ids).
        stop_token_ids: Optional iterable of token ids to ignore. ``None`` or
            an empty iterable disables stop-token filtering.
        device: Torch device on which the statistics are computed.

    Returns:
        dict of NumPy values:
            ``'correction_across_time'``  [num_steps, max_step_len] raw grid,
            ``'per_block_mean_correction'``  [num_steps] mean over valid tokens,
            ``'per_block_correction'``  [num_steps] sum over valid tokens,
            ``'total_correction'``  scalar sum over all valid tokens,
            ``'mean_correction'``  scalar mean over all valid tokens.
        Counts are clamped to at least 1, so a step (or a whole map) with no
        valid token reports a mean of 0 rather than NaN.

    Note:
        A ``conf`` or ``likelihood`` of exactly 0 yields an infinite or NaN
        correction at that position; no clamping is applied.
    """
    stop_ids = [] if stop_token_ids is None else list(stop_token_ids)
    stop_tensor = (
        torch.as_tensor(stop_ids, device=device, dtype=torch.long)
        if stop_ids else None
    )

    num_blocks = len(likelihood_map)

    # First pass: per-step corrections plus (optionally) which of them are stops.
    rows = []
    for num_block in range(num_blocks):
        entry = likelihood_map[num_block]

        conf = torch.as_tensor(entry['conf'], device=device, dtype=torch.float).reshape(-1)
        likelihood = torch.as_tensor(
            entry['likelihood'], device=device, dtype=torch.float
        ).reshape(-1)
        corrections = torch.log(conf) - torch.log(likelihood)

        stop_mask = None
        if stop_tensor is not None:
            tok_ids = torch.as_tensor(
                entry['tok_list'], device=device, dtype=torch.long
            ).reshape(-1)
            # stop_mask[j] is True iff tok_ids[j] is one of the stop tokens
            stop_mask = (tok_ids.unsqueeze(1) == stop_tensor.unsqueeze(0)).any(dim=1)
            # zero them in the raw grid so they do not show up when plotted
            corrections[stop_mask] = 0.0

        rows.append((corrections, stop_mask))

    # Size the grid to the longest step instead of assuming every step matches
    # the first one; this also makes an empty map produce an empty grid.
    block_len = max((corr.numel() for corr, _ in rows), default=0)

    block_corrections = torch.zeros(num_blocks, block_len, device=device, dtype=torch.float)
    # Starts all-False so padding slots never count as valid tokens.
    valid_mask = torch.zeros(num_blocks, block_len, device=device, dtype=torch.bool)

    for row, (corrections, stop_mask) in enumerate(rows):
        width = corrections.numel()
        block_corrections[row, :width] = corrections
        valid_mask[row, :width] = True if stop_mask is None else ~stop_mask

    masked = block_corrections * valid_mask

    # per-step stats over valid tokens only
    valid_counts = valid_mask.sum(dim=1).clamp(min=1)
    per_block_correction = masked.sum(dim=1)
    per_block_mean_correction = per_block_correction / valid_counts

    # whole-map stats over valid tokens only
    total_correction = masked.sum()
    total_valid = valid_mask.sum().clamp(min=1)
    mean_correction = total_correction / total_valid

    return {
        'correction_across_time': block_corrections.cpu().numpy(),
        'per_block_mean_correction': per_block_mean_correction.cpu().numpy(),
        'per_block_correction': per_block_correction.cpu().numpy(),
        'total_correction': total_correction.cpu().numpy(),
        'mean_correction': mean_correction.cpu().numpy()
    }



@torch.no_grad()
def likelihood_vs_true_prob(
    block_length=128,
    gen_length=128,
    temperature=0.0,
    cfg_scale=0.0,
    parallel_tokens=8,
    test_data='gsm8k',
    ignore_stop_tokens=False,
    device='cuda',
):
    """
    Measure, over a dataset, the gap between the confidences recorded during
    generation and the probabilities obtained by replaying the same trajectory.

    For every sample in ``data[test_data]`` this builds a single-turn chat
    prompt from ``sample['question']``, runs ``generate`` with likelihood
    analysis enabled, replays the returned step map through ``get_likelihood``,
    and summarises ``log(conf) - log(likelihood)`` with
    ``get_correction_factor``. Per-example statistics are then averaged.

    Args:
        block_length: Forwarded to ``generate`` and ``get_likelihood``.
        gen_length: Number of tokens to generate per sample.
        temperature: Forwarded to ``generate``.
        cfg_scale: Forwarded to ``generate``.
        parallel_tokens: Forwarded to ``generate``.
        test_data: Key into ``data`` selecting the dataset to iterate.
        ignore_stop_tokens: When True, the tokenizer's EOS and PAD ids (those
            that are defined) are excluded from all correction statistics.
            Defaults to False, i.e. every token is counted.
        device: Torch device for the model and all tensors.

    Returns:
        dict with ``'per_example'`` (list of dicts holding ``'index'``,
        ``'question'`` and ``'correction_stats'``) and ``'dataset_stats'``
        (means of the per-example totals, means, and per-step arrays).

    Raises:
        ValueError: If the dataset yields no samples, or if examples produce
            differently sized per-step arrays (they are stacked with
            ``np.stack``).
    """
    model = AutoModel.from_pretrained(
        'GSAI-ML/LLaDA-8B-Instruct',
        trust_remote_code=True,
        torch_dtype=torch.bfloat16
    ).to(device).eval()

    tokenizer = AutoTokenizer.from_pretrained(
        'GSAI-ML/LLaDA-8B-Instruct',
        trust_remote_code=True
    )
    tokenizer.padding_side = 'left'

    # Stop tokens are only collected when the caller asked to ignore them;
    # extend this set if the model uses additional end-of-turn tokens.
    stop_token_ids = None
    if ignore_stop_tokens:
        stop_token_ids = {
            token_id
            for token_id in (tokenizer.eos_token_id, tokenizer.pad_token_id)
            if token_id is not None
        }

    dataset = data[test_data]   # use full split
    all_results = []

    # per-example values collected for dataset-level aggregation
    total_corrections = []
    mean_corrections = []
    per_block_mean_list = []
    per_block_corr_list = []

    for i, sample in enumerate(dataset):
        print(f'\n================ Example {i} ================')

        # build chat prompt for this sample
        messages = [{"role": "user", "content": sample['question']}]
        prompt_str = tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, tokenize=False
        )

        encoded_outputs = tokenizer(
            [prompt_str],
            add_special_tokens=False,
            padding=True,
            return_tensors="pt"
        )
        input_ids = encoded_outputs['input_ids'].to(device)
        attention_mask = encoded_outputs['attention_mask'].to(device)

        # Run generation with likelihood analysis. Only the step map is needed
        # here; the generated ids are not used further.
        _generated, step_map = generate(
            model,
            input_ids,
            attention_mask,
            parallel_tokens=parallel_tokens,
            gen_length=gen_length,
            block_length=block_length,
            temperature=temperature,
            cfg_scale=cfg_scale,
            tokenizer=tokenizer,
            log_to_file=False,
            do_likelihood_analysis=True
        )

        # Replay the same trajectory to obtain the model's own probabilities.
        # To sanity-check a run, decode `_generated[:, input_ids.shape[1]:]` and
        # `_replayed[:, input_ids.shape[1]:]` with `tokenizer.batch_decode`
        # and compare them (presumably they should match — confirm against
        # `generate`).
        _replayed, likelihood_map = get_likelihood(
            step_map=step_map,
            model=model,
            prompt=input_ids,
            gen_length=gen_length,
            block_length=block_length,
        )

        # correction statistics for this example
        corr_stats = get_correction_factor(
            likelihood_map=likelihood_map,
            stop_token_ids=stop_token_ids,
            device=device
        )

        # collect per-example stats
        all_results.append({
            'index': i,
            'question': sample['question'],
            'correction_stats': corr_stats,
        })

        # collect for dataset-level mean
        total_corrections.append(corr_stats['total_correction'])
        mean_corrections.append(corr_stats['mean_correction'])
        per_block_mean_list.append(corr_stats['per_block_mean_correction'])
        per_block_corr_list.append(corr_stats['per_block_correction'])

    if not all_results:
        # Fail with a clear message instead of NaN means followed by an
        # opaque np.stack error.
        raise ValueError(f'dataset {test_data!r} produced no samples; nothing to aggregate')

    # ---- aggregate over dataset ----
    total_corrections = np.array(total_corrections, dtype=np.float64)
    mean_corrections = np.array(mean_corrections, dtype=np.float64)
    per_block_mean_arr = np.stack(per_block_mean_list, axis=0)   # [N, num_blocks]
    per_block_corr_arr = np.stack(per_block_corr_list, axis=0)   # [N, num_blocks]

    dataset_stats = {
        'mean_total_correction': float(total_corrections.mean()),
        'mean_mean_correction': float(mean_corrections.mean()),
        'mean_per_block_mean_correction': per_block_mean_arr.mean(axis=0),
        'mean_per_block_correction': per_block_corr_arr.mean(axis=0),
    }

    print('\n================ DATASET-LEVEL STATS ================')
    print(dataset_stats)

    return {
        'per_example': all_results,
        'dataset_stats': dataset_stats,
    }


# Script entry point: run the full analysis with every parameter at its default.
if __name__ == "__main__":
    likelihood_vs_true_prob()
