from chip.evaluation.benchmark import get_default_argument_parser, active_learning_benchmark, log_iteration
from chip.training.iterative_reconstruction import finetune_sinogram_consistency

import torch
from chip.utils.plotting import plot_comparison
from chip.utils.sinogram import compute_sinogram, Sinogram
from chip.utils.utils import save_image_and_log
from torchvision.transforms.functional import gaussian_blur as torch_gaussian_blur

from tqdm import tqdm
import wandb

from chip.evaluation.active_learning import uniform, gradient_norm_squared, choose_top_k
from chip.models.iterative_model import TomographicReconstruction


def get_iterative_reconstruction(tr_prior, lr_image, sinogram, steps=None, lr=None):
    """Fit a freshly constructed ``TomographicReconstruction`` to a sinogram.

    Args:
        tr_prior: Prior handed to ``TomographicReconstruction(prior=...)``.
        lr_image: Low-resolution image forwarded to
            ``finetune_sinogram_consistency``; ``None`` is passed through as-is.
        sinogram: Observed sinogram, used as ``target_sinogram_hr``.
        steps: Step schedule; defaults to ``[2000, 2000]`` (presumably one
            entry per optimisation stage -- confirm in the training module).
        lr: Learning-rate schedule; defaults to ``[0.1, 0.01]``.

    Returns:
        ``(model, loss)``: the fitted model and the value returned by
        ``finetune_sinogram_consistency``.
    """
    # Defaults are built per call so callers never share one mutable list.
    step_schedule = [2000, 2000] if steps is None else steps
    lr_schedule = [0.1, 0.01] if lr is None else lr

    model = TomographicReconstruction(prior=tr_prior, use_sigmoid=True)
    final_loss = finetune_sinogram_consistency(
        model,
        target_sinogram_hr=sinogram,
        lr_image=lr_image,
        steps=step_schedule,
        lr=lr_schedule,
        batch_size_lr=10,
        verbose=False,
    )
    return model, final_loss

def _select_candidate_indices(tr, theta, selected_indices, config):
    """Return the angle indices proposed by the configured acquisition score.

    Args:
        tr: Current reconstruction model (used only by 'mle_gradient_norm').
        theta: Tensor of all available projection angles.
        selected_indices: Indices of angles that have already been measured.
        config: Parsed configuration dict; reads 'acquisition.score' and, for
            'mle_gradient_norm', 'experiment.batch_size', 'acquisition.radius'
            and 'acquisition.allow_duplicates'.

    Raises:
        ValueError: If 'acquisition.score' is not a known score.
    """
    score_name = config['acquisition.score']
    if score_name == 'uniform':
        return [uniform(len(theta), len(selected_indices))]
    if score_name == 'mle_gradient_norm':
        scores = gradient_norm_squared(tr, tr.get_mle_tr(), theta)
        excluded = None if config['acquisition.allow_duplicates'] else selected_indices
        return choose_top_k(scores, theta,
                            k=config['experiment.batch_size'],
                            radius=config['acquisition.radius'],
                            exclude_indices=excluded)
    raise ValueError(f"Unknown acquisition score {config['acquisition.score']}")


def run_active_learning(prior, target, config):
    """Sequentially pick projection angles, measure them and re-fit a reconstruction.

    Each iteration does the following:
      1. Select candidate angle indices via the configured acquisition score.
      2. Drop indices that were already measured.
      3. Compute the sinogram of ``target`` at the new angles and append it.
      4. Re-fit a reconstruction warm-started from the previous prediction.
      5. Log metrics (and optionally a comparison image) to the active wandb run.

    Args:
        prior: Prior image tensor, presumably the low-resolution observation
            when 'model.prior' == 'lr' (TODO confirm against the benchmark).
        target: Ground-truth image tensor; its device is used for ``theta``.
        config: Parsed configuration dict. Keys read here are
            'experiment.theta', 'experiment.num_iterations', 'model.prior',
            'wandb.image_logging', plus the acquisition keys read by
            ``_select_candidate_indices``.
    """
    # Candidate angles: num_angles values evenly spaced by 180/num_angles,
    # starting at 0 and stopping one step short of 180.
    num_angles = config['experiment.theta']
    theta = torch.linspace(0, 180 * (1 - 1 / num_angles), num_angles, device=target.device)

    run = wandb.run

    hr_sinogram = Sinogram()
    selected_indices = []

    # lr_image enables the low-resolution consistency term; a uniform prior
    # disables it and replaces the prior with a constant 0.5 image.
    lr_image = prior
    if config['model.prior'] == 'uniform':
        prior = torch.ones_like(target) * 0.5
        lr_image = None
    # Capture the reconstruction initialisation AFTER the override. It used to
    # be taken before it, so 'uniform' still warm-started from the original prior.
    tr_prior = prior

    tr = TomographicReconstruction(prior=prior, use_sigmoid=True)

    for _ in tqdm(range(config['experiment.num_iterations']), desc="Iteration", position=1, leave=False):
        candidates = _select_candidate_indices(tr, theta, selected_indices, config)

        # Do not re-evaluate angles that have already been measured.
        top_k = [idx for idx in candidates if idx not in selected_indices]

        top_k_angles = theta[top_k]
        obs = compute_sinogram(target, top_k_angles)
        hr_sinogram.add(top_k_angles, obs)
        selected_indices += top_k

        tr, loss = get_iterative_reconstruction(tr_prior, lr_image, hr_sinogram)
        # The next fit is warm-started from this iteration's prediction.
        tr_prior = pred_image = tr.get_img().detach()

        # Log standard metrics; the commit happens in log_iteration below.
        run.log({"training_loss": loss}, commit=False)
        if config['wandb.image_logging']:
            # NOTE(review): `fig` is not closed here; if save_image_and_log does
            # not close the current figure, figures accumulate -- confirm.
            fig, _ = plot_comparison(pred_image.cpu().detach(), target.cpu().detach(), prior.cpu().detach())
            image_path = f"tmp/comparison_{run.id}.png"
            save_image_and_log(image_path, "comparison", commit=False, delete=True)
        log_iteration(top_k_angles, pred_image, target, commit=True)


def main():
    """Parse command-line options and run active learning on each benchmark instance.

    Extends the shared default parser with model and acquisition options,
    moves every ``(prior, target)`` pair yielded by ``active_learning_benchmark``
    to the available device (CUDA if present, else CPU), and runs
    ``run_active_learning`` on it.
    """
    # get argument parser with default arguments
    parser = get_default_argument_parser()

    # model parameters
    parser.add_argument('--model.prior', type=str, choices=['uniform', 'lr'])
    parser.add_argument('--model.sigma', type=float, default=8)
    parser.add_argument('--model.kernel_size', type=int, default=21)

    # algorithm parameters
    parser.add_argument('--acquisition.score', type=str, choices=['uniform', 'mle_gradient_norm'])
    parser.add_argument('--acquisition.allow_duplicates', action='store_true')
    # This value is forwarded as `radius=` to choose_top_k. A store_true flag
    # could only ever yield False/True, so it now accepts a value. nargs='?'
    # keeps old invocations working: no flag -> 0 (== False), bare flag ->
    # 1 (== True), and `--acquisition.radius N` passes N.
    parser.add_argument('--acquisition.radius', type=int, nargs='?', const=1, default=0)

    parser.add_argument("--wandb.image_logging", action='store_true', help="Log images during the run")

    config = vars(parser.parse_args())

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    for prior, target in active_learning_benchmark(config):
        run_active_learning(prior.to(device), target.to(device), config)


if __name__ == '__main__':
    main()
