import os
from argparse import ArgumentParser

import torch

from chip.datasets.superres_dataset import SuperresolutionDS
from chip.datasets.tomogram_dataset import TomogramDataset
from pytorch_base.experiment import PyTorchExperiment
from pytorch_base.base_loss import BaseLoss

import random

from diffusers import UNet2DModel
from diffusers import DDPMScheduler
from diffusers.optimization import get_cosine_schedule_with_warmup

from torch.utils.data import Dataset
import matplotlib.pyplot as plt

import h5py
import hdf5plugin



def fft_2D(x):
    """Return the 2-D Fourier transform of ``x`` with the zero frequency centered.

    ``torch.fft.fft2`` is called with its default arguments and the result is
    shifted with ``torch.fft.fftshift`` along the last two dimensions.
    """
    return torch.fft.fftshift(torch.fft.fft2(x), dim=[-1, -2])

def ifft_2D(x):
    """Invert a centered 2-D spectrum.

    The shift along the last two dimensions is undone with
    ``torch.fft.ifftshift`` before ``torch.fft.ifft2`` is applied with its
    default arguments.
    """
    return torch.fft.ifft2(torch.fft.ifftshift(x, dim=[-1, -2]))

# Batched variants of the transforms above: torch.vmap applies the wrapped
# function to every slice along the first dimension of its input.
b_fft_2D = torch.vmap(fft_2D)
b_ifft_2D = torch.vmap(ifft_2D)

def log_compression(x):
    """Apply a signed logarithmic compression to a complex tensor.

    The real and imaginary parts are compressed independently: every
    component ``v`` is mapped to ``sign(v) * log(1 + |v|)``, so zero stays
    zero and the sign of each component is preserved.

    The value is computed directly from ``sign`` and ``log1p`` rather than as
    ``v * log(|v| + 1) / |v|``. That form produces ``0 / 0`` for zero
    components and overflows to ``inf`` for large magnitudes, because the
    product is formed before the division.

    Args:
        x: Complex tensor (its ``.real`` and ``.imag`` views are read).
            The input is not modified.

    Returns:
        A new complex tensor with the same shape as ``x``. NaN components in
        the input propagate to the output.
    """
    real, imag = x.real, x.imag
    return torch.complex(
        torch.sign(real) * torch.log1p(torch.abs(real)),
        torch.sign(imag) * torch.log1p(torch.abs(imag)),
    )


def recover_spatial_domain(x):
    """Map a stacked real/imaginary spectrum back to the spatial domain.

    ``x`` is multiplied by 24000 (the input itself is left untouched),
    channel 0 and channel 1 along dimension 1 are combined into a complex
    tensor, and the batched inverse transform ``b_ifft_2D`` is applied.

    NOTE(review): the factor 24000 undoes the division performed by the first
    ``normalize_fourier`` definition in this file. The later definition of
    the same name (signed log divided by 10) is the one in effect at call
    time and is not inverted by this function - confirm which normalization
    is intended.

    Returns:
        The complex tensor produced by ``b_ifft_2D``.
    """
    rescaled = x * 24000
    spectrum = torch.complex(rescaled[:, 0], rescaled[:, 1])
    return b_ifft_2D(spectrum)


def normalize_fourier(tomogram):
    """Transform ``tomogram`` with ``b_fft_2D`` and scale the result by 1/24000.

    The real and imaginary parts of the scaled spectrum are concatenated
    along dimension 1.

    NOTE(review): a second function named ``normalize_fourier`` is defined
    further down in this file and replaces this one when the module is
    loaded, so this linear variant is never the one that gets called.
    """
    spectrum = b_fft_2D(tomogram) / 24000
    return torch.cat((spectrum.real, spectrum.imag), dim=1)


def normalize_fourier(tomogram):
    """Transform ``tomogram`` with ``b_fft_2D`` and log-compress the spectrum.

    Each of the real and imaginary parts ``v`` is mapped to
    ``sign(v) * log(1 + |v|) / 10``; the two compressed parts are then
    concatenated along dimension 1 (real first, imaginary second).
    """
    spectrum = b_fft_2D(tomogram)

    def _signed_log(part):
        # Sign-preserving logarithm, scaled down by a factor of 10.
        return torch.sign(part) * torch.log(1 + torch.abs(part)) / 10.

    return torch.cat([_signed_log(spectrum.real), _signed_log(spectrum.imag)], 1)

class diffusion_loss(BaseLoss):
    """Diffusion training loss on the normalized Fourier representation.

    The total loss is the sum of
      * the MSE between the predicted and the sampled noise, and
      * a reconstruction term comparing the spatial-domain image recovered
        from the predicted clean sample with the input tomogram.

    The reported statistics are ``"loss"`` (noise MSE only) and
    ``"loss_x_0"`` (reconstruction term).
    """

    def __init__(self, noise_scheduler=None):
        """Create the loss.

        Args:
            noise_scheduler: Scheduler providing ``add_noise``,
                ``alphas_cumprod`` and ``config.num_train_timesteps``.
                When omitted, a ``DDPMScheduler`` with 1000 training
                timesteps is created, so the loss no longer depends on a
                module-level variable that only exists when this file is
                executed as a script.
        """
        stats_names = ["loss", "loss_x_0"]
        super().__init__(stats_names)
        if noise_scheduler is None:
            noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
        self.noise_scheduler = noise_scheduler

    def compute_loss(self, instance, model):
        """Compute the loss for one batch.

        Args:
            instance: 3-tuple whose second element is the tomogram batch;
                the first and third elements are ignored.
            model: Network called as
                ``model(x_t, timesteps, return_dict=False)[0]`` to predict
                the noise. Its gradients are cleared before the forward pass.

        Returns:
            ``(loss, stats)`` where ``loss`` is the sum of both terms and
            ``stats`` maps ``"loss"`` to the noise MSE and ``"loss_x_0"`` to
            the reconstruction term.
        """
        scheduler = self.noise_scheduler
        mse = torch.nn.MSELoss()
        # Use the device the model lives on instead of a module-level global.
        device = next(model.parameters()).device

        _, tomogram, _ = instance
        tomogram = tomogram.to(device)
        with torch.no_grad():
            # We stack the real and imaginary part into a (bs, 2, im_size, im_size) tensor
            x_0 = normalize_fourier(tomogram)

        noise = torch.randn_like(x_0)
        bs = x_0.shape[0]

        # Sample a random timestep for each image
        timesteps = torch.randint(
            0, scheduler.config.num_train_timesteps, (bs,), device=x_0.device
        ).long()

        # Add noise to the clean images according to the noise magnitude at each timestep
        # (this is the forward diffusion process)
        x_t = scheduler.add_noise(x_0, noise, timesteps)

        model.zero_grad()
        noise_pred = model(x_t, timesteps, return_dict=False)[0]

        loss_noise = mse(noise_pred, noise)

        # Coefficients of the forward process x_t = fading * x_0 + noise_factor * noise,
        # read directly from the scheduler's cumulative alpha products.
        alpha_bar = scheduler.alphas_cumprod.to(device=x_t.device)[timesteps]
        fading_factor = (alpha_bar ** 0.5)[:, None, None, None]
        noise_factor = ((1 - alpha_bar) ** 0.5)[:, None, None, None]

        # Solve the forward-process equation for x_0 using the predicted noise.
        x_0_pred = (x_t - noise_factor * noise_pred) / fading_factor
        # NOTE(review): x_0_pred lives in the domain produced by
        # normalize_fourier (signed log / 10 in the definition that is active
        # in this file), while recover_spatial_domain rescales by 24000
        # without undoing a logarithm - confirm the two are meant to pair up.
        tomo_pred = recover_spatial_domain(x_0_pred)

        # The recovered image should match the input and have no imaginary part.
        loss_x_0 = mse(tomo_pred.real.reshape(-1), tomogram.reshape(-1))
        loss_x_0 = loss_x_0 + mse(tomo_pred.imag, torch.zeros_like(tomo_pred.imag))

        loss = loss_noise + loss_x_0
        return loss, {"loss": loss_noise, "loss_x_0": loss_x_0}


def load_model(model, model_path):
    """Load weights into ``model`` from the checkpoint file at ``model_path``.

    The checkpoint is read onto the CPU and must contain the weights under
    the key ``'model_state_dict'``. A confirmation line is printed afterwards.
    Errors from ``torch.load`` or ``load_state_dict`` propagate to the caller.
    """
    # NOTE(review): torch.load unpickles the file - only use trusted checkpoints.
    cpu = torch.device('cpu')
    state_dict = torch.load(model_path, map_location=cpu)['model_state_dict']
    model.load_state_dict(state_dict)
    print(f"models loaded from checkpoint {model_path}")


if __name__ == '__main__':
    # Script entry point: parse the command line, build dataset, model and
    # noise scheduler, then hand everything to PyTorchExperiment for training.
    import lovely_tensors as lt

    lt.monkey_patch()

    parser = ArgumentParser(description="PyTorch experiments")
    parser.add_argument("--batch_size", default=50, type=int, help="batch size of every process")
    parser.add_argument("--epochs", default=1001, type=int, help="number of epochs to train")
    parser.add_argument("--learning_rate", default=0.0001, type=float, help="learning rate")
    parser.add_argument("--weight_decay", default=0.001, type=float, help="weight decay")
    parser.add_argument("--scheduler", default="[500]", type=str, help="scheduler decrease after epochs given")
    parser.add_argument("--lr_decay", default=0.1, type=float, help="Learning rate decay")
    parser.add_argument("--wandb_exp_name", default='random_experiments', type=str, help="Experiment name in wandb")
    parser.add_argument('--wandb', action='store_true', help="Use wandb")
    parser.add_argument("--load_checkpoint", default='', type=str, help="name of models in folder checkpoints to load")
    parser.add_argument("--seed", default=-1, type=int, help="Random seed")
    args = vars(parser.parse_args())
    # "--scheduler" is given as text such as "[500, 800]". Brackets and spaces
    # are stripped; empty tokens are skipped so "[]" yields no milestones
    # instead of failing on int("").
    temp = args["scheduler"].replace(" ", "").replace("[", "").replace("]", "").split(",")
    args["scheduler"] = [int(x) for x in temp if x]
    # A seed of -1 means "pick one at random".
    args["seed"] = random.randint(0, 20000) if args["seed"] == -1 else args["seed"]

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    # device = torch.device('cpu')
    print(device)

    # previous dataset class
    # files = os.listdir('data/imgs_synthetic')
    # ds = SuperresolutionDS(files, data_path="data")
    #
    # trainSet_raw = torch.utils.data.Subset(ds, range(0, round(0.9 * len(ds))))
    # testSet = torch.utils.data.Subset(ds, range(round(0.9 * len(ds)), len(ds)))
    # trainSet = TomogramDataset(trainSet_raw)

    # new dataset class
    data_path = "/mydata/chip/shared/data/tomograms_blueprint.h5" if torch.cuda.is_available() else "data/tomograms_blueprint.h5"
    # NOTE(review): both subsets wrap the same dataset object, so the test
    # split is also created with train_transform=True - confirm this is wanted.
    dataset = TomogramDataset(data_path, train_transform=True)
    # First 90% of the indices for training, remaining 10% for testing.
    split_index = round(0.9 * len(dataset))
    testSet = torch.utils.data.Subset(dataset, range(split_index, len(dataset)))
    trainSet = torch.utils.data.Subset(dataset, range(0, split_index))

    model_path = "checkpoints/diffusion_model_fourier.pt"

    model = UNet2DModel(
        sample_size=512,  # the target image resolution
        in_channels=2,  # the number of input channels: real and imaginary part
        out_channels=2,  # the number of output channels
        layers_per_block=2,  # how many ResNet layers to use per UNet block
        block_out_channels=(64, 64, 128, 128, 256, 256),  # the number of output channels for each UNet block
        down_block_types=(
            "DownBlock2D",  # a regular ResNet downsampling block
            "DownBlock2D",
            "DownBlock2D",
            "DownBlock2D",
            "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
            "DownBlock2D",
        ),
        up_block_types=(
            "UpBlock2D",  # a regular ResNet upsampling block
            "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention
            "UpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
        ),
    ).to(device)

    if args['load_checkpoint'] != "":
        # Loading stays best-effort, but only ordinary errors are caught (so
        # Ctrl-C still interrupts) and the actual cause is reported instead
        # of always claiming the file is missing.
        try:
            load_model(model, args['load_checkpoint'])
        except Exception as err:
            print(f"model not loaded from {args['load_checkpoint']} ({err!r}), initializing randomly")

    noise_scheduler = DDPMScheduler(num_train_timesteps=1000)

    exp = PyTorchExperiment(
        train_dataset=trainSet,
        test_dataset=testSet,
        batch_size=args['batch_size'],
        model=model,
        loss_fn=diffusion_loss(),
        checkpoint_path=model_path,
        experiment_name=args['wandb_exp_name'],
        with_wandb=args['wandb'],
        num_workers=torch.get_num_threads() if torch.cuda.is_available() else 0,
        seed=args["seed"],
        args=args
    )

    # --weight_decay is now forwarded to the optimizer; previously the option
    # was parsed but ignored and AdamW ran with its library default.
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=args['learning_rate'],
        weight_decay=args['weight_decay'],
    )

    # NOTE(review): num_epochs is fixed at 50 and independent of --epochs, and
    # len(trainSet) counts samples rather than batches. Whether this matches
    # the number of scheduler steps depends on how PyTorchExperiment.train
    # advances the scheduler - confirm.
    num_epochs = 50
    lr_scheduler = get_cosine_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=500,
        num_training_steps=(len(trainSet) * num_epochs),
    )

    exp.train(args['epochs'], optimizer, milestones=args['scheduler'], gamma=args['lr_decay'], scheduler=lr_scheduler)
