from nltk.util import ngrams
from nltk.tokenize import word_tokenize
import math
from collections import defaultdict, Counter
import numpy as np
import json

# Highest n-gram order to score with. Orders 1..max_n each act as one weak
# judge, so this is also the number of weak judges.
max_n = 10

# Responses to score: all_responses[i][j] is the response text of model #j to
# query #i. Every query must list its responses in the same model order, and
# the baseline model's response must come LAST (the pipeline compares every
# response against index -1).
all_responses = ...

# Names of the non-baseline models, in the same order as their responses in
# all_responses[i] (model_names[j] <-> all_responses[i][j]); used as the keys
# of the output JSON dict.
model_names = ...

# Name of the baseline LLM; written to the output JSON with a value of None.
baseline_model = ...

# ---------- Helpers ---------- #

def tokenize_and_pad(sentences, n):
    """Lower-case and word-tokenize each sentence, then pad it for order-n n-grams.

    Every token list gets ``n - 1`` leading ``"<s>"`` markers and a single
    trailing ``"</s>"`` marker.

    Returns:
        A new list holding one padded token list per input sentence, in the
        same order as ``sentences``.
    """
    start_markers = ["<s>"] * (n - 1)
    end_markers = ["</s>"]
    # List concatenation builds a fresh list per sentence, so no padding list
    # is shared between results.
    return [
        start_markers + word_tokenize(sentence.lower()) + end_markers
        for sentence in sentences
    ]

def build_ngram_model(tokenized, max_n=4):
    """Count n-grams of every order 1..max_n over pre-padded token sequences.

    Each sequence in ``tokenized`` is expected to carry ``max_n - 1`` leading
    ``"<s>"`` markers and a trailing ``"</s>"`` (as produced by
    ``tokenize_and_pad(..., max_n)``). That padding is sized for the highest
    order, so lower orders would otherwise see pad->pad n-grams that never
    occur when a sentence is scored with its own ``n - 1`` pads. Those
    n-grams (the ones whose predicted token is ``"<s>"``) are skipped, so the
    counts for order ``n`` equal those of a model trained with exactly
    ``n - 1`` pads.

    Args:
        tokenized: Iterable of token sequences.
        max_n: Highest n-gram order to count.

    Returns:
        ``(ngram_counts_by_n, context_counts_by_n, vocab_size)`` where
        ``ngram_counts_by_n[n][context][target]`` is the count of the order-n
        n-gram ``context + (target,)``, ``context_counts_by_n[n][context]`` is
        the number of counted order-n n-grams starting with ``context``, and
        ``vocab_size`` is the number of distinct tokens seen (markers included).
    """
    ngram_counts_by_n = {n: defaultdict(Counter) for n in range(1, max_n + 1)}
    context_counts_by_n = {n: defaultdict(int) for n in range(1, max_n + 1)}
    vocab = set()

    for tokens in tokenized:
        # Materialize once so iterator inputs are not exhausted by vocab.update.
        tokens = list(tokens)
        vocab.update(tokens)
        for n in range(1, max_n + 1):
            # Sliding window of n consecutive tokens (same as nltk.util.ngrams).
            for ng in zip(*(tokens[i:] for i in range(n))):
                context, target = ng[:-1], ng[-1]
                if target == "<s>":
                    # Pad-only n-gram created by the oversized max_n padding;
                    # "<s>" is never a prediction target when scoring.
                    continue
                ngram_counts_by_n[n][context][target] += 1
                context_counts_by_n[n][context] += 1

    return ngram_counts_by_n, context_counts_by_n, len(vocab)

def compute_log_likelihood(
    tokens, n, ngram_counts, context_counts, vocab_size, smoothing=0.0
):
    """Return the mean per-n-gram log-probability of ``tokens`` under an order-n model.

    ``tokens`` is an unpadded token sequence; it is padded here with ``n - 1``
    ``"<s>"`` markers and a trailing ``"</s>"`` so every real token, plus the
    end marker, is predicted exactly once.

    Args:
        tokens: Unpadded token sequence (not modified).
        n: n-gram order to score with.
        ngram_counts: ``ngram_counts[n][context][target]`` counts, as returned
            by ``build_ngram_model``. Only read; lookups never insert keys.
        context_counts: ``context_counts[n][context]`` counts, as returned by
            ``build_ngram_model``. Only read.
        vocab_size: Vocabulary size; only affects the result when
            ``smoothing > 0``.
        smoothing: Additive (Lidstone) smoothing constant. ``0.0`` (default)
            gives the plain maximum-likelihood estimate; ``1.0`` is Laplace.

    Returns:
        Mean natural-log probability over the scored n-grams. If any n-gram
        has probability 0 (unseen without smoothing), returns ``float("-inf")``
        instead of raising.
    """
    padded = ["<s>"] * (n - 1) + list(tokens) + ["</s>"]
    order_ngrams = ngram_counts[n]
    order_contexts = context_counts[n]
    log_prob = 0.0
    count = 0
    # Sliding window of n consecutive tokens (same as nltk.util.ngrams).
    for ng in zip(*(padded[i:] for i in range(n))):
        context, target = ng[:-1], ng[-1]
        # .get() instead of [] so a defaultdict-backed model is not mutated.
        context_count = order_contexts.get(context, 0)
        target_count = order_ngrams.get(context, {}).get(target, 0)

        numerator = target_count + smoothing
        denominator = context_count + smoothing * vocab_size
        if numerator <= 0 or denominator <= 0:
            # Zero-probability n-gram: log(0) = -inf, and so is the mean.
            return float("-inf")
        log_prob += math.log(numerator / denominator)
        count += 1
    return log_prob / count

# ---------- Main Pipeline ---------- #

# all_log_likelihoods[q][r][n - 1]: mean log-likelihood of response r to query q
# under the order-n model built from every response to query q.
all_log_likelihoods = []

for query_idx, responses in enumerate(all_responses):
    # Tokenize all responses for this query and build model
    tokenized = tokenize_and_pad(responses, max_n)
    ngram_counts, context_counts, vocab_size = build_ngram_model(tokenized, max_n)
    query_scores = []
    for resp in responses:
        tokens = word_tokenize(resp.lower())
        # One score per n-gram order 1..max_n (each order is one weak judge).
        query_scores.append(
            [
                compute_log_likelihood(
                    tokens, n, ngram_counts, context_counts, vocab_size
                )
                for n in range(1, max_n + 1)
            ]
        )
    all_log_likelihoods.append(query_scores)
    print(f"Query {query_idx + 1} processed.")

## ---------- Display Results ---------- #
#
#for q_idx, responses in enumerate(all_responses):
#    print(f"\n=== Query {q_idx + 1} ===")
#    for r_idx, resp in enumerate(responses):
#        print(f"\nResponse {r_idx + 1}:")
#        for n in range(1, max_n + 1):
#            ll = all_log_likelihoods[q_idx][r_idx][n-1]
#            print(f"  {n}-gram Log-Likelihood: {ll:.4f}")

# The array below must be rectangular (queries, responses, orders); report an
# empty or ragged input clearly instead of failing inside numpy.
responses_per_query = {len(scores) for scores in all_log_likelihoods}
if len(responses_per_query) != 1:
    raise ValueError(
        "all_responses must be non-empty and every query must have the same "
        f"number of responses; got response counts {sorted(responses_per_query)}"
    )
all_log_likelihoods = np.array(all_log_likelihoods)

# Responses 0..num_models-2 belong to model_names (same order); the last one
# is the baseline.
num_models = all_log_likelihoods.shape[1]
if len(model_names) != num_models - 1:
    raise ValueError(
        f"expected {num_models - 1} model_names (one per non-baseline response, "
        f"baseline last), got {len(model_names)}"
    )

# Per (query, order) weak judge: +1 if a response scores higher than the
# baseline response, -1 if lower, 0 on a tie.
baseline_scores = all_log_likelihoods[:, [-1], :]
all_weak_judge_out = (all_log_likelihoods > baseline_scores).astype(float) - (
    all_log_likelihoods < baseline_scores
).astype(float)

weak_judge_out = {
    model_name: all_weak_judge_out[:, j, :].tolist()
    for j, model_name in enumerate(model_names)
}
weak_judge_out[baseline_model] = None

# Context manager guarantees the file is flushed and closed.
with open("weak_judge.json", "w", encoding="utf-8") as out_file:
    json.dump(weak_judge_out, out_file)