import torch
import pprint

from quant.datasets import SequenceDataset
from quant.models.cnn import Autoencoder
from quant.fsq import fsq_level_book


# --- Checkpoint-test configuration -------------------------------------------
# Root directory handed to SequenceDataset; placeholder, replace with the real
# location of the training data.
soi_root_dir = "path_to_qpsk_train"
# Length passed to SequenceDataset as `signal_length`.
signal_length = 40960
# Training run to evaluate; formatted zero-padded to 4 digits in the path.
run_id = 28
ckpt_path = f"runs/quant_cnn{run_id:04}/model_best.ckpt"
# Prefer the GPU, but fall back to CPU so the script still runs on machines
# without CUDA (it previously failed there).
device = "cuda" if torch.cuda.is_available() else "cpu"
# Keyword arguments for Autoencoder; must match the architecture the
# checkpoint was trained with, or load_state_dict will reject it.
model_config = dict(
    patch_channels=8,
    channels=[64, 128, 256, 256],
    resnet_count=2,
    use_fsq=True,
    # NOTE(review): 10 presumably selects a preset of FSQ levels from
    # quant.fsq.fsq_level_book — confirm against that module.
    fsq_levels=fsq_level_book[10],
)


if __name__ == "__main__":
    # Smoke-test a trained checkpoint: restore it, push the first dataset
    # item through encode/decode, and report shapes, value prefixes and the
    # mean squared reconstruction error.
    print("Testing checkpoint, path:", ckpt_path)
    print("Model config:")
    pprint.pprint(model_config, indent=4)

    dataset = SequenceDataset(root_dir=soi_root_dir, signal_length=signal_length)
    model = Autoencoder(**model_config).to(device)
    # map_location restores tensors onto `device` even if the checkpoint was
    # saved on a different device (e.g. another GPU, or GPU when on CPU).
    state_dict = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(state_dict)
    # Evaluation mode: layers such as dropout / batch-norm (if the model has
    # any) must use inference behaviour, otherwise the reported error is
    # skewed and running statistics may be updated.
    model.eval()

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        # [None, :] adds a leading batch dimension of size 1.
        inp = dataset[0][None, :].to(device)
        enc = model.encode(inp)
        out = model.decode(enc)

    print("Input shape", inp.shape)
    print("Encoding shape:", enc.shape)
    print("Output shape:", out.shape)

    print("Input prefix:", inp[0, :16])
    print("Encoding prefix:", enc[0, :16])
    print("Output prefix:", out[0, :16])
    print("Error on first item:", ((inp - out) ** 2).mean().item())
