import argparse

import matplotlib.pyplot as plt
import torch
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader

from models import Siren, STAFNet
from utils import ImageFitting

# Pick the compute device once; the models and training tensors are moved to it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    print("CUDA is available. Using GPU.")
else:
    print("CUDA is not available. Using CPU.")


parser = argparse.ArgumentParser(description="Image reconstruction parameters")

parser.add_argument(
    "-e",
    "--epochs",
    type=int,
    help="Number of epochs for training",
    default=250,
)

parser.add_argument(
    "-nf",
    "--number_frequencies",
    type=int,
    help="Number of frequencies (passed to STAFNet as `nf`)",
    default=25,
)
args = parser.parse_args()

# The LR schedule divides by the epoch count, and the final figure uses the
# outputs of the last training step, so at least one epoch is required.
if args.epochs < 1:
    parser.error("--epochs must be a positive integer")

# Fixed experiment settings; epoch and frequency counts come from the CLI.
imsize = 512
total_steps = args.epochs
nf = args.number_frequencies
steps_til_summary = 1  # print a summary every `steps_til_summary` steps

# Target image dataset. Only the first batch is drawn below; presumably
# ImageFitting holds a single image — confirm in utils.ImageFitting.
cameraman = ImageFitting(imsize)
dataloader = DataLoader(
    cameraman,
    batch_size=1,
    shuffle=False,
    num_workers=0,
    pin_memory=False,
)

# Two coordinate networks of identical width/depth (2 inputs -> 1 output) so
# their reconstructions can be compared; STAFNet also takes the CLI `nf`.
# STAFNet is constructed first, as before, so parameter initialisation order
# is unchanged.
img_staff = STAFNet(
    in_features=2,
    out_features=1,
    hidden_features=256,
    hidden_layers=3,
    outermost_linear=True,
    nf=nf,
)

img_siren = Siren(
    in_features=2,
    out_features=1,
    hidden_layers=3,
    hidden_features=256,
    outermost_linear=True,
)

# Move both networks to the selected device.
img_siren.to(device)
img_staff.to(device)


def _lr_decay(step):
    """Return the multiplicative LR factor for scheduler step ``step``.

    The factor decays geometrically from 1 at step 0 to 0.1 at step
    ``total_steps``, and stays at 0.1 afterwards. ``max(total_steps, 1)``
    prevents a ZeroDivisionError when ``total_steps`` is 0: LambdaLR evaluates
    the factor once at construction time.
    """
    return 0.1 ** min(step / max(total_steps, 1), 1)


# One Adam optimizer and one scheduler per model. Both use identical
# hyperparameters, and each optimizer only holds its own model's parameters.
optim = torch.optim.Adam(
    params=img_siren.parameters(), lr=1e-4, betas=(0.9, 0.999), eps=1e-08
)
optim_para = torch.optim.Adam(
    params=img_staff.parameters(), lr=1e-4, betas=(0.9, 0.999), eps=1e-08
)
scheduler = LambdaLR(optim, _lr_decay)
para_scheduler = LambdaLR(optim_para, _lr_decay)

# Draw the first (input, target) batch once and keep it on the device for all steps.
model_input, ground_truth = (batch.to(device) for batch in next(iter(dataloader)))


psnr = []  # per-step PSNR (dB) of the Siren reconstruction
psnr_staf = []  # per-step PSNR (dB) of the STAF reconstruction


def _psnr(pred, target):
    """Return the PSNR of ``pred`` against ``target`` in dB, as ``10*log10(1/MSE)``.

    NOTE(review): a peak value of 1 is assumed; this is the textbook PSNR only
    if pixel values lie in [0, 1] — confirm the range produced by
    utils.ImageFitting.
    """
    return 10 * torch.log10(1 / torch.mean((pred - target) ** 2))


# Train both models side by side on the same input/target pair.
for step in range(total_steps):
    # The second value returned by each model is not used here.
    model_output, _ = img_siren(model_input)
    model_output_para, _ = img_staff(model_input)
    loss = ((model_output - ground_truth) ** 2).mean()
    loss_para = ((model_output_para - ground_truth) ** 2).mean()

    optim.zero_grad()
    optim_para.zero_grad()
    loss.backward()
    loss_para.backward()
    optim.step()
    optim_para.step()
    scheduler.step()
    para_scheduler.step()

    # Metrics use this iteration's outputs, computed before the optimizer step.
    # `outim` / `outim_para` are kept at module scope for the final figure.
    with torch.no_grad():
        outim = model_output.cpu().detach().view(-1, 1)
        outim_para = model_output_para.cpu().detach().view(-1, 1)
        psnr.append(_psnr(outim, cameraman.pixels))
        psnr_staf.append(_psnr(outim_para, cameraman.pixels))

    if step % steps_til_summary == 0:
        # Prefix each line with the model name so the two logs are distinguishable.
        print(
            f"Step {step}, Siren total loss {loss.item():0.6f}, "
            f"psnr {psnr[-1].item():0.6f}"
        )
        print(
            f"Step {step}, STAF total loss {loss_para.item():0.6f}, "
            f"psnr {psnr_staf[-1].item():0.6f}"
        )


# Final comparison figure. The first three panels are the ground truth and both
# reconstructions from the last training iteration. The fourth panel is the
# per-step PSNR curve of each model.
fig, axes = plt.subplots(1, 4, figsize=(18, 6))
for ax, title, image in zip(
    axes[:3],
    ("Original", "Siren", "STAF"),
    (cameraman.pixels, outim, outim_para),
):
    ax.imshow(image.view(imsize, imsize).detach().numpy(), cmap="gray")
    ax.set_title(title)

axes[3].plot(psnr)
axes[3].plot(psnr_staf)
axes[3].legend(["Siren", "STAF"])
axes[3].grid()
axes[3].set_title("PSNR")
fig.suptitle("Implicit Neural Representation", fontsize=16)

fig.savefig("result.png")  # write the figure to disk before displaying it
plt.show()
