"""
Contains evaluation utilities for pytorch-based rewriting methods.
To use, call `compute_akew` or `compute_probs` with the appropriate
arguments; each returns a dictionary containing the evaluation results.
"""

import typing
from itertools import chain

import nltk
import numpy as np
import scipy
import torch
from sklearn.feature_extraction.text import TfidfVectorizer
from transformers import AutoModelForCausalLM, AutoTokenizer

from dsets import AttributeSnippets
from util.generate import generate_fast
from util.perplexity import perplexity
from glue_eval.mmlu_eval_each import MMLUEval

def apply_template(tok, QAs, add_qa_prompt=False):
    """Render question/answer pairs as a single prompt string.

    The last pair in ``QAs`` is the query: only its question is used and the
    model is expected to continue from there. All earlier pairs are few-shot
    examples.

    Two formats are produced, selected by ``tok.name_or_path``:

    * Name does not contain "instruct": each few-shot pair becomes one
      "question answer." line (a period is appended when the answer does not
      already end with one), followed by the final question.
    * Name contains "instruct": the pairs become user/assistant chat messages
      rendered with ``tok.apply_chat_template(..., tokenize=False)``, and an
      assistant header is appended by hand so generation starts in the
      assistant turn. Names containing "qwen" additionally get a system
      message and the ``<|im_start|>`` header; every other instruct model
      gets the ``<|start_header_id|>`` header (this looks like the Llama-3
      chat format -- confirm before adding other model families).

    Args:
        tok: Tokenizer-like object exposing ``name_or_path`` and, for the
            instruct format, ``apply_chat_template``.
        QAs: Non-empty sequence of ``(question, answer)`` pairs.
        add_qa_prompt: Instruct format only. When true, every question is
            wrapped in a "Question: ... Answer: " frame.

    Returns:
        The prompt as a string.
    """
    model_name = tok.name_or_path.lower()

    if 'instruct' not in model_name:
        lines = []
        for question, answer in QAs[:-1]:
            # endswith (rather than answer[-1]) so an empty few-shot answer
            # does not raise IndexError.
            if not answer.endswith('.'):
                answer += '.'
            lines.append(f'{question} {answer}\n')
        return ''.join(lines) + QAs[-1][0]

    is_qwen = 'qwen' in model_name
    qa_prompt = 'Question: {}\nAnswer: ' if add_qa_prompt else '{}'

    messages = []
    if is_qwen:
        messages.append({'role': 'system', 'content': 'Answer following questions concisely.'})
    last_idx = len(QAs) - 1
    for idx, (question, answer) in enumerate(QAs):
        messages.append({'role': 'user', 'content': qa_prompt.format(question)})
        # The final pair is the query, so its answer is never shown.
        if idx != last_idx:
            messages.append({'role': 'assistant', 'content': answer})

    rendered = tok.apply_chat_template(messages, tokenize=False)
    if is_qwen:
        assistant_header = '<|im_start|>assistant\n'
    else:
        assistant_header = '<|start_header_id|>assistant<|end_header_id|>\n\n'
    return rendered + assistant_header

def compute_mmlu(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    record: typing.Dict,
):
    """Run a small MMLU evaluation and return its ``f1_new`` score.

    An ``MMLUEval`` is built with ``number_of_tests=2`` and
    ``begin_id=record['case_id']``; of the two values returned by its
    ``evaluate()`` call only the ``'f1_new'`` entry of the first one is used.
    """
    evaluator = MMLUEval(
        model,
        tok,
        number_of_tests=2,
        begin_id=record['case_id'],
    )
    scores, _generations = evaluator.evaluate()
    return scores['f1_new']

def compute_neighborhood(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    record: typing.Dict,
):
    """Return the fraction of neighborhood prompts that still prefer the true target.

    Each entry of ``record['neighborhood']`` is rendered with
    ``apply_template`` and fed to the model in one padded batch. For every
    prompt the next-token distribution at its last real token is read, and
    the prompt counts as preserved when the first token of
    ``requested_rewrite.target_true.str`` is strictly more probable than the
    first token of ``requested_rewrite.target_new.str``.

    Args:
        model: Causal LM; inputs are moved to ``"cuda"`` before the call.
        tok: Tokenizer matching ``model``.
        record: Dict with ``'neighborhood'`` (list of prompt strings) and
            ``'requested_rewrite'`` (holding ``target_true``/``target_new``
            dicts with a ``'str'`` field).

    Returns:
        A float in [0, 1]: the mean of the per-prompt comparison.
    """

    def first_target_token(text):
        # Targets are scored as a continuation, so they need a leading space.
        # startswith (rather than text[0]) keeps an empty target from raising.
        spaced = text if text.startswith(' ') else ' ' + text
        # Without add_special_tokens=False, a tokenizer that prepends special
        # tokens (e.g. BOS) would return that special token at index 0 for
        # both targets, making the comparison below always False.
        return tok(spaced, add_special_tokens=False)['input_ids'][0]

    prompts = [
        apply_template(tok, [[item, '']], add_qa_prompt=False)
        for item in record['neighborhood']
    ]
    prompts_inp = tok(prompts, padding=True, return_tensors="pt").to("cuda")

    # Index of the last non-padding token of each row. Taking the largest
    # position where the mask is 1 works for both right and left padding;
    # `mask.sum() - 1` is only correct for right padding.
    attention_mask = prompts_inp['attention_mask']
    positions = torch.arange(attention_mask.size(1), device=attention_mask.device)
    pred_idx = (attention_mask * positions).max(dim=1).values

    with torch.no_grad():
        logits = model(**prompts_inp).logits
        rows = torch.arange(logits.size(0), device=pred_idx.device)
        pred_logits = logits[rows, pred_idx, :]
        pred_probs = torch.nn.functional.softmax(pred_logits, dim=-1)

    true_token = first_target_token(record['requested_rewrite']['target_true']['str'])
    new_token = first_target_token(record['requested_rewrite']['target_new']['str'])

    true_probs = pred_probs[:, true_token]
    new_probs = pred_probs[:, new_token]

    return (true_probs > new_probs).float().mean().item()

def compute_akew(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    record: typing.Dict,
    snips: AttributeSnippets,
    vec: TfidfVectorizer,
) -> typing.Dict:
    """Collect generation, neighborhood and MMLU results for one record.

    For every ``(few_shot, eval_data)`` pair of the record a prompt is built
    with ``apply_template`` (few-shot examples followed by the evaluation
    item) and completed with ``generate_fast`` using ``top_k=1``. The output
    length limit is the padded prompt length plus 50 tokens when the
    tokenizer name contains "qwen", otherwise plus 20.

    ``snips`` and ``vec`` are accepted but not used by this function.

    Returns:
        Dict with
        ``'generation_result'``: each generation with every occurrence of its
        prompt text removed;
        ``'neighborhood_result'``: value of ``compute_neighborhood``;
        ``'mmlu_result'``: value of ``compute_mmlu``.
    """
    prompts = []
    for shots, query in zip(record['few_shot'], record['eval_data']):
        prompts.append(apply_template(tok, shots + [query], add_qa_prompt=False))

    extra_tokens = 50 if 'qwen' in tok.name_or_path.lower() else 20
    padded_prompt_len = tok(prompts, return_tensors='pt', padding=True)['input_ids'].size(1)

    generations = generate_fast(
        model,
        tok,
        prompts,
        top_k=1,
        max_out_len=padded_prompt_len + extra_tokens,
    )

    results = {}
    results['generation_result'] = [
        text.replace(prompt, '') for text, prompt in zip(generations, prompts)
    ]
    results['neighborhood_result'] = compute_neighborhood(model, tok, record)
    results['mmlu_result'] = compute_mmlu(model, tok, record)
    return results

def compute_probs(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    record: typing.Dict,
    snips: AttributeSnippets,
    vec: TfidfVectorizer,
) -> typing.Dict:
    """Score the probability of each evaluation answer given its question.

    Every item of ``record['eval_data']`` is read as ``(question, answer)``.
    The answer (with a leading space) is scored twice: once after the bare
    question and once after the question preceded by the matching entry of
    ``record['prefix']`` (entries that are falsy leave the question as is).

    ``snips`` and ``vec`` are accepted but not used by this function.

    Returns:
        Dict with ``'pre_probs'`` (bare questions) and ``'post_probs'``
        (prefixed questions), each a list with one probability per item.
    """

    def sequence_probs(questions, answers):
        with torch.no_grad():
            # NOTE(review): the answers are tokenized with the tokenizer's
            # default special-token handling, so any token it prepends (e.g.
            # BOS) becomes part of the scored target -- confirm this is
            # intended for the tokenizers in use.
            answer_ids = [
                tok(answer, return_tensors="pt").to("cuda")["input_ids"][0]
                for answer in answers
            ]
            # The model input is the question plus all answer tokens except
            # the last one; the last answer token is only ever a target.
            contexts = [
                question + tok.decode(ids[:-1])
                for question, ids in zip(questions, answer_ids)
            ]
            batch = tok(contexts, return_tensors="pt", padding=True).to("cuda")

            # -100 marks positions that do not contribute to the score.
            labels = torch.full_like(batch["input_ids"], -100)
            for row, ids in enumerate(answer_ids):
                # Number of attended tokens in this row; the answer ids are
                # written into the last len(ids) attended positions.
                end = batch["attention_mask"][row].sum()
                labels[row, end - len(ids) : end] = ids

            logits = model(**batch).logits
            is_target = labels != -100
            # Ignored positions gather index 0 and are zeroed out by the mask.
            gather_index = torch.where(is_target, labels, 0).unsqueeze(2)
            token_log_probs = (
                torch.log_softmax(logits, dim=2).gather(2, gather_index).squeeze(2)
            )
            sequence_log_probs = (token_log_probs * is_target.float()).sum(1)
            return sequence_log_probs.exp().cpu().tolist()

    eval_items = record['eval_data']

    bare_questions = [item[0] for item in eval_items]
    pre_probs = sequence_probs(
        bare_questions,
        [f' {item[1]}' for item in eval_items],
    )

    prefixed_questions = []
    for prefix, item in zip(record['prefix'], eval_items):
        if prefix:
            prefixed_questions.append(f'{prefix} {item[0]}')
        else:
            prefixed_questions.append(item[0])
    post_probs = sequence_probs(
        prefixed_questions,
        [f' {item[1]}' for item in eval_items],
    )

    return {
        'pre_probs': pre_probs,
        'post_probs': post_probs
    }


def test_batch_prediction(
    model,
    tok,
    prefixes: typing.List[str],
    which_correct: str,
    target_new: str,
    target_true: str,
):
    """Score ``target_new`` and ``target_true`` as continuations of each prefix.

    which_correct: Which target to consider correct. Either 0 for "new" or 1 for "true".

    Every prefix is paired with both targets, giving a batch whose even rows
    hold "prefix target_new" and whose odd rows hold "prefix target_true".
    For target token ``j`` of a row, the logits at position
    ``prefix_len + j - 1`` are used.

    NOTE(review): ``which_correct`` is annotated as ``str`` but is indexed
    once per prefix and compared against 0 and 1 -- it looks like a sequence
    of ints is expected; confirm against callers.

    Returns a pair:
      * a list with one dict per prefix, ``{"target_new": x, "target_true": y}``,
        where each value is the mean negative log-probability of that
        target's tokens;
      * a list of booleans, one per prefix whose ``which_correct`` entry is 0
        or 1, telling whether every token of the designated target is the
        argmax prediction at its position.
    """

    prefix_lens = [len(ids) for ids in tok(prefixes)["input_ids"]]
    candidates = [target_new, target_true]

    prompt_tok = tok(
        [f"{prefix} {suffix}" for prefix in prefixes for suffix in candidates],
        padding=True,
        return_tensors="pt",
    ).to("cuda")

    a_tok, b_tok = (tok(f" {n}")["input_ids"] for n in candidates)

    with torch.no_grad():
        logits = model(**prompt_tok).logits

    n_rows = logits.size(0)
    probs = np.zeros((n_rows,), dtype=np.float32)
    targets_correct = []

    for row in range(n_rows):
        # pair: index of the prefix; parity: 0 -> new target, 1 -> true target.
        pair, parity = divmod(row, 2)
        target_ids = a_tok if parity == 0 else b_tok
        # Position whose logits are used for the first target token.
        first_pos = prefix_lens[pair] - 1

        # Mean negative log-probability of the target tokens.
        for j, cur_tok in enumerate(target_ids):
            step_log_probs = torch.nn.functional.log_softmax(
                logits[row, first_pos + j, :], dim=0
            )
            probs[row] += -step_log_probs[cur_tok].item()
        probs[row] /= len(target_ids)

        # Greedy-match accuracy, only for the row holding the designated target.
        if which_correct[pair] == parity:
            targets_correct.append(
                all(
                    logits[row, first_pos + j, :].argmax().item() == cur_tok
                    for j, cur_tok in enumerate(target_ids)
                )
            )

    scores = []
    for i in range(0, len(probs), 2):
        scores.append(
            {"target_new": probs[i].item(), "target_true": probs[i + 1].item()}
        )
    return scores, targets_correct


def test_generation(
    model,
    tok,
    prefixes: typing.List[str],
    consistency_texts: typing.List[str],
    essence_texts: typing.List[str],
    vec: TfidfVectorizer,
):
    """Generate from ``prefixes`` and compute text-quality statistics.

    One generation per prefix is produced with ``generate_fast``
    (``max_out_len=100``).

    Returns a dict with:
      * ``"ngram_entropy"``: ``n_gram_entropy`` of the generations;
      * ``"reference_score"``: ``tfidf_similarity`` between the space-joined
        generations and the space-joined ``consistency_texts``;
      * ``"text"``: the generations themselves;
      * ``"essence_score"`` / ``"essence_text"``: only when ``essence_texts``
        is non-empty -- the ``perplexity`` of the space-joined essence texts
        (``max_input_length=100``) and the texts themselves.
    """
    gen_texts = generate_fast(
        model,
        tok,
        prefixes,
        n_gen_per_prompt=1,
        max_out_len=100,
    )

    joined_generations = " ".join(gen_texts)
    joined_references = " ".join(consistency_texts)

    ret = {
        "ngram_entropy": n_gram_entropy(gen_texts),
        "reference_score": tfidf_similarity(joined_generations, joined_references, vec),
        "text": gen_texts,
    }

    if len(essence_texts) == 0:
        return ret

    ret["essence_score"] = perplexity(
        model, tok, " ".join(essence_texts), max_input_length=100
    )
    ret["essence_text"] = essence_texts
    return ret


def n_gram_entropy(gen_texts, agg="arith"):
    """Aggregate the n-gram entropy of several texts into one number.

    Each text is scored with ``compute_n_gram_entropy`` using that function's
    defaults; ``agg`` only selects how the per-text scores are combined:
    ``"arith"`` for the arithmetic mean, ``"geom"`` for the geometric mean.
    """
    assert agg in ["arith", "geom"]

    if agg == "geom":
        aggregate = scipy.stats.mstats.gmean
    else:
        aggregate = np.mean

    per_text = [compute_n_gram_entropy(txt) for txt in gen_texts]
    return aggregate(per_text).item()


def compute_n_gram_entropy(sentence, ns=None, weights=None, agg="arith"):
    """Return a weighted combination of the n-gram entropies of ``sentence``.

    For every order ``n`` in ``ns`` (default ``[2, 3]``) the base-2 entropy
    of the sentence's n-gram frequency distribution is computed and scaled by
    the matching entry of ``weights`` (default ``[2 / 3, 4 / 3]``). The
    weighted values are then combined with the arithmetic mean
    (``agg="arith"``) or the geometric mean (``agg="geom"``).
    """
    ns = [2, 3] if ns is None else ns
    weights = [2 / 3, 4 / 3] if weights is None else weights
    assert agg in ["arith", "geom"]

    def order_entropy(n):
        # Normalise the n-gram counts to a probability distribution.
        counts = np.array(list(compute_freq(sentence, n).values()))
        dist = counts / counts.sum()
        # Dividing the natural log by log(2) expresses the entropy in bits.
        return np.sum(-dist * np.log(dist) / np.log(2))

    weighted = np.array([order_entropy(n) for n in ns]) * np.array(weights)

    reducer = scipy.stats.mstats.gmean if agg == "geom" else np.mean
    return reducer(weighted)


def compute_freq(sentence, n=2):
    """Return an ``nltk.FreqDist`` over the n-grams of the word-tokenized sentence."""
    return nltk.FreqDist(nltk.ngrams(nltk.word_tokenize(sentence), n))


def tfidf_similarity(text_a, text_b, vec):
    """Return the cosine similarity of two texts under the vectorizer ``vec``.

    Args:
        text_a: First text.
        text_b: Second text.
        vec: Object whose ``transform`` maps a list of texts to a sparse
            2-D result with one row per text.

    Returns:
        The cosine similarity as a Python float. When either text maps to an
        all-zero vector the similarity is undefined; 0.0 is returned instead
        of the NaN that a 0/0 division would give.
    """
    # `.toarray()` exists on both scipy sparse matrices and sparse arrays;
    # the `.A` shorthand is not available on sparse arrays in recent SciPy.
    encs = vec.transform([text_a, text_b]).toarray()
    norm_a = np.linalg.norm(encs[0])
    norm_b = np.linalg.norm(encs[1])
    if norm_a == 0 or norm_b == 0:
        return 0.0
    return (np.dot(encs[0], encs[1]) / norm_a / norm_b).item()
