import os
import random
import lovely_tensors as lt

from chip.datasets.superres_dataset import SuperresolutionDS
from chip.models.iterative_model import TomographicReconstruction
from chip.utils import create_gaussian_filter
from chip.utils.utils import load_model
from torch.utils.data import Dataset
from torchvision import transforms
from chip.utils.fourier import fft_2D, ifft_2D

lt.monkey_patch()
import torch

from diffusers import UNet2DModel
from diffusers import DDPMScheduler

from chip.models.tomographic_diffusion import TomographicDiffusion


class RotatedTomograms(Dataset):
    """Wrap a paired dataset with random rotation/scale augmentation.

    Each item of the wrapped collection is an ``(x, y, file)`` triple where
    ``x`` and ``y`` are tensors that can be stacked together. Both are passed
    through the same random affine transform; the transformed ``y`` is the
    target, and the source is re-synthesised from it by filtering in Fourier
    space with ``gaussian_filter`` (the transformed ``x`` is discarded).
    """

    def __init__(self, data, gaussian_filter):
        """
        Args:
            data (Dataset or sequence): Indexable collection whose items are
                ``(x, y, file)`` triples.
            gaussian_filter (Tensor): Frequency-domain filter that is
                multiplied elementwise with the 2-D FFT of each augmented
                target to produce the degraded source image.
        """
        self.gaussian_filter = gaussian_filter
        self.data = data
        # Random rotation in [-180, 180] degrees, no translation, and a random
        # scale factor in [0.6, 1.0].
        self.train_transform = transforms.Compose(
            [
                transforms.RandomAffine(
                    degrees=(-180, 180),
                    translate=(0, 0),
                    scale=(0.6, 1.0),
                    # InterpolationMode is accessed through `transforms`; the
                    # bare name is not imported in this module.
                    interpolation=transforms.InterpolationMode.BILINEAR,
                ),
            ]
        )

    def test_transform(self):
        """Return an evaluation-time transform (tensor conversion only).

        NOTE(review): not used by ``__getitem__``; kept for any external
        callers that may reference it.
        """
        return transforms.Compose([
            transforms.ToTensor()
        ])

    def __len__(self):
        """Return the number of items in the wrapped collection."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(source, target, file)`` for item ``idx``.

        ``target`` is the randomly rotated/scaled ``y``; ``source`` is the real
        part of the inverse FFT of ``target``'s FFT times ``gaussian_filter``.
        """
        x, y, file = self.data[idx]

        # Stack x and y so that a single call applies identical random affine
        # parameters to both.
        sample = self.train_transform(torch.stack([x, y], 0))

        # Replace the augmented x with a Fourier-filtered copy of the augmented
        # y, so the source is always aligned with the target. Writing into
        # sample[0] also casts the result to sample's dtype.
        sample[0] = ifft_2D(fft_2D(sample[1]) * self.gaussian_filter).real

        return sample[0], sample[1], file

if __name__ == '__main__':
    # Demo: pick a random synthetic tomogram, build a sparse-angle sinogram of
    # it, and run sinogram-guided diffusion (with Fourier inpainting from a
    # filtered copy) to produce `bs` candidate reconstructions.
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # Single-channel 512x512 UNet. Presumably this configuration must match the
    # one used to train the checkpoint loaded below.
    model = UNet2DModel(
        sample_size=512,  # the target image resolution
        in_channels=1,  # single-channel images
        out_channels=1,  # the number of output channels
        layers_per_block=2,  # how many ResNet layers to use per UNet block
        block_out_channels=(64, 64, 128, 128, 256, 256),  # output channels for each UNet block
        down_block_types=(
            "DownBlock2D",  # a regular ResNet downsampling block
            "DownBlock2D",
            "DownBlock2D",
            "DownBlock2D",
            "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
            "DownBlock2D",
        ),
        up_block_types=(
            "UpBlock2D",  # a regular ResNet upsampling block
            "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention
            "UpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
        ),
    ).to(device)

    model_path = "checkpoints/diffusion_model_tomogram.pt"
    load_model(model, model_path)
    model.eval()

    files = os.listdir('data/imgs_synthetic')
    ds = SuperresolutionDS(files, data_path="data")

    gaussian_filter = create_gaussian_filter(size=512, sigma=20)
    trainSet = RotatedTomograms(ds, gaussian_filter)

    # Fail with a readable message instead of random's ValueError on an empty range.
    if len(trainSet) == 0:
        raise SystemExit("No images found in data/imgs_synthetic")

    # The dataset yields (filtered source, augmented target, file name).
    index = random.randrange(len(trainSet))
    source, target, _ = trainSet[index]

    # Ten projection angles evenly spaced over [0, 179] (presumably degrees).
    sinogram_angles = torch.linspace(0, 179, 10)
    target_model = TomographicReconstruction(target, False)
    with torch.no_grad():
        # Sinogram of the clean target at the chosen angles; it guides the
        # diffusion. None when no angles are requested.
        target_sinogram = target_model(sinogram_angles).to(device) if len(sinogram_angles) > 0 else None

    noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
    t_start = 999  # last training timestep: start from (almost) pure noise

    # With a zero clean image, add_noise returns only the scheduler-scaled
    # noise term for t_start; unsqueeze(0) adds a leading batch dimension.
    x_t = noise_scheduler.add_noise(
        torch.zeros_like(target),
        torch.randn_like(target),
        torch.tensor([t_start], dtype=torch.long),
    ).unsqueeze(0).to(device)

    bs = 2  # number of images to be produced by the diffusion model
    td = TomographicDiffusion((bs, 512, 512), model, use_sigmoid=True).to(device)

    images = td.guided_diffusion_pipeline(
        x_t.repeat(bs, 1, 1, 1), t_start, 0, noise_scheduler, 50,
        target_sinogram, sinogram_angles.to(device),
        batch_size=min(30, len(sinogram_angles)),
        verbose=True, sgd_steps=50, lr=0.1,
        with_finetuning=True,
        fourier_inpainting=True,
        lr_tomogram=source.repeat(bs, 1, 1, 1).to(device), gaussian_filter=gaussian_filter.to(device),
        fourier_threshold=1e-1
    )
    print(images)