import math
import random

import torch
import torch.nn.functional as F
import torch.nn as nn
from tqdm.auto import tqdm

from diffusers import UNet2DModel
from torch import optim

from chip.utils import create_gaussian_filter
from chip.utils.fourier import fft_2D, ifft_2D
import torchvision.transforms.functional as TF

def _fft_2d_ortho(x):
    """Apply the project's ``fft_2D`` to a single sample with ``ortho=True``."""
    return fft_2D(x, ortho=True)


def _ifft_2d_ortho(x):
    """Apply the project's ``ifft_2D`` to a single sample with ``ortho=True``."""
    return ifft_2D(x, ortho=True)


# Single vmap: the transform is mapped over the leading dimension of the input.
b_fft_2D = torch.vmap(_fft_2d_ortho)
b_ifft_2D = torch.vmap(_ifft_2d_ortho)

# Double vmap: the transform is mapped over the two leading dimensions.
bb_fft_2D = torch.vmap(b_fft_2D)
bb_ifft_2D = torch.vmap(b_ifft_2D)


class PtychographicDiffusion(nn.Module):
    """Diffusion sampler whose image estimate is steered by ptychographic data.

    The module keeps the current complex image estimate in the parameter
    ``x_t``, a stack of shifted probes in the buffer ``shifted_probes`` and a
    UNet used as noise predictor.  Measurements are Fourier magnitudes of the
    probe-weighted image (see :meth:`forward` and :meth:`ptycho_loss`).
    """

    def __init__(
            self, image_shape, unet: "UNet2DModel",
            probe, num_probes_per_row,
            probe_positions=None,
            prior=None,
            magnitude=None,
            shifts_padding=50, buffer=5,
            with_circle_mask: bool = True
    ):
        """Set up the image estimate, the shifted probes and the circular mask.

        Args:
            image_shape: Shape of the image estimate.  ``image_shape[0]`` is
                stored as ``bs`` and ``image_shape[-1]`` is the side length of
                the (square) circular mask.
            unet: Noise predictor, called as
                ``unet(x_t, timesteps, return_dict=False)[0]``.
            probe: Probe tensor that is translated to every probe position and
                then centre-cropped to the image width (presumably 2-D and at
                least as large as the image -- confirm against callers).
            num_probes_per_row: Number of grid points per axis; only used when
                ``probe_positions`` is ``None``.
            probe_positions: Optional tensor of ``(dx, dy)`` translations.
            prior: Optional initial image; its last two dimensions must match
                ``image_shape``.  A leading dimension is added and the data is
                cast to ``complex64``.
            magnitude: Optional Fourier magnitude used by
                :meth:`fourier_magnitude_step`.  ``None`` yields an empty
                parameter, which disables that step.
            shifts_padding: Margin kept at both ends of the default grid.
            buffer: Number of unguided denoising steps appended at the end of
                :meth:`guided_diffusion_pipeline`.
            with_circle_mask: When ``False`` the mask lets every pixel through.
        """
        super().__init__()
        self.computed_image = None
        self.unet = unet
        self.bs = image_shape[0]
        self.buffer = buffer
        # Blend factor between the old estimate and the inpainting result.
        self.alpha = 0.
        # nn.Parameter(None) creates an empty tensor, so ``len(self.magnitude)``
        # is 0 when no magnitude is supplied.
        self.magnitude = nn.Parameter(magnitude, requires_grad=False)

        if prior is None:
            self.x_t = nn.Parameter(torch.zeros(image_shape).to(torch.complex64))
        else:
            # Normalise both sides to tuples so a list ``image_shape`` compares
            # correctly against ``torch.Size``.
            assert tuple(prior.shape[-2:]) == tuple(image_shape[-2:])
            self.x_t = nn.Parameter(prior.clone().unsqueeze(0).to(torch.complex64))

        if probe_positions is None:
            # NOTE(review): the grid is hard-coded for a 512-pixel field of
            # view centred at 256 -- confirm before using other image sizes.
            shifts = torch.linspace(0 + shifts_padding, 512 - +shifts_padding, math.floor(num_probes_per_row))
            probe_positions = torch.cartesian_prod(shifts, shifts) - 256
        # Count the positions actually used instead of trusting
        # ``num_probes_per_row ** 2`` (which disagrees with the floor above for
        # non-integer inputs).
        self.num_probes = len(probe_positions)

        # Needs ``self.x_t`` to exist: the crop width is taken from it.
        shifted_probes = self.compute_shifted_probes(probe, probe_positions)

        self.register_buffer("probe_positions", probe_positions)
        self.register_buffer("shifted_probes", shifted_probes)
        self.circle_mask = self._circle_mask(image_shape[-1], with_circle_mask)

    @staticmethod
    def _circle_mask(width, enabled=True):
        """Return a ``(width, width)`` boolean mask of the inscribed disc.

        A pixel ``(i, j)`` is kept when
        ``(i - width / 2) ** 2 + (j - width / 2) ** 2 < (width / 2) ** 2``.
        With ``enabled=False`` every pixel is kept, so multiplying by the mask
        is a no-op.
        """
        if not enabled:
            return torch.ones(width, width, dtype=torch.bool)
        offsets = torch.arange(width) - width / 2
        return offsets[:, None] ** 2 + offsets[None, :] ** 2 < (width / 2) ** 2

    def compute_shifted_probes(self, probe, positions):
        """Translate ``probe`` to every position and crop to the image width.

        Args:
            probe: Probe tensor; a leading dimension is added before shifting.
            positions: Iterable of ``(dx, dy)`` translations.

        Returns:
            Tensor with one centre-cropped, shifted probe per position, stacked
            along dimension 0.
        """
        shifted = [
            TF.affine(probe.unsqueeze(0), angle=0, translate=(dx, dy), scale=1, shear=0)
            for (dx, dy) in positions
        ]
        shifted_probes = torch.cat(shifted, 0)
        center = shifted_probes.shape[1] // 2
        # Crop a window of the image width around the centre of the probe.
        half_width = self.x_t.shape[-1] // 2
        window = slice(center - half_width, center + half_width)
        return shifted_probes[:, window, window]

    def get_img(self):
        """Apply the circular mask to ``x_t`` in place and return ``x_t``.

        The returned object is the parameter itself, so gradients can flow to
        it when it is used outside of ``torch.no_grad()``.
        """
        with torch.no_grad():
            image = self.x_t
            image *= self.circle_mask[None, :].to(image.device)
            # NOTE(review): ``image`` is an nn.Parameter, so this assignment
            # registers it on the module under the name ``computed_image`` as
            # well -- confirm that is intended.
            self.computed_image = image
        return image

    def predict_x_0(self, t, x_t, noise_scheduler):
        """Predict the noise in ``x_t`` and the corresponding clean sample.

        The signal and noise coefficients at timestep ``t`` are read off the
        scheduler by passing unit tensors through ``add_noise``.

        Returns:
            Tuple ``(noise_pred, x_0_pred)``.
        """
        device = self.x_t.device
        timesteps = torch.LongTensor([t]).to(device)
        noise_pred = self.unet(x_t, timesteps, return_dict=False)[0]
        # add_noise(1, 0, t) yields the signal factor, add_noise(0, 1, t) the
        # noise factor of the forward process at timestep t.
        fading_factor = noise_scheduler.add_noise(torch.ones(1), torch.zeros(1), timesteps).to(device)
        noise_factor = noise_scheduler.add_noise(torch.zeros(1), torch.ones(1), timesteps).to(device)
        x_0_pred = (x_t - noise_factor[:, None, None, None] * noise_pred) / fading_factor[:, None, None, None]
        return noise_pred, x_0_pred

    def step(self, t, target_t, x_t, noise_scheduler):
        """Denoise ``x_t`` at timestep ``t`` and re-noise it to ``target_t``.

        Returns:
            Tuple ``(new_x_t, noise_pred)``; ``new_x_t`` is the predicted clean
            sample with fresh Gaussian noise added for timestep ``target_t``.
        """
        device = self.x_t.device
        noise_pred, x_0_pred = self.predict_x_0(t, x_t, noise_scheduler)
        new_timestep = torch.LongTensor([target_t]).to(device)
        new_x_t = noise_scheduler.add_noise(x_0_pred, torch.randn_like(x_0_pred), new_timestep).to(device)

        return new_x_t, noise_pred

    def diffusion_pipeline(self, x_t_start, t_start, t_end, noise_scheduler, num_steps=50, verbose=False):
        """Run unguided sampling from ``t_start`` down to ``t_end``.

        The timesteps are ``num_steps + 1`` evenly spaced values cast to int;
        with ``num_steps=0`` the loop is skipped and the real part of the
        start sample is returned.
        """
        with torch.no_grad():
            x_t = x_t_start.clone().real

            timesteps = torch.linspace(t_start, t_end, num_steps + 1).int()
            for i in tqdm(range(1, len(timesteps)), disable=not verbose):
                t = timesteps[i - 1]
                target_t = timesteps[i]
                x_t, _ = self.step(t, target_t, x_t, noise_scheduler)
            return x_t

    def guided_diffusion_pipeline(
            self, x_t_start, t_start, t_end, noise_scheduler, num_steps=50,
            probe_indices=None, target_magnitudes=None, batch_size=10, verbose=False, sgd_steps=50, lr=0.1,
            magnitude_inpainting: bool = False, sample_randomly=False
    ):
        """Sample with measurement guidance applied to every predicted ``x_0``.

        Per step: predict ``x_0`` and store its channel 0 in ``self.x_t``,
        optionally run magnitude inpainting and/or gradient-based guidance,
        optionally enforce the global Fourier magnitude, then re-noise the
        masked estimate to the next timestep.  The guided part stops at
        ``t_end + self.buffer``; the remaining ``self.buffer`` steps are
        unguided.

        Args:
            x_t_start: Initial noisy sample.
            t_start: First timestep.
            t_end: Last timestep.
            noise_scheduler: Scheduler providing ``add_noise``.
            num_steps: Number of guided steps.
            probe_indices: Indices of measured probes; defaults to all probes.
            target_magnitudes: Measured Fourier magnitudes, indexed by probe.
                When ``None`` (or no probe is selected) the SGD guidance is
                skipped.
            batch_size: Probes per SGD step.
            verbose: Show progress bars.
            sgd_steps: Optimiser steps per diffusion step.
            lr: Optimiser learning rate.
            magnitude_inpainting: Run five inpainting passes per step.
            sample_randomly: Forwarded to the inpainting routine.

        Returns:
            The final sample returned by :meth:`diffusion_pipeline`.
        """
        device = self.x_t.device
        if probe_indices is None:
            probe_indices = torch.arange(len(self.shifted_probes))

        x_t = x_t_start.clone()

        sgd_steps = 0 if target_magnitudes is None or len(probe_indices) == 0 else sgd_steps
        if sgd_steps == 0:
            print("using no SGD steps")

        timesteps = torch.linspace(t_start, t_end + self.buffer, num_steps + 1).int()
        for i in tqdm(range(1, len(timesteps)), disable=not verbose):
            t = timesteps[i - 1]
            target_t = timesteps[i]

            with torch.no_grad():
                noise_pred, x_0_pred = self.predict_x_0(t, x_t, noise_scheduler)
                # Overwrite the parameter in place so the same tensor object
                # stays registered on the module.
                self.x_t *= 0
                self.x_t += x_0_pred[:, 0]

            if magnitude_inpainting:
                for _ in range(5):
                    self.ptychographic_magnitude_inpainting(
                        target_magnitudes, probe_indices,
                        sample_randomly=sample_randomly
                    )

            if sgd_steps > 0:
                self.ptychographic_guidance(
                    target_magnitudes, probe_indices, sgd_steps, batch_size, lr,
                    verbose=False
                )

            # An empty magnitude parameter means no global magnitude was given.
            if len(self.magnitude) > 0:
                self.fourier_magnitude_step()

            new_timestep = torch.LongTensor([target_t]).to(device)
            x_t = noise_scheduler.add_noise(
                self.get_img().unsqueeze(1).real, torch.randn_like(x_0_pred), new_timestep
            ).to(device)

        # Finish with ``self.buffer`` unguided steps down to ``t_end``.
        x_t = self.diffusion_pipeline(
            x_t, t_end + self.buffer, t_end,
            noise_scheduler, num_steps=self.buffer,
            verbose=verbose
        )

        return x_t

    def fourier_magnitude_step(self):
        """Replace the Fourier magnitude of ``x_t`` by ``self.magnitude``.

        The phase of the current estimate is kept; the real part of the
        inverse transform is written back into ``x_t`` in place.
        """
        with torch.no_grad():
            # NOTE(review): bb_fft_2D maps over two leading dimensions; if x_t
            # is 3-D (batch, H, W) each mapped call only sees a 1-D slice --
            # confirm whether the single-vmap b_fft_2D was intended here.
            fourier_ptychogram = bb_fft_2D(self.x_t)
            phase = torch.angle(fourier_ptychogram)
            complex_representation = self.magnitude * (torch.cos(phase) + 1j * torch.sin(phase))
            self.x_t *= 0
            self.x_t += bb_ifft_2D(complex_representation).real

    def ptychographic_magnitude_inpainting(
            self,
            target_magnitudes,
            indices,
            sample_randomly: bool = False
    ):
        """Project the estimate onto the measured magnitudes of ``indices``.

        All probes are propagated to Fourier space; for the selected probes the
        predicted magnitude is replaced by the target, the predicted phase is
        kept, and the back-transformed exit waves are recombined into a new
        image weighted by the probes.  The result is blended into ``x_t`` with
        ``self.alpha``.

        Args:
            target_magnitudes: Measured magnitudes, indexed by probe.
            indices: Probe indices whose measurements are enforced.
            sample_randomly: Enforce only one randomly drawn index.
        """
        if sample_randomly:
            indices = indices[random.randint(0, len(indices) - 1)].unsqueeze(0)
        with torch.no_grad():
            all_indices = torch.arange(len(self.shifted_probes))
            predicted_ptychograms = self.forward(probe_indices=all_indices)
            fourier_ptychograms = bb_fft_2D(predicted_ptychograms)
            predicted_phase = torch.angle(fourier_ptychograms)
            # Start from the predicted magnitudes and overwrite the measured ones.
            expanded_target = torch.abs(fourier_ptychograms)
            expanded_target[:, indices] = target_magnitudes[indices].clone()

            complex_recons = expanded_target * (torch.cos(predicted_phase) + 1j * torch.sin(predicted_phase))
            recons = bb_ifft_2D(complex_recons) * self.shifted_probes
            recons /= torch.sum(self.shifted_probes ** 2, dim=0)

            # Pixels not covered by any probe divide by zero; zero them out.
            recons[torch.isinf(recons)] = 0
            recons[torch.isnan(recons)] = 0
            avg = torch.sum(recons, dim=1)
            avg *= self.circle_mask.to(avg.device)
            self.x_t *= self.alpha
            self.x_t += (1 - self.alpha) * avg

    def ptycho_loss(self, target, indices):
        """Return the exit-wave mismatch loss for the probes in ``indices``.

        The predicted exit waves are transformed to Fourier space, their
        magnitude is replaced by ``target`` while the predicted phase is kept,
        and the result is transformed back.  The loss is the sum of the MSEs of
        the real and imaginary parts between that reconstruction and the
        prediction.
        """
        mse = torch.nn.MSELoss()

        predicted_ptychograms = self(probe_indices=indices)
        fourier_ptychograms = bb_fft_2D(predicted_ptychograms)
        predicted_phase = torch.angle(fourier_ptychograms)

        complex_recons = target * (torch.cos(predicted_phase) + 1j * torch.sin(predicted_phase))
        missmatch_recons = bb_ifft_2D(complex_recons)
        loss = mse(missmatch_recons.real, predicted_ptychograms.real)
        loss += mse(missmatch_recons.imag, predicted_ptychograms.imag)
        return loss

    def ptychographic_guidance(
            self, target_magnitudes, probe_indices,
            sgd_steps=50,
            batch_size=10,
            lr=0.1, verbose=False
    ):
        """Refine ``x_t`` with AdamW steps on :meth:`ptycho_loss`.

        Args:
            target_magnitudes: Measured magnitudes, indexed by probe.
            probe_indices: Indices of the probes that may be sampled.
            sgd_steps: Number of optimiser steps.
            batch_size: Number of probes drawn (without replacement) per step.
            lr: Learning rate.
            verbose: Show a progress bar.

        Returns:
            The loss of the last step as a float (``0.0`` if no step ran).
        """
        device = self.x_t.device

        target_magnitudes = target_magnitudes.to(device) if target_magnitudes is not None else None

        # Only the image estimate depends on the loss.  Handing the optimiser
        # every module parameter would drag the UNet weights along and could
        # apply weight decay to them if they carry stale gradients.
        optimizer = optim.AdamW([self.x_t], lr=lr)
        loss = torch.tensor(0.).to(device)
        iterator = tqdm(range(sgd_steps), disable=not verbose)
        for _ in iterator:
            self.zero_grad()
            random_integers = torch.randperm(len(probe_indices))[:batch_size].int()

            indices = probe_indices[random_integers]
            target = target_magnitudes[indices]

            loss = self.ptycho_loss(target, indices)

            iterator.set_postfix({"loss": f"{loss.item():.8f}"})

            loss.backward()
            optimizer.step()

        return loss.item()

    def forward(self, probe_indices=None, image=None):
        """Multiply the image with the selected probes.

        Args:
            probe_indices: Iterable of probe indices; ``None`` selects all.
            image: Image to use instead of the masked ``x_t``.

        Returns:
            The per-probe products stacked along dimension 1.
        """
        if image is None:
            # The masking happens without gradient tracking; the returned
            # parameter itself still tracks gradients in the products below.
            with torch.no_grad():
                image = self.get_img()
        if probe_indices is None:
            probes = self.shifted_probes
        else:
            probes = (self.shifted_probes[i] for i in probe_indices)
        return torch.stack([probe * image for probe in probes], 1)


if __name__ == '__main__':
    # Manual smoke test: build a model with spiral probe positions and run one
    # magnitude-inpainting pass on random data.
    import lovely_tensors as lt

    lt.monkey_patch()

    model = UNet2DModel(
        sample_size=512,  # the target image resolution
        in_channels=1,  # the number of input channels, 3 for RGB images
        out_channels=1,  # the number of output channels
        layers_per_block=2,  # how many ResNet layers to use per UNet block
        block_out_channels=(64, 64, 128, 128, 256, 256),  # the number of output channels for each UNet block
        down_block_types=(
            "DownBlock2D",  # a regular ResNet downsampling block
            "DownBlock2D",
            "DownBlock2D",
            "DownBlock2D",
            "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
            "DownBlock2D",
        ),
        up_block_types=(
            "UpBlock2D",  # a regular ResNet upsampling block
            "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention
            "UpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
        ),
    )


    def get_spiral(num_samples, wrapping_number):
        """Return ``num_samples`` points on the spiral ``r = sqrt(theta)``.

        ``theta`` runs linearly from 0 to ``wrapping_number * pi``.  The result
        has shape ``(num_samples, 2)`` with the x and y coordinates in the
        last dimension.
        """
        # Angles of the samples along the spiral
        theta_positive = torch.linspace(0, wrapping_number * torch.pi, num_samples)

        # Radius grows with the square root of the angle
        r_positive = torch.sqrt(theta_positive)

        # Convert polar coordinates to Cartesian coordinates
        x_positive = r_positive * torch.cos(theta_positive)
        y_positive = r_positive * torch.sin(theta_positive)
        return torch.stack([x_positive, y_positive], -1)


    spiral = 155 * get_spiral(num_samples=25, wrapping_number=7)

    bs = 1
    gaussian_filter = create_gaussian_filter(size=2048, sigma=50)
    # ``num_probes_per_row`` is ignored here because explicit positions are
    # passed.  The constructor has no ``alternate_loss`` parameter, so only
    # supported keywords are given.
    pd = PtychographicDiffusion(
        (bs, 512, 512), model, probe_positions=spiral,
        probe=gaussian_filter,
        num_probes_per_row=5, prior=torch.randn(512, 512),
        buffer=0
    )

    pd.ptychographic_magnitude_inpainting(torch.randn(25, 512, 512), indices=torch.arange(25))

    # Further manual checks, disabled by default:
    #
    # print(pd.forward().real)
    #
    # pd.ptychographic_guidance(
    #     torch.randn(25, 512, 512), probe_indices=torch.arange(25),
    #     sgd_steps=0,
    #     batch_size=10,
    #     lr=.1, verbose=True,
    # )

    # from diffusers import DDPMScheduler
    #
    # noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
    # # img = pd.diffusion_pipeline(torch.randn(bs, 512, 512).unsqueeze(1), 999, 0, noise_scheduler, 2)
    # img = pd.guided_diffusion_pipeline(
    #     torch.randn(bs, 1, 512, 512), 999, 0, noise_scheduler, 2,
    #     probe_indices=None, target_magnitudes=torch.randn(25, 512, 512), batch_size=10,
    #     verbose=True, sgd_steps=0, lr=0.1,
    #     magnitude_inpainting=True
    # )
    # print(img)
