import argparse
import os
import time
import numpy as np
import wandb
import gc
import itertools

from chip.evaluation.active_learning import mean_gradient_norm, choose_top_k
from chip.evaluation.benchmark import active_learning_benchmark, log_iteration, get_default_argument_parser
from chip.evaluation.tr_active_learning import get_iterative_reconstruction
from chip.models.forward_models import fourier_filtering
from chip.utils.diffusion import get_diff_unet_model, get_diffusion_samples
from chip.utils.laplace import SubsetRayTrafo, get_dip_mask, get_posterior, get_unet_reconstructor, optimize_unet, sample_unet, train_unet
from chip.utils import create_gaussian_filter
from chip.utils.metrics import get_metrics
from chip.utils.plotting import plot_sinogram, plot_comparison
from chip.utils.sinogram import batched_sinogram, compute_sinogram, Sinogram, sklearn_fbp
from chip.utils.tracker import Tracker, FlaggedArgumentParser

import lovely_tensors as lt
lt.monkey_patch()

import matplotlib.pyplot as plt
import torch
from tqdm.auto import tqdm
from torchvision.transforms.functional import gaussian_blur as torch_gaussian_blur

from diffusers import UNet2DModel
from diffusers import DDPMScheduler

from chip.models.tomographic_diffusion import TomographicDiffusion
from chip.models.iterative_model import TomographicReconstruction

from chip.training.iterative_reconstruction import sample_ensemble, sample_swag, total_variation_loss, finetune_sinogram_consistency
from chip.utils.utils import create_circle_filter, get_uniform_angles, load_model, save_image_and_log
from chip.training.train_diffusion_tomogram import get_dataset

import warnings

# Ignore UserWarning
warnings.filterwarnings('ignore', category=UserWarning)


def uniform_scores(num):
    """Rank the indices ``0..num-1`` so that low ranks are spread evenly.

    Index 0 gets rank 0. The remaining indices ``1..num-1`` are ranked in the
    breadth-first order of recursive interval bisection: the midpoint first,
    then the midpoints of both halves, and so on. ``run_active_learning``
    negates these ranks, so earlier-ranked indices receive higher scores.

    Args:
        num: number of indices to rank (non-negative).

    Returns:
        Float ``np.ndarray`` of shape ``(num,)`` where ``scores[idx]`` is the
        rank of ``idx``. Each rank ``0..num-1`` appears exactly once.
    """
    from collections import deque

    scores = np.zeros(num)
    if num <= 1:
        # Nothing to bisect. Seeding the queue with the empty interval
        # (1, num - 1) here would make the loop below run forever.
        return scores

    order = [0]
    pending = deque([(1, num - 1)])
    while pending:
        start, stop = pending.popleft()
        mid = (start + stop) // 2
        order.append(mid)
        if start != mid:
            pending.append((start, mid - 1))
        if mid != stop:
            pending.append((mid + 1, stop))

    scores[np.asarray(order)] = np.arange(len(order))
    return scores

def sampled_expected_info_gain(samples, noise_obs_std: float):
    """Per-angle log-determinant score from centred sinogram samples.

    For each angle this builds the empirical second-moment matrix of the
    detector readings across samples. It then adds observation noise
    ``noise_obs_std**2 * I`` and returns the log-determinant. Under a Gaussian
    approximation, this equals the expected information gain of measuring that
    angle up to a factor 1/2 and an additive constant.

    Args:
        samples: sinogram samples, assumed ``(num_samples, num_angles,
            num_detectors)`` and already centred by the caller (it subtracts a
            reference). TODO confirm the layout against ``compute_sinogram``.
        noise_obs_std: observation-noise standard deviation.

    Returns:
        Tensor of shape ``(num_angles,)`` holding the log-determinants.

    Raises:
        ValueError: if some per-angle covariance is not positive definite, e.g.
            when ``noise_obs_std`` is 0 and the samples are degenerate.
    """
    with torch.no_grad():
        # (samples, angles, detectors) -> (angles, detectors, samples)
        per_angle = samples.permute(1, 2, 0)
        num_mc = per_angle.shape[-1]
        cov = per_angle @ per_angle.transpose(1, 2) / num_mc  # (angles, detectors, detectors)
        eye = torch.eye(cov.shape[-1], dtype=cov.dtype, device=cov.device)
        cov = cov + (noise_obs_std ** 2) * eye
        sign, logdet = torch.linalg.slogdet(cov)

    # Explicit check (not `assert`, which -O strips); torch.all avoids a
    # per-element Python loop over the tensor.
    if not torch.all(sign == 1):
        raise ValueError(
            "sampled_expected_info_gain: covariance is not positive definite "
            "for some angle; consider a larger noise_obs_std"
        )
    return logdet

@torch.no_grad()
def sampled_var(zero_mean_samples, **kwargs):
    """Mean squared value of already-centred samples.

    Squares ``zero_mean_samples`` and averages over the first axis (samples)
    and the last axis. Then ``squeeze(0)`` removes a leading size-1 axis if
    one is left. Any keyword arguments are accepted and ignored.
    """
    squared = torch.square(zero_mean_samples)
    averaged = squared.mean(dim=(0, -1))
    return averaged.squeeze(0)

def filter_sinogram(sinogram: Sinogram, indices):
    """Return a new ``Sinogram`` that keeps only the projections at ``indices``.

    The same ``indices`` select both the projection rows and their angles.
    """
    kept_projections = sinogram.sinogram[indices]
    kept_angles = sinogram.angles[indices]
    return Sinogram(kept_projections, kept_angles)


def get_mle_sample(samples, hr_sinogram : Sinogram, lr_sinogram=None, lr_forward_function=None):
    """Pick the sample whose predicted sinograms best fit the measurements.

    For each sample, sum the squared residual against:
    - ``hr_sinogram``, comparing forward projections of the samples;
    - ``lr_sinogram``, comparing forward projections of
      ``lr_forward_function(sample)``.
    Missing or empty sinograms contribute nothing.

    Returns:
        ``samples[k]`` for the ``k`` with the smallest total residual. Ties,
        including the case with no measurements, resolve to the first sample.
    """
    total_error = torch.zeros(samples.shape[0], device=samples.device)

    def residual(measured, images):
        # Squared error per image between measured and predicted projections.
        predicted = batched_sinogram(images, sinogram_angles=measured.angles)
        return torch.sum((measured.sinogram - predicted) ** 2, dim=(1, 2))

    if hr_sinogram is not None and len(hr_sinogram) > 0:
        total_error += residual(hr_sinogram, samples)

    if lr_sinogram is not None and len(lr_sinogram) > 0:
        low_res_images = torch.func.vmap(lr_forward_function)(samples)
        total_error += residual(lr_sinogram, low_res_images)

    best = torch.argmin(total_error)
    return samples[best]


# def get_reference(samples, hr_sinogram=None, reference_fct='mean'):
#     if reference_fct == 'mean':
#         return torch.mean(samples, dim=0)
#     elif reference_fct == 'mle':
#         return get_mle_sample(samples, hr_sinogram)
    

class GenerativeModel:
    """Common base for the reconstruction models used in the active-learning loop.

    Subclasses implement ``train``, which conditions or fits the model on the
    measurements acquired so far. They also implement ``predict``, which returns
    a reconstruction together with samples.

    Args:
        config: flat configuration dict; ``'dataset.im_size'`` is read here.
        device: torch device the model works on.
        filter_fct: filter tensor passed to ``fourier_filtering`` to form the
            low-resolution forward model.
        name: human-readable model name.
    """

    def __init__(self, config, device, filter_fct, name):
        self.config = config
        self.name = name
        self.device = device
        # Move the filter once and close over the moved copy. Otherwise every
        # forward call would repeat `.to(device)` (a host-to-device copy on GPU)
        # and keep the original tensor alive.
        filter_on_device = filter_fct.to(device)
        self.filter_fct = filter_on_device
        self.lr_forward_function = lambda x: fourier_filtering(x, filter_on_device)
        self.im_size = config['dataset.im_size']
        self.im_shape = (self.im_size, self.im_size)

    def train(self, hr_sinogram=None, lr_sinogram=None):
        """Condition the model on the given sinograms; must be overridden."""
        raise NotImplementedError

    def predict(self, num_samples=20):
        """Return ``(reconstruction, samples)``; must be overridden."""
        raise NotImplementedError


class Diffusion(GenerativeModel):
    """Pretrained diffusion model sampled conditionally on the acquired sinograms.

    ``train`` only stores the measurements. The conditioning itself happens
    inside ``get_diffusion_samples`` at prediction time.
    """

    def __init__(self, config, device, filter_fct):
        super().__init__(config, device, filter_fct, name='Diffusion')
        checkpoint = config['model.path']
        # A directory resolves to the checkpoint named after dataset and image size.
        if os.path.isdir(checkpoint):
            checkpoint = os.path.join(
                checkpoint, f"ddpm_DDPM_{config['dataset']}_Tomograms_{self.im_size}.pt"
            )
        self.model = get_diff_unet_model(checkpoint, im_size=self.im_size, device=device)
        self.model.eval()
        self.hr_sinogram, self.lr_sinogram = None, None

    def train(self, hr_sinogram, lr_sinogram):
        """Remember the measurements that condition subsequent sampling."""
        self.hr_sinogram, self.lr_sinogram = hr_sinogram, lr_sinogram

    def predict(self, num_samples=20):
        """Draw conditioned samples and pick the best-fitting one as reconstruction.

        Returns:
            ``(reconstruction, samples)``: the reconstruction is the sample with the
            smallest sinogram residual (``get_mle_sample``). ``samples`` has every
            size-1 dimension squeezed away, so with ``num_samples == 1`` the
            sample dimension is dropped as well.
        """
        diffusion_options = {
            option: self.config[f'model.diffusion.{option}']
            for option in ('batch_size', 'buffer', 'use_sigmoid')
        }
        raw_samples = get_diffusion_samples(
            self.model,
            hr_sinogram=self.hr_sinogram,
            lr_sinogram=self.lr_sinogram,
            lr_forward_function=self.lr_forward_function,
            num_samples=num_samples,
            device=self.device,
            **diffusion_options,
        )
        samples = raw_samples.squeeze()
        best = get_mle_sample(
            samples,
            hr_sinogram=self.hr_sinogram,
            lr_sinogram=self.lr_sinogram,
            lr_forward_function=self.lr_forward_function,
        )
        return best, samples
    

class IterativeBaseModel(GenerativeModel):
    """Shared base for models built around one ``TomographicReconstruction``.

    ``train`` fits the reconstruction to the acquired sinograms with
    ``finetune_sinogram_consistency``. Subclasses override ``predict`` to add
    samples.
    """

    def __init__(self, config, device, filter_fct, name):
        super().__init__(config, device, filter_fct, name=name)
        self.model = self._new_reconstruction()
        self.lr_sinogram = None
        self.hr_sinogram = None

    def _new_reconstruction(self):
        """Return a fresh zero-prior reconstruction in eval mode."""
        model = TomographicReconstruction(prior=torch.zeros(self.im_shape, device=self.device), use_sigmoid=True)
        model.eval()
        return model

    def train(self, hr_sinogram, lr_sinogram):
        """Fit the reconstruction to the given sinograms and remember them.

        With ``model.iterative.reset_every_step`` the reconstruction is rebuilt
        first. Otherwise fitting continues from the state left by the previous
        call. Fitting is skipped when both sinograms are ``None``.
        """
        if self.config['model.iterative.reset_every_step']:
            self.model = self._new_reconstruction()
        if lr_sinogram is not None or hr_sinogram is not None:
            finetune_sinogram_consistency(self.model,
                                          lr_forward_function=self.lr_forward_function,  # low-resolution forward model
                                          target_sinogram_lr=lr_sinogram,  # low-resolution sinogram
                                          target_sinogram_hr=hr_sinogram,  # high-resolution sinogram
                                          steps=self.config['model.iterative.steps'],
                                          lr=self.config['model.iterative.lr'],
                                          verbose=False)
        self.hr_sinogram = hr_sinogram
        self.lr_sinogram = lr_sinogram

    def predict(self, num_samples=20):
        """Return the current reconstruction; this base model produces no samples.

        ``num_samples`` is accepted and ignored. This matches
        ``GenerativeModel.predict`` and the ``model.predict(num_samples=...)``
        call in the active-learning loop, which previously raised TypeError here.
        """
        return self.model.get_img(), None


class IterativeBootstrap(IterativeBaseModel):
    """Iterative reconstruction whose samples come from ``sample_ensemble``."""

    def __init__(self, config, device, filter_fct):
        super().__init__(config, device, filter_fct, name='Iter/Bootstrap')

    def predict(self, num_samples=20, reference_fct='mean'):
        """Return the current reconstruction and ``num_samples`` samples.

        Before any sinogram has been stored, the samples are standard-normal
        images of shape ``im_shape``. Afterwards they come from ``sample_ensemble``,
        configured by the ``model.bootstrap.*`` options.

        Args:
            num_samples: number of samples to draw.
            reference_fct: unused; kept so callers that pass it keep working.
        """
        has_measurements = self.hr_sinogram is not None or self.lr_sinogram is not None
        if has_measurements:
            bootstrap_options = {
                option: self.config[f'model.bootstrap.{option}']
                for option in ('input_std', 'batch_size', 'lr', 'steps')
            }
            samples = sample_ensemble(
                lr_sinogram=self.lr_sinogram,
                hr_sinogram=self.hr_sinogram,
                im_shape=self.im_shape,
                lr_forward_function=self.lr_forward_function,
                num_samples=num_samples,
                device=self.device,
                **bootstrap_options,
            )
        else:
            samples = torch.randn(num_samples, *self.im_shape, device=self.device)

        return self.model.get_img(), samples
    

class IterativeSWAG(IterativeBaseModel):
    """Iterative reconstruction whose samples come from ``sample_swag``."""

    def __init__(self, config, device, filter_fct):
        super().__init__(config, device=device, filter_fct=filter_fct, name='Iter/SWAG')

    def predict(self, num_samples=20):
        """Return the current reconstruction and ``num_samples`` samples.

        Without stored sinograms the samples are standard-normal images.
        Otherwise they come from ``sample_swag`` applied to the fitted model.
        """
        if self.hr_sinogram is not None or self.lr_sinogram is not None:
            samples = sample_swag(
                self.model,
                lr_sinogram=self.lr_sinogram,
                hr_sinogram=self.hr_sinogram,
                lr_forward_function=self.lr_forward_function,
                num_samples=num_samples,
                subsample=self.config['model.swag.subsample'],
            )
        else:
            samples = torch.randn(num_samples, *self.im_shape, device=self.device)
        return self.model.get_img(), samples
    

class Laplace(GenerativeModel):
    """Deep-image-prior reconstruction with a linearized-posterior sampler.

    ``train`` fits the U-Net (``train_unet``), builds the posterior
    (``get_posterior``), and optionally refines linearized weights
    (``optimize_unet``). ``predict`` returns the chosen reference image and
    that reference plus posterior samples.

    Args:
        config: flat configuration dict (``model.laplace.*`` options).
        device: torch device.
        filter_fct: filter for the low-resolution forward model.
        ray_trafo: full ray transform; its ``im_shape`` sizes the random network input.
        hydra_cfg: composed hydra config forwarded to the helper functions.
    """

    def __init__(self, config, device, filter_fct, ray_trafo, hydra_cfg):
        super().__init__(config, device, filter_fct, name='DIP/Laplace')
        self.ray_trafo = ray_trafo
        self.hydra_cfg = hydra_cfg
        self.reconstructor = get_unet_reconstructor(im_shape=self.im_shape, cfg=hydra_cfg, device=self.device)
        # Fixed random network input, used whenever train() gets no explicit input.
        self.rand_unet_input = 0.1 * torch.randn(1, 1, *ray_trafo.im_shape, device=self.device)
        self.linearized_weights = None
        self.linearized_recon = torch.zeros(self.im_shape, device=self.device).unsqueeze(0).unsqueeze(0)
        # Defined up front so predict() and logging work before the first train().
        self.dip_recon = torch.zeros_like(self.linearized_recon)
        self.posterior = None
        self.optim_kwargs = None

    def train(self, obs, ray_trafo, unet_input=None, optimize_iterations=2, ground_truth=None):
        """Fit the DIP network to ``obs`` and refresh the posterior.

        Args:
            obs: observed sinogram for the acquired angles; may be empty.
            ray_trafo: forward operator matching ``obs``.
            unet_input: network input; defaults to the fixed random input.
            optimize_iterations: iterations for ``optimize_unet``; a falsy value skips it.
            ground_truth: forwarded to ``train_unet`` only.
        """
        if self.config['model.laplace.reset_every_step']:
            self.reconstructor = get_unet_reconstructor(im_shape=self.im_shape, cfg=self.hydra_cfg, device=self.device)
        if unet_input is None:
            unet_input = self.rand_unet_input
        train_iterations = self.config['model.laplace.train_iterations']
        # numel() instead of len(obs.squeeze()): same answer for realistic shapes,
        # but it does not raise on a single-element observation (0-d after squeeze).
        has_obs = obs.numel() > 0

        if has_obs:
            self.dip_recon = train_unet(self.reconstructor,
                                        unet_input=unet_input,
                                        cfg=self.hydra_cfg,
                                        observation=obs,
                                        ray_trafo=ray_trafo,
                                        iterations=train_iterations,
                                        lr=self.config['model.laplace.train_lr'],
                                        ground_truth=ground_truth,
                                        verbose=self.config['verbose'])
        else:
            # No data yet: the reconstruction is the untrained network output.
            with torch.no_grad():
                self.dip_recon = self.reconstructor.nn_model(unet_input)

        self.posterior, self.optim_kwargs = get_posterior(self.reconstructor,
                                                          cfg=self.hydra_cfg,
                                                          ray_trafo=ray_trafo,
                                                          unet_input=unet_input,
                                                          device=self.device)
        if optimize_iterations and has_obs:
            self.linearized_weights, self.linearized_recon = optimize_unet(
                recon=self.dip_recon,
                predictive_posterior=self.posterior,
                prev_linear_weights=None,
                observation=obs,
                cfg=self.hydra_cfg,
                optim_kwargs=self.optim_kwargs,
                iterations=optimize_iterations,
                ground_truth=None
            )

    def predict(self, num_samples=20):
        """Return ``(reference, reference + samples)``.

        ``reference`` is the linearized or DIP reconstruction, selected by
        ``model.laplace.reconstruction``. Until ``optimize_unet`` has produced
        linearized weights, the samples are standard-normal images.

        Raises:
            ValueError: for an unknown ``model.laplace.reconstruction`` value
                (checked before any sampling work).
        """
        mode = self.config['model.laplace.reconstruction']
        if mode == 'linearized_recon':
            reference = self.linearized_recon
        elif mode == 'dip_recon':
            reference = self.dip_recon
        else:
            raise ValueError(f"Unknown reconstruction type '{self.config['model.laplace.reconstruction']}'")

        if self.linearized_weights is None:
            samples = torch.randn(num_samples, *self.im_shape, device=self.device)
        else:
            samples = sample_unet(predictive_posterior=self.posterior,
                                  optim_kwargs=self.optim_kwargs,
                                  sample_then_optimize=False,
                                  num_samples=num_samples)

        reference = reference.squeeze()
        return reference, reference + samples.squeeze()


def get_testset(dataset, lr_forward_function, path, im_size):
    """Load the test split of one of the supported tomography datasets.

    Args:
        dataset: one of ``'Lamino'``, ``'Chip'``, ``'Lung'``, ``'Composite'``.
        lr_forward_function: low-resolution forward model handed to the dataset.
        path: root directory that contains the raw data folders.
        im_size: output image size (``'rescale'``); only 128 and 512 are supported.

    Returns:
        The test split returned by ``get_dataset``.

    Raises:
        ValueError: for an unsupported ``im_size`` or an unknown ``dataset``.
    """
    if im_size not in (128, 512):
        raise ValueError(f"Unsupported im_size {im_size}; expected 128 or 512")

    # Passed as the dataset's 'im_size' next to 'rescale'. Presumably the
    # loading resolution before rescaling; confirm in get_dataset.
    raw_im_size = 256 if im_size == 128 else 512

    kwargs = {
        'lr_forward_function': lr_forward_function,
        'gray_background': False,
        'train_transform': False,
        'to_gray': False,
        'rotation_angle': 30,
        'normalize_range': True,
        'rescale': im_size,
    }

    if dataset == 'Lamino':
        kwargs.update({
            'path': os.path.join(path, 'lamino_tiff'),
            'im_size': raw_im_size,
        })
        file_type = 'tiff'
    elif dataset == 'Chip':
        kwargs.update({
            'path': os.path.join(path, 'DATASET_G7_170um_10nm_rect'),
            'im_size': raw_im_size,
        })
        file_type = 'tiff'
    elif dataset == 'Lung':
        kwargs.update({
            'path': os.path.join(path, 'lung/ground_truth_train'),
        })
        file_type = 'h5'
    elif dataset == 'Composite':
        kwargs.update({
            'path': os.path.join(path, 'composite/SampleG-FBI22-Stitch-0-1-2.txm.nii'),
            'im_size': raw_im_size,
            'file_range': [20, 360],
            'clip_range': [3e4, 5e4],
        })
        file_type = 'nii'
    else:
        raise ValueError(f"Unknown dataset '{dataset}'")

    _, test_set = get_dataset(kwargs, file_type)
    return test_set


# Acquisition scores implemented by _compute_scores.
_IMPLEMENTED_SCORES = ('variance', 'committee', 'gauss_entropy', 'uniform')


def _compute_scores(score, samples, reconstruction, angles, config):
    """Return one acquisition score per candidate angle for the score named ``score``.

    The result is passed to ``choose_top_k``. The ``'uniform'`` branch negates
    its bisection ranks, which indicates that larger scores are preferred.

    Raises:
        ValueError: for a score name without an implementation.
    """
    with torch.no_grad():
        if score == 'variance':
            # spread of the samples around their own mean
            return sampled_var(compute_sinogram(samples - torch.mean(samples, dim=0), angles))
        if score == 'committee':
            # spread of the samples around the model's reconstruction
            return sampled_var(compute_sinogram(samples - reconstruction, angles))
        if score == 'gauss_entropy':
            return sampled_expected_info_gain(compute_sinogram(samples - reconstruction, angles),
                                              noise_obs_std=config['acquisition.noise_std'])
        if score == 'uniform':
            return -torch.tensor(uniform_scores(len(angles)), dtype=torch.get_default_dtype())
    raise ValueError(f"Acquisition score '{score}' is not implemented")


def run_active_learning(prior, target, model, config, tracker : Tracker, device):
    """Greedily acquire high-resolution projection angles for one test image.

    Each iteration does the following:
    - trains ``model`` on the high-resolution projections acquired so far,
      plus the optional low-resolution sinogram of ``prior``;
    - draws samples and scores every candidate angle;
    - adds the best angle not yet selected.
    Metrics, samples and scores are logged through ``tracker``.

    Args:
        prior: image used for the low-resolution sinogram (and optionally as
            the Laplace network input).
        target: ground-truth image; its simulated sinogram is the measurement
            source and it is the reference for the metrics.
        model: one of the model classes defined in this module.
        config: flat configuration dict built by ``main``.
        tracker: result logger for this image.
        device: torch device.

    Raises:
        ValueError: if ``config['acquisition.score']`` has no implementation,
            or if the Laplace ``'fbp'`` input is requested without
            low-resolution angles.
    """
    score = config['acquisition.score']
    if score not in _IMPLEMENTED_SCORES:
        # The CLI also accepts e.g. 'gradient_norm' and 'random'. They have no
        # implementation and used to surface as an UnboundLocalError on
        # `scores` only after a full training step.
        raise ValueError(f"Acquisition score '{score}' is not implemented; "
                         f"choose one of {', '.join(_IMPLEMENTED_SCORES)}")

    tracker.log_file('prior', prior)
    tracker.log_file('target', target)

    # available angles
    angles = get_uniform_angles(config['experiment.num_angles'], device=device)
    tracker.log_file('angles', angles)

    # full high-resolution sinogram over all candidate angles
    hr_sinogram_full = Sinogram(compute_sinogram(target, angles), angles=angles)

    if config['experiment.obs_noise_rel_std']:
        # NOTE(review): the "relative" std is scaled by target.shape[-1]
        # (presumably the image width); confirm this is intended.
        noise_std = config['experiment.obs_noise_rel_std'] * target.shape[-1]
        hr_sinogram_full.sinogram += noise_std * torch.randn_like(hr_sinogram_full.sinogram)

    lr_sinogram = None
    if config['acquisition.init.num_lr_angles'] is not None:
        lr_angles = get_uniform_angles(config['acquisition.init.num_lr_angles'], device=device)
        lr_sinogram = Sinogram(compute_sinogram(prior, lr_angles), angles=lr_angles)

    # special treatment for Laplace framework: it works on its own ray transform
    if config['model'] == 'Laplace':
        dip_sinogram_full = model.ray_trafo(target.unsqueeze(0).unsqueeze(0))
        if config['experiment.obs_noise_rel_std']:
            noise_std = config['experiment.obs_noise_rel_std'] * target.shape[-1]
            dip_sinogram_full += noise_std * torch.randn_like(dip_sinogram_full)

        subset_ray_trafo = SubsetRayTrafo(model.ray_trafo)

        unet_input = None
        if config['model.laplace.unet_input'] == 'lr_image':
            unet_input = prior.unsqueeze(0).unsqueeze(0)
        elif config['model.laplace.unet_input'] == 'fbp':
            if lr_sinogram is None:
                raise ValueError("model.laplace.unet_input='fbp' requires acquisition.init.num_lr_angles")
            unet_input = sklearn_fbp(lr_sinogram.sinogram.cpu(), lr_sinogram.angles.cpu()).unsqueeze(0).unsqueeze(0).to(device)

    selected_indices = []

    if config['acquisition.init.num_hr_angles'] is not None:
        selected_indices = np.linspace(0, len(angles), config['acquisition.init.num_hr_angles'], endpoint=False, dtype=int).tolist()

    # run active learning
    for i in tqdm(range(config['experiment.num_iterations']), desc="Iteration", position=1, leave=False):

        if len(selected_indices) > 0:
            hr_sinogram = filter_sinogram(hr_sinogram_full, selected_indices)
        else:
            hr_sinogram = None

        if config['model'] == 'Laplace':
            # special data generation for Laplace model
            mask = get_dip_mask(our_angles=angles[selected_indices], ray_trafo=model.ray_trafo)
            subset_ray_trafo.set_angular_mask(mask)
            dip_sinogram = dip_sinogram_full[:, :, mask]

            t0 = time.time()
            model.train(dip_sinogram, ray_trafo=subset_ray_trafo, unet_input=unet_input,
                        optimize_iterations=config['model.laplace.optimize_iterations'], ground_truth=target.unsqueeze(0).unsqueeze(0))
            time_train = time.time() - t0
        else:
            t0 = time.time()
            model.train(hr_sinogram, lr_sinogram)
            time_train = time.time() - t0

        t0 = time.time()
        reconstruction, samples = model.predict(num_samples=config['acquisition.num_samples'])
        time_predict = time.time() - t0

        scores = _compute_scores(score, samples, reconstruction, angles, config)

        # select angles
        top_k = choose_top_k(scores, angles, k=1, exclude_indices=selected_indices)
        selected_indices += top_k

        assert len(top_k) == 1

        tracker.log(iteration=i,
                    selected_angle=angles[top_k[0]],
                    time_train=time_train,
                    time_predict=time_predict,
                    **get_metrics(target, reconstruction))
        tracker.log_file('sample_1', samples[0], iteration=i)
        tracker.log_file('sample_2', samples[1], iteration=i)
        tracker.log_file('sample_3', samples[2], iteration=i)
        tracker.log_file('std', torch.std(samples, dim=0), iteration=i)
        tracker.log_file('reference', reconstruction, iteration=i)
        tracker.log_file('scores', scores, iteration=i)

        if config['model'] == 'Laplace':
            tracker.log_file('linearized_recon', model.linearized_recon, iteration=i)
            tracker.log_file('dip_recon', model.dip_recon, iteration=i)

        tracker.save(complete=False)

        # release per-iteration tensors before the next (memory-heavy) iteration
        gc.collect()
        torch.cuda.empty_cache()


def _build_parser():
    """Return the command-line parser for the active-learning experiments.

    Arguments registered with ``flag=True`` make up the config that
    ``parse_args(only_flagged=True)`` returns. ``main`` passes that config to
    ``Tracker``.
    """
    int_list = lambda x: [*map(int, x.split(","))]
    float_list = lambda x: [*map(float, x.split(","))]

    parser = FlaggedArgumentParser(description="Diffusion Active Learning Experiments")
    parser.add_argument("--device", type=str, default=None, help="torch device")
    parser.add_argument("--verbose", action="store_true", help="")

    # dataset arguments
    parser.add_argument("--dataset", type=str, default='Chip', choices=['Chip', 'Lung', 'Lamino', 'Composite'], flag=True)
    # get_testset only supports these two sizes
    parser.add_argument("--dataset.im_size", type=int, default=128, choices=[128, 512], help="Image size", flag=True)
    parser.add_argument("--dataset.circle_filter_radius", type=int, default=15, help="circle filter radius", flag=True)
    parser.add_argument("--dataset.path", type=str, default='/mydata/chip/shared/data', help="data path", flag=True)
    parser.add_argument("--dataset.indices", type=int_list, required=True, help="Indices of the dataset to use")

    # experiment arguments (the name is part of the logging path, hence required)
    parser.add_argument('--experiment', type=str, required=True, help='experiment name (subdirectory under --logging.path)', flag=True)
    parser.add_argument('--experiment.num_iterations', type=int, default=10, help='active learning number of iterations', flag=True)
    parser.add_argument('--experiment.num_angles', type=int, default=180, help='number of angles available for active learning', flag=True)
    parser.add_argument('--experiment.obs_noise_rel_std', type=float, default=None, help='observation noise', flag=True)

    # model arguments
    parser.add_argument("--model", type=str, required=True, choices=['Diffusion', 'Laplace', 'Swag', 'Bootstrap'], flag=True)
    parser.add_argument("--model.path", type=str, default='checkpoints/diffusion_model_tomogram.pt', help="Path to model checkpoint", flag=True)
    parser.add_argument("--model.diffusion.batch_size", type=int, default=10, help="", flag=True)
    parser.add_argument("--model.diffusion.use_sigmoid", action="store_true", help="", flag=True)
    parser.add_argument("--model.diffusion.buffer", type=int, default=5, help="", flag=True)
    parser.add_argument("--model.iterative.steps", type=int_list, default="400,400,200", help="Number of steps for iterative reconstruction", flag=True)
    parser.add_argument("--model.iterative.lr", type=float_list, default="0.1,0.01,0.001", help="Learning rate for iterative reconstruction", flag=True)
    parser.add_argument("--model.iterative.reset_every_step", action="store_true", help="", flag=True)
    parser.add_argument("--model.swag.subsample", type=int, default=10, help="Subsampling for SWAG", flag=True)
    parser.add_argument("--model.bootstrap.input_std", type=float, default=0.01, help="", flag=True)
    parser.add_argument("--model.bootstrap.batch_size", type=int, default=4, help="", flag=True)
    parser.add_argument("--model.bootstrap.steps", type=int_list, default="200,200,100", help="Number of steps for iterative reconstruction", flag=True)
    parser.add_argument("--model.bootstrap.lr", type=float_list, default="0.1,0.01,0.001", help="Learning rate for iterative reconstruction", flag=True)
    parser.add_argument("--model.laplace.use_sigmoid", action="store_true", help="", flag=True)
    parser.add_argument("--model.laplace.train_iterations", type=int, default=2000, help="", flag=True)
    parser.add_argument("--model.laplace.train_lr", type=float, default=3e-4, help="", flag=True)
    parser.add_argument("--model.laplace.reset_every_step", action="store_true", help="", flag=True)
    parser.add_argument("--model.laplace.fbp_adjoint", action="store_true", help="", flag=True)
    parser.add_argument("--model.laplace.optimize_iterations", type=int, default=1, help="", flag=True)
    parser.add_argument("--model.laplace.matrix_path", type=str, default=None, help="", flag=True)
    parser.add_argument("--model.laplace.hydra_path", type=str, default=None, help="", flag=True)
    parser.add_argument("--model.laplace.reconstruction", type=str, default='linearized_recon',
                        choices=['linearized_recon', 'dip_recon'], help="", flag=True)
    parser.add_argument("--model.laplace.unet_input", type=str, default=None,
                        choices=['lr_image', 'fbp'], help="", flag=True)

    # acquisition arguments
    parser.add_argument("--acquisition.init.num_lr_angles", type=int, default=None, help="Initial angles for low resolution sinogram", flag=True)
    parser.add_argument("--acquisition.init.num_hr_angles", type=int, default=None, help="Initial angles for high resolution sinogram", flag=True)
    parser.add_argument("--acquisition.noise_std", type=float, default=0.1, help="Noise std", flag=True)
    parser.add_argument("--acquisition.num_samples", type=int, default=20, help="Number of samples drawn per iteration", flag=True)
    parser.add_argument("--acquisition.score", type=str, default='variance', flag=True,
                        choices=['gauss_entropy', 'gradient_norm', 'variance', 'uniform', 'random', 'committee'])

    # logging arguments (the path is joined into the output directory, hence required)
    parser.add_argument("--logging.path", type=str, required=True, help="Path for local logging.")
    parser.add_argument("--logging.slurm_id", type=str, default=None, help="SLURM job id stored in the run metadata.")
    return parser


def _build_laplace_components(config, device):
    """Compose the bayes_dip hydra config and build its ray transform on ``device``."""
    from hydra import compose, initialize_config_dir
    from bayes_dip.utils.experiment_utils import get_standard_ray_trafo

    im_size = config['dataset.im_size']
    initialize_config_dir(version_base=None, config_dir=config['model.laplace.hydra_path'], job_name="test_app")
    cfg = compose(config_name="config", overrides=[
        "experiment=chip",
        "dip.optim.iterations=2000",
        "dip.optim.gamma=1e-4",
        f"dip.net.use_sigmoid={config['model.laplace.use_sigmoid']}",
        "dataset.noise_stddev=0.0",
        f"dataset.im_size={im_size}",
        "mll_optim=walnut_sample_based_mll_optim",
        "mll_optim.activate_debugging_mode=False",
        "mll_optim.num_samples=10",
        "mll_optim.sampling.batch_size=10",
        "mll_optim.sampling.use_conj_grad_inv=true",
        "priors.use_gprior=True",
        "priors.gprior.scale.obs_subsample_fct=10",
        # NOTE(review): fixed at 180, independent of --experiment.num_angles.
        "trafo.num_angles=180",
        f"trafo.matrix_ray_trafo={not config['model.laplace.fbp_adjoint']}",
        f"trafo.matrix_path={config['model.laplace.matrix_path']}",
        f"trafo.geometry_specs.num_det_pixels={im_size}",
    ])
    ray_trafo = get_standard_ray_trafo(cfg)
    ray_trafo.to(dtype=torch.get_default_dtype(), device=device)
    return cfg, ray_trafo


def _build_model(config, device, filter_fct, hydra_cfg=None, ray_trafo=None):
    """Instantiate the model selected by ``config['model']``."""
    name = config['model']
    if name == 'Laplace':
        return Laplace(config, device, filter_fct, hydra_cfg=hydra_cfg, ray_trafo=ray_trafo)
    if name == 'Diffusion':
        return Diffusion(config, device, filter_fct=filter_fct)
    if name == 'Swag':
        return IterativeSWAG(config, device, filter_fct=filter_fct)
    if name == 'Bootstrap':
        return IterativeBootstrap(config, device, filter_fct=filter_fct)
    raise ValueError(f"Unknown model '{config['model']}'")


def main():
    """Run active learning for every requested dataset index.

    Results for each index go through a ``Tracker`` under
    ``<logging.path>/<experiment>/<dataset>/<model>``. Indices whose results
    already exist (``FileExistsError`` from ``Tracker``) are skipped.
    """
    parser = _build_parser()
    config = vars(parser.parse_args())
    config_flagged = vars(parser.parse_args(only_flagged=True))

    im_size = config['dataset.im_size']
    if config['device'] is None:
        config['device'] = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    device = config['device']
    print(f"Using device {device}")

    # Filter defining the low-resolution forward model.
    current_filter = create_circle_filter(radius=config['dataset.circle_filter_radius'], size=im_size)
    lr_forward_function = lambda x: fourier_filtering(x, current_filter)

    dataset = get_testset(config['dataset'], lr_forward_function, config['dataset.path'], im_size=im_size)

    cfg, ray_trafo = None, None
    if config['model'] == 'Laplace':
        cfg, ray_trafo = _build_laplace_components(config, device)

    log_dir = os.path.join(config['logging.path'], config['experiment'], config['dataset'], config['model'])

    for index in tqdm(config['dataset.indices']):
        try:
            tracker = Tracker(config=config_flagged, path=log_dir, index=index, delete_existing=config['experiment'] == 'debug')
        except FileExistsError as e:
            print(e)
            continue

        model = _build_model(config, device, current_filter, hydra_cfg=cfg, ray_trafo=ray_trafo)

        prior, target, _ = dataset[index]
        print(f"running active learning on index {index}...")
        t0 = time.time()
        run_active_learning(prior=prior.to(device).squeeze(),
                            target=target.to(device).squeeze(),
                            model=model,
                            config=config,
                            tracker=tracker,
                            device=device)

        tracker.log_meta(total_time=time.time() - t0)
        tracker.log_meta(slurm_id=config['logging.slurm_id'])
        tracker.save(complete=True)


# Run the experiment driver only when this file is executed as a script.
if __name__ == '__main__':
    main()