import torch
from tabulate import tabulate

# Maps each method key (the same string used for the run directory and the
# metrics file name) to the label shown in the "Method" column of the table.
# Several labels contain raw LaTeX (e.g. `$p$`), so they must be emitted
# without escaping.
display_names = {
    'rope': 'RoPE',
    'nope': 'NoPE',
    'p_rope': '$p$-RoPE',
    'ntk_aware_rope': 'RoPE+NTK',
    'logn_and_p_rope_special_s': 'LogN+$p$-RoPE',
    'logn_and_rope_special_s': 'LogN+RoPE',
    'logn_and_ntk_special_s': 'LogN+NTK',
    'alibi': 'ALiBi',
    'nousyarn': 'YaRN',
    'scale_invariant_p_rope_no_qk_norm': 'Scale-invariant $p$-RoPE (ours)'
}

def get_location(method, *, root="torchtune/llama2_7B"):
    """Return the path of the saved metrics file for ``method``.

    The path has the form ``<root>/<method>/metrics_<method>.pt``: the method
    name is used both as the sub-directory and inside the file name.

    Args:
        method: Method/run name, e.g. ``'rope'``.
        root: Directory holding one sub-directory per method. Defaults to the
            previously hard-coded ``torchtune/llama2_7B``, so existing callers
            get the same path as before.

    Returns:
        The path as a ``str`` with ``/`` separators.
    """
    return f"{root}/{method}/metrics_{method}.pt"

def load_metrics(methods=None):
    """Load the saved metrics file of each method.

    Args:
        methods: Iterable of method names to load. When ``None`` (the
            default), the ten methods compared in the table are loaded, in
            the order their rows should appear.

    Returns:
        A dict mapping each method name to the object ``torch.load`` returned
        for ``get_location(method)``. Insertion order follows ``methods``, and
        this order determines the row order of the printed table.
    """
    if methods is None:
        methods = (
            'rope', 'p_rope', 'nope', 'ntk_aware_rope', 'nousyarn',
            'logn_and_rope_special_s', 'logn_and_p_rope_special_s',
            'logn_and_ntk_special_s', 'alibi',
            'scale_invariant_p_rope_no_qk_norm',
        )
    # map_location='cpu': if a GPU training run saved tensors, they can still
    # be loaded for reporting on a machine without CUDA.
    # NOTE(review): torch.load unpickles its input; only use it on trusted files.
    return {
        method: torch.load(get_location(method), map_location='cpu')
        for method in methods
    }

# Validation-loss columns as (table header, key in a metrics entry), in table order.
_LOSS_COLUMNS = (
    ('Val @ 4k', 'val_loss_4096'),
    ('Val @ 16k', 'val_loss_16384'),
    ('Val @ 64k', 'val_loss_65536'),
)

# Methods whose label gets a trailing "$^*$" marker.
# NOTE(review): presumably refers to a footnote in the paper's table — confirm.
_STARRED_METHODS = frozenset({'rope', 'ntk_aware_rope'})


def _select_entry(records, step):
    """Pick the record to report for one method, or return None if there is none.

    A lone record is used as-is, whatever its step. Otherwise the first record
    whose ``'step'`` equals ``step`` is chosen.
    """
    if len(records) == 1:
        return records[0]
    return next((r for r in records if r.get('step') == step), None)


def _format_loss(entry, key):
    """Format ``entry[key]`` with 3 decimals; return 'N/A' if it is missing or None."""
    value = entry.get(key)
    if value is None:
        return 'N/A'
    return f"{value:.3f}"


def get_latest_metrics(metrics, *, step=125):
    """Render a LaTeX table of validation losses, one row per method.

    Despite the name, the reported entry is not necessarily the last one. For
    each method, a single record is used directly; otherwise the record logged
    at ``step`` is used.

    Args:
        metrics: Dict mapping method name -> sequence of per-step records.
            Each record is a dict, possibly holding ``'step'`` and
            ``'val_loss_<length>'`` keys (assumed schema, inferred from the
            keys read here — confirm against the producer).
        step: Training step whose record is reported when a method has
            several records. Defaults to 125.

    Returns:
        The table as a string in tabulate's ``latex_raw`` format. Labels keep
        their raw LaTeX, and cells are left as the formatted strings.
        Methods with no usable record are omitted. Methods missing from
        ``display_names`` are shown under their raw key.
    """
    headers = ['Method'] + [header for header, _ in _LOSS_COLUMNS]
    table_data = []

    for method, records in metrics.items():
        entry = _select_entry(records, step)
        if not entry:
            # No record at the requested step (or an empty one): skip the row.
            continue

        label = display_names.get(method, method)
        if method in _STARRED_METHODS:
            label = f"{label}$^*$"

        table_data.append([label] + [_format_loss(entry, key) for _, key in _LOSS_COLUMNS])

    return tabulate(table_data, headers=headers, tablefmt='latex_raw', disable_numparse=True)

if __name__ == "__main__":
    # Script entry point: load every method's metrics, then print the table.
    all_metrics = load_metrics()
    print("\nValidation Metrics:")
    latex_table = get_latest_metrics(all_metrics)
    print(latex_table)