import argparse
import torch.nn as nn
import torch
from matplotlib import pyplot as plt
from torchvision.transforms import ToPILImage, ToTensor
from tqdm import tqdm
import numpy as np
import torch.nn.functional as F
import imageio
from torchvision.transforms import ToPILImage

from sism.mnist.train import rotate_image, get_sigma, ScoreNet, get_diffusion_coefficients
from sism.mnist.mnist_data import MNISTDataModule


def compute_beta(t):
    """Return the ``beta`` coefficient the sampler uses at time ``t``.

    Evaluates ``sqrt(3) / 2 * t * 180`` elementwise, so ``t`` may be a
    Python/numpy scalar, an array, or a tensor (the sampling loop in this file
    passes the normalised step ``t / T``).
    """
    # TODO: Wrong, must be fixed
    half_root_three = np.sqrt(3.0) / 2
    return half_root_three * t * 180


def smooth_image(image):
    """Blur ``image`` with a centre-weighted 3x3 smoothing kernel.

    The centre pixel is weighted 8/16 and each of its eight neighbours 1/16
    (the weights sum to 1). Zero padding of 1 keeps the spatial size unchanged.

    Args:
        image: single-channel tensor accepted by ``F.conv2d``, e.g.
            ``(N, 1, H, W)`` as produced by the sampling loop in this file.

    Returns:
        The smoothed tensor with dimension 0 squeezed away if it has size 1
        (so ``(1, 1, H, W)`` becomes ``(1, H, W)``; larger batches keep
        their shape).
    """
    # Build the kernel on the image's device and in its dtype: conv2d rejects
    # mismatched devices/dtypes (e.g. a CUDA or float64 image vs. a CPU float32 kernel).
    kernel = torch.tensor([[1.0, 1.0, 1.0],
                           [1.0, 8.0, 1.0],
                           [1.0, 1.0, 1.0]], dtype=image.dtype, device=image.device) / 16.0

    # conv2d expects weights shaped (out_channels, in_channels, kH, kW).
    kernel = kernel[None, None, :, :]

    smoothed = F.conv2d(image, kernel, padding=1)

    return smoothed.squeeze(0)


def rotate_batch(images, angles):
    """Rotate each image in ``images`` by the matching entry of ``angles``.

    Each image is round-tripped through PIL (``ToPILImage`` -> ``rotate_image``
    -> ``ToTensor``); the PIL round-trip presumably re-quantises pixel values
    to 8 bits on every call.

    Args:
        images: iterable of image tensors accepted by ``ToPILImage``.
        angles: iterable of per-image angles passed to ``rotate_image``.

    Returns:
        The rotated images stacked along a new first axis, with the singleton
        channel axis removed (``(N, H, W)`` for single-channel input).
    """
    to_pil, to_tensor = ToPILImage(), ToTensor()
    rotated = [to_tensor(rotate_image(to_pil(image), angle)) for image, angle in zip(images, angles)]
    # squeeze(1) (not squeeze()) so a batch of one sample keeps its batch axis.
    return torch.stack(rotated).squeeze(1)


def to_uint8_frame(image):
    """Convert a float image tensor with values in ``[0, 1]`` to a ``uint8`` numpy array.

    Values are clamped to ``[0, 1]``, scaled to ``[0, 255]`` and rounded. Every
    saved frame therefore uses the same absolute intensity scale, instead of
    leaving the float-to-uint8 conversion to the image writer.
    """
    return (image.detach().clamp(0.0, 1.0) * 255).round().to(torch.uint8).cpu().numpy()


# Script: rotate a few MNIST validation digits by random angles, then repeatedly
# apply score-driven rotation updates (intended to undo the initial rotation),
# plot start/end/reference images and save each sample's trajectory as a GIF.
if __name__ == "__main__":

    parser = argparse.ArgumentParser(description="Sampling settings")

    # --use_conv defaults to True, so on its own it can never disable the conv net;
    # --no_conv writes the same destination and selects the MLP score network.
    parser.add_argument('--use_conv', action='store_true', default=True)
    parser.add_argument('--no_conv', dest='use_conv', action='store_false',
                        help='use the MLP score network instead of ScoreNet')
    # NOTE(review): --display_images is parsed but never read below; confirm whether it should gate plt.show().
    parser.add_argument('--display_images', action='store_true', default=False)
    parser.add_argument('--device', type=str, default='cpu')
    parser.add_argument('--time_scheduler', action='store_true', default=False)
    parser.add_argument('--checkpoint', type=str, default='model_epoch_99_T=100.ckpt',
                        help='path of the score-network state dict to load')
    parser.add_argument('--num_samples', type=int, default=10,
                        help='number of validation images to rotate and restore')

    args = parser.parse_args()

    # Update variants: plain score drift (default), drift + Gaussian noise
    # (`stochastic`), or the `odeflow` update that rescales the drift by 0.5 / sigma.
    stochastic = False
    odeflow = False

    use_time_scheduler = args.time_scheduler

    device = torch.device(args.device)
    use_conv = args.use_conv

    if not use_conv:
        in_dim = 784  # flattened 28x28 image
        hidden_dim = 32
        score_net = nn.Sequential(nn.Linear(in_dim + 1, hidden_dim),
                                  nn.SiLU(),
                                  nn.Linear(hidden_dim, hidden_dim),
                                  nn.SiLU(),
                                  nn.Linear(hidden_dim, 1),
                                  ).to(device)
    else:
        score_net = ScoreNet().to(device)

    # map_location lets a checkpoint saved on another device load onto --device.
    score_net.load_state_dict(torch.load(args.checkpoint, map_location=device))

    datamodule = MNISTDataModule(affine=False)
    datamodule.setup()

    num_samples = args.num_samples
    T = 10  # number of reverse update steps
    chain = np.arange(T)

    dataloader_val = datamodule.val_dataloader()

    # Initial rotation angles: standard normal scaled by 90 (presumably degrees).
    lambda_T = torch.randn(num_samples) * (180 / 2)

    # Angles with magnitude below 180/8 are replaced by +180/8 so every sample starts rotated.
    # NOTE(review): this also turns small negative angles positive; confirm that is intended.
    lambda_T = torch.where(lambda_T.abs() < 180 / 8, torch.full_like(lambda_T, 180 / 8), lambda_T)

    unrotated = torch.stack([dataloader_val.dataset[i][0] for i in range(num_samples)])

    xsampled = rotate_batch(unrotated, lambda_T)  # (num_samples, H, W)
    rotated_start = xsampled.clone()

    x_traj = []
    # Reverse-time step size; replaced per step when the time scheduler is enabled.
    Delta_t = - 1.0 / T

    with torch.no_grad():
        for t in tqdm(reversed(chain), total=T):
            t = torch.tensor([t] * num_samples)

            temb = (t / T).unsqueeze(-1)  # (num_samples, 1) normalised time

            # One beta per sample; all entries are equal because every sample shares t.
            # NOTE(review): compute_beta is flagged as wrong by its own TODO.
            betas = compute_beta(temb).reshape(-1)
            sigma = get_sigma(temb).squeeze()

            if use_time_scheduler:
                # Same values as the former `betas[t]` lookup (all entries are equal),
                # without indexing out of range when num_samples < T.
                Delta_t = - betas

            if not use_conv:
                xsampled_input = torch.cat([xsampled.reshape(xsampled.size(0), -1), temb], dim=1)
                scores = score_net(xsampled_input.to(device))
            else:
                scores = score_net(xsampled.unsqueeze(1).to(device), temb.to(device))
            # One scalar score per sample, back on the CPU where the PIL rotation happens.
            # (The MLP returns (N, 1); leaving it unflattened would broadcast to (N, N).)
            scores = scores.reshape(-1).cpu()

            update_thetas = - betas * scores * Delta_t
            if stochastic:
                z_t = torch.randn(num_samples)
                update_thetas += torch.sqrt(betas * abs(Delta_t)) * z_t
            elif odeflow:
                update_thetas = - 0.5 * betas * scores * Delta_t / sigma

            xsampled = rotate_batch(xsampled, update_thetas)
            # smooth_image needs a channel axis; reshape restores (num_samples, H, W).
            xsampled = smooth_image(xsampled.unsqueeze(1)).reshape(xsampled.shape)
            x_traj.append(xsampled)

    num_samples_to_display = num_samples
    final_np = xsampled[:num_samples_to_display].detach().cpu().numpy()
    rotated_start_np = rotated_start[:num_samples_to_display].detach().cpu().numpy()
    unrotated_np = unrotated[:num_samples_to_display].detach().cpu().numpy()

    # One row per sample: rotated input, final sample, unrotated reference.
    # squeeze=False keeps `axs` two-dimensional even for a single row.
    fig, axs = plt.subplots(num_samples_to_display, 3, figsize=(10, 2 * num_samples_to_display),
                            squeeze=False)

    for i in range(num_samples_to_display):
        panels = (
            (rotated_start_np[i], 'Original xsampled'),
            (final_np[i], 'Final xsampled'),
            (unrotated_np[i], 'Original unrotated'),
        )
        for col, (image, title) in enumerate(panels):
            # np.squeeze drops a leading channel axis on the dataset images, if any.
            axs[i, col].imshow(np.squeeze(image), cmap='gray')
            axs[i, col].set_title(title)
            axs[i, col].axis('off')

    plt.tight_layout()
    plt.show()

    # Save each sample's trajectory (one frame per update step) as an animated GIF.
    for i in range(num_samples_to_display):
        frames = [to_uint8_frame(step[i]) for step in x_traj]
        # NOTE(review): the unit of `duration` differs between imageio versions
        # (seconds in the legacy GIF writer, milliseconds in the pillow-based one) — confirm.
        imageio.mimsave(f'trajectory_{i}.gif', frames, duration=2.0)
