import torch
import os
from tqdm import tqdm
from diffusers.utils import load_image, check_min_version
from models.camctrl_transformer import CamCtrlFluxTransformer2DModel
from pipelines.pipeline_camtrl_removal_fuse_pixel import FluxControlSingleScaleRemovalPipeline
from geocalib import GeoCalib
from PIL import Image

# Require diffusers >= 0.30.2 before doing any work.
check_min_version("0.30.2")

# ---------------------- Configuration ----------------------
# Placeholder locations; replace them before running the script.
imgs_folder = 'path_to_imgs_folder'
masks_folder = 'path_to_mask_folder'
output_folder = 'path_to_output_folder'
# Text file with one filename per line; blank lines are ignored.
file_list_txt = 'path_to_txt_file'
# Passed through to the pipeline call under the same keyword name.
test_zomm_ratio = '1x'
prompt = 'There is nothing here.'
# (width, height) that inputs are resized to and that the pipeline is asked for.
size = (1024, 1024)
os.makedirs(output_folder, exist_ok=True)

# Gather the whitespace-trimmed, non-empty lines of the list file.
with open(file_list_txt, 'r') as f:
    filenames = []
    for line in f:
        entry = line.strip()
        if entry:
            filenames.append(entry)

print(f"Total files to process: {len(filenames)}")

# ---------------------- Build pipeline ----------------------
# Load the transformer sub-model in bfloat16.
transformer = CamCtrlFluxTransformer2DModel.from_pretrained(
    'path_to_model',
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)

# Replace the input embedder with one that accepts 4x as many input features.
# The original weights are copied into the leading columns and the remaining
# columns are zero-initialised; the bias (if any) is carried over unchanged.
initial_input_channels = transformer.config.in_channels
with torch.no_grad():
    base_embedder = transformer.x_embedder
    has_bias = base_embedder.bias is not None

    new_linear = torch.nn.Linear(
        base_embedder.in_features * 4,
        base_embedder.out_features,
        bias=has_bias,
        dtype=transformer.dtype,
        device=transformer.device,
    )
    new_linear.weight.zero_()
    new_linear.weight[:, :initial_input_channels].copy_(base_embedder.weight)
    if has_bias:
        new_linear.bias.copy_(base_embedder.bias)

    transformer.x_embedder = new_linear
    # Keep the stored config in step with the widened layer.
    transformer.register_to_config(in_channels=initial_input_channels * 4)

# Build the removal pipeline around the widened transformer, then move it to
# the GPU.
pipe = FluxControlSingleScaleRemovalPipeline.from_pretrained(
    'path_to_model',
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)
pipe = pipe.to("cuda")

pipe.transformer.to(torch.bfloat16)

# Sanity check: the pipeline's transformer must report the widened channel
# count that was registered on the transformer config.
expected_in_channels = initial_input_channels * 4
assert (
    pipe.transformer.config.in_channels == expected_in_channels
), "Transformer input channels mismatch."

pipe.load_lora_weights(
    'path_to_model',
    weight_name="pytorch_lora_weights.safetensors",
)

# GeoCalib model with the "pinhole" weights; handed to the pipeline on every
# call as `camera_model`.
camera_model = GeoCalib(weights="pinhole").to("cuda")

# ---------------------- Inpainting ----------------------
# Fixed (width, height) every result is resampled to before it is written.
output_size = (960, 540)
# Number of list entries skipped because the image or its mask was missing.
skipped = 0

# Run the pipeline once per listed file. The mask is looked up under the same
# relative filename as the image, and the result is written under that
# filename inside `output_folder`.
for filename in tqdm(filenames, desc="Processing images"):
    image_path = os.path.join(imgs_folder, filename)
    mask_path = os.path.join(masks_folder, filename)

    # Missing inputs are reported and skipped rather than aborting the batch.
    # tqdm.write keeps these messages from corrupting the progress bar.
    if not os.path.exists(image_path):
        tqdm.write(f"Image not found: {image_path}")
        skipped += 1
        continue
    if not os.path.exists(mask_path):
        tqdm.write(f"Mask not found: {mask_path}")
        skipped += 1
        continue

    image = load_image(image_path).convert("RGB").resize(size)
    mask = load_image(mask_path).convert("RGB").resize(size)
    # Re-seed for every image so each result is reproducible independently of
    # its position in the list.
    generator = torch.Generator(device="cuda").manual_seed(24)

    result = pipe(
        prompt=prompt,
        control_image=image,
        control_mask=mask,
        num_inference_steps=28,
        guidance_scale=3.5,
        generator=generator,
        max_sequence_length=512,
        height=size[1],
        width=size[0],
        camera_model=camera_model,
        test_zomm_ratio=test_zomm_ratio,
    ).images[0]

    result = result.resize(output_size, Image.Resampling.LANCZOS)
    save_path = os.path.join(output_folder, filename)
    # List entries may contain sub-directories; create them so the save does
    # not fail after the inference work has already been done.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    result.save(save_path)

# Only claim full success when nothing was skipped.
if skipped:
    print(
        f"Processed {len(filenames) - skipped} of {len(filenames)} images; "
        f"{skipped} skipped (missing image or mask)."
    )
else:
    print("All images processed successfully!")



