import os

import wandb
from chip.models.forward_models import fourier_filtering
from chip.models.iterative_model import TomographicReconstruction
import torch
from tqdm import tqdm
import numpy as np

from chip.datasets.superres_dataset import SuperresolutionDS
from chip.models.unet import build_unet
from argparse import ArgumentParser

from chip.utils import get_uniform_angles, create_circle_filter, create_gaussian_filter
from chip.utils.metrics import get_metrics
from chip.training.iterative_reconstruction import finetune_sinogram_consistency

from torchvision.transforms.functional import gaussian_blur as torch_gaussian_blur

from chip.utils.sinogram import Sinogram, compute_sinogram

# Angles passed to compute_sinogram for the target's "hr" sinogram
# (presumably degrees — TODO confirm against chip.utils.sinogram).
# NOTE(review): fixed at 6 angles; the --num_hr_angles flag is parsed but unused.
HR_ANGLES = [0, 90, 1, 89, 179, 91]

# Metric names pre-seeded in the accumulator so the summary lists them in this order.
METRICS = ["PSNR", "RMSE", "ZeroOne", "SS"]


def select_device(allow_mps=True):
    """Return the best available torch device.

    Prefers CUDA; then Apple MPS if ``allow_mps`` is true and MPS is actually
    available; otherwise CPU (never an unavailable backend).
    """
    if torch.cuda.is_available():
        return torch.device('cuda')
    if allow_mps and torch.backends.mps.is_available():
        return torch.device('mps')
    return torch.device('cpu')


def repo_path(*parts):
    """Join ``parts`` onto the directory two levels above this script.

    Uses os.path so it works regardless of the platform path separator.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    return os.path.join(script_dir, "..", "..", *parts)


def safe_mean(values):
    """Return the arithmetic mean of ``values``, or NaN if ``values`` is empty."""
    return sum(values) / len(values) if values else float('nan')


def _parse_args():
    """Parse the command-line options into a plain dict."""
    parser = ArgumentParser(description="PyTorch experiments")
    parser.add_argument("--model_path", type=str, required=True, help="Path with unet checkpoint")
    parser.add_argument("--with_wandb", action='store_true', help="Use or not wandb to log metric")
    parser.add_argument("--num_hr_angles", type=int, default=10, help="Number of hr directions to be used")
    parser.add_argument("--batch_size_lr", type=int, default=10, help="Number of lr random directions to be used in each SGD step")
    return vars(parser.parse_args())


def main():
    """Evaluate UNet-prior iterative reconstruction on the last 10% of the dataset.

    For each test sample:
    1. The 2-class UNet's per-pixel argmax of the source is used as the prior of
       a TomographicReconstruction.
    2. That reconstruction is fine-tuned with finetune_sinogram_consistency
       against sinograms of the target (at HR_ANGLES) and of the source (at
       get_uniform_angles()).
    3. get_metrics compares the resulting image with the target.

    Per-metric averages are printed and, when a W&B run is active, written to
    its summary.
    """
    args = _parse_args()

    experiment_name = "iter_recons_performance"
    if args['with_wandb']:
        wandb.init(project=experiment_name, name=f"{experiment_name}_{args['model_path']}", config=args)

    device = select_device()
    # NOTE(review): the reconstruction is never placed on MPS — presumably an op
    # it needs is unsupported there; confirm.
    recons_device = select_device(allow_mps=False)

    # UNet mapping the source image to 2 per-pixel classes; its argmax is the prior.
    unet_model = build_unet(1, num_classes=2).to(device)
    checkpoint = torch.load(repo_path(args['model_path']), map_location=torch.device('cpu'))
    unet_model.load_state_dict(checkpoint['model_state_dict'])
    # Inference only: put any BatchNorm/Dropout layers into evaluation mode.
    unet_model.eval()
    print(f"model loaded from checkpoint {args['model_path']}")

    data_path = repo_path("data")
    # NOTE(review): os.listdir order is filesystem-dependent. It is deliberately
    # not sorted so the 90/10 split can match the one used at training time —
    # confirm the training script splits the same way.
    files = os.listdir(os.path.join(data_path, "imgs_synthetic"))
    ds = SuperresolutionDS(files, data_path=data_path)
    split = round(0.9 * len(ds))
    test_set = torch.utils.data.Subset(ds, range(split, len(ds)))

    total_metrics = {key: [] for key in METRICS}
    iterator = tqdm(test_set)
    for source, target, _ in iterator:
        # No autograd graph is needed for the prior.
        with torch.no_grad():
            output = unet_model(source.unsqueeze(0).unsqueeze(0).to(device))
        prior = torch.argmax(output[0], dim=0).to(recons_device).float()

        tr = TomographicReconstruction(prior=prior, use_sigmoid=True)
        circle_filter = create_circle_filter(30, target.shape[-1])

        hr_angles = torch.tensor(HR_ANGLES).float()
        hr_sinogram = Sinogram(compute_sinogram(target, hr_angles), hr_angles)

        lr_angles = get_uniform_angles()
        lr_sinogram = Sinogram(compute_sinogram(source, lr_angles), lr_angles)

        # NOTE(review): batch_size is hard-coded to 20; --batch_size_lr is parsed
        # (and logged to W&B) but not used here.
        finetune_sinogram_consistency(
            tr,
            target_sinogram_hr=hr_sinogram,
            target_sinogram_lr=lr_sinogram,
            # circle_filter is read at call time; finetuning finishes within
            # this iteration, so the closure always sees this sample's filter.
            lr_forward_function=lambda x: fourier_filtering(x, circle_filter),
            batch_size=20,
            verbose=True, steps=[200, 200, 200], lr=[0.5, 0.1, 0.01]
        )

        metric_dict = get_metrics(target, tr.get_img().detach().to('cpu'))
        for key, value in metric_dict.items():
            total_metrics.setdefault(key, []).append(value)
        iterator.set_postfix(metric_dict)

    for key, values in total_metrics.items():
        # NaN (not ZeroDivisionError) when the test split is empty.
        average = safe_mean(values)
        print(f"Average {key} on test dataset", average)
        if wandb.run:
            wandb.run.summary[key] = average


if __name__ == '__main__':
    main()



