import torch
import numpy as np
import os
from PIL import Image
from fused_ssim import fused_ssim

# Load the ground-truth image as a single-channel float tensor scaled to [0, 1].
# The save step at the end of this script squeezes two leading dims and writes a
# 2-D array, so force grayscale ("L") here; without this, an RGB/palette JPEG
# would yield a (1, 1, H, W, 3) tensor instead of the (1, 1, H, W) layout that
# the two unsqueeze(0) calls below are meant to produce.
# The context manager guarantees the underlying file handle is released.
with Image.open(os.path.join("..", "images", "albert.jpg")) as img:
    gt_array = np.array(img.convert("L"))
gt_image = torch.tensor(gt_array, dtype=torch.float32, device="cuda").unsqueeze(0).unsqueeze(0) / 255.0
# Start from uniform noise with the same shape/device; this is the tensor Adam optimizes.
pred_image = torch.nn.Parameter(torch.rand_like(gt_image))

# Optimization stops once SSIM reaches this value...
TARGET_SSIM = 0.9999
# ...or after this many Adam steps, so the script cannot loop forever if
# optimization plateaus below the target.
MAX_STEPS = 100_000


def _eval_ssim():
    """Return the current SSIM between ``pred_image`` and ``gt_image`` as a float.

    Calls ``fused_ssim`` with ``train=False`` inside ``torch.no_grad()`` so this
    monitoring pass records no autograd graph.
    """
    with torch.no_grad():
        # fused_ssim yields a scalar tensor (the loop below calls .backward() on it),
        # so .item() turns it into a plain Python float for comparison/printing.
        return fused_ssim(pred_image, gt_image, train=False).item()


ssim_value = _eval_ssim()
print("Starting with SSIM value:", ssim_value)


optimizer = torch.optim.Adam([pred_image])

step = 0
while ssim_value < TARGET_SSIM and step < MAX_STEPS:
    optimizer.zero_grad()
    # Minimize 1 - SSIM. fused_ssim is called without train=False here, i.e. with
    # its default mode (presumably the differentiable/training path — confirm).
    loss = 1.0 - fused_ssim(pred_image, gt_image)
    loss.backward()
    optimizer.step()
    step += 1

    ssim_value = _eval_ssim()
    print("SSIM value:", ssim_value)

if ssim_value < TARGET_SSIM:
    print(f"Stopped after {step} steps without reaching SSIM {TARGET_SSIM}; saving the last iterate.")

# Convert the optimized prediction back to an 8-bit 2-D image and write it out.
# Adam does not constrain pred_image to [0, 1], so clamp before scaling: casting
# out-of-range floats to uint8 has no saturating behavior (it typically wraps,
# e.g. 1.01 * 255 -> 1 and -0.01 * 255 -> 254). Round instead of truncating so
# values such as 0.999 * 255 = 254.7 map to 255 rather than 254.
to_save = (
    (pred_image.detach().clamp(0.0, 1.0) * 255.0)
    .round()
    .squeeze(0)
    .squeeze(0)
    .cpu()
    .numpy()
    .astype(np.uint8)
)
Image.fromarray(to_save).save(os.path.join("..", "images", "predicted.jpg"))
