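"""
Run DINCO confidence estimation on a sample of TriviaQA questions.

Pipeline: beam search -> lexical de-duplication -> P(true) for each candidate
-> pairwise NLI -> normalized verbalized confidence (NVC); sampled generations
-> self-consistency (SC) via NLI. The final DINCO confidence is the mean of NVC
and SC. Intermediate results are saved to the working directory.
"""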

import torch
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification
from datasets import load_dataset
import pickle as pkl
import re

# Few-shot answer-generation prompt: two example question/answer pairs followed
# by the target question. `{question}` is a str.format placeholder. The text is
# model input, so any wording change alters generations.
PROMPT = """
Here are 2 sets of example prompt and answer.

Example Prompt: Which American-born Sinclair won the Nobel Prize for Literature in 1930?
Example Answer: Sinclair Lewis

Example Prompt: Where in England was Dame Judi Dench born?
Example Answer: York

---

Now, here is a new prompt to answer. Answer with a concise phrase, as in the examples.

Prompt: {question}
Answer:
""".strip()

# Verification prompt used to elicit P(true): the model is asked to judge a
# candidate answer with "Yes" or "No". Placeholders: `{question}` and
# `{candidate_answer}` (filled via str.format).
PTRUE_PROMPT = """
Below is a question and a candidate answer. Your task is to determine whether the answer is correct or not. Only output \"Yes\" (correct) or \"No\" (incorrect).

Question: {question}
Candidate answer: {candidate_answer}
""".strip()

def beam_search(ds, model, tokenizer, num_beams=10, num_return_sequences=10, length_penalty=0.0, max_new_tokens=100):
    """Run deterministic beam search on every question in ``ds``.

    Each question is formatted with ``PROMPT`` and wrapped in the tokenizer's
    chat template (``enable_thinking=False`` is forwarded to the template).

    Args:
        ds: sized iterable of examples exposing a ``'question'`` field.
        model: causal LM providing ``generate`` and ``device``.
        tokenizer: tokenizer with a chat template.
        num_beams: beam width.
        num_return_sequences: finished beams returned per question (must not
            exceed ``num_beams``).
        length_penalty: exponent on sequence length used by HF to normalize
            beam scores; 0.0 leaves the plain sum of token log-probabilities.
        max_new_tokens: generation budget per question.

    Returns:
        ``(beam_strs, beam_lls)``: ``beam_strs[q]`` is the list of decoded
        continuations for question ``q`` (prompt tokens stripped, special
        tokens skipped); ``beam_lls`` is a ``(len(ds), num_return_sequences)``
        tensor holding the matching ``sequences_scores``.
    """
    beam_strs = []
    beam_lls = torch.zeros((len(ds), num_return_sequences))

    for ex_i, ex in enumerate(ds):
        msg = [
            {'role': 'user', 'content': PROMPT.format(question=ex['question'])}
        ]
        input_ids = tokenizer.apply_chat_template(
            [msg],
            enable_thinking=False,
            add_generation_prompt=True,
            padding=True,
            return_tensors="pt"
        ).to(model.device)

        outputs = model.generate(
            input_ids,
            # Single unpadded prompt: every position is real. Passing the mask
            # explicitly stops HF from inferring one from pad-token ids.
            attention_mask=torch.ones_like(input_ids),
            num_beams=num_beams,
            num_return_sequences=num_return_sequences,
            max_new_tokens=max_new_tokens,
            length_penalty=length_penalty,
            # Force plain beam search. Without this, a model whose default
            # generation_config enables sampling (e.g. Qwen3) silently turns
            # this into stochastic beam sampling.
            do_sample=False,
            output_scores=True,
            return_dict_in_generate=True
        )

        # Drop the prompt tokens so only the generated answer is decoded.
        beam_strs.append(tokenizer.batch_decode(outputs.sequences[:, input_ids.shape[1]:], skip_special_tokens=True))
        beam_lls[ex_i] = outputs.sequences_scores.cpu()

    return beam_strs, beam_lls

def clean_str(s):
    """Reduce a raw generation to a single-line answer string.

    Keeps only the text before the first newline, removes every occurrence of
    the literal ``'Answer:'``, trims surrounding whitespace, and collapses each
    internal run of whitespace into a single space.
    """
    first_line, _sep, _rest = s.partition('\n')
    unlabeled = first_line.replace('Answer:', '')
    return re.sub(r'\s+', ' ', unlabeled.strip())

def lexical_cleaning(beam_strs, beam_lls):
    """Clean beam outputs and drop lexical duplicates, keeping the best-scored copy.

    For each question, candidates are visited in descending order of their
    score in ``beam_lls``. Each one is passed through ``clean_str``. Empty
    results are dropped. A candidate is kept only if its normalized form
    (lower-cased, with every ``'.'`` removed) has not been seen yet for that
    question.

    Args:
        beam_strs: per-question lists of raw generations, aligned with the rows
            of ``beam_lls``. The list is NOT modified.
        beam_lls: ``(n_questions, n_beams)`` tensor of beam scores.

    Returns:
        ``(cleaned_strs, filtered_lls)``: ``cleaned_strs[q]`` holds the kept
        strings in descending score order; ``filtered_lls[q, k]`` is the score
        of ``cleaned_strs[q][k]``. Slots past the kept count are ``-inf``.
    """
    filtered_beam_lls = -float('Inf') * torch.ones_like(beam_lls)
    cleaned_beam_strs = []

    for ex_i, raw_strs in enumerate(beam_strs):
        # Highest score first, so the first occurrence of each normalized
        # answer is its best-scored variant.
        top_ll_is = torch.topk(beam_lls[ex_i], k=len(beam_lls[ex_i])).indices

        strs = []
        norm_strs = set()
        for seq_i in top_ll_is.tolist():
            s = clean_str(raw_strs[seq_i])
            if not s:
                continue

            norm_s = s.lower().replace('.', '')
            if norm_s in norm_strs:
                continue

            norm_strs.add(norm_s)
            filtered_beam_lls[ex_i, len(strs)] = beam_lls[ex_i, seq_i]
            strs.append(s)

        # Build a fresh list instead of overwriting the caller's entries.
        cleaned_beam_strs.append(strs)

    return cleaned_beam_strs, filtered_beam_lls

def get_ptrue(ds, model, tokenizer, beam_strs):
    """Compute P(true) for every candidate answer of every question.

    Each (question, candidate) pair is shown to the model through
    ``PTRUE_PROMPT`` and the chat template. The next-token logits of the
    ``'Yes'`` and ``'No'`` tokens are softmaxed against each other, and the
    ``'Yes'`` share is stored.

    Args:
        ds: sized iterable of examples exposing a ``'question'`` field.
        model: causal LM accepting ``logits_to_keep``.
        tokenizer: tokenizer with a chat template.
        beam_strs: per-question lists of candidate answers, aligned with ``ds``.

    Returns:
        ``(len(ds), W)`` float tensor, where ``W`` is the largest number of
        candidates for any question. Unused slots are -1.

    Raises:
        ValueError: if ``'Yes'`` or ``'No'`` does not map to a known token id.
    """
    # Size by len(ds) rather than a script-level global, so the function
    # also works when imported.
    width = max((len(strs) for strs in beam_strs), default=0)
    ptrues = -torch.ones(len(ds), width)

    option_tok_ids = tokenizer.convert_tokens_to_ids(['Yes', 'No'])
    if any(tok_id is None or tok_id == tokenizer.unk_token_id for tok_id in option_tok_ids):
        raise ValueError(f"Expected 'Yes' and 'No' to be single vocabulary tokens, got ids {option_tok_ids}")

    for ex_i, ex in enumerate(ds):
        for ans_i, ans in enumerate(beam_strs[ex_i]):
            msg = [
                {'role': 'user', 'content': PTRUE_PROMPT.format(question=ex['question'], candidate_answer=ans)},
            ]
            seq = tokenizer.apply_chat_template(
                [msg],
                enable_thinking=False,
                add_generation_prompt=True,
                padding=True,
                return_tensors="pt"
            ).to(model.device)

            with torch.no_grad():
                outputs = model(
                    seq,
                    logits_to_keep=1,
                )
            # Restrict to the two option tokens and renormalize in float32
            # (the model may run in fp16); keep P('Yes').
            option_logits = outputs.logits[0, -1, option_tok_ids].float()
            ptrues[ex_i, ans_i] = torch.softmax(option_logits, dim=-1)[0].cpu()

    return ptrues

def run_nli(ds, beam_strs):
    """Score pairwise NLI between the candidate answers of each question.

    Loads the DeBERTa NLI checkpoint and, for every ordered pair ``(i, j)`` of
    distinct candidates, uses ``"Question: ...\\nAnswer: <i>"`` as premise and
    ``"Answer: <j>"`` as hypothesis.

    Returns:
        Tensor of shape ``(len(ds), W, W, 3)``, where ``W`` is the largest
        number of candidates for any question. ``[q, i, j]`` holds the softmax
        over the model's three output logits. Diagonal and padding positions
        stay at -1.
    """
    nli_model_name = "MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli"
    nli_tokenizer = AutoTokenizer.from_pretrained(nli_model_name)
    nli_model = AutoModelForSequenceClassification.from_pretrained(nli_model_name, device_map='auto')

    width = max(len(cands) for cands in beam_strs)
    scores = torch.full((len(ds), width, width, 3), -1.0)

    for q_idx, example in enumerate(ds):
        candidates = beam_strs[q_idx]
        for prem_idx, premise in enumerate(candidates):
            for hyp_idx, hypothesis in enumerate(candidates):
                # Self-pairs are skipped and keep their -1 placeholder.
                if prem_idx == hyp_idx:
                    continue

                encoded = nli_tokenizer(
                    f"Question: {example['question']}\nAnswer: {premise}",
                    f"Answer: {hypothesis}",
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=512,
                )
                token_ids = encoded.input_ids.to(nli_model.device)
                with torch.no_grad():
                    logits = nli_model(token_ids).logits
                    scores[q_idx, prem_idx, hyp_idx] = torch.softmax(logits, dim=-1)[0].cpu()

    return scores

def get_normalized_verbalized_confidence(ptrues: torch.Tensor, nlis: torch.Tensor) -> torch.Tensor:
    """Normalize the main answer's P(true) by the P(true) of contradicting candidates.

    For each question, candidate 0 is the main answer. Every other candidate
    adds ``ptrue * contradiction_weight / (degree - sim)`` to a denominator
    that starts at the main answer's own P(true). If the denominator exceeds 1,
    the main answer's P(true) is divided by it; otherwise P(true) is returned
    unchanged. If a row's first entry is negative (no candidates), that
    negative value is returned as-is.

    Args:
        ptrues: ``(n_questions, n_candidates)`` P(true) scores. Iteration over a
            row stops at the first negative entry after the main answer, so
            negative values act as end-of-row padding.
        nlis: rank-4 tensor of NLI label probabilities; index 0 of the last dim
            is used as entailment and index 2 as contradiction. Presumably
            ``nlis[q, i, j]`` scores candidate ``i`` as premise and ``j`` as
            hypothesis, with dims 1-2 sized like ``ptrues.shape[1]``. TODO
            confirm against the producer.

    Returns:
        ``(n_questions,)`` tensor of normalized verbalized confidences.
    """
    entail_i = 0
    contra_i = 2

    # Symmetrize so the contradiction weight ignores premise/hypothesis order.
    sym_nlis = (nlis + nlis.swapdims(1, 2)) / 2
    contra_weights = sym_nlis[:, :, :, contra_i]
    # Directional (non-symmetrized) entailment probabilities.
    sims = nlis[:, :, :, entail_i]
    # degrees[q, j] = 1 + sum_k max(0, sims[q, j, k]). Clamping at 0 discards
    # negative (presumably placeholder) entries; the +1 presumably counts j itself.
    degrees = torch.sum(torch.max(torch.tensor(0.), sims), dim=2) + 1

    n_qst = ptrues.shape[0]
    nvcs = torch.empty(n_qst)

    for ex_i in range(n_qst):
        main_ans_i = 0 # set main answer to highest probability generation (assumes candidates are sorted by score)
        numerator = ptrues[ex_i, main_ans_i]
        denominator = numerator.clone()
        for ans_i in range(ptrues.shape[1]):
            if ans_i == main_ans_i:
                continue
            # Negative P(true) marks padding: no further candidates in this row.
            if ptrues[ex_i, ans_i] < 0:
                break

            # NOTE(review): `degrees` sums over row ans_i (sims[ans_i, :]), but the
            # correction subtracts sims[main, ans_i], the transposed entry. If the
            # intent is "degree of ans_i excluding the main answer", the index should
            # probably be [ex_i, ans_i, main_ans_i], which keeps the divisor >= 1.
            # As written, with probabilities in [0, 1], the divisor can approach 0.
            # Confirm against the DINCO formulation.
            denominator += ptrues[ex_i, ans_i] * contra_weights[ex_i, main_ans_i, ans_i] / (degrees[ex_i, ans_i] - sims[ex_i, main_ans_i, ans_i])

        # Only rescale when the accumulated mass exceeds 1; otherwise keep raw P(true).
        if denominator > 1:
            nvcs[ex_i] = numerator / denominator
        else:
            nvcs[ex_i] = numerator

    return nvcs

def sample_generations(ds, model, tokenizer, n_sample=10, max_new_tokens=100):
    """Draw ``n_sample`` sampled answers per question for self-consistency.

    Each question is formatted with ``PROMPT`` and the chat template, then the
    prompt is repeated ``n_sample`` times and decoded with sampling.
    Temperature / top-p / top-k come from the model's generation_config.

    Args:
        ds: iterable of examples exposing a ``'question'`` field.
        model: causal LM providing ``generate`` and ``device``.
        tokenizer: tokenizer with a chat template.
        n_sample: number of samples per question.
        max_new_tokens: generation budget per sample.

    Returns:
        List with one entry per question, each a list of ``n_sample`` decoded
        continuations (prompt stripped, special tokens skipped).
    """
    sampled_strs = []

    for ex in ds:
        msg = [
            {'role': 'user', 'content': PROMPT.format(question=ex['question'])}
        ]
        input_ids = tokenizer.apply_chat_template(
            [msg],
            enable_thinking=False,
            add_generation_prompt=True,
            padding=True,
            return_tensors="pt"
        ).to(model.device)

        batch = input_ids.repeat(n_sample, 1)
        outputs = model.generate(
            batch,
            # Identical unpadded prompts: every position is real.
            attention_mask=torch.ones_like(batch),
            # Sample explicitly. A model whose generation_config defaults to
            # greedy decoding would otherwise return n_sample identical strings,
            # making self-consistency meaningless.
            do_sample=True,
            max_new_tokens=max_new_tokens,
            return_dict_in_generate=True
        )
        sampled_strs.append(tokenizer.batch_decode(outputs.sequences[:, input_ids.shape[1]:], skip_special_tokens=True))

    return sampled_strs

def run_sc_nli(main_strs, sampled_strs, ds=None):
    """Score NLI agreement between each main answer and its sampled answers.

    For every question, the main answer and each sampled answer are cleaned
    with ``clean_str``. The pair is scored by the NLI model in both directions
    (question-prefixed premise, answer hypothesis), and the two label-0
    (entailment) probabilities are averaged.

    Args:
        main_strs: one main answer per question.
        sampled_strs: per question, a list of sampled generations. Every
            question must have the same number of samples.
        ds: sized iterable of examples exposing a ``'question'`` field, aligned
            with ``main_strs``. Optional only for backward compatibility: if
            omitted, the module-level ``ds`` (created when this file runs as a
            script) is used.

    Returns:
        ``(len(main_strs), n_sample)`` tensor of averaged bidirectional
        entailment probabilities.

    Raises:
        ValueError: if no question dataset is available, if ``ds``,
            ``main_strs`` and ``sampled_strs`` differ in length, or if the
            number of samples differs across questions.
    """
    if ds is None:
        # The original signature had no `ds` and read the script-level global;
        # keep existing two-argument callers working.
        ds = globals().get('ds')
        if ds is None:
            raise ValueError("run_sc_nli needs the question dataset: pass ds=...")

    n_qst = len(main_strs)
    if not (len(ds) == len(sampled_strs) == n_qst):
        raise ValueError(
            f"Length mismatch: {len(ds)} questions, {n_qst} main answers, {len(sampled_strs)} sample lists"
        )
    n_sample = max((len(strs) for strs in sampled_strs), default=0)
    if any(len(strs) != n_sample for strs in sampled_strs):
        raise ValueError("Every question must have the same number of sampled generations")

    # Validate before the (expensive) model load.
    nli_model_name = "MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli"
    nli_tokenizer = AutoTokenizer.from_pretrained(nli_model_name)
    nli_model = AutoModelForSequenceClassification.from_pretrained(nli_model_name, device_map='auto')
    entail_i = 0

    nlis = torch.zeros(n_qst, n_sample)

    for ex_i, (ex, main_raw, samples) in enumerate(zip(ds, main_strs, sampled_strs)):
        main_str = clean_str(main_raw)

        for sample_i, sampled_str in enumerate(samples):
            ans_pair = [main_str, clean_str(sampled_str)]
            # Score both directions, then average (below).
            for premise, hypothesis in (ans_pair, ans_pair[::-1]):
                input_ids = nli_tokenizer(
                    f"Question: {ex['question']}\nAnswer: {premise}",
                    f"Answer: {hypothesis}",
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=512,
                ).input_ids.to(nli_model.device)

                with torch.no_grad():
                    nlis[ex_i, sample_i] += torch.softmax(nli_model(input_ids).logits, dim=-1)[0, entail_i].cpu()

            nlis[ex_i, sample_i] /= 2

    return nlis

if __name__ == "__main__":
    # load dataset: a reproducible (fixed shuffle seed) subset of the
    # TriviaQA validation split
    ds = load_dataset("mandarjoshi/trivia_qa", "rc.nocontext")
    # number of questions to evaluate
    n_qst = 1
    ds = ds['validation'].shuffle(seed=17).select(range(n_qst))

    # load model (tokenizer pads on the left; weights loaded in fp16)
    model_name = 'Qwen/Qwen3-1.7B'
    tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left')
    model = AutoModelForCausalLM.from_pretrained(model_name, device_map='auto', dtype=torch.float16)

    # beam search: candidate answers and their sequence scores
    beam_strs, beam_lls = beam_search(ds, model, tokenizer)
    # save raw (uncleaned) beams
    with open('beam_strs.pkl', 'wb') as f:
        pkl.dump(beam_strs, f)
    torch.save(beam_lls, 'beam_lls.pth')
    print('Saved beam search results', flush=True)

    # perform lexical cleaning: normalize strings and drop duplicate answers
    beam_strs, beam_lls = lexical_cleaning(beam_strs, beam_lls)
    # save
    with open('beam_strs_cleaned.pkl', 'wb') as f:
        pkl.dump(beam_strs, f)
    torch.save(beam_lls, 'beam_lls_cleaned.pth')
    print('Saved cleaned beam search results', flush=True)

    # collect verbalized confidence: P(true) for every cleaned candidate
    ptrues = get_ptrue(ds, model, tokenizer, beam_strs)
    # save
    torch.save(ptrues, 'ptrues.pth')
    print('Saved P(true) results', flush=True)

    # get NLI predictions between every ordered pair of distinct candidates
    nlis = run_nli(ds, beam_strs)
    # save
    torch.save(nlis, 'nlis.pth')
    print('Saved NLI results', flush=True)

    # compute normalized verbalized confidence
    nvcs = get_normalized_verbalized_confidence(ptrues, nlis)
    # save
    torch.save(nvcs, 'nvcs.pth')
    print('Saved NVC results', flush=True)

    # run self-consistency: draw sampled generations for each question
    sampled_strs = sample_generations(ds, model, tokenizer)
    # save
    with open('sampled_strs.pkl', 'wb') as f:
        pkl.dump(sampled_strs, f)
    print('Saved sampled generations', flush=True)

    # compute self-consistency matches against each question's main answer
    # (its top-ranked cleaned beam).
    # NOTE(review): strs[0] raises IndexError if lexical cleaning left a
    # question with no candidates (every beam cleaned to an empty string).
    main_strs = [strs[0] for strs in beam_strs]
    sc_nlis = run_sc_nli(main_strs, sampled_strs)
    # save
    torch.save(sc_nlis, 'sc_nlis.pth')
    print('Saved self-consistency NLI results', flush=True)

    # compute SC confidences
    # The main answer trivially agrees with itself: append a score of 1 along
    # dim=1 so it is counted among the answers being averaged.
    sc_nlis = torch.cat((sc_nlis, torch.ones(n_qst, 1)), dim=1)
    # An answer counts as a match when its SC NLI score exceeds this threshold.
    sc_match_threshold = 0.9
    sc_confs = torch.mean((sc_nlis > sc_match_threshold).float(), dim=1)

    # DINCO confidence: equal-weight average of NVC and SC confidence
    dinco_confs = (nvcs + sc_confs) / 2
    # save
    torch.save(dinco_confs, 'dinco_confs.pth')
    print('Saved DINCO results', flush=True)
