import torch
import numpy as np
import torch.nn.functional as F
import torch.nn as nn
from tqdm.auto import tqdm

from diffusers import UNet2DModel
from torch import optim
from chip.utils.fourier import fft_2D, ifft_2D
from chip.gridrec.gridRecAUS import gridRec, gridRecAUS
from chip.utils.sinogram import compute_sinogram, Sinogram
from chip.utils.utils import create_circle_filter, find_match_indices


class TomographicDiffusion(nn.Module):
    """Diffusion sampler whose clean-image estimate is a trainable parameter.

    The module wraps a noise-predicting ``unet`` and keeps the current image
    estimate in ``self.x_t`` (an ``nn.Parameter`` of shape ``image_shape``), so
    the estimate can be refined by gradient descent against measured sinograms
    between denoising steps.

    Args:
        image_shape: shape of the optimised image parameter. UNet-shaped
            tensors are indexed with ``[:, 0]`` before being stored here, so
            a ``(batch, height, width)`` layout is presumably expected.
        unet: model invoked as ``unet(x_t, timesteps, return_dict=False)[0]``
            to obtain the noise prediction.
        use_sigmoid: if true, :meth:`get_img` returns
            ``sigmoid(5 * (x_t - 0.5))`` instead of the raw parameter.
        buffer: guided sampling stops ``buffer`` timesteps above ``t_end``;
            the remaining steps are run unguided with ``num_steps=buffer``.
        fourier_magnitude: stored on the instance. No active code in this
            class reads it.
    """

    def __init__(self, image_shape, unet: "UNet2DModel", use_sigmoid: bool = True, buffer=5,
                 fourier_magnitude=None):
        super().__init__()
        # Last image returned by get_img(); None until get_img() is called.
        self.computed_image = None
        self.unet = unet
        self.x_t = nn.Parameter(torch.zeros(image_shape))
        self.use_sigmoid = use_sigmoid
        self.buffer = buffer
        self.fourier_magnitude = fourier_magnitude

    def _timestep_tensor(self, t):
        """Wrap a scalar timestep in a 1-element int64 tensor on the parameter's device."""
        return torch.tensor([int(t)], dtype=torch.long, device=self.x_t.device)

    @staticmethod
    def _as_schedule(lr, sgd_steps):
        """Normalise ``lr`` / ``sgd_steps`` into two lists describing the optimisation stages.

        Each argument may be a scalar or a list/tuple. A scalar paired with a
        sequence is repeated once per stage; two scalars give a single stage.
        """
        lr_is_seq = isinstance(lr, (list, tuple))
        steps_is_seq = isinstance(sgd_steps, (list, tuple))
        if lr_is_seq and steps_is_seq:
            return list(lr), list(sgd_steps)
        if lr_is_seq:
            return list(lr), [sgd_steps] * len(lr)
        if steps_is_seq:
            return [lr] * len(sgd_steps), list(sgd_steps)
        return [lr], [sgd_steps]

    def get_img(self):
        """Return the current image and cache it in ``self.computed_image``.

        With ``use_sigmoid`` the parameter is mapped through
        ``sigmoid(5 * (x_t - 0.5))``; otherwise the parameter itself is returned.
        """
        if self.use_sigmoid:
            image = 5 * (self.x_t - 0.5)
            image = torch.sigmoid(image)
        else:
            image = self.x_t

        self.computed_image = image
        return image

    def predict_x_0(self, t, x_t, noise_scheduler):
        """Predict the noise in ``x_t`` at timestep ``t`` and the implied clean sample.

        Args:
            t: scalar timestep (Python int or 0-dim tensor).
            x_t: noisy sample fed to the UNet.
            noise_scheduler: object exposing ``add_noise(sample, noise, timesteps)``.

        Returns:
            ``(noise_pred, x_0_pred)``.
        """
        device = self.x_t.device
        timesteps = self._timestep_tensor(t)
        noise_pred = self.unet(x_t, timesteps, return_dict=False)[0]
        # Probe the scheduler for the coefficients it applies at this timestep:
        # noising (sample=1, noise=0) yields the multiplier on the clean sample,
        # (sample=0, noise=1) the multiplier on the noise. This assumes
        # add_noise is linear in both arguments.
        fading_factor = noise_scheduler.add_noise(torch.ones(1), torch.zeros(1), timesteps).to(device)
        noise_factor = noise_scheduler.add_noise(torch.zeros(1), torch.ones(1), timesteps).to(device)
        x_0_pred = (x_t - noise_factor[:, None, None, None] * noise_pred) / fading_factor[:, None, None, None]
        return noise_pred, x_0_pred

    def step(self, t, target_t, x_t, noise_scheduler):
        """Move ``x_t`` from timestep ``t`` to ``target_t``.

        The clean sample is predicted at ``t`` and then re-noised with fresh
        Gaussian noise at ``target_t``.

        Returns:
            ``(new_x_t, noise_pred)``.
        """
        device = self.x_t.device
        noise_pred, x_0_pred = self.predict_x_0(t, x_t, noise_scheduler)
        new_timestep = self._timestep_tensor(target_t)
        new_x_t = noise_scheduler.add_noise(x_0_pred, torch.randn_like(x_0_pred), new_timestep).to(device)

        return new_x_t, noise_pred

    def diffusion_pipeline(self, x_t_start, t_start, t_end, noise_scheduler, num_steps=50, verbose=False):
        """Run unguided sampling from ``t_start`` to ``t_end`` without tracking gradients.

        Timesteps are ``num_steps + 1`` integer-truncated points of a linear
        grid between ``t_start`` and ``t_end``; :meth:`step` is applied between
        consecutive points. ``x_t_start`` is cloned, not modified.

        Returns:
            The sample after the last step.
        """
        with torch.no_grad():
            x_t = x_t_start.clone()

            timesteps = torch.linspace(t_start, t_end, num_steps + 1).int()
            for t, target_t in tqdm(zip(timesteps[:-1], timesteps[1:]),
                                    total=len(timesteps) - 1, disable=not verbose):
                x_t, _ = self.step(t, target_t, x_t, noise_scheduler)
            return x_t

    def guided_diffusion_pipeline(
            self, x_t_start, t_start, t_end, noise_scheduler, num_steps=50,
            hr_sinogram: "Sinogram | None" = None,
            lr_sinogram: "Sinogram | None" = None,
            lr_forward_function=None,
            batch_size=10, verbose=False, sgd_steps=50, lr=0.1,
            with_finetuning: bool = False,
    ):
        """Sample from ``t_start`` down to ``t_end`` with sinogram guidance.

        Each of the ``num_steps`` guided steps (ending at ``t_end + buffer``):
        predicts the clean sample, stores channel 0 of it in ``self.x_t``,
        refines ``self.x_t`` with :meth:`sinogram_guidance` when at least one
        sinogram is given, and re-noises ``get_img()`` at the next timestep.
        The last ``buffer`` timesteps are sampled without guidance.

        Args:
            x_t_start: initial noisy sample; cloned, not modified.
            t_start, t_end: first and last timestep.
            noise_scheduler: object exposing ``add_noise(sample, noise, timesteps)``.
            num_steps: number of guided steps.
            hr_sinogram: optional target compared with the sinogram of ``x_t``.
            lr_sinogram: optional target compared with the sinogram of
                ``lr_forward_function(x_t)``.
            lr_forward_function: per-sample forward operator; it is batched
                with ``torch.vmap``. ``None`` means the identity.
            batch_size: number of angles drawn per gradient step.
            verbose: show progress bars.
            sgd_steps, lr: optimisation schedule used at every guided step
                (scalars or lists, see :meth:`sinogram_guidance`).
            with_finetuning: after the guided steps, copy the current noisy
                sample into ``self.x_t`` and run a fixed three-stage schedule
                (500/200/200 steps at lr 0.1/0.01/0.001) before the final
                unguided steps.

        Returns:
            The final sample produced by :meth:`diffusion_pipeline`.
        """
        device = self.x_t.device
        # Move the targets once here instead of on every guidance call.
        if hr_sinogram is not None:
            hr_sinogram = hr_sinogram.to(device)
        if lr_sinogram is not None:
            lr_sinogram = lr_sinogram.to(device)
        has_measurements = hr_sinogram is not None or lr_sinogram is not None

        # The forward function is written per sample; vmap makes it batched.
        # Without one, sinogram_guidance keeps its identity default.
        guidance_kwargs = {}
        if lr_forward_function is not None:
            guidance_kwargs["lr_forward_function"] = torch.vmap(lr_forward_function)

        x_t = x_t_start.clone()
        timesteps = torch.linspace(t_start, t_end + self.buffer, num_steps + 1).int()
        for t, target_t in tqdm(zip(timesteps[:-1], timesteps[1:]),
                                total=len(timesteps) - 1, disable=not verbose):
            with torch.no_grad():
                _, x_0_pred = self.predict_x_0(t, x_t, noise_scheduler)
                # NOTE: a Fourier inpainting of x_0_pred (fourier_step /
                # fourier_magnitude_step) used to be applied here; it is disabled.
                # copy_ rather than "*= 0; +=" so that a non-finite value in
                # the parameter cannot survive the reset (inf * 0 is NaN).
                self.x_t.copy_(x_0_pred[:, 0])

            if has_measurements:
                self.sinogram_guidance(
                    hr_sinogram, lr_sinogram,
                    sgd_steps=sgd_steps, batch_size=batch_size, lr=lr,
                    **guidance_kwargs
                )

            new_timestep = self._timestep_tensor(target_t)
            x_t = noise_scheduler.add_noise(
                self.get_img().unsqueeze(1), torch.randn_like(x_0_pred), new_timestep
            ).to(device)

        if with_finetuning:
            with torch.no_grad():
                self.x_t.copy_(x_t[:, 0])

            # The same forward operator as in the guided loop must be used
            # here, otherwise the low-resolution target would be compared with
            # the unfiltered image.
            self.sinogram_guidance(
                hr_sinogram, lr_sinogram,
                sgd_steps=[500, 200, 200], batch_size=batch_size, lr=[0.1, 0.01, 0.001],
                verbose=verbose,
                **guidance_kwargs
            )
            x_t = self.get_img().unsqueeze(1)

        x_t = self.diffusion_pipeline(
            x_t, t_end + self.buffer, t_end,
            noise_scheduler, num_steps=self.buffer,
            verbose=verbose
        )

        return x_t

    def fourier_step(
            self,
            tomogram,
            lr_tomogram,
            circle_filter
    ):
        """Replace the Fourier coefficients of ``tomogram`` selected by ``circle_filter``.

        The selected coefficients are taken from the transform of
        ``lr_tomogram``; the real part of the inverse transform is returned.
        ``tomogram`` is indexed as ``[:, :, mask]``, i.e. it carries two
        leading dimensions in front of the masked ones.
        """
        mask = circle_filter.bool()
        fourier_lr = fft_2D(lr_tomogram)
        fourier_hr = fft_2D(tomogram)
        fourier_hr[:, :, mask] = fourier_lr[mask]
        return ifft_2D(fourier_hr).real

    def gridrec_step(
            self,
            current_tomogram,
            lr_tomogram,
            circle_filter,
            magnitude_threshold,
            target_sinogram,
            sinogram_angles,
            doPostCorrect
    ):
        """Inject gridrec and low-resolution Fourier data into ``current_tomogram``.

        ``gridRec`` is run on ``target_sinogram`` (angles converted from
        degrees to radians). Where the magnitude of the reconstruction's
        transform exceeds ``magnitude_threshold``, its coefficients replace
        those of ``current_tomogram``; afterwards the coefficients selected by
        ``circle_filter`` are replaced by those of ``lr_tomogram``. Runs
        without gradient tracking and returns the real part of the inverse
        transform.
        """
        with torch.no_grad():
            device = self.x_t.device
            low_res_mask = circle_filter.bool()
            sinogram_radians = torch.deg2rad(sinogram_angles)
            fft_x_t = fft_2D(current_tomogram)

            # gridRec operates on NumPy arrays, hence the CPU round trip.
            rec = torch.FloatTensor(
                gridRec(target_sinogram.cpu().numpy(), sinogram_radians.cpu().numpy(),
                        np.arange(len(sinogram_angles)).astype(np.int32), doPostCorrect=doPostCorrect,
                        zeroPaddingFactor=0)
            ).to(device)
            fft_rec = fft_2D(rec)

            fft_lr_tomogram = fft_2D(lr_tomogram)

            mask = (torch.abs(fft_rec) > magnitude_threshold).to(device)

            # Order matters: the low-resolution data is written last and wins
            # where both masks overlap.
            fft_x_t[:, :, mask] = fft_rec[mask]
            fft_x_t[:, :, low_res_mask] = fft_lr_tomogram[low_res_mask]

            return ifft_2D(fft_x_t).real

    def fourier_magnitude_step(
            self,
            tomogram,
            magnitude,
    ):
        """Keep the Fourier phase of ``tomogram`` and impose ``magnitude``.

        Returns the real part of the inverse transform of
        ``magnitude * exp(i * phase)``.
        """
        fourier_hr = fft_2D(tomogram)
        phase_hr = torch.angle(fourier_hr)

        complex_representation = magnitude * (torch.cos(phase_hr) + 1j * torch.sin(phase_hr))
        return ifft_2D(complex_representation).real

    def _sampled_sinogram_loss(self, target, forward_function, batch_size, criterion):
        """Loss between a random subset of ``target`` and the sinogram of ``x_t``.

        ``forward_function`` (or nothing when ``None``) is applied to
        ``self.x_t`` before projecting. The same target is compared against
        every element of the batch.
        """
        bs = len(self.x_t)
        angle_ids = torch.randperm(len(target.angles))[:batch_size].tolist()
        image = self.x_t if forward_function is None else forward_function(self.x_t)
        predicted = compute_sinogram(image, target.angles[angle_ids])
        target_sino = target[angle_ids]
        return criterion(
            predicted.reshape(bs, *target_sino.shape),
            target_sino.expand(bs, *target_sino.shape)
        )

    def sinogram_guidance(
            self, hr_sinogram: "Sinogram | None",
            lr_sinogram: "Sinogram | None",
            lr_forward_function=lambda x: x,
            sgd_steps=50,
            batch_size=10,
            lr=0.1, verbose=False
    ):
        """Refine ``self.x_t`` by minimising the MSE to the given sinograms.

        Every gradient step draws up to ``batch_size`` random angles per
        target, projects ``x_t`` (respectively ``lr_forward_function(x_t)``
        for the low-resolution target) with ``compute_sinogram`` and takes one
        AdamW step on the summed losses.

        Args:
            hr_sinogram: optional target for the sinogram of ``x_t``.
            lr_sinogram: optional target for the sinogram of
                ``lr_forward_function(x_t)``.
            lr_forward_function: batched forward operator for the
                low-resolution branch; identity by default.
            sgd_steps: number of steps, or a list with one entry per stage.
            batch_size: number of angles sampled per step and target.
            lr: learning rate, or a list with one entry per stage. A scalar
                given together with a list is repeated for every stage; two
                lists are paired with ``zip``.
            verbose: show a progress bar per stage.

        Returns:
            The loss of the last gradient step as a float, evaluated before
            that step was applied. ``0.0`` if no step was run.
        """
        # Without any target the loss would have no graph to backpropagate.
        if hr_sinogram is None and lr_sinogram is None:
            return 0.0

        learning_rates, stage_steps = self._as_schedule(lr, sgd_steps)

        mse = torch.nn.MSELoss()
        device = self.x_t.device

        hr_sinogram = hr_sinogram.to(device) if hr_sinogram is not None else None
        lr_sinogram = lr_sinogram.to(device) if lr_sinogram is not None else None

        # NOTE(review): the raw parameter is projected here, whereas the
        # pipeline reads the image through get_img() (a sigmoid mapping when
        # use_sigmoid is set) - confirm that this is intended.
        loss = torch.tensor(0., device=device)
        for lr_instance, sgd_steps_instance in zip(learning_rates, stage_steps):
            # Only the image is optimised; the UNet does not take part in the
            # loss and must not be touched by the optimizer.
            optimizer = optim.AdamW([self.x_t], lr=lr_instance)
            iterator = tqdm(range(sgd_steps_instance), disable=not verbose)
            for _ in iterator:
                optimizer.zero_grad()
                loss = torch.tensor(0., device=device)

                if hr_sinogram is not None:
                    loss = loss + self._sampled_sinogram_loss(hr_sinogram, None, batch_size, mse)

                if lr_sinogram is not None:
                    loss = loss + self._sampled_sinogram_loss(
                        lr_sinogram, lr_forward_function, batch_size, mse
                    )

                iterator.set_postfix({"loss": f"{loss.item():.3f}"})
                loss.backward()
                optimizer.step()

        return loss.item()

    def forward(self, sinogram_angles, filter=None, filter_in_sinogram_space: bool = False):
        """Project the current image along the given angles.

        The image from :meth:`get_img` is rotated once per angle with
        ``affine_grid`` / ``grid_sample`` and summed over dimension ``-2``.

        Args:
            sinogram_angles: 1-D tensor of angles in degrees.
            filter: callable applied to the sinogram when
                ``filter_in_sinogram_space`` is true; must be given in that case.
            filter_in_sinogram_space: whether to apply ``filter`` to the result.

        Returns:
            The sinogram, with the angle dimension after the leading
            dimension. If the leading dimension has length 1 it is dropped.
        """
        images = self.get_img()

        # Build the rotation matrices on the image's device so that the stack
        # below never mixes devices, and in float32 to match the sampled input.
        radians = torch.deg2rad(sinogram_angles.to(images.device))
        cos, sin = torch.cos(radians), torch.sin(radians)
        no_shift = torch.zeros_like(cos)
        rotation_matrix = torch.stack([
            torch.stack([cos, -sin, no_shift], 1),
            torch.stack([sin, cos, no_shift], 1)
        ], 1).float()

        # One copy of the image stack per angle; the leading dimension of
        # `images` ends up in the channel position and shares the grid.
        per_angle = images.repeat(len(sinogram_angles), 1, 1, 1).float()
        current_grid = F.affine_grid(rotation_matrix, per_angle.size(), align_corners=False)
        rotated = F.grid_sample(per_angle, current_grid, align_corners=False)
        rotated = rotated.transpose(0, 1)
        # Sum over one of the dimensions to compute the projection
        sinogram = rotated.sum(dim=-2).squeeze(2)
        sinogram = sinogram if len(sinogram) > 1 else sinogram[0]
        if not filter_in_sinogram_space:
            return sinogram
        return filter(sinogram)


if __name__ == '__main__':
    # Manual smoke test: build an untrained UNet, project the zero-initialised
    # image once, then run two guided sampling steps on random sinograms.
    import lovely_tensors as lt

    lt.monkey_patch()

    side_length = 512  # image resolution used throughout this script
    bs = 2

    model = UNet2DModel(
        sample_size=side_length,
        in_channels=1,  # single-channel images
        out_channels=1,
        layers_per_block=2,  # ResNet layers per UNet block
        block_out_channels=(64, 64, 128, 128, 256, 256),
        # Plain ResNet blocks, with spatial self-attention only at the
        # second-deepest resolution.
        down_block_types=("DownBlock2D",) * 4 + ("AttnDownBlock2D", "DownBlock2D"),
        up_block_types=("UpBlock2D", "AttnUpBlock2D") + ("UpBlock2D",) * 4,
    )

    td = TomographicDiffusion((bs, side_length, side_length), model, True, buffer=0)

    print(td.forward(torch.linspace(0, 180, 10)))

    from diffusers import DDPMScheduler
    from chip.utils.utils import create_circle_filter, create_gaussian_filter
    from chip.models.forward_models import fourier_filtering

    frequency_cut_out_radius = 15
    circle_filter = create_circle_filter(frequency_cut_out_radius, side_length)
    gaussian_filter = create_gaussian_filter(sigma=10, size=side_length)

    # Filter used by the low-resolution forward model below.
    current_filter = circle_filter

    def lr_forward_function(x):
        return fourier_filtering(x, current_filter)

    noise_scheduler = DDPMScheduler(num_train_timesteps=1000)

    # To exercise the unguided sampler instead:
    # img = td.diffusion_pipeline(torch.randn(bs, 512, 512).unsqueeze(1), 999, 0, noise_scheduler, 2)
    start_sample = torch.randn(bs, side_length, side_length).unsqueeze(1)
    hr_target = Sinogram(torch.randn(180, side_length), torch.arange(180))
    lr_target = Sinogram(torch.randn(180, side_length), torch.arange(180))
    img = td.guided_diffusion_pipeline(
        start_sample, 999, 0, noise_scheduler, 2,
        hr_target,
        lr_target,
        lr_forward_function=lr_forward_function,
        batch_size=10, verbose=True, sgd_steps=[2, 10], lr=[0.1, 0.01]
    )
    print(img)
