import torch
import torch.nn as nn
import pprint

from quant.datasets import MixtureDataset
from quant.models.cnn import Autoencoder
from tqdm import tqdm
from torch.utils.data import DataLoader
from utils.utils import get_train_val_dataset
from quant.utils import RelativeMSELoss

# Data locations. These look like placeholder values; set them to the real
# directories before running.
soi_root_dir = "path_to_soi"
interference_root_dir = "path_to_interference"
# Forwarded to MixtureDataset as ``signal_length``.
signal_length = 2560
# Picks the run directory: runs/quant_cnn<run_id zero-padded to 4 digits>/.
run_id = 6
batch_size = 8
ckpt_path = f"runs/quant_cnn{run_id:04}/model_best.ckpt"
# Prefer the GPU, but fall back to the CPU so the script still runs on hosts
# without CUDA (a hard-coded "cuda" crashes there on the first .to(device)).
device = "cuda" if torch.cuda.is_available() else "cpu"
# Forwarded to get_train_val_dataset; presumably the share of samples used for
# training (the remainder being the validation split) -- TODO confirm.
train_fraction = 0.95
# "MSE" selects nn.MSELoss; any other value (including "RMSE") selects
# RelativeMSELoss in the __main__ block.
loss_func = "RMSE"
# Keyword arguments for Autoencoder(**model_config). These should match the
# configuration the checkpoint was trained with, otherwise load_state_dict is
# likely to fail on mismatched parameter names/shapes.
model_config = dict(
    patch_channels=8,
    channels=[128, 256, 256],
    resnet_count=2,
    use_fsq=True,
    use_norm=True,
    fsq_bits=16,
    num_transformer_blocks=1,
)

def get_dataloader(dataset, batch_size, *, shuffle=True, num_workers=4,
                   pin_memory=True):
    """Wrap ``dataset`` in a ``torch.utils.data.DataLoader``.

    The keyword-only options default to the settings this helper always used
    (shuffled order, 4 worker processes, pinned host memory), so existing
    calls behave exactly as before. Callers can pass ``shuffle=False`` to get
    a reproducible evaluation order, or ``num_workers=0`` to load in-process.

    Args:
        dataset: any map-style dataset accepted by ``DataLoader``.
        batch_size: number of samples per batch.
        shuffle: reshuffle the sample order every epoch.
        num_workers: number of worker processes used for loading.
        pin_memory: copy batches into page-locked host memory.

    Returns:
        The configured ``DataLoader``.
    """
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )


def run_validation(model, dataloader, criterion):
    """Return the mean per-batch loss of ``model`` over ``dataloader``.

    The model is switched to eval mode and run with autograd disabled. Each
    batch is moved to the module-level ``device``. The model output is then
    scored against the input batch itself: the target *is* the input.

    Args:
        model: module to evaluate; called as ``model(batch)``.
        dataloader: iterable yielding tensors.
        criterion: callable ``criterion(output, target)`` returning a scalar
            tensor.

    Returns:
        Arithmetic mean of the per-batch loss values. Every batch is weighted
        equally, including a smaller final batch.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    print("Running validation")
    model.eval()
    tot_loss = 0.0
    batch_count = 0
    # No gradients are needed for evaluation; without no_grad() an autograd
    # graph would be built (and held in memory) for every forward pass.
    with torch.no_grad():
        for batch in tqdm(dataloader):
            batch = batch.to(device)
            out = model(batch)
            tot_loss += criterion(out, batch).item()
            batch_count += 1
    if batch_count == 0:
        raise ValueError("run_validation: dataloader yielded no batches")
    return tot_loss / batch_count

if __name__ == "__main__":
    print("Testing checkpoint, path:", ckpt_path)
    print("Model config:")
    pp = pprint.PrettyPrinter(indent=4)
    pp.pprint(model_config)

    model = Autoencoder(**model_config).to(device)
    # map_location lets the checkpoint load even if it was saved from a
    # different device than ``device``. NOTE: torch.load unpickles the file,
    # so only load checkpoints from trusted sources.
    state_dict = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(state_dict)

    # The loss does not depend on the noise setting, so build it once.
    # "MSE" -> nn.MSELoss; any other loss_func value -> RelativeMSELoss.
    criterion = nn.MSELoss() if loss_func == "MSE" else RelativeMSELoss()

    # Sweep noise = -33, -30, ..., 0 (the range end 3 is exclusive). Each value
    # is passed as both noise_L and noise_R, and a fresh validation split is
    # built for it.
    for noise in range(-33, 3, 3):
        mixture = MixtureDataset(soi_dir=soi_root_dir,
                                 interference_dir=interference_root_dir,
                                 signal_length=signal_length,
                                 use_rand_phase=True,
                                 noise_L=noise,
                                 noise_R=noise)
        _, val_dataset = get_train_val_dataset(mixture, train_fraction)
        val_loader = get_dataloader(val_dataset, batch_size)
        val_loss = run_validation(model, val_loader, criterion)
        print("Val loss: ", noise, val_loss)

