import argparse
import gc
import os
import time
from typing import Any, List, Optional, Tuple
import pandas as pd
from tqdm import tqdm
from pathlib import Path
import json

import torch
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from examples.config import DATA_DIR_CACHE
from examples.data import DatasetLoader, get_canaries
from examples.utils import keep_question_answer, print_gpu_utilization, split_text
from examples.metrics import plot_histogram
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from datasets import DatasetDict, Dataset

def z_div(lp0, lp1, lq, relative):
    """Return ``log(Z0) - log(Z1)`` for the min-combination of each model with ``lq``.

    Every input is first passed through ``normalize_logits`` (mean-centred when
    ``relative`` is True, log-softmax otherwise). ``Zi`` is the sum over the last
    dimension of ``exp(min(lpi, lq))``, evaluated in log-space with ``logsumexp``.
    """
    reference = normalize_logits(lq, relative)
    log_normalisers = [
        torch.logsumexp(torch.min(normalize_logits(candidate, relative), reference), dim=-1)
        for candidate in (lp0, lp1)
    ]
    return log_normalisers[0] - log_normalisers[1]

def log_rel_bound_tde(lp1, lp0, lq, labels, ignore_index=-100):
    """Relative-logit bound, per position, for the balance (TDE) setting.

    All three logit tensors are mean-centred over the last (vocabulary) dimension.
    The result is the sum of three terms:

    * ``log_rel_bound(lp1, lq)``;
    * ``ltp(min(lp1, lq)) - ltp(min(lp0, lq))``, each min re-centred first;
    * ``max(lq - lp0, 0)`` read at each position's label.

    Args:
        lp1, lp0, lq: Logits with the vocabulary on the last dimension.
        labels: Target ids, shape ``lq.shape[:-1]``.
        ignore_index: Label value marking positions without a target (padding).

    Returns:
        Tensor shaped like ``labels``. It is NaN where ``labels == ignore_index``.
    """
    lq = rlogits(lq)
    lp0 = rlogits(lp0)
    lp1 = rlogits(lp1)
    kx = log_rel_bound(lp1, lq)
    ltr1 = ltp(rlogits(torch.min(lp1, lq)))
    ltr0 = ltp(rlogits(torch.min(lp0, lq)))

    # torch.gather rejects negative indices such as -100. Ignored positions read
    # index 0 here and are overwritten with NaN below.
    valid = labels != ignore_index
    safe_labels = labels.masked_fill(~valid, 0).unsqueeze(-1)
    ds = (lq - lp0).clamp_min(0).gather(-1, safe_labels).squeeze(-1)
    ds = ds.masked_fill(~valid, float("nan"))
    return kx + (ltr1 - ltr0) + ds

def log_raw_bound_tde(lp1, lp0, lq, labels, ignore_index=-100):
    """Raw-probability bound, per position, for the balance (TDE) setting.

    All inputs are log-softmaxed in float32. The result is
    ``log_raw_bound(lp1, lq)`` plus ``max(lq - lp0, 0)`` read at each label.

    Args:
        lp1, lp0, lq: Logits with the vocabulary on the last dimension.
        labels: Target ids, shape ``lq.shape[:-1]``.
        ignore_index: Label value marking positions without a target (padding).

    Returns:
        Float32 tensor shaped like ``labels``. It is NaN where
        ``labels == ignore_index``.
    """
    lp0 = F.log_softmax(lp0.float(), dim=-1)
    lp1 = F.log_softmax(lp1.float(), dim=-1)
    lq = F.log_softmax(lq.float(), dim=-1)
    kx = log_raw_bound(lp1, lq)

    # torch.gather cannot take the negative ignore index. Gather at 0, then mask.
    valid = labels != ignore_index
    safe_labels = labels.masked_fill(~valid, 0).unsqueeze(-1)
    ds = (lq - lp0).clamp_min(0).gather(-1, safe_labels).squeeze(-1)
    return kx + ds.masked_fill(~valid, float("nan"))

def log_rel_bound(logits1, logits2):
    """Half the mean absolute difference between mean-centred logits (last dim)."""
    centred_gap = rlogits(logits1) - rlogits(logits2)
    return centred_gap.abs().mean(dim=-1) / 2.0

def log_raw_bound(logits1, logits2):
    """Return ``log(1 / (1 - TV))`` between the softmax distributions of two logit tensors.

    TV is the total-variation distance over the last dimension, computed in
    float32. Non-finite per-entry differences count as 0. ``1 - TV`` is floored
    at 1e-8 so the logarithm stays finite.
    """
    gap = (torch.softmax(logits1.float(), dim=-1) - torch.softmax(logits2.float(), dim=-1)).abs()
    gap = gap.masked_fill(~torch.isfinite(gap), 0)
    total_variation = gap.sum(dim=-1) / 2.0
    overlap = torch.clamp_min(1 - total_variation, 1e-8)
    return torch.log(1 / overlap)

def ltp(logits):
    """Mean log-softmax value across the last (vocabulary) dimension."""
    return F.log_softmax(logits, dim=-1).mean(dim=-1)

def rlogits(logits):
    """Shift logits so that every vector along the last dimension has zero mean."""
    return logits - logits.mean(dim=-1, keepdim=True)

def normalize_logits(logits, relative):
    """Mean-centre logits when ``relative`` is truthy, otherwise log-softmax them (last dim)."""
    return rlogits(logits) if relative else F.log_softmax(logits, dim=-1)

def kl_divergence(logits_p, logits_q):
    """KL(P || Q) over the last dimension, where P and Q are the softmaxes of the inputs."""
    log_ratio = F.log_softmax(logits_p, dim=-1) - F.log_softmax(logits_q, dim=-1)
    return (torch.softmax(logits_p, dim=-1) * log_ratio).sum(dim=-1)

def combine_with_base(logits, base_logits, k=10):
    """Splice a model's top-k most divergent vocabulary entries onto a base model.

    Both inputs are log-softmaxed, clamped at -20 and mean-centred. For every
    vocabulary entry the contribution ``p * (logits - base_logits)`` is computed,
    where ``p`` is the model's softmax. Entries strictly above the (k+1)-th largest
    contribution keep the model's centred logits. Every other entry takes the
    base model's centred logits.

    Returns:
        (combined, bound):
        * ``combined``: the spliced, mean-centred logits.
        * ``bound``: per position,
          ``log(1 + exp(ltp(model) - ltp(base)) * S_base) + S_rem``. Here
          ``S_base`` is the base probability mass and ``S_rem`` is the summed
          contributions, both over the entries that fall back to the base.
    """
    def _centred_log_probs(raw):
        # Log-probabilities floored at -20, then mean-centred.
        return rlogits(F.log_softmax(raw, dim=-1).clamp_min(-20))

    model_lp = _centred_log_probs(logits)
    base_lp = _centred_log_probs(base_logits)

    contrib = torch.softmax(model_lp, dim=-1) * (model_lp - base_lp)
    # kthvalue returns the k-th *smallest* value. Negating gives the (k+1)-th
    # largest, so the strict ">" used below keeps (at most) the top k entries.
    threshold = -torch.kthvalue(-contrib, k + 1, dim=-1).values.unsqueeze(-1)
    falls_back = contrib <= threshold

    log_ratio_scale = torch.exp(ltp(model_lp) - ltp(base_lp))
    base_mass = torch.sum(falls_back * torch.softmax(base_lp, dim=-1), dim=-1)
    leftover = torch.sum(falls_back * contrib, dim=-1)
    bound = torch.log(1 + log_ratio_scale * base_mass) + leftover

    combined = torch.where(contrib > threshold, rlogits(model_lp), rlogits(base_lp))
    return combined, bound

def parse_arguments():
    """
    Parse command-line arguments for the logit-comparison evaluation.

    Note: ``--verbose`` is accepted but not read anywhere in this script.
    """
    parser = argparse.ArgumentParser(description="Evaluation script for scp_dr models")
    parser.add_argument("--model_checkpoint1", type=str, help="Path to the first model checkpoint")
    parser.add_argument("--model_checkpoint2", type=str, help="Path to the second model checkpoint")
    parser.add_argument("--model_for_tokenizer", type=str, help="Path to a model checkpoint used for tokenizer only")
    parser.add_argument("--logits1", type=str, help="Path to the logits1 file")
    parser.add_argument("--logits2", type=str, help="Path to the logits2 file")
    parser.add_argument("--logits_base", type=str, help="Path to the logits_base file")
    parser.add_argument("--logits_balance", type=str, help="Path to the logits_balance file")
    parser.add_argument("--balance_checkpoint", type=str, help="Path to the balance model checkpoint for balancing")
    parser.add_argument("--base_model_checkpoint", type=str, help="Path to the base model checkpoint for base smoothing")
    # Previously also required=True, which made argparse ignore this default.
    # Dropping `required` lets the declared default apply. Existing invocations
    # that pass the flag are unaffected.
    parser.add_argument("--dataset_name", type=str, default="MathAbstracts", help="Name of the dataset")
    parser.add_argument("--n_test_samples", type=int, default=500, help="Number of test samples")
    parser.add_argument("--sample_start_index", type=int, default=0, help="sample start index")
    parser.add_argument("--output_dir", type=str, default="./eval", help="Directory to save evaluation results")
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size for evaluation")
    parser.add_argument("--k", type=int, default=10, help="Base smoothing parameter k")
    parser.add_argument("--verbose", action="store_true", help="Enable verbose output during evaluation")
    parser.add_argument("--use_relative_probs", action="store_true", help="use relative probabilities when aggregating")
    parser.add_argument("--save_const_logit_vec", action="store_true", help="save the average of all logits vectors")
    parser.add_argument("--canaries_path", type=str, help="canaries dataset path for evaluation")
    return parser.parse_args()


def init_tokenizer(model_checkpoint):
    """Load the tokenizer for ``model_checkpoint`` with left padding and report its size."""
    tokenizer_options = {"padding_side": "left", "trust_remote_code": True}
    tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, **tokenizer_options)
    # Special tokens (e.g. [SEP]/[PAD]) are deliberately NOT added. load_model
    # resizes embeddings to len(tokenizer), so adding tokens would change the
    # vocabulary of checkpoints that were trained without them.
    print("Tokenizer vocab size:", len(tokenizer))
    return tokenizer


def load_model(path, tokenizer, is_logits=False, name="model"):
    """Load a causal-LM checkpoint, or precomputed logits saved by this script.

    Args:
        path: The source to load, depending on ``is_logits``:
            * ``is_logits=False``: a HF checkpoint path.
            * ``is_logits=True``: the literal ``"default"`` (all-zero logits), or a
              directory holding ``*train.pt`` / ``*validation.pt`` /
              ``*const_logit_vec.pt`` files.
        tokenizer: Only ``len(tokenizer)`` is used, for the embedding resize and
            the size of the default logits vector.
        is_logits: Load logits instead of a model.
        name: Label used in log messages.

    Returns:
        The HF model (float16), or ``{'train': ..., 'validation': ...}``.
        With a single logits file per split, each value is that file's content.
        With several files, each value is a list with one entry per file.
        A const-vector file is stored as a one-element list.

    Raises:
        NotADirectoryError: ``path`` is not a directory (logits mode).
        FileNotFoundError: no recognised logits files were found in ``path``.
        ValueError: the train and validation file counts differ.
    """
    if not is_logits:
        print(f"Loading {name} from: {path}")
        model = AutoModelForCausalLM.from_pretrained(
            path, device_map="auto", torch_dtype=torch.float16, trust_remote_code=True
        ).half()
        model.resize_token_embeddings(len(tokenizer))
        return model

    if path == "default":
        print(f"Using default logits for {name}.")
        logits = torch.zeros((len(tokenizer),), dtype=torch.float16).to("cuda")
        print(f"Default logits shape: {logits.shape}")
        return {'train': [logits], 'validation': [logits]}

    path = Path(path)
    # Explicit raises rather than `assert`: asserts disappear under `python -O`.
    if not path.is_dir():
        raise NotADirectoryError(f"{name} logits path is not a directory: {path}")
    print(f"Loading {name} logits from directory: {path}")

    logits_train = []
    logits_validation = []
    for file in sorted(os.listdir(path)):
        file_path = os.path.join(path, file)
        if file.endswith("train.pt"):
            print(f"Loading train logits file: {file}")
            logits_train.append(torch.load(file_path))
        elif file.endswith("validation.pt"):
            print(f"Loading validation logits file: {file}")
            logits_validation.append(torch.load(file_path))
        elif file.endswith("const_logit_vec.pt"):
            print(f"Loading constant logits vector file: {file}")
            const_vec = torch.load(file_path)
            # The same vector serves both splits, wrapped so it looks like a
            # one-batch logits file.
            logits_train.append([const_vec])
            logits_validation.append([const_vec])

    if not logits_train:
        raise FileNotFoundError(f"No *train.pt or *const_logit_vec.pt files found in {path}")
    print("len logits:", len(logits_train))
    print("len logits[0]:", len(logits_train[0]))
    if len(logits_train) != len(logits_validation):
        raise ValueError(
            f"Found {len(logits_train)} train but {len(logits_validation)} validation logits files in {path}"
        )
    if len(logits_train) == 1:
        # Single source: unwrap so each split is directly the per-batch list.
        logits_train = logits_train[0]
        logits_validation = logits_validation[0]
    return {'train': logits_train, 'validation': logits_validation}

def load_models(args, tokenizer):
    """
    Load model1, model2 and the optional base and balance models.

    For every slot a checkpoint argument takes precedence over a logits argument.

    Returns:
        (model1, model2, base_model, balance_model); unconfigured slots are None.

    Raises:
        ValueError: model1 has no source, or model2 has no source while
            ``model_checkpoint1`` is absent.
    """
    def _load_slot(checkpoint, logits_path, slot_name):
        # Checkpoint first, then logits. None when neither is configured.
        if checkpoint:
            return load_model(checkpoint, tokenizer, name=slot_name)
        if logits_path:
            return load_model(logits_path, tokenizer, is_logits=True, name=slot_name)
        return None

    if not (args.model_checkpoint1 or args.logits1):
        raise ValueError("Either model_checkpoint1 or logits1 must be provided.")
    model1 = _load_slot(args.model_checkpoint1, args.logits1, "model1")

    model2_configured = args.model_checkpoint2 or args.logits2
    if not model2_configured and not args.model_checkpoint1:
        raise ValueError("Either model_checkpoint2 or logits2 must be provided when model_checkpoint1 is not provided.")
    model2 = _load_slot(args.model_checkpoint2, args.logits2, "model2")

    base_model = _load_slot(args.base_model_checkpoint, args.logits_base, "base_model")
    balance_model = _load_slot(args.balance_checkpoint, args.logits_balance, "balance_model")

    return model1, model2, base_model, balance_model


def logit_compare_stats(logits1, logits2, base_logits, balance_logits, labels,
                        relative=True, k=10, logits_ste_estimate=None, ignore_index=-100):
    """
    Compute per-position statistics comparing two sets of next-token logits.

    Args:
        logits1, logits2: Logits with the vocabulary on the last dimension.
            ``labels`` indexes that dimension, so its shape is ``logits.shape[:-1]``.
        base_logits: Optional base logits. When given, each model is spliced with
            it through ``combine_with_base`` (top ``k`` entries).
        balance_logits: Optional balance logits. When given, both models are
            min-combined with them before comparison.
        labels: Target token ids. Positions equal to ``ignore_index`` have no target.
        relative: Mean-centre logits (True) or log-softmax them (False).
        k: Forwarded to ``combine_with_base``.
        logits_ste_estimate: Optional standard-error estimate. It enables the
            ``t_value_*`` stats.
        ignore_index: Label value marking positions without a target.

    Returns:
        dict mapping stat name to a flat numpy array with one value per position.
        Top-k accuracies and label-gathered values are NaN at ``ignore_index``
        positions, so pandas aggregations skip them instead of counting
        padding as a miss.
    """
    nan = float("nan")
    valid = labels != ignore_index
    # torch.gather rejects negative indices such as -100. Ignored positions read
    # index 0 and are overwritten with NaN afterwards.
    safe_labels = labels.masked_fill(~valid, 0)

    def _flat(t):
        return t.cpu().numpy().flatten()

    def _mask_invalid(t):
        return t.masked_fill(~valid, nan)

    def _at_labels(values):
        return _mask_invalid(values.gather(-1, safe_labels.unsqueeze(-1)).squeeze(-1))

    def _topk_hit(scores, top):
        return (torch.topk(scores, k=top, dim=-1).indices == labels.unsqueeze(-1)).any(dim=-1)

    def _acc(hits):
        return _flat(_mask_invalid(hits.float()))

    logits1_orig = logits1
    logits2_orig = logits2
    logits1 = normalize_logits(logits1, relative)
    logits2 = normalize_logits(logits2, relative)
    if balance_logits is not None:
        balance_logits = normalize_logits(balance_logits, relative)
        q_minus_p0_at_labels = _at_labels(torch.clamp(balance_logits - logits2, min=0))
        logits1 = normalize_logits(torch.min(logits1, balance_logits), relative)
        logits2 = normalize_logits(torch.min(logits2, balance_logits), relative)
        alpha0 = torch.min(logits1, balance_logits) - logits2
        alpha1 = torch.min(logits2, balance_logits) - logits1
        alpha = torch.max(alpha0, alpha1)

    if base_logits is not None:
        base_logits = normalize_logits(base_logits, relative)
        logits1, bound1 = combine_with_base(logits1, base_logits, k=k)
        logits2, bound2 = combine_with_base(logits2, base_logits, k=k)
        logits1 = normalize_logits(logits1, relative)
        logits2 = normalize_logits(logits2, relative)
        loss_base = per_token_log_prob(base_logits, labels)

    loss1 = per_token_log_prob(logits1, labels)
    loss2 = per_token_log_prob(logits2, labels)
    diff_loss = loss1 - loss2

    diff = logits1 - logits2
    if logits_ste_estimate is not None:
        t_value = (diff / logits_ste_estimate).float()

    p1 = torch.softmax(logits1, dim=-1)
    p2 = torch.softmax(logits2, dim=-1)
    min_rel = torch.min(rlogits(logits1), rlogits(logits2))
    min_raw = torch.min(p1, p2)
    geo_raw_mean = 0.5 * F.log_softmax(logits1, dim=-1) + 0.5 * F.log_softmax(logits2, dim=-1)

    chosen1 = torch.argmax(p1, dim=-1).unsqueeze(-1)
    diff_at_chosen1 = (normalize_logits(logits1, False) - normalize_logits(logits2, False)).gather(-1, chosen1).squeeze(-1)
    lp_at_chosen1 = torch.log(p1.gather(-1, chosen1).squeeze(-1))
    lp2_at_chosen1 = torch.log(p2.gather(-1, chosen1).squeeze(-1))
    chosen_min_rel = torch.argmax(min_rel, dim=-1).unsqueeze(-1)
    chosen_min_raw = torch.argmax(min_raw, dim=-1).unsqueeze(-1)
    chosen_geo_raw = torch.argmax(geo_raw_mean, dim=-1).unsqueeze(-1)
    lp_chosen_min_rel = torch.log_softmax(min_rel, dim=-1).gather(-1, chosen_min_rel).squeeze(-1)
    lp_chosen_min_raw = torch.log_softmax(torch.log(min_raw), dim=-1).gather(-1, chosen_min_raw).squeeze(-1)
    lp_chosen_geo_raw = torch.log_softmax(geo_raw_mean, dim=-1).gather(-1, chosen_geo_raw).squeeze(-1)

    # Convert to float32 for quantile/kthvalue (quantile doesn't support float16)
    diff = diff.float()

    ltp1 = ltp(logits1)
    ltp2 = ltp(logits2)
    ltp_min_rel = ltp(min_rel)

    stats = {}
    stats["ltp1"] = _flat(ltp1)
    stats["ltp2"] = _flat(ltp2)
    stats["ltp_min_rel"] = _flat(ltp_min_rel)
    stats["ltp_min_rel - ltp1"] = _flat(ltp_min_rel - ltp1)
    stats["ltp_min_rel - ltp2"] = _flat(ltp_min_rel - ltp2)
    stats["ltp1 - ltp2"] = _flat(ltp1 - ltp2)
    stats['loss'] = _flat(loss1)
    stats['loss_difference'] = _flat(diff_loss)
    if base_logits is not None:
        stats['loss_difference_base'] = _flat(loss1 - loss_base)
        stats['utility_bound1'] = _flat(bound1)
        stats['utility_bound2'] = _flat(bound2)
        stats['utility_kl1'] = _flat(kl_divergence(logits1_orig, logits1))
        stats['utility_kl2'] = _flat(kl_divergence(logits2_orig, logits2))
    else:
        stats['loss_difference_base'] = _flat(torch.zeros_like(loss1))
    stats['diff_at_chosen1'] = _flat(diff_at_chosen1)
    stats['lp_at_chosen1'] = _flat(lp_at_chosen1)
    stats['lp2_at_chosen1'] = _flat(lp2_at_chosen1)
    stats['lp_chosen_min_rel'] = _flat(lp_chosen_min_rel)
    stats['lp_chosen_min_raw'] = _flat(lp_chosen_min_raw)
    stats['lp_chosen_geo_raw'] = _flat(lp_chosen_geo_raw)
    if logits_ste_estimate is not None:
        stats['t_value_mean'] = _flat(torch.mean(t_value, dim=-1))
        stats['t_value_std'] = _flat(torch.std(t_value, dim=-1))
        stats['t_value_max'] = _flat(torch.max(t_value, dim=-1).values)
        stats['t_value_95%'] = _flat(torch.quantile(t_value, 0.95, dim=-1))
        stats['t_value_99%'] = _flat(torch.quantile(t_value, 0.99, dim=-1))
        stats['t_value_99.9%'] = _flat(torch.quantile(t_value, 0.999, dim=-1))
    if balance_logits is not None:
        stats['log(Z0/Z1)_rel'] = _flat(z_div(logits2_orig, logits1_orig, balance_logits, True))
        stats['log(Z0/Z1)_raw'] = _flat(z_div(logits2_orig, logits1_orig, balance_logits, False))
        stats['alpha_mean'] = _flat(torch.mean(alpha, dim=-1))
        stats['alpha_max'] = _flat(torch.max(alpha, dim=-1).values)
        stats['alpha0_max'] = _flat(torch.max(alpha0, dim=-1).values)
        stats['alpha1_max'] = _flat(torch.max(alpha1, dim=-1).values)
        stats['q_minus_p0_at_labels'] = _flat(q_minus_p0_at_labels)
        # safe_labels keeps these helpers from gathering at ignore_index.
        # Those positions are masked to NaN here.
        stats['LogRelBoundTDE'] = _flat(_mask_invalid(
            log_rel_bound_tde(logits1_orig, logits2_orig, balance_logits, safe_labels)))
        stats['LogRawBoundTDE'] = _flat(_mask_invalid(
            log_raw_bound_tde(logits1_orig, logits2_orig, balance_logits, safe_labels)))
    stats['LogRelBound'] = _flat(log_rel_bound(logits1, logits2))
    stats['LogRelBound_orig'] = _flat(log_rel_bound(logits1_orig, logits2_orig))
    stats['LogRawBound'] = _flat(log_raw_bound(logits1, logits2))
    stats['MAE'] = _flat(torch.mean(torch.abs(diff), dim=-1))
    stats['mean'] = _flat(torch.mean(diff, dim=-1))
    stats['std'] = _flat(torch.std(diff, dim=-1))
    stats['max'] = _flat(torch.max(diff, dim=-1).values)
    stats['min'] = _flat(torch.min(diff, dim=-1).values)
    stats['median'] = _flat(torch.median(diff, dim=-1).values)
    stats['top1_acc_rel'] = _acc(torch.argmax(min_rel, dim=-1) == labels)
    stats['top5_acc_rel'] = _acc(_topk_hit(min_rel, 5))
    stats['top10_acc_rel'] = _acc(_topk_hit(min_rel, 10))
    stats['top1_acc_raw'] = _acc(torch.argmax(min_raw, dim=-1) == labels)
    stats['top5_acc_raw'] = _acc(_topk_hit(min_raw, 5))
    stats['top10_acc_raw'] = _acc(_topk_hit(min_raw, 10))
    stats['top1_acc_geo'] = _acc(torch.argmax(geo_raw_mean, dim=-1) == labels)
    stats['top5_acc_geo'] = _acc(_topk_hit(geo_raw_mean, 5))
    stats['top10_acc_geo'] = _acc(_topk_hit(geo_raw_mean, 10))
    stats['top1_acc_half'] = _acc(torch.argmax(logits1, dim=-1) == labels)
    stats['top1_acc_half2'] = _acc(torch.argmax(logits2, dim=-1) == labels)
    stats['top5_acc_half'] = _acc(_topk_hit(logits1, 5))
    stats['top10_acc_half'] = _acc(_topk_hit(logits1, 10))
    stats['amount_equal_0'] = _flat((diff == 0).float().mean(dim=-1))
    for suffix, threshold in (("01", 0.1), ("03", 0.3), ("05", 0.5), ("1", 1),
                              ("2", 2), ("3", 3), ("5", 5), ("10", 10)):
        stats[f'amount_greater_than_{suffix}'] = _flat((diff > threshold).float().mean(dim=-1))

    # B_j: the (2**j + 1)-th largest |diff| entry at each position.
    abs_diff = torch.abs(diff)
    for k_log in range(14):
        k_exp = int((2 ** k_log))
        B_all = -torch.kthvalue(-abs_diff, k_exp + 1, dim=-1).values.flatten()
        stats[f"B_{k_log}"] = _flat(B_all)

    return stats

def per_token_log_prob(logits, labels, ignore_index=-100):
    """Return the log-probability assigned to each label (the negated cross-entropy).

    Args:
        logits: Scores with the class dimension last; shape ``labels.shape + (C,)``.
        labels: Target class indices. ``ignore_index`` marks positions to skip.
        ignore_index: Label value with no target.

    Returns:
        Tensor shaped like ``labels``, NaN at ignored positions.
    """
    loss_fct = CrossEntropyLoss(reduction="none", ignore_index=ignore_index)
    # reshape (not view): callers may pass non-contiguous tensors, e.g. produced
    # by expand/transpose, which `.view` rejects.
    flat_loss = loss_fct(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))
    loss_per_token = flat_loss.reshape(labels.size())
    # CrossEntropyLoss yields 0 for ignored targets. NaN keeps them out of means.
    loss_per_token = loss_per_token.masked_fill(labels == ignore_index, float("nan"))
    return -loss_per_token

def get_logits_from_model(model, batch_tensors, counter, logits1=None):
    """Return ``(outputs, logits, logits_s)`` for batch number ``counter`` (1-based).

    ``model`` may be:
      * None, giving ``(None, None, None)``.
      * A one-element list holding one tensor, with ``logits1`` given. That
        tensor is broadcast to ``logits1``'s shape.
      * A list of per-model lists of per-batch tensors. ``logits`` is the mean
        over models and ``logits_s`` the sum of squared deviations from it.
      * A list of per-batch tensors. The tensor for this batch is returned.
      * Anything else is called with ``batch_tensors``; its ``.logits`` are used.

    Raises:
        ValueError: a one-element list whose entry is itself a list, with
            ``logits1`` given.
    """
    outputs = None
    logits = None
    logits_s = None
    if model is None:
        return outputs, logits, logits_s

    if not isinstance(model, list):
        outputs = model(**batch_tensors, return_dict=True)
        return outputs, outputs.logits, logits_s

    idx = counter - 1
    if len(model) == 1 and logits1 is not None:
        if isinstance(model[0], list):
            raise ValueError("len(model) == 1 but model[0] is a list")
        logits = model[0].expand_as(logits1)
    elif isinstance(model[0], list):
        # Accumulate in float32. In float16 the squared sums overflow to inf
        # and `sum_sq - sum**2 / n` cancels catastrophically (NaN/negative).
        n_models = len(model)
        total = torch.zeros_like(model[0][idx], dtype=torch.float32)
        total_sq = torch.zeros_like(total)
        for member in model:
            batch_logits = member[idx].float()
            total += batch_logits
            total_sq += batch_logits * batch_logits
        logits = (total / n_models).to(model[0][idx].dtype)
        # Rounding can still leave tiny negatives; a variance cannot be < 0.
        logits_s = (total_sq - total ** 2 / n_models).clamp_min(0)
    else:
        logits = model[idx]
    return outputs, logits, logits_s

def compute_lc_statistics(model1, model2, base_model, balance_model, data, tokenizer, batch=16,
                          relative=True, k=10):
    """Run the models over ``data`` batch by batch and collect comparison stats.

    Args:
        model1, model2, base_model, balance_model: Anything accepted by
            ``get_logits_from_model`` (HF model, precomputed per-batch logits, or
            None). Precomputed logits are indexed by batch number, so they must
            come from a run with the same ``batch`` size and data order.
        data: Dataset whose items have a ``"text"`` field.
        tokenizer: Tokenizer used to encode each batch with padding.
        batch: Batch size.
        relative, k: Forwarded to ``logit_compare_stats`` / ``combine_with_base``.

    Returns:
        (df, all_logits). Exactly one of the two is populated:
        * When ``model2`` yields logits, ``df`` is a DataFrame with one row per
          token position and ``all_logits`` is None.
        * Otherwise ``df`` is None and ``all_logits`` is the list of per-batch
          ``model1`` logits. They are base-combined and normalised when a base
          model is given, and the value is None if there were no batches.
    """
    dataloader = DataLoader(
        data,
        batch_size=batch,
        shuffle=False,
        pin_memory=True,
    )

    all_stats = []
    all_logits = []
    with torch.no_grad():
        # counter is the 1-based batch index used to look up precomputed logits
        for counter, batch_data in enumerate(tqdm(dataloader), start=1):
            batch_texts = batch_data["text"]
            batch_tensors = tokenizer(batch_texts, return_tensors="pt", padding=True).to("cuda")
            # Next-token targets: shift ids left by one, and the last slot gets EOS.
            labels = torch.roll(batch_tensors["input_ids"], shifts=-1, dims=1)
            labels[:, -1] = tokenizer.eos_token_id
            if tokenizer.pad_token_id is not None:
                # Ignore padding in the loss. If pad == eos this also masks EOS targets.
                labels[labels == tokenizer.pad_token_id] = -100

            outputs1, logits1, logits1_s = get_logits_from_model(model1, batch_tensors, counter)
            outputs2, logits2, logits2_s = get_logits_from_model(model2, batch_tensors, counter)
            base_outputs, base_logits, _ = get_logits_from_model(base_model, batch_tensors, counter, logits1=logits1)
            balance_outputs, balance_logits, _ = get_logits_from_model(balance_model, batch_tensors, counter)

            logits_ste_estimate = None
            if logits1_s is not None and logits2_s is not None:
                # Pooled-variance standard error of the difference between the two ensemble means.
                degrees_freedom = len(model1) + len(model2) - 2
                logits_ste_estimate = (((logits1_s + logits2_s) / degrees_freedom) * (1. / len(model1) + 1. / len(model2))) ** 0.5

            if logits2 is not None:
                all_stats.append(logit_compare_stats(logits1, logits2, base_logits, balance_logits, labels,
                                                     relative=relative, k=k, logits_ste_estimate=logits_ste_estimate))
            else:
                # No second model: keep model1's logits so they can be saved.
                if base_logits is not None:
                    logits1, _ = combine_with_base(logits1, base_logits, k=k)
                    logits1 = normalize_logits(logits1, relative)
                all_logits.append(logits1)
            del (
                logits1,
                logits2,
                base_logits,
                balance_logits,
                outputs1,
                outputs2,
                base_outputs,
                balance_outputs,
                batch_tensors,
                batch_texts,
            )
            torch.cuda.empty_cache()
            gc.collect()

    # all_stats holds one dict of flat per-position arrays per batch.
    # Concatenate them column by column into a single DataFrame.
    df = None
    if all_stats:
        combined_stats = {
            stat_name: [value for batch_stats in all_stats for value in batch_stats[stat_name]]
            for stat_name in tqdm(all_stats[0].keys())
        }
        for stat_name, values in combined_stats.items():
            print(f"{stat_name}: {len(values)}")
        df = pd.DataFrame(combined_stats)
    if len(all_logits) == 0:
        all_logits = None
    return df, all_logits


def _split_model(model, split):
    """Return ``model[split]`` for per-split logits dicts; anything else passes through."""
    return model[split] if isinstance(model, dict) else model


def _k2b_table(eval_res, coverages=(50, 75, 90, 95, 99, 99.9), n_k=14):
    """Return ``{c: {k_log: B}}``, the order statistic of column ``B_{k_log}`` at rank int(c% * n).

    The rank is clamped to ``[1, n]``. ``torch.kthvalue`` rejects rank 0, which
    the unclamped formula produces when there are few rows.
    """
    k2B = {}
    for c in coverages:
        k2B[c] = {}
        for k_log in range(n_k):
            B_all = torch.tensor(eval_res[f"B_{k_log}"].to_numpy())
            n_rows = B_all.numel()
            rank = min(max(int(c / 100.0 * n_rows), 1), n_rows)
            k2B[c][k_log] = torch.kthvalue(B_all, rank, dim=-1).values.item()
    return k2B


def _plot_k2b(k2B, name, eval_dir):
    """Plot one ``B`` vs ``log2(k)`` curve per coverage level and save it as a PNG."""
    plt.figure()
    for c, per_k in k2B.items():
        ks = sorted(per_k.keys())
        plt.plot(ks, [per_k[k_log] for k_log in ks], label=f"C = {c}%")
    plt.xlabel("log2(k)")
    plt.ylabel("B value")
    plt.title(f"B vs log2(k) for the {name} set")
    plt.legend()
    plt.grid()
    plt.savefig(os.path.join(eval_dir, f"LC_{name}_k2B_curves.png"))
    plt.close()


def _report_stats(eval_res, name, eval_dir):
    """Write the CSV, summary, k2B JSON and plots, and print summaries, for one split."""
    eval_res.to_csv(os.path.join(eval_dir, f"LC_{name}.csv"))
    main_res = eval_res[['LogRelBound', 'LogRawBound', 'MAE', 'max', 'loss', 'loss_difference', 'loss_difference_base',
                         'top1_acc_rel', 'top1_acc_raw', 'amount_equal_0', 'amount_greater_than_01', 'amount_greater_than_1']].describe(percentiles=[0.999])
    main_res.to_csv(os.path.join(eval_dir, f"LC_{name}_summary.csv"))
    pd.set_option('display.max_columns', None)
    print(eval_res.describe(percentiles=[0.5, 0.75, 0.9, 0.95, 0.99, 0.999, 0.9999]))

    k2B = _k2b_table(eval_res)
    with open(os.path.join(eval_dir, f"LC_{name}_k2B.json"), 'w') as f:
        json.dump(k2B, f)
    print("k2B saved to json.")

    _plot_k2b(k2B, name, eval_dir)
    print("k2B curves plot saved.")

    plot_histogram(
        eval_res,
        ['LogRawBound', 'LogRelBound'],
        eval_dir,
        name,
        "blue",
        "$k_x$ histograms",
        "$k_x$",
        "Frequency",
        labels=[r'$\Delta_m$', r'$\Delta_r$']
    )
    plot_histogram(
        eval_res,
        ['amount_greater_than_01', 'amount_greater_than_1'],
        eval_dir,
        name,
        "blue",
        "> histograms",
        "Fraction",
        "Frequency",
        labels=['> 0.1', '> 1']
    )


def _mean_logit_vector(logits):
    """Average per-batch logits over every non-vocabulary dim, giving shape (1, ..., 1, V).

    Each batch gets equal weight (padding positions included), as before. Reducing
    over the batch dimension too means batches of different sizes no longer fail
    to broadcast. The result also expands to any later batch shape. Accumulation
    is done in float32, and the result is cast back to the input dtype.
    """
    total = None
    for chunk in logits:
        reduce_dims = tuple(range(chunk.dim() - 1))
        chunk_mean = chunk.float().mean(dim=reduce_dims, keepdim=True) if reduce_dims else chunk.float()
        total = chunk_mean if total is None else total + chunk_mean
    return (total / len(logits)).to(logits[0].dtype)


def _save_logits(logits, name, eval_dir, save_const_logit_vec):
    """Save one split's logits, or their mean vector, next to ``eval_dir``; return the path."""
    if save_const_logit_vec:
        # NOTE(review): both splits write this same path, so the validation
        # vector overwrites the train one. load_model matches on the
        # "const_logit_vec.pt" suffix, so the name is kept. Confirm intent.
        to_save = _mean_logit_vector(logits)
        save_path = eval_dir + "_const_logit_vec.pt"
    else:
        to_save = logits
        save_path = eval_dir + f"_{name}.pt"
    torch.save(to_save, save_path)
    print(f"Logits saved to {save_path}")
    return save_path


def evaluate_datasets(train_dataset, validation, model1, model2, base_model, balance_model,
                      tokenizer, eval_dir, batch_size, relative=True, k=10, save_const_logit_vec=False):
    """
    Evaluate the train and validation splits and save the results.

    Output depends on whether ``model2`` is given:
    * With ``model2``, each split's stats DataFrame is written under ``eval_dir``
      as CSV, summary CSV, k2B JSON and PNG plots.
    * Without it, each split's ``model1`` logits are saved to
      ``eval_dir + "_<split>.pt"``. With ``save_const_logit_vec``, their mean
      vector goes to ``eval_dir + "_const_logit_vec.pt"`` instead.

    Dict-valued models (per-split logits from ``load_model``) are indexed by
    split name.
    """
    datasets = {"train": train_dataset, "validation": validation}

    for name, data in datasets.items():
        print(f"Evaluating {name} set...")
        start_time = time.time()
        eval_res, logits = compute_lc_statistics(
            model1=_split_model(model1, name),
            model2=_split_model(model2, name),
            base_model=_split_model(base_model, name),
            balance_model=_split_model(balance_model, name),
            data=data,
            tokenizer=tokenizer,
            batch=batch_size,
            relative=relative,
            k=k,
        )
        # Explicit check instead of assert (asserts vanish under -O).
        if (eval_res is None) == (logits is None):
            raise RuntimeError(f"Expected exactly one of stats or logits for the {name} split")

        if eval_res is not None:
            _report_stats(eval_res, name, eval_dir)
            print(f"Evaluation of {name} completed in {time.time() - start_time:.2f} seconds.")
        else:
            _save_logits(logits, name, eval_dir, save_const_logit_vec)
            print(f"Time to save logits: {time.time() - start_time:.2f} seconds.")

        # Free this split's results before processing the next one.
        del eval_res
        del logits
        torch.cuda.empty_cache()
        gc.collect()


def main():
    """Entry point: parse arguments, load data and models, and run the evaluation."""
    args = parse_arguments()
    print("LC evaluation configuration:", args)

    # Validate before loading anything: models can be many GB.
    if (args.base_model_checkpoint or args.logits_base) and (args.balance_checkpoint or args.logits_balance):
        print("Both base model and balance model provided. Only one should be used. Exiting.")
        return

    # Initialize tokenizer
    tokenizer_source = args.model_for_tokenizer or args.model_checkpoint1
    if not tokenizer_source:
        raise ValueError(
            "No tokenizer source: pass --model_for_tokenizer (required when only "
            "--logits1 is given) or --model_checkpoint1."
        )
    tokenizer = init_tokenizer(tokenizer_source)

    # Load datasets
    if args.canaries_path is None:
        dataset_loader = DatasetLoader()
        _, train_dataset, validation_dataset = dataset_loader.load_or_create_datasets(
            dataset_name=args.dataset_name,
            ntrain=args.n_test_samples,
            k=1,
            start_index=args.sample_start_index,
        )
    else:
        print("Using canaries dataset for evaluation.")
        train_dataset = DatasetDict.load_from_disk(os.path.join(args.canaries_path, "canaries_datasets"))['3']
        validation_dataset = Dataset.from_list(get_canaries(100, can_len=3, seed=27412586))

    # Load model(s)
    model1, model2, base_model, balance_model = load_models(args, tokenizer)

    if model2 is None:
        # load_models only allows a missing model2 when model_checkpoint1 is set.
        print("Model2 not provided. Calculating and saving logits for Model1")
        model_name = os.path.basename(args.model_checkpoint1.rstrip('/'))
        # Logits files are written as "<eval_dir>_<split>.pt" inside output_dir.
        eval_dir = os.path.join(args.output_dir, f"{args.dataset_name}_logits_{model_name}")
        os.makedirs(args.output_dir, exist_ok=True)
    else:
        prob_type = "relative" if args.use_relative_probs else "raw"
        eval_dir = os.path.join(args.output_dir, f"{args.dataset_name}_{prob_type}_lc_evaluation")
        os.makedirs(eval_dir, exist_ok=True)

    # Run evaluation
    evaluate_datasets(
        train_dataset=train_dataset,
        validation=validation_dataset,
        model1=model1,
        model2=model2,
        base_model=base_model,
        balance_model=balance_model,
        tokenizer=tokenizer,
        eval_dir=eval_dir,
        batch_size=args.batch_size,
        relative=args.use_relative_probs,
        k=args.k,
        save_const_logit_vec=args.save_const_logit_vec,
    )


if __name__ == "__main__":
    main()
