import os

import torch
import wandb
from tqdm import tqdm

from chip.datasets.superres_dataset import SuperresolutionDS
from chip.models.unet import build_unet
from argparse import ArgumentParser
from chip.utils.metrics import b_RMSE, b_PSNR


def evaluate(model, device, dataLoader, num_iterations=-1):
    """Run ``model`` over ``dataLoader`` and return the mean RMSE and PSNR.

    The model output is reduced with ``argmax`` over dim 1 (one score map per
    class), and the resulting class-index map is compared against the labels
    with ``b_RMSE`` / ``b_PSNR``.

    Args:
        model: Network to evaluate; it is switched to eval mode and moved to
            ``device``.
        device: Torch device on which inference runs.
        dataLoader: Iterable yielding ``(image, label, _)`` batches.
        num_iterations: Maximum number of batches to evaluate. A negative
            value (the default, -1) evaluates the whole loader.

    Returns:
        Tuple ``(mean_rmse, mean_psnr)`` of 0-dim tensors.

    Raises:
        ValueError: If no batch was evaluated (empty loader or
            ``num_iterations == 0``).
    """
    model.eval()
    model = model.to(device)
    rmse = []
    psnr = []

    with torch.no_grad():
        # Wrap the loader itself (not the enumerate object) so tqdm can read
        # its length and show a proper total.
        for i, (image, label, _) in enumerate(tqdm(dataLoader)):
            # Stop *before* processing batch ``num_iterations`` so exactly
            # ``num_iterations`` batches are evaluated (negative = no limit).
            if 0 <= num_iterations <= i:
                break

            image = image.to(device)
            label = label.to(device)
            # Add a channel dimension; the model is built with one input
            # channel in this script (build_unet(1, ...)).
            output = model(image.unsqueeze(1))

            prediction = torch.argmax(output, dim=1)

            # NOTE(review): torch.cat below assumes b_RMSE/b_PSNR return
            # per-sample tensors with at least one dimension — confirm.
            rmse.append(b_RMSE(prediction, label))
            psnr.append(b_PSNR(prediction, label))

    if not rmse:
        raise ValueError("No batches were evaluated; cannot compute RMSE/PSNR.")

    rmse = torch.cat(rmse)
    psnr = torch.cat(psnr)
    return torch.mean(rmse), torch.mean(psnr)

if __name__ == '__main__':
    # Directory containing this script; checkpoints and data are resolved
    # relative to the directory two levels above it.
    file_path = os.path.dirname(os.path.abspath(__file__))
    root_path = os.path.join(file_path, "..", "..")

    parser = ArgumentParser(description="PyTorch experiments")
    parser.add_argument("--model_path", type=str, required=True, help="Path with unet checkpoint")
    parser.add_argument("--with_wandb", action='store_true', help="Use or not wandb to log metric")
    args = vars(parser.parse_args())

    experiment_name = "unet_performance"
    if args['with_wandb']:
        wandb.init(project=experiment_name, name=f"{experiment_name}_{args['model_path']}", config=args)

    # Prefer CUDA, then Apple MPS, and fall back to CPU so the script also
    # runs on machines that have neither accelerator.
    if torch.cuda.is_available():
        device = torch.device('cuda')
    elif torch.backends.mps.is_available():
        device = torch.device('mps')
    else:
        device = torch.device('cpu')

    # U-Net with 1 input channel and 2 output classes, matching the checkpoint.
    unet_model = build_unet(1, num_classes=2).to(device)

    # Load onto CPU first so checkpoints saved on any device can be restored.
    checkpoint = torch.load(os.path.join(root_path, args['model_path']), map_location=torch.device('cpu'))
    unet_model.load_state_dict(checkpoint['model_state_dict'])
    print(f"model loaded from checkpoint {args['model_path']}")

    data_path = os.path.join(root_path, "data")
    files = os.listdir(os.path.join(data_path, "imgs_synthetic"))
    ds = SuperresolutionDS(files, data_path=data_path)

    # The last 10% of the dataset is used as the held-out test split.
    # NOTE(review): os.listdir order is filesystem-dependent; this split only
    # matches the training split if training listed the files the same way.
    split_index = round(0.9 * len(ds))
    testSet = torch.utils.data.Subset(ds, range(split_index, len(ds)))

    dataLoader = torch.utils.data.DataLoader(testSet, batch_size=20, shuffle=False)

    rmse, psnr = evaluate(unet_model, device, dataLoader)
    print("Average RMSE on test dataset", rmse.item())
    print("Average PSNR on test dataset", psnr.item())
    if wandb.run:
        # Log plain Python floats rather than tensors.
        wandb.run.summary["RMSE"] = rmse.item()
        wandb.run.summary["PSNR"] = psnr.item()

