import torch
from diffusers import UNet2DModel
from chip.models.tomographic_diffusion import TomographicDiffusion
from diffusers import DDPMScheduler

def get_diff_unet_model(checkpoint_path, im_size=128, device='cuda'):
    """Build the single-channel diffusion UNet and load pretrained weights into it.

    Args:
        checkpoint_path: path to a file readable by ``torch.load``; the network
            weights must be stored under its ``'model_state_dict'`` key.
        im_size: spatial resolution handed to ``UNet2DModel`` as ``sample_size``.
        device: device the freshly built model is moved to (before the weights
            are loaded).

    Returns:
        The ``UNet2DModel`` instance with the checkpoint weights loaded.
    """
    # Six resolution levels. Only the fifth down block and its mirrored second
    # up block use spatial self-attention; every other block is plain ResNet.
    down_blocks = ("DownBlock2D",) * 4 + ("AttnDownBlock2D", "DownBlock2D")
    up_blocks = ("UpBlock2D", "AttnUpBlock2D") + ("UpBlock2D",) * 4

    unet = UNet2DModel(
        sample_size=im_size,
        in_channels=1,   # single-channel (grayscale) input
        out_channels=1,  # single-channel output
        layers_per_block=2,
        block_out_channels=(64, 64, 128, 128, 256, 256),
        down_block_types=down_blocks,
        up_block_types=up_blocks,
    )
    unet = unet.to(device)

    # map_location keeps all tensors on the CPU while deserializing, independent
    # of the device the checkpoint was saved from.
    state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    unet.load_state_dict(state['model_state_dict'])
    print(f"model loaded from checkpoint {checkpoint_path}")
    return unet


def get_diffusion_samples(model, lr_forward_function, hr_sinogram=None, lr_sinogram=None, batch_size=10, device=None, verbose=False, num_samples=20, buffer=5, use_sigmoid=False, *, num_inference_steps=50, sgd_steps=None, sgd_lr=None):
    """Draw samples with ``TomographicDiffusion.guided_diffusion_pipeline``.

    Sampling starts from (almost) pure noise at the last training timestep of a
    1000-step ``DDPMScheduler``.

    Args:
        model: denoising network; its ``sample_size`` (read from ``model.config``
            when available) sets the square image resolution.
        lr_forward_function: forwarded to the pipeline as ``lr_forward_function``.
        hr_sinogram: forwarded positionally to the pipeline (may be ``None``).
        lr_sinogram: forwarded positionally to the pipeline (may be ``None``).
        batch_size: forwarded to the pipeline as ``batch_size``.
        device: device the initial noise and the ``TomographicDiffusion`` module
            are moved to.
        verbose: forwarded to the pipeline as ``verbose``.
        num_samples: number of samples drawn in parallel (leading dimension of
            the initial noise tensor).
        buffer: forwarded to ``TomographicDiffusion`` as ``buffer``.
        use_sigmoid: forwarded to ``TomographicDiffusion`` as ``use_sigmoid``.
        num_inference_steps: number passed to the pipeline where 50 was
            previously hard-coded (presumably the number of denoising steps —
            confirm against ``guided_diffusion_pipeline``).
        sgd_steps: list passed as ``sgd_steps``; defaults to ``[40, 20]``.
        sgd_lr: list passed as the pipeline's ``lr``; defaults to ``[0.05, 0.01]``.

    Returns:
        Whatever ``guided_diffusion_pipeline`` returns (presumably the sampled
        images — confirm against ``TomographicDiffusion``).
    """
    # Fresh lists per call, so no default is shared across invocations.
    sgd_steps = [40, 20] if sgd_steps is None else list(sgd_steps)
    sgd_lr = [0.05, 0.01] if sgd_lr is None else list(sgd_lr)

    noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
    # Begin at the final training timestep (999 for a 1000-step schedule).
    t_start = noise_scheduler.config.num_train_timesteps - 1

    # diffusers deprecates reading config values as model attributes; prefer
    # model.config but keep supporting models that expose only `sample_size`.
    size = model.config.sample_size if hasattr(model, "config") else model.sample_size

    # All-zero "clean" images of shape (num_samples, 1, size, size), noised to
    # t_start. The noise is drawn on the CPU and then moved, so seeded runs keep
    # the same random stream as before.
    clean = torch.zeros((num_samples, 1, size, size))
    x_t = noise_scheduler.add_noise(
        clean,
        torch.randn_like(clean),
        torch.tensor([t_start], dtype=torch.long),
    ).to(device)

    td = TomographicDiffusion(
        (num_samples, size, size), model, use_sigmoid=use_sigmoid,
        buffer=buffer, fourier_magnitude=None
    ).to(device)

    # The literal 0 is presumably the final timestep of the reverse process —
    # TODO confirm against guided_diffusion_pipeline's signature.
    images = td.guided_diffusion_pipeline(
        x_t, t_start, 0, noise_scheduler, num_inference_steps,
        hr_sinogram,
        lr_sinogram,
        lr_forward_function=lr_forward_function,
        batch_size=batch_size,
        verbose=verbose, sgd_steps=sgd_steps, lr=sgd_lr,
        with_finetuning=False
    )

    return images
