import os
import gradio as gr
import torch
import numpy as np
from PIL import Image
import gc

from options.test_options import TestOptions
from data.data_loader import CreateDataLoader
from models.models import create_model
from util.util import tensor2im
from utils import update_counter

# Build the option object from the parser's defaults. An explicit empty
# argument list is passed to parse_args, so no command-line arguments are
# supplied to the parser.
_option_factory = TestOptions()
_option_factory.initialize()
opt = _option_factory.parser.parse_args(args=[])

# Values forced on top of the parsed defaults, applied in this order.
_OPTION_OVERRIDES = {
    "model": "pix2pix",
    "name": "pretrained",
    "checkpoints_dir": "./PhotoSketch/Checkpoints",
    "dataset_mode": "test_dir",
    "dataroot": "examples/",
    "input_nc": 3,
    "output_nc": 1,
    "which_model_netG": "resnet_9blocks",
    "which_direction": "AtoB",
    "norm": "batch",
    "no_dropout": True,
    "nThreads": 1,
    "batchSize": 1,
    "serial_batches": True,
    "no_flip": True,
    "isTrain": False,
    "use_cuda": torch.cuda.is_available(),
}
for _option_name, _option_value in _OPTION_OVERRIDES.items():
    setattr(opt, _option_name, _option_value)


# Created once at import time and shared by every request.
data_loader = CreateDataLoader(opt)
dataset = data_loader.load_data()
model = create_model(opt)

# First item yielded by the dataset. The request handler below copies this
# dict and overwrites its 'A', 'A_paths' and 'B' entries for each upload.
data_template = next(iter(dataset))

def _run_sketch_inference(input_image: Image.Image):
    """Run the module-level ``model`` on one PIL image and return a PIL image.

    The image is converted to RGB, scaled to [-1, 1] and laid out as a
    ``[1, 3, H, W]`` float32 tensor. It replaces the ``'A'`` entry of a copy of
    ``data_template``; ``'B'`` is filled with a zero tensor of shape
    ``[1, 1, H, W]``.

    Raises:
        RuntimeError: if the model's visuals contain no ``'fake_B'`` entry.
        Exception: anything raised by the model or by PIL is propagated.
    """
    # HWC uint8 -> float32 in [0, 1]; then CHW with a batch axis, in [-1, 1].
    input_array = np.array(input_image.convert("RGB")).astype(np.float32) / 255.0
    input_tensor = torch.from_numpy(input_array).permute(2, 0, 1).unsqueeze(0) * 2 - 1

    # Shallow copy: the shared template dict itself is never modified.
    data = dict(data_template)
    data['A'] = input_tensor
    data['A_paths'] = ["gradio_input.jpg"]

    # Zero-filled stand-in for 'B' with the input's spatial size.
    _, _, height, width = input_tensor.shape
    data['B'] = torch.zeros((1, 1, height, width), dtype=torch.float32)

    # Nothing in this function runs a backward pass, so autograd bookkeeping
    # is disabled for the forward computation.
    with torch.no_grad():
        model.set_input(data)
        model.test()
        visuals = model.get_current_visuals()

    fake_B = visuals.get('fake_B')
    if fake_B is None:
        # Fail with a clear message instead of an AttributeError on None.
        raise RuntimeError("model visuals do not contain 'fake_B'")

    # Drop a leading batch axis if one is present.
    if fake_B.ndim == 4:
        fake_B = fake_B[0]
    # Image.fromarray rejects (H, W, 1) arrays; reduce them to (H, W).
    if fake_B.ndim == 3 and fake_B.shape[-1] == 1:
        fake_B = fake_B[..., 0]

    return Image.fromarray(fake_B)


def test_pretrained_gradio(input_image: Image.Image):
    """Gradio handler: turn an uploaded image into the model's output image.

    Args:
        input_image: the uploaded PIL image, or ``None`` when nothing was
            uploaded.

    Returns:
        The output as a PIL image, or ``None`` when ``input_image`` is ``None``
        or when processing fails (the error is printed, not raised).
    """
    if input_image is None:
        return None

    try:
        update_counter()
        return _run_sketch_inference(input_image)
    except Exception as e:
        # UI boundary: report the failure and show an empty output rather than
        # propagating the exception to the caller.
        print(f"Error processing image: {e}")
        return None
    finally:
        # Runs on both the success and the failure path. The helper's locals
        # are gone by now; collect garbage first so any tensors kept alive by
        # reference cycles are freed before the CUDA cache is released.
        gc.collect()
        torch.cuda.empty_cache()


# Components are created input first, then output, and handed to the
# Interface; both exchange PIL images with the handler.
_input_component = gr.Image(type="pil", label="Input Image")
_output_component = gr.Image(type="pil", label="Output Image")

demo = gr.Interface(
    fn=test_pretrained_gradio,
    inputs=_input_component,
    outputs=_output_component,
    title="Photo Sketch Model",
    description="Upload an image. Model processes it and returns output without saving to disk.",
)

# Launch settings: sharing enabled, host 0.0.0.0, fixed port 8083.
_LAUNCH_OPTIONS = {
    "share": True,
    "server_name": "0.0.0.0",
    "server_port": 8083,
}
demo.launch(**_LAUNCH_OPTIONS)