import os
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM

def load_llm(model_name: str, hf_token: str | None = None):
    """Load a causal language model and its tokenizer, ready for inference.

    Args:
        model_name: Model identifier or path passed to ``from_pretrained``.
        hf_token: Optional access token forwarded to both ``from_pretrained``
            calls.

    Returns:
        A ``(tokenizer, model, device)`` tuple. The model has been moved to
        ``device`` (CUDA when available, otherwise CPU) and set to eval mode.
    """
    print(f"Loading: {model_name}...")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)

    # Pass the credential as ``token=`` so both loaders use the same keyword;
    # ``use_auth_token`` is the deprecated spelling of this argument.
    model = AutoModelForCausalLM.from_pretrained(model_name, token=hf_token)
    model = model.to(device).eval()

    return tokenizer, model, device

def load_existing_likelihoods(file_path):
    """Load previously saved likelihoods from a JSON file.

    Args:
        file_path: Path to the JSON file.

    Returns:
        The parsed JSON content, or an empty list when ``file_path`` does not
        exist.
    """
    if not os.path.exists(file_path):
        return []

    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default locale encoding.
    with open(file_path, "r", encoding="utf-8") as f:
        return json.load(f)

def load_model_outputs(file_path):
    """Load model outputs from a JSON Lines file.

    Args:
        file_path: Path to a file holding one JSON document per line.

    Returns:
        A list with one parsed object per non-blank line, or ``None`` (after
        printing an error message) when ``file_path`` does not exist.
    """
    if not os.path.exists(file_path):
        print(f"Error: {file_path} not found.")
        return None

    with open(file_path, "r", encoding="utf-8") as f:
        # Blank lines (e.g. a stray empty line at the end of the file) carry
        # no record and would make json.loads raise, so they are skipped.
        return [json.loads(line) for line in f if line.strip()]
    
def load_topk_jsonl(path, num_tests):
    """Read per-test few-shot selections from a JSON Lines file.

    Each non-blank line must be a JSON object with a ``"test_idx"`` key (the
    position to fill) and a ``"fewshot_topk"`` key (the value stored there).

    Args:
        path: Path to the JSON Lines file.
        num_tests: Length of the returned list.

    Returns:
        A list of length ``num_tests``. Positions that no record refers to
        keep an empty list; if several records share a ``test_idx`` the last
        one wins.

    Raises:
        IndexError: If a record's ``test_idx`` is outside the list.
    """
    selected = [[] for _ in range(num_tests)]
    with open(path, encoding="utf-8") as f:
        for ln in f:
            # Blank lines carry no record and would make json.loads raise.
            if not ln.strip():
                continue
            rec = json.loads(ln)
            selected[rec["test_idx"]] = rec["fewshot_topk"]
    return selected
    
def calculate_diffs(current_likelihoods, sel_idx):
    """Store the replace-vs-zero cross-entropy gap for the selected entries.

    For each of ``"replace_likelihoods"`` and ``"zero_likelihoods"`` (a
    missing key counts as an empty list), the mean of every selected entry's
    ``"ce_losses"`` is taken, and those per-entry means are averaged. The
    difference (replace minus zero) is written to
    ``current_likelihoods["replace_ce_diff"]``.

    Args:
        current_likelihoods: Dict holding the two likelihood lists; it is
            updated in place.
        sel_idx: Iterable of list positions to include in the averages.

    Returns:
        None.
    """
    chosen = set(sel_idx)

    def average(values):
        # An empty collection averages to 0.0 instead of dividing by zero.
        if not values:
            return 0.0
        return sum(values) / len(values)

    def selected_ce_mean(entries):
        per_entry = []
        for position, item in enumerate(entries):
            # Entries wrapped in a list/tuple carry the record in slot 0.
            record = item[0] if isinstance(item, (list, tuple)) else item
            if position in chosen:
                per_entry.append(average(record.get("ce_losses", [])))
        return average(per_entry)

    replace_mean = selected_ce_mean(
        current_likelihoods.get("replace_likelihoods", [])
    )
    zero_mean = selected_ce_mean(
        current_likelihoods.get("zero_likelihoods", [])
    )

    current_likelihoods["replace_ce_diff"] = replace_mean - zero_mean

def calculate_cross_entropy_loss_with_topk(prompt, answer, model, tokenizer, device, top_k=2):
    """Score ``answer`` token by token, conditioned on ``prompt``.

    ``prompt + answer`` is tokenized and run through ``model`` with the
    prompt positions masked out of the labels (set to -100), so only the
    remaining tokens are scored.

    Args:
        prompt: Conditioning text; its tokens are excluded from scoring.
        answer: Text appended to ``prompt`` whose tokens are scored.
        model: Callable as ``model(input_ids=..., labels=...)`` returning an
            object with ``loss`` and ``logits``.
        tokenizer: Callable as ``tokenizer(text, return_tensors="pt")``
            returning an object with ``input_ids``; must also provide
            ``convert_ids_to_tokens``.
        device: Device the input ids are moved to.
        top_k: Kept for backward compatibility. Top-k candidates are not part
            of the returned dict, so this value is currently unused.

    Returns:
        A dict with, per scored token:
            ``"ce_losses"``: cross-entropy loss.
            ``"entropy"``: entropy of the predicted distribution.
            ``"answer_labels"``: the target token string.
        All three lists are empty when no token is scored.
    """
    combined_ids = tokenizer(prompt + answer, return_tensors="pt").input_ids.to(device)

    # NOTE(review): this assumes the tokens of ``prompt`` alone are a prefix
    # of the tokens of ``prompt + answer`` -- confirm for the tokenizer in use.
    prompt_len = tokenizer(prompt, return_tensors="pt").input_ids.shape[1]

    labels = combined_ids.clone()
    labels[:, :prompt_len] = -100

    with torch.no_grad():
        outputs = model(input_ids=combined_ids, labels=labels)
        ce_loss = outputs.loss
        logits = outputs.logits              # shape: [batch_size, seq_len, vocab_size]

    # Position t predicts token t + 1, so drop the last logit and first label.
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()

    # Flatten with the logits' own last dimension: it is the only size that is
    # guaranteed to fit, whereas model.config.vocab_size can differ from it.
    shift_logits_2d = shift_logits.view(-1, shift_logits.size(-1))
    shift_labels_1d = shift_labels.view(-1)

    loss_fct_none = nn.CrossEntropyLoss(ignore_index=-100, reduction="none")
    token_wise_ce_loss = loss_fct_none(shift_logits_2d, shift_labels_1d)

    valid_mask = (shift_labels_1d != -100)
    valid_ce = token_wise_ce_loss[valid_mask]         # [N_valid]
    valid_logits = shift_logits_2d[valid_mask]        # [N_valid, vocab_size]
    valid_labels = shift_labels_1d[valid_mask]        # [N_valid]

    valid_probs = F.softmax(valid_logits, dim=-1)     # [N_valid, vocab_size]
    valid_entropy = -torch.sum(
        valid_probs * torch.log(valid_probs + 1e-10),
        dim=-1
    )  # [N_valid]

    # Sanity check: the mean of the per-token losses should match the loss the
    # model reported. Skipped when nothing is scored, where the mean would be
    # 0/0 and the comparison would always report a mismatch.
    valid_tokens = int(valid_mask.sum().item())
    if valid_tokens > 0:
        manual_ce_loss = valid_ce.sum() / valid_tokens
        if not torch.isclose(torch.tensor(manual_ce_loss.item()), torch.tensor(ce_loss.item()), atol=1e-6):
            print("Loss mismatch!")
            print(f"Auto loss: {ce_loss.item()}, Manual loss: {manual_ce_loss.item()}")

    result = {
        "ce_losses": valid_ce.tolist(),
        "entropy": valid_entropy.tolist(),
        "answer_labels": tokenizer.convert_ids_to_tokens(valid_labels.tolist()),
    }

    return result

def process_scored_file(scored_path: str, _type: str):
    """Print accuracy statistics for a scored JSON Lines file.

    Every non-blank line must be a JSON object with an ``"is_correct"`` key.
    The shape of the first record's ``"is_correct"`` selects the report:

    * a non-empty list of booleans: per-repeat accuracy, their average, and
      the share of questions with at least one correct repeat;
    * a non-empty list of two-item ``(em, f1)`` pairs: per-repeat EM and F1
      plus their averages.

    Any other shape prints an "Unrecognized format" message. The number of
    repeats is taken from the first record.

    Args:
        scored_path: Path to the scored JSON Lines file.
        _type: Label printed in front of the averaged results.

    Returns:
        None; all results are printed.
    """
    with open(scored_path, 'r', encoding='utf-8') as f:
        # Blank lines carry no record and would make json.loads raise.
        scored_entries = [json.loads(line) for line in f if line.strip()]

    if not scored_entries:
        print(f"Scored file {scored_path} is empty.\n")
        return

    first = scored_entries[0]["is_correct"]
    num_questions = len(scored_entries)

    if isinstance(first, list) and first and isinstance(first[0], bool):
        n_repeats = len(first)

        # scores[i] collects the outcome of repeat i across all questions.
        scores = [[] for _ in range(n_repeats)]
        any_correct = []

        for entry in scored_entries:
            results = entry["is_correct"]
            for i, r in enumerate(results):
                scores[i].append(r)
            any_correct.append(any(results))

        rep_accuracies = [sum(rep) / num_questions for rep in scores]
        avg_rep_acc = sum(rep_accuracies) / n_repeats
        any_correct_acc = sum(any_correct) / num_questions

        for i, acc_val in enumerate(rep_accuracies):
            print(f"{i}:  {acc_val:.3f}")
        print(f"{_type} Avg: {avg_rep_acc:.3f}")
        print(f"Any_correct: {any_correct_acc:.3f}\n")

    elif isinstance(first, list) and first and isinstance(first[0], (list, tuple)) and len(first[0]) == 2:
        n_repeats = len(first)

        em_scores = [[] for _ in range(n_repeats)]
        f1_scores = [[] for _ in range(n_repeats)]

        for entry in scored_entries:
            for i, (em, f1) in enumerate(entry["is_correct"]):
                em_scores[i].append(em)
                f1_scores[i].append(f1)

        rep_em = [sum(rep) / num_questions for rep in em_scores]
        rep_f1 = [sum(rep) / num_questions for rep in f1_scores]

        avg_em = sum(rep_em) / n_repeats
        avg_f1 = sum(rep_f1) / n_repeats

        for i, (em_val, f1_val) in enumerate(zip(rep_em, rep_f1)):
            print(f"{i}: EM={em_val:.3f}, F1={f1_val:.3f}")
        print(f"{_type} Avg EM: {avg_em:.3f}, Avg F1: {avg_f1:.3f}\n")

    else:
        print(f"Unrecognized format for is_correct in {scored_path}")