import torch
import torch.nn as nn
from tqdm.auto import tqdm

from diffusers import UNet2DModel
from torch import optim
from chip.utils.fourier import b_fft_2D, b_ifft_2D
from chip.utils.projections import Projections, project_2D, laminography_projection

# vmap-wrapped transforms: each wrapper maps the underlying helper over the
# leading dimension of its input (torch.vmap's default in_dims=0).
bb_fft_2D, bb_ifft_2D = (torch.vmap(transform) for transform in (b_fft_2D, b_ifft_2D))


def probe_image_product(image, probes, probe_indices=None):
    """Multiply ``image`` by each selected probe and stack the products on dim 1.

    Args:
        image: Tensor that is multiplied (with broadcasting) by every probe.
        probes: Indexable collection of probes; iterated directly when
            ``probe_indices`` is None.
        probe_indices: Optional iterable of indices into ``probes``. When
            given, only those probes are used, in the given order.

    Returns:
        The products ``probe * image`` stacked along a new dimension 1.
    """
    if probe_indices is None:
        selected = probes
    else:
        selected = (probes[index] for index in probe_indices)
    products = [probe * image for probe in selected]
    return torch.stack(products, 1)


class PtychoTomographicDiffusion(nn.Module):
    """Diffusion sampler whose clean-volume estimate is refined against projections.

    The module owns a learnable volume ``x_t`` (initialised to zeros with shape
    ``volume_shape``) and a boolean buffer ``circle_mask`` of the same shape
    holding an inscribed disc for every slice. During guided sampling the
    UNet's estimate of the clean volume is written into ``x_t``, optimised
    against the target projections, and then re-noised to the next timestep.

    Args:
        volume_shape: Shape of the volume. The first entry is stored as
            ``self.bs`` and the last entry is used as the side length of the
            disc mask; the mask construction requires the per-slice size to be
            ``w * w`` (e.g. a ``(bs, w, w)`` volume).
        unet: Noise-prediction network, called as
            ``unet(x, timesteps, return_dict=False)[0]``.
    """

    def __init__(
            self, volume_shape, unet: "UNet2DModel"
    ):
        super().__init__()
        self.unet = unet
        self.bs = volume_shape[0]
        self.x_t = nn.Parameter(torch.zeros(volume_shape))

        # Disc of radius w / 2 centred at (w / 2, w / 2), repeated for every slice.
        w = volume_shape[-1]
        offsets = torch.arange(w) - w / 2
        disc = offsets[:, None] ** 2 + offsets[None, :] ** 2 <= (w / 2) ** 2
        self.register_buffer("circle_mask", disc.repeat(self.bs, 1, 1).reshape(self.x_t.shape))

    @staticmethod
    def _timestep_tensor(t, device):
        """Wrap a scalar timestep in a one-element int64 tensor on ``device``."""
        return torch.tensor([int(t)], dtype=torch.long, device=device)

    def predict_x_0(self, t, x_t, noise_scheduler):
        """Predict the noise in ``x_t`` and the corresponding clean sample.

        Args:
            t: Scalar timestep (Python int or 0-dim integer tensor).
            x_t: Noisy input passed to the UNet; the scaling factors are
                broadcast as ``[:, None, None, None]``, i.e. against a 4-D input.
            noise_scheduler: Object exposing
                ``add_noise(samples, noise, timesteps)``.

        Returns:
            Tuple ``(noise_pred, x_0_pred)``.
        """
        device = self.x_t.device
        timesteps = self._timestep_tensor(t, device)
        noise_pred = self.unet(x_t, timesteps, return_dict=False)[0]
        # Probe the scheduler with unit inputs to read off the coefficient it
        # applies to the sample (fading) and to the noise at this timestep.
        fading_factor = noise_scheduler.add_noise(torch.ones(1), torch.zeros(1), timesteps).to(device)
        noise_factor = noise_scheduler.add_noise(torch.zeros(1), torch.ones(1), timesteps).to(device)
        x_0_pred = (x_t - noise_factor[:, None, None, None] * noise_pred) / fading_factor[:, None, None, None]
        return noise_pred, x_0_pred

    def step(self, t, target_t, x_t, noise_scheduler):
        """Move ``x_t`` from timestep ``t`` to ``target_t``.

        The clean sample is predicted at ``t`` and then re-noised with fresh
        Gaussian noise to ``target_t``.

        Returns:
            Tuple ``(new_x_t, noise_pred)``.
        """
        device = self.x_t.device
        noise_pred, x_0_pred = self.predict_x_0(t, x_t, noise_scheduler)
        new_timestep = self._timestep_tensor(target_t, device)
        new_x_t = noise_scheduler.add_noise(x_0_pred, torch.randn_like(x_0_pred), new_timestep).to(device)

        return new_x_t, noise_pred

    def diffusion_pipeline(self, x_t_start, t_start, t_end, noise_scheduler, num_steps=50, verbose=False):
        """Run unguided sampling from ``t_start`` to ``t_end`` without gradients.

        Args:
            x_t_start: Starting sample; it is cloned and its real part is used.
            t_start: First timestep.
            t_end: Last timestep.
            noise_scheduler: Scheduler forwarded to :meth:`step`.
            num_steps: Number of transitions; timesteps are ``num_steps + 1``
                linearly spaced values truncated to integers.
            verbose: Show a progress bar when True.

        Returns:
            The sample after the last transition.
        """
        with torch.no_grad():
            x_t = x_t_start.clone().real

            timesteps = torch.linspace(t_start, t_end, num_steps + 1).int()
            transitions = zip(timesteps[:-1], timesteps[1:])
            for t, target_t in tqdm(transitions, total=len(timesteps) - 1, disable=not verbose):
                x_t, _ = self.step(t, target_t, x_t, noise_scheduler)
            return x_t

    def guided_diffusion_pipeline(
            self, x_t_start, t_start, t_end, noise_scheduler, num_steps=50,
            target_projections_hr: "Projections" = None,
            target_projections_lr: "Projections" = None,
            lr_forward_function=None,
            sgd_steps=50,
            batch_size=10,
            lr=0.1,
            buffer=5,
            verbose=False,
            with_projection=False
    ):
        """Sample with projection guidance, then finish with unguided steps.

        Guided transitions run from ``t_start`` down to ``t_end + buffer``. In
        each one the predicted clean volume is stored in ``self.x_t``,
        optionally refined by :meth:`ptycho_tomographic_guidance`, and re-noised
        to the next timestep. The final ``buffer`` transitions from
        ``t_end + buffer`` to ``t_end`` are run by :meth:`diffusion_pipeline`
        without guidance.

        Args:
            x_t_start: Starting volume; a channel dimension is inserted at
                dim 1 before it is passed to the UNet.
            t_start: First timestep.
            t_end: Final timestep of the whole procedure.
            noise_scheduler: Object exposing ``add_noise``.
            num_steps: Number of guided transitions.
            target_projections_hr: Target compared against projections of
                ``self.x_t`` directly (optional).
            target_projections_lr: Target compared against projections of
                ``lr_forward_function(self.x_t)`` (optional).
            lr_forward_function: Callable applied to the volume before the LR
                projections are computed.
            sgd_steps: Optimiser steps per guided transition; 0 disables
                guidance.
            batch_size: Number of projections sampled per optimiser step.
            lr: Learning rate of the guidance optimiser.
            buffer: Number of trailing unguided transitions.
            verbose: Show progress bars when True.
            with_projection: Forwarded to :meth:`ptycho_tomographic_guidance`.

        Returns:
            The final volume, with the channel dimension squeezed out again.
        """
        device = self.x_t.device

        x_t = x_t_start.clone()

        if sgd_steps == 0:
            print("using no SGD steps")

        timesteps = torch.linspace(t_start, t_end + buffer, num_steps + 1).int()
        transitions = zip(timesteps[:-1], timesteps[1:])
        for t, target_t in tqdm(transitions, total=len(timesteps) - 1, disable=not verbose):
            with torch.no_grad():
                _, x_0_pred = self.predict_x_0(t, x_t.unsqueeze(1), noise_scheduler)
                # copy_ overwrites the parameter outright; the previous
                # `*= 0; +=` idiom kept any NaN/Inf already stored in x_t.
                self.x_t.copy_(x_0_pred[:, 0])

            if sgd_steps > 0:
                self.ptycho_tomographic_guidance(
                    target_projections_hr,
                    target_projections_lr,
                    lr_forward_function,
                    sgd_steps,
                    batch_size,
                    lr=lr, verbose=False,
                    with_projection=with_projection
                )

            new_timestep = self._timestep_tensor(target_t, device)
            # detach: the re-noised sample must not drag an autograd graph along.
            x_t = noise_scheduler.add_noise(self.x_t.detach(), torch.randn_like(self.x_t),
                                            new_timestep).to(device)

        x_t = self.diffusion_pipeline(
            x_t.unsqueeze(1), t_end + buffer, t_end,
            noise_scheduler, num_steps=buffer,
            verbose=verbose
        ).squeeze(1)

        return x_t

    def ptycho_loss(self, prediction: "Projections", target: "Projections", indices):
        """Loss for targets that only provide Fourier magnitudes.

        Sums three MSE terms:
        1. target magnitudes (selected by ``indices``) vs. ``prediction.projections``;
        2. the real part of the inverse transform of "target magnitude combined
           with predicted phase" vs. ``prediction.real_projections``;
        3. the imaginary part of that inverse transform vs. zero.

        NOTE(review): relies on ``prediction.phases`` and
        ``prediction.real_projections`` being provided by ``Projections`` when
        it is built with ``fourier_magnitude=True`` - confirm against that class.
        """
        mse = torch.nn.MSELoss()

        target_projections = target.projections[indices]

        loss = mse(
            target_projections,
            prediction.projections
        )

        # Target magnitude with the predicted phase: |target| * exp(i * phase).
        complex_recons = target_projections * (torch.cos(prediction.phases) + 1j * torch.sin(prediction.phases))
        mismatch_recons = b_ifft_2D(complex_recons)
        loss += mse(mismatch_recons.real, prediction.real_projections)
        loss += mse(mismatch_recons.imag, torch.zeros_like(mismatch_recons.imag))  # we are assuming real images
        return loss

    def _projection_loss(self, target, batch_size, forward_function=None):
        """Loss between one random batch of ``target`` and projections of ``x_t``.

        Args:
            target: Target projections. Uses ``len(target)``, ``.angles``,
                ``.projections``, ``.fourier_magnitude``, ``.laminography`` and
                ``.laminography_tilt_angle``.
            batch_size: Number of projections to sample without replacement;
                a falsy value selects every projection.
            forward_function: Optional callable applied to ``self.x_t`` before
                projecting.
        """
        num_projections = len(target)
        if batch_size:
            indices = torch.randperm(num_projections)[:batch_size].tolist()
        else:
            indices = list(range(num_projections))

        angles = target.angles[indices]
        volume = self.x_t if forward_function is None else forward_function(self.x_t)
        if target.laminography:
            projected = laminography_projection(angles, volume, target.laminography_tilt_angle)
        else:
            projected = project_2D(angles, volume)

        prediction = Projections(
            projected,
            angles,
            fourier_magnitude=target.fourier_magnitude,
            laminography=target.laminography,
            laminography_tilt_angle=target.laminography_tilt_angle
        )

        # if in fourier space use ptychographic loss instead
        if target.fourier_magnitude:
            return self.ptycho_loss(prediction, target, indices)

        target_batch = target.projections[indices]
        height = target_batch.shape[1]  # they might have different height due to projections
        return nn.functional.mse_loss(prediction.projections[:, :height], target_batch)

    def ptycho_tomographic_guidance(
            self,
            target_projections_hr: "Projections" = None,
            target_projections_lr: "Projections" = None,
            lr_forward_function=None,
            sgd_steps=50,
            batch_size=10,
            lr=0.1, verbose=False,
            with_projection=True
    ):
        """Optimise ``self.x_t`` so that its projections match the targets.

        Args:
            target_projections_hr: Target compared against projections of
                ``self.x_t`` (optional).
            target_projections_lr: Target compared against projections of
                ``lr_forward_function(self.x_t)`` (optional).
            lr_forward_function: Callable applied to the volume for the LR
                term; required when ``target_projections_lr`` is given.
            sgd_steps: Number of AdamW steps.
            batch_size: Projections sampled per step and per target; a falsy
                value uses all projections.
            lr: Learning rate.
            verbose: Show a progress bar with the current loss when True.
            with_projection: After every step, multiply ``x_t`` by
                ``circle_mask`` and clamp it to ``[0, 1]``.

        Returns:
            The loss (Python float) evaluated in the last step, before that
            step's update; NaN when ``sgd_steps`` is 0 because no loss was
            evaluated.

        Raises:
            ValueError: If neither target is given, or if
                ``target_projections_lr`` is given without
                ``lr_forward_function``.
        """
        if target_projections_hr is None and target_projections_lr is None:
            raise ValueError(
                "at least one of target_projections_hr or target_projections_lr must be provided"
            )
        if target_projections_lr is not None and lr_forward_function is None:
            raise ValueError("lr_forward_function is required when target_projections_lr is provided")

        device = self.x_t.device
        if target_projections_hr is not None:
            target_projections_hr = target_projections_hr.to(device)
        if target_projections_lr is not None:
            target_projections_lr = target_projections_lr.to(device)

        # Only the volume is optimised. Handing every module parameter to AdamW
        # would expose the pretrained UNet weights to weight decay whenever
        # they carry (even all-zero) gradients.
        optimizer = optim.AdamW([self.x_t], lr=lr)

        loss = None
        iterator = tqdm(range(sgd_steps), disable=not verbose)
        for _ in iterator:
            optimizer.zero_grad()
            loss = torch.tensor(0., device=device)
            if target_projections_lr is not None:
                loss += self._projection_loss(target_projections_lr, batch_size, lr_forward_function)
            if target_projections_hr is not None:
                loss += self._projection_loss(target_projections_hr, batch_size)

            if verbose:
                iterator.set_postfix({"loss": f"{loss.item():.8f}"})

            loss.backward()
            optimizer.step()
            # project to solution space
            if with_projection:
                with torch.no_grad():
                    self.x_t.copy_((self.x_t * self.circle_mask).clamp(0, 1))

        return loss.item() if loss is not None else float("nan")


if __name__ == '__main__':
    # Smoke test: run a few guidance steps on a random volume with one
    # Fourier-magnitude laminography target and one low-pass-filtered target.
    import lovely_tensors as lt
    from chip.utils.utils import create_circle_filter
    from chip.models.forward_models import fourier_filtering

    lt.monkey_patch()

    # Four plain ResNet down blocks, one with spatial self-attention, one plain;
    # the up path is: plain, attention, then four plain blocks.
    down_blocks = ("DownBlock2D",) * 4 + ("AttnDownBlock2D", "DownBlock2D")
    up_blocks = ("UpBlock2D", "AttnUpBlock2D") + ("UpBlock2D",) * 4

    # The guidance call below never invokes the UNet; it only has to exist so
    # the module can be constructed.
    model = UNet2DModel(
        sample_size=512,  # the target image resolution
        in_channels=1,  # single-channel input
        out_channels=1,  # single-channel output
        layers_per_block=2,  # ResNet layers per UNet block
        block_out_channels=(64, 64, 128, 128, 256, 256),  # output channels of each UNet block
        down_block_types=down_blocks,
        up_block_types=up_blocks,
    )

    volume = torch.randn(20, 256, 256)

    gf = create_circle_filter(30, volume.shape)

    def low_pass(x):
        # Forward model of the low-resolution measurement.
        return fourier_filtering(x, gf)

    tilt_angle = 20.
    projection_angles = torch.arange(180)[::10]

    hr_data = laminography_projection(projection_angles, volume, tilt_angle)
    target_projections_hr = Projections(
        hr_data, projection_angles,
        fourier_magnitude=True, laminography=True, laminography_tilt_angle=tilt_angle
    )

    lr_data = laminography_projection(projection_angles, fourier_filtering(volume, gf), tilt_angle)
    target_projections_lr = Projections(
        lr_data, projection_angles,
        fourier_magnitude=False, laminography=True, laminography_tilt_angle=tilt_angle
    )

    ptd = PtychoTomographicDiffusion(volume.shape, model)
    # 10 optimiser steps with 2 projections per step and target.
    ptd.ptycho_tomographic_guidance(
        target_projections_hr, target_projections_lr, low_pass,
        10, 2, lr=0.1, verbose=True
    )
