import wandb
import gc
import itertools

from chip.evaluation.active_learning import mean_gradient_norm, entropy, choose_top_k, angular_distance
from chip.evaluation.benchmark import active_learning_benchmark, log_iteration, get_default_argument_parser
from chip.evaluation.tr_active_learning import get_iterative_reconstruction
from chip.utils import create_gaussian_filter
from chip.utils.metrics import get_metrics
from chip.utils.plotting import plot_sinogram, plot_comparison
from chip.utils.sinogram import batched_sinogram, compute_sinogram, Sinogram

import lovely_tensors as lt
lt.monkey_patch()

import matplotlib.pyplot as plt
import torch
from chip.utils.utils import load_model, save_image_and_log
from torchvision.transforms.functional import gaussian_blur as torch_gaussian_blur

from diffusers import UNet2DModel
from diffusers import DDPMScheduler

from chip.models.tomographic_diffusion import TomographicDiffusion
from chip.models.iterative_model import TomographicReconstruction

from chip.training.iterative_reconstruction import total_variation_loss, finetune_sinogram_consistency
from tqdm.auto import tqdm

import warnings

# Ignore UserWarning
warnings.filterwarnings('ignore', category=UserWarning)


def get_diffusion_model(device=None):
    """
    Build the (untrained) UNet2DModel used as the diffusion denoiser.

    The network works on 512x512 single-channel images with six resolution
    levels; only the fifth down block and the second up block use spatial
    self-attention.

    Args:
        device: optional torch device; when given, the model is moved there.

    Returns:
        The freshly initialized ``UNet2DModel``.
    """
    # five plain ResNet downsampling levels with self-attention on the fifth one
    down_blocks = ("DownBlock2D",) * 4 + ("AttnDownBlock2D", "DownBlock2D")
    # mirror image of the encoder: attention on the second upsampling level
    up_blocks = ("UpBlock2D", "AttnUpBlock2D") + ("UpBlock2D",) * 4

    unet = UNet2DModel(
        sample_size=512,
        in_channels=1,
        out_channels=1,
        layers_per_block=2,
        block_out_channels=(64, 64, 128, 128, 256, 256),
        down_block_types=down_blocks,
        up_block_types=up_blocks,
    )
    return unet if device is None else unet.to(device)


def get_diffusion_samples(model, source, target_sinogram, gaussian_filter, num_samples=10, verbose=True):
    """
    Draw ``num_samples`` images with the sinogram-guided diffusion pipeline.

    Sampling starts from noise scaled to timestep 999 of a 1000-step DDPM
    schedule. The pipeline is given the observed projections and angles of
    ``target_sinogram``, and ``source`` is handed to it as ``lr_tomogram``.
    Pipeline hyper-parameters are read from the module-level ``config`` dict
    and tensors are moved to the module-level ``device`` (both are bound in
    the ``__main__`` block).

    Args:
        model: denoising network wrapped by ``TomographicDiffusion``.
        source: image whose shape sizes the initial noise; also passed to the
            pipeline as ``lr_tomogram``.
        target_sinogram: ``Sinogram`` holding the observed ``projections``
            and their ``angles``.
        gaussian_filter: not used by this function; kept so existing callers
            keep working.
        num_samples: number of samples drawn in one batch.
        verbose: forwarded to ``guided_diffusion_pipeline``.

    Returns:
        The images returned by ``TomographicDiffusion.guided_diffusion_pipeline``.
    """
    noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
    t_start = 999

    # Size the starting noise after `source`, not the script-level ground-truth
    # `target` (undefined when this module is imported, and it should never feed
    # into the sampler anyway).
    template = source.repeat(num_samples, 1, 1, 1)

    # produces noise and uses the scheduler to scale it appropriately for t_start
    x_t = noise_scheduler.add_noise(
        torch.zeros_like(template),
        torch.randn_like(template),
        torch.LongTensor([t_start]),
    ).to(device)

    td = TomographicDiffusion(
        (num_samples, 512, 512),
        model, use_sigmoid=True,
        buffer=5
    ).to(device)

    images = td.guided_diffusion_pipeline(
        x_t,
        t_start, 0, noise_scheduler, config['model.diffusion_steps'],
        target_sinogram.projections, target_sinogram.angles,
        batch_size=min(30, len(target_sinogram)),
        verbose=verbose, sgd_steps=config['model.sgd_steps'], lr=config['model.learning_rate'],
        with_finetuning=config['model.with_finetuning'],
        fourier_inpainting=config['model.fourier_inpainting'],
        lr_tomogram=source.to(device),
        frequency_cut_out_radius=config['model.frequency_cut_out_radius'],
        inpainting_range=config['model.inpainting_range']
    )

    return images

# torch.vmap maps total_variation_loss over the leading dimension of its input, giving one
# result per leading slice (presumably a scalar TV loss each — confirm against
# total_variation_loss). total_variation_initialization uses it to score every row of a sinogram.
batched_total_variation = torch.vmap(total_variation_loss)

def total_variation_initialization(source, angles):
    """
    Pick two initial acquisition angles from the low-resolution image.

    ``source`` is wrapped in a ``TomographicReconstruction`` whose forward
    model is evaluated at the 180 integer angles 0..179 (degrees). The angle
    whose projection has the largest total variation is the main angle, and
    the second angle is the one 90 degrees away from it. Both are snapped to
    the nearest candidate in ``angles`` so the results are valid indices.
    Uses the module-level ``device``.

    Args:
        source: image to project (moved to ``device``).
        angles: 1-D tensor of candidate acquisition angles in degrees.

    Returns:
        ``[main_index, second_index]``: integer indices into ``angles``.
    """
    with torch.no_grad():
        tr = TomographicReconstruction(source.to(device), False).to(device)
        lr_sinogram = tr.forward(torch.arange(180).to(device))

        sinogram_variation = batched_total_variation(lr_sinogram.unsqueeze(1)).cpu().detach()
        # argmax over the 180 one-degree projections is an angle in degrees,
        # not an index into `angles`.
        main_angle = torch.argmax(sinogram_variation).int().item()

        # Callers index `theta[...]` with the returned values, so both the main
        # angle and its orthogonal partner are mapped to the nearest candidate.
        # With 180 one-degree candidates this maps main_angle onto itself.
        main_index = torch.argmin(angular_distance(angles, main_angle)).int().item()
        second_index = torch.argmin(angular_distance(angles, (main_angle + 90) % 180)).int().item()

    return [main_index, second_index]


def select_best_sample(samples, prior, obs_sinogram):
    """
    Return the sample that best explains the available evidence.

    Every candidate is scored by its squared L2 distance to ``prior``. When
    ``obs_sinogram`` contains at least one projection, the squared residual
    between the observed projections and the candidate's simulated
    projections (at the same angles) is added. The lowest-scoring sample wins.
    """
    score = ((samples - prior) ** 2).sum(dim=(1, 2))

    has_observations = len(obs_sinogram) > 0
    if has_observations:
        simulated = batched_sinogram(samples, angles=obs_sinogram.angles)
        score += ((obs_sinogram.projections - simulated) ** 2).sum(dim=(1, 2))

    best = score.argmin()
    return samples[best]


def get_diffusion_predictions(model, source, hr_sinogram, gaussian_filter, num_samples):
    """
    Draw diffusion samples and summarize them.

    Returns ``(samples, mean_image, mle_image)``:
    - ``samples``: the drawn samples with dimension 1 squeezed out (a no-op
      unless it has size 1);
    - ``mean_image``: their pixel-wise mean over the sample dimension;
    - ``mle_image``: the sample chosen by ``select_best_sample`` against
      ``source`` and ``hr_sinogram``.
    """
    raw = get_diffusion_samples(
        model,
        source,
        hr_sinogram,
        num_samples=num_samples,
        gaussian_filter=gaussian_filter,
        verbose=False,
    )
    samples = raw.squeeze(dim=1)
    return samples, samples.mean(dim=0), select_best_sample(samples, source, hr_sinogram)


def _save_and_close(fig, file, key):
    """Log the current figure through save_image_and_log (uncommitted, temp file deleted), then close ``fig``.

    save_image_and_log receives no figure handle, so it presumably saves pyplot's
    current figure; ``fig`` must be that figure. Closing it afterwards keeps open
    figures from accumulating across iterations.
    """
    save_image_and_log(file, key, commit=False, delete=True)
    plt.close(fig)


def _log_sample_grid(samples, run, nrows=2, ncols=3):
    """Show up to ``nrows * ncols`` diffusion samples in a grid and log it as "generated_images".

    Panels beyond ``len(samples)`` stay blank instead of raising IndexError.
    """
    fig, axes = plt.subplots(nrows, ncols, figsize=(10, 7))
    for idx, ax in enumerate(axes.flat):
        ax.axis(False)
        if idx < len(samples):
            ax.imshow(samples[idx].cpu().detach(), cmap='gray')
    _save_and_close(fig, f"tmp/generated_images_{run.id}.png", "generated_images")


def _update_iterative_reconstruction(tr_prior, prior, target, hr_sinogram, low_res_filter):
    """Re-run iterative reconstruction, log its metrics and loss (uncommitted) and return the detached image.

    The returned image is passed back as ``tr_prior`` on the next call.
    """
    tr, loss = get_iterative_reconstruction(tr_prior, prior, hr_sinogram, low_res_filter)
    reconstruction = tr.get_img().detach()
    wandb.log({f'iterative_reconstruction_{k}': v for k, v in get_metrics(target, reconstruction).items()}, commit=False)
    wandb.log({"iterative_reconstruction_loss": loss}, commit=False)
    return reconstruction


def run_active_learning(prior, target, model, config):
    """
    Run one active-learning episode that acquires projections of ``target``.

    Each iteration does the following:
    - scores the candidate angles ``theta`` with the configured acquisition
      function (``max_entropy`` or ``mean_gradient_norm``) computed from the
      current diffusion samples;
    - measures the selected, not-yet-measured projections of ``target`` and
      adds them to the observed sinogram;
    - redraws diffusion samples conditioned on that sinogram.
    Results are logged to the active wandb run, with one committed step per
    ``log_iteration`` call. Uses the module-level ``device``.

    Args:
        prior: image passed to the diffusion sampler as ``source``.
        target: ground-truth image; used to simulate measurements and for evaluation.
        model: diffusion denoiser passed to ``get_diffusion_predictions``.
        config: flat dict of CLI options (keys such as ``'model.batch_size'``).
    """
    run = wandb.run
    gaussian_filter = create_gaussian_filter(size=512, sigma=config['model.kernel_sigma'])
    image_logging = config['wandb.image_logging']

    if image_logging:
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 10))
        ax1.imshow(prior.cpu(), cmap='gray')
        ax2.imshow(target.cpu(), cmap='gray')
        _save_and_close(fig, f"tmp/source_target_{run.id}.png", "source_target")

    # available angles
    num_angles = config['experiment.theta']
    theta = torch.linspace(0, 180 * (1 - 1 / num_angles), num_angles, device=device)

    def low_res_filter(image):
        # Gaussian blur used as the low-resolution model for iterative reconstruction
        return torch_gaussian_blur(
            image.unsqueeze(0), kernel_size=config['model.ir_kernel_size'], sigma=config['model.ir_sigma']
        ).squeeze(0)

    hr_sinogram = Sinogram()
    selected_indices = []
    tr_prior = prior

    # run initialization scheme
    if config['acquisition.initialization'] == 'sinogram_total_variation':
        selected_indices += total_variation_initialization(prior, theta)
        top_k_angles = theta[selected_indices]
        hr_sinogram.add(top_k_angles, compute_sinogram(target, top_k_angles))

        samples, mean_image, mle_image = get_diffusion_predictions(
            model, prior, hr_sinogram, gaussian_filter, config['model.batch_size']
        )

        if config['model.iterative_reconstruction']:
            tr_prior = _update_iterative_reconstruction(tr_prior, prior, target, hr_sinogram, low_res_filter)

        log_iteration(top_k_angles, mle_image, target, commit=True)
    else:
        samples, mean_image, mle_image = get_diffusion_predictions(
            model, prior, hr_sinogram, gaussian_filter, config['model.batch_size']
        )

    # run active learning
    for iteration in tqdm(range(config['experiment.num_iterations']), desc="Iteration", position=1, leave=False):

        # score candidate angles from the current samples
        if config['acquisition.score'] == 'max_entropy':
            scores, entropy_sinogram, angle_entropy = entropy(samples, theta)

            if image_logging:
                plot_sinogram(Sinogram(entropy_sinogram, theta))
                plt.plot(512 - angle_entropy.cpu().numpy())
                _save_and_close(plt.gcf(), f"tmp/entropy_sinogram_{run.id}.png", "entropy_sinogram")

        elif config['acquisition.score'] == 'mean_gradient_norm':
            tr = TomographicReconstruction(mean_image, use_sigmoid=True).to(device)
            scores = mean_gradient_norm(tr, samples, theta)

        # select angles
        top_k = choose_top_k(scores, theta,
                             k=config['experiment.batch_size'],
                             radius=config['acquisition.radius'],
                             exclude_indices=None if config['acquisition.allow_duplicates'] else selected_indices)

        # do not re-evaluate angles that were already measured
        top_k = [idx for idx in top_k if idx not in selected_indices]

        if top_k:
            top_k_angles = theta[top_k]
            hr_sinogram.add(top_k_angles, compute_sinogram(target.to(device), top_k_angles))
            selected_indices += top_k
        else:
            # NOTE(review): top_k_angles keeps the previous iteration's value here, so
            # log_iteration below reports stale angles (and it is unbound if this
            # happens in the very first iteration without an initialization scheme).
            print(f"Warning: No new angles selected in iteration {iteration}")

        # generate new predictions
        samples, mean_image, mle_image = get_diffusion_predictions(
            model, prior, hr_sinogram, gaussian_filter, config['model.batch_size']
        )

        if image_logging:
            _log_sample_grid(samples, run)

            fig, _ = plot_comparison(mle_image.cpu().detach(), target.cpu().detach(), prior.cpu().detach())
            _save_and_close(fig, f"tmp/comparison_{run.id}.png", "comparison")

        if config['model.iterative_reconstruction']:
            tr_prior = _update_iterative_reconstruction(tr_prior, prior, target, hr_sinogram, low_res_filter)

            if image_logging:
                fig, _ = plot_comparison(tr_prior.cpu().detach(), target.cpu().detach(), prior.cpu().detach())
                _save_and_close(fig, f"tmp/comparison_{run.id}.png", "iterative_reconstruction_comparison")

        # log standard metrics
        log_iteration(top_k_angles, mle_image, target, commit=True)

        gc.collect()
        torch.cuda.empty_cache()


if __name__ == '__main__':
    # NOTE: `config` and `device` are bound at module scope here and are read as
    # globals by functions above (e.g. get_diffusion_samples), so this must stay
    # a module-level script block rather than a main() function.
    parser = get_default_argument_parser(description="Diffusion Active Learning Experiments")

    # model arguments
    parser.add_argument("--model.path", type=str, default='checkpoints/diffusion_model_tomogram.pt', help="Path to model checkpoint")
    parser.add_argument("--model.batch_size", default=10, type=int, help="batch size of generative process")
    parser.add_argument("--model.diffusion_steps", default=50, type=int, help="Num of diffusion steps")
    parser.add_argument("--model.sgd_steps", type=int, default=50, help="Num of SGD steps in iterative reconstruction")
    parser.add_argument("--model.with_finetuning", action='store_true',
                        help="Use fine tuning at the end of the diffusion to match the hr sinogram projections better")
    parser.add_argument("--model.fourier_inpainting", action='store_true',
                        help="Use lr data to do inpainting of low frequencies in fourier space")
    parser.add_argument("--model.learning_rate", default=0.05, type=float, help="Learning rate for iterative steps ")
    parser.add_argument("--model.kernel_sigma", default=5, type=float, help="Sigma of the gaussian kernel")
    parser.add_argument("--model.frequency_cut_out_radius", default=-1, type=float, help="Radius to use in fourier space for the cutout to get low resolution images")
    parser.add_argument("--model.inpainting_range", default=100, type=int, help="The maximum step in the diffusion in which we still do inpainting of low frequencies")
    # NOTE(review): not read anywhere in this file; may be consumed via active_learning_benchmark(config) — confirm.
    parser.add_argument("--model.use_iterative_reconstruction", action='store_true',
                        help="Use the reconstruction of the iterative model instead of that of the diffusion for the logging")
    # NOTE(review): not read anywhere in this file; may be consumed via active_learning_benchmark(config) — confirm.
    parser.add_argument("--model.fourier_magnitude", action='store_true',
                        help="Do the fourier inpainting of only the magnitude")

    parser.add_argument("--model.iterative_reconstruction", action='store_true', help="Use iterative reconstruction to estimate the image")
    parser.add_argument("--model.ir_sigma", default=8, type=float,
                        help="Sigma of the gaussian blur used as low-resolution filter in iterative reconstruction")
    parser.add_argument("--model.ir_kernel_size", default=21, type=int,
                        help="Kernel size of the gaussian blur used as low-resolution filter in iterative reconstruction")

    # algorithm arguments
    parser.add_argument("--acquisition.initialization", type=str, default=None, choices=['sinogram_total_variation'],
                        help="Start with angles of max total variation in lr sinogram")

    parser.add_argument("--acquisition.score", type=str, default='max_entropy', choices=['max_entropy', 'mean_gradient_norm'])
    parser.add_argument("--acquisition.allow_duplicates", action='store_true',
                        help="Do not exclude already-selected angles when ranking candidates (duplicates are still dropped before acquisition)")
    parser.add_argument("--acquisition.radius", type=float, default=None)

    # logging arguments
    parser.add_argument("--wandb.image_logging", action='store_true', help="Log images during the run")

    # argparse keeps the dots in the destination names, e.g. config['model.batch_size']
    config = vars(parser.parse_args())

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    model = get_diffusion_model(device)
    load_model(model, config['model.path'])
    model.eval()

    for prior, target in active_learning_benchmark(config):
        print("running active learning...")
        run_active_learning(prior.to(device), target.to(device), model, config)