import re
from pathlib import Path

# One table row of a memory report, made of three columns separated by runs of
# spaces and terminated by a newline:
#   group 1: the entry name (letters, digits, underscores and dots),
#   group 2: a field made of digits, commas, spaces and parentheses
#            (presumably the tensor shape -- confirm against the report format),
#   group 3: a decimal number immediately followed by one uppercase letter,
#            which analyze_report interprets as a size with a unit suffix.
# NOTE: the trailing "\n" is part of the pattern, so a row without a line
# terminator does not match.
TABLE_LINE_PATTERN = re.compile(r"([a-zA-Z0-9_.]+) +([0-9, ()]+) +([0-9.]+[A-Z])\n")

# Entry names that start with "Parameter" or "Tensor" followed by at least one
# digit; analyze_report skips such rows unless it runs in tensor-only mode.
DISCARD_NAME_PATTERN = re.compile(r"((Parameter)|(Tensor))[0-9]+")

KILO = 1000  # for easy switch b/w 1000 and 1024

def prettify_bytes(memory):
    """Format a byte count as a short human-readable string.

    The value is divided by ``KILO`` until it drops below ``KILO`` or the
    largest known unit ("G") is reached, then rendered with two decimals
    followed by the unit letter, e.g. ``1500000 -> "1.50M"``.

    Args:
        memory: Number of bytes; anything accepted by ``float()``.

    Returns:
        The formatted string, e.g. ``"0.00B"`` or ``"2.00G"``. Values of
        ``KILO ** 4`` and above are still expressed in "G"
        (e.g. ``"1000.00G"``).
    """
    units = ("B", "K", "M", "G")
    memory = float(memory)
    unit_index = 0
    # Stop at the last unit: advancing past it would index outside `units`
    # for very large values.
    while memory >= KILO and unit_index < len(units) - 1:
        memory /= KILO
        unit_index += 1
    return f"{memory:.2f}{units[unit_index]}"

def analyze_report(fname, print_name, *, tensor_only_mode=False):
    """Sum the size column of a memory-report file and print the totals.

    Every line of the file is matched against ``TABLE_LINE_PATTERN``; lines
    that do not match are ignored. For each counted row the size (a number
    followed by a one-letter unit) is converted to bytes using ``KILO`` and
    added to the overall total; rows whose name ends with ``".grad"`` are
    additionally added to a separate gradient total. Both totals are printed,
    raw and formatted with ``prettify_bytes``.

    Args:
        fname: Path of the report file to read.
        print_name: Label printed in front of the overall total.
        tensor_only_mode: When false (default), rows whose name matches
            ``DISCARD_NAME_PATTERN`` are skipped. When true, only rows whose
            name contains ``"Tensor"`` are counted.

    Raises:
        RuntimeError: If a counted row uses a unit other than B, K, M or G.
    """
    # Multiplier that converts a value in the given unit to bytes.
    unit_factors = {
        "B": 1,
        "K": KILO,
        "M": KILO * KILO,
        "G": KILO * KILO * KILO,
    }

    total_memory = 0
    total_grad_memory = 0
    with open(fname, 'r') as f:
        for line in f:
            # TABLE_LINE_PATTERN ends with "\n"; the last line of a file that
            # lacks a terminating newline would otherwise be dropped silently.
            if not line.endswith("\n"):
                line += "\n"

            m = TABLE_LINE_PATTERN.match(line)
            if m is None:
                continue

            name = m.group(1)
            memory = m.group(3)

            if tensor_only_mode:
                counted = "Tensor" in name
            else:
                counted = DISCARD_NAME_PATTERN.match(name) is None
            if not counted:
                continue

            # The size field is "<number><unit letter>".
            memory_unit = memory[-1]
            if memory_unit not in unit_factors:
                raise RuntimeError(f"Unrecognized unit: {memory_unit}")
            memory_value = float(memory[:-1]) * unit_factors[memory_unit]

            total_memory += memory_value
            if name.endswith(".grad"):
                total_grad_memory += memory_value

    print(f"{print_name}: {total_memory} ({prettify_bytes(total_memory)})")
    print(f"    di cui grad: {total_grad_memory} ({prettify_bytes(total_grad_memory)})")

if __name__ == "__main__":
    # Report paths are resolved relative to the parent of the directory that
    # contains this script.
    base_dir = Path(__file__).resolve().parent.parent

    # (label, report file) pairs for the "initial" snapshot.
    initial_reports = (
        ("crosscoder", base_dir / "outputs/2024-09-05/13-49-10/memrep_crosscoder_initial.txt"),
        ("encoder", base_dir / "outputs/2024-09-05/13-49-10/memrep_encoder_initial.txt"),
        ("decoder", base_dir / "outputs/2024-09-05/13-49-10/memrep_decoder_initial.txt"),
    )
    # (label, report file) pairs for the "ep2" snapshot.
    final_reports = (
        ("crosscoder", base_dir / "outputs/2024-09-05/13-49-10/memrep_crosscoder_ep2.txt"),
        ("encoder", base_dir / "outputs/2024-09-05/13-49-10/memrep_encoder_ep2.txt"),
        ("decoder", base_dir / "outputs/2024-09-05/13-49-10/memrep_decoder_ep2.txt"),
    )

    print("INITIAL")
    for label, report_path in initial_reports:
        analyze_report(report_path, label)

    print("FINAL")
    for label, report_path in final_reports:
        analyze_report(report_path, label)

    # Same "ep2" reports again, this time counting only the "Tensor" rows.
    print("FINAL TENSORS")
    for label, report_path in final_reports:
        analyze_report(report_path, label, tensor_only_mode=True)
