import argparse
import os

import rembg
import torch
from PIL import Image
from tqdm import tqdm
from dataclasses import dataclass, field
from typing import Optional

from stable_fast_3d.sf3d.system import SF3D
from stable_fast_3d.sf3d.utils import remove_background, resize_foreground



def SF3D_img2mesh(image_path, output_dir):
    """Reconstruct textured meshes from an image (or a folder of images) with SF3D.

    Each input image has its background removed and its foreground resized,
    and the preprocessed image is saved as ``<output_dir>/<index>/input.png``.
    The images are then fed to the SF3D model in batches and each resulting
    mesh is exported to ``<output_dir>/<index>/mesh.glb``.

    Args:
        image_path: Path to a single image file, or to a directory. For a
            directory, every entry ending in ``.png``, ``.jpg`` or ``.jpeg``
            is processed, in sorted filename order.
        output_dir: Directory that receives one numbered sub-directory per
            input image. Created if it does not exist.

    Returns:
        ``None`` if no image was processed, the single mesh returned by
        ``model.run_image`` if exactly one image was processed, otherwise a
        list with one mesh per image (in input order).

    Raises:
        ValueError: If the configured device is not a CUDA device.
    """

    @dataclass
    class Config:
        pretrained_model: str = "stabilityai/stable-fast-3d"
        device: str = "cuda:0"
        image: Optional[str] = "./"
        output_dir: Optional[str] = "./"
        foreground_ratio: float = 0.85
        batch_size: int = 1
        texture_resolution: int = 1024
        remesh_option: str = "none"

    cfg = Config(image=image_path, output_dir=output_dir)

    # Ensure the configured device is a CUDA device.
    if "cuda" not in cfg.device:
        raise ValueError(
            "CUDA device is required for baking and hence running the method."
        )
    os.makedirs(output_dir, exist_ok=True)

    cuda_available = torch.cuda.is_available()
    device = cfg.device if cuda_available else "cpu"

    # Prefer the locally cached snapshot when it exists on this machine;
    # otherwise fall back to the configured model identifier so the function
    # is not tied to a single user's home directory.
    # NOTE(review): assumes SF3D.from_pretrained also accepts the hub id in
    # cfg.pretrained_model (as the previously commented-out call suggests).
    local_snapshot = "/home/amax/.cache/huggingface/hub/models--stabilityai--stable-fast-3d/snapshots/56d07dee021eacfa8c083310a7d9c63bcbf5d989"
    model_source = (
        local_snapshot if os.path.isdir(local_snapshot) else cfg.pretrained_model
    )
    model = SF3D.from_pretrained(
        model_source,
        config_name="config.yaml",
        weight_name="model.safetensors",
    )
    model.to(device)
    model.eval()

    rembg_session = rembg.new_session()

    if os.path.isdir(image_path):
        # Sort so that the numbered output folders are reproducible;
        # os.listdir returns entries in arbitrary order.
        input_paths = sorted(
            os.path.join(image_path, name)
            for name in os.listdir(image_path)
            if name.endswith((".png", ".jpg", ".jpeg"))
        )
    else:
        input_paths = [image_path]

    images = []
    for idx, path in enumerate(input_paths):
        # Close the underlying file as soon as the pixels are converted.
        with Image.open(path) as source:
            rgba = source.convert("RGBA")
        image = remove_background(rgba, rembg_session)
        image = resize_foreground(image, cfg.foreground_ratio)
        sample_dir = os.path.join(output_dir, str(idx))
        os.makedirs(sample_dir, exist_ok=True)
        image.save(os.path.join(sample_dir, "input.png"))
        images.append(image)

    meshes = []
    for i in tqdm(range(0, len(images), cfg.batch_size)):
        batch = images[i : i + cfg.batch_size]
        # The CUDA memory statistics are only usable when CUDA is present.
        if cuda_available:
            torch.cuda.reset_peak_memory_stats()
        with torch.no_grad():
            with torch.autocast(device_type="cuda", dtype=torch.float16):
                mesh, _glob_dict = model.run_image(
                    batch,
                    bake_resolution=cfg.texture_resolution,
                    remesh=cfg.remesh_option,
                )
        if cuda_available:
            print(
                "Peak Memory:",
                torch.cuda.max_memory_allocated() / 1024 / 1024,
                "MB",
            )

        # A one-image batch is treated as yielding a single mesh object,
        # a larger batch as yielding an indexable collection of meshes.
        if len(batch) == 1:
            batch_meshes = [mesh]
        else:
            batch_meshes = [mesh[j] for j in range(len(mesh))]

        for j, single_mesh in enumerate(batch_meshes):
            out_mesh_path = os.path.join(output_dir, str(i + j), "mesh.glb")
            single_mesh.export(out_mesh_path, include_normals=True)
        meshes.extend(batch_meshes)

    # Return only after every batch has been processed.
    if not meshes:
        return None
    if len(meshes) == 1:
        return meshes[0]
    return meshes