import os
import torch
import json
from PIL import Image
import numpy as np
from tqdm import tqdm
from functools import partial

from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig


class FluxKontextInfer:
    """Prompt-driven image editing with DiffSynth's ``FluxImagePipeline``
    loaded from FLUX.1-Kontext-dev weights.

    The prompt and sampling parameters come from a config dict (usually loaded
    with :meth:`from_json`). The output resolution follows each input image.
    """

    def __init__(self, config, seed=42):
        """
        Args:
            config: dict with a required ``"prompt"`` key and optional
                ``"device"``, ``"negative_prompt"``, ``"guidance_scale"``,
                ``"num_inference_steps"`` and ``"upscale"`` keys.
            seed: seed applied to torch's global RNG and passed to the
                pipeline on every :meth:`generate` call.
        """
        self.config = config
        self.device = config.get("device", "cuda" if torch.cuda.is_available() else "cpu")
        self.prompt = config["prompt"]
        self.negative_prompt = config.get("negative_prompt", "")
        self.guidance_scale = config.get("guidance_scale", 2.5)
        self.seed = seed
        self.num_inference_steps = config.get("num_inference_steps", 30)
        # When True, inputs are upscaled 4x before editing (see torch_to_PIL).
        self.upscale = config.get("upscale", True)

        self.pipe = None
        # Not assigned anywhere in this class; kept as public attributes.
        self.input_image = None
        self.generated_image = None

        # These two are no longer read from JSON and are not set by this class;
        # the working size is derived from each input image in generate().
        self.original_size = None
        self.target_size = None

        # Load model
        self.setup()

    @staticmethod
    def _cuda_device(device):
        """Return ``torch.device(device)`` if it names a CUDA device, else ``None``.

        Accepts strings such as ``"cuda"``, ``"cuda:1"`` or ``"cpu"``.
        """
        dev = torch.device(device)
        return dev if dev.type == "cuda" else None

    def setup(self, model_id='/mnt/[anonymized_public_path]/DiffSynth-Studio/models/black-forest-labs/FLUX.1-Kontext-dev'):
        """Prepare CUDA memory (when running on CUDA) and load the pipeline.

        Args:
            model_id: local directory with the FLUX.1-Kontext-dev files. Nothing
                is downloaded (``skip_download=True``).
        """
        cuda_device = self._cuda_device(self.device)
        if cuda_device is not None and torch.cuda.is_available():
            # The CUDA caching allocator reads this variable when it is first
            # initialized, so set it before any torch.cuda call below that may
            # initialize CUDA. It has no effect if CUDA was already initialized
            # elsewhere in the process.
            os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
            # "cuda" without an explicit index refers to the current device.
            if cuda_device.index is not None:
                index = cuda_device.index
            else:
                index = torch.cuda.current_device()
            try:
                torch.cuda.set_per_process_memory_fraction(1.0, index)
            except Exception as e:
                # Best effort: continue loading even if the limit cannot be applied.
                print(f"Failed to set GPU memory usage ratio: {e}")
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()

        # Every component is registered with offload_device="cpu", and VRAM
        # management is enabled afterwards.
        self.pipe = FluxImagePipeline.from_pretrained(
            torch_dtype=torch.bfloat16,
            device=self.device,
            model_configs=[
                ModelConfig(
                    model_id=model_id,
                    origin_file_pattern="flux1-kontext-dev.safetensors",
                    offload_device="cpu"
                ),
                ModelConfig(
                    model_id=model_id,
                    origin_file_pattern="text_encoder/model.safetensors",
                    offload_device="cpu"
                ),
                ModelConfig(
                    model_id=model_id,
                    origin_file_pattern="text_encoder_2/",
                    offload_device="cpu"
                ),
                ModelConfig(
                    model_id=model_id,
                    origin_file_pattern="ae.safetensors",
                    offload_device="cpu"
                ),
            ],
            skip_download=True,
        )
        self.pipe.enable_vram_management()

    @torch.inference_mode()
    def generate(self, img, enable_tqdm=True):
        '''Edit one image with the configured prompt.

        Args:
            img: torch.Tensor of shape (1, C, H, W) with values in (-1, 1).
                Only a batch of one is supported, to bound GPU memory use.
            enable_tqdm: show the sampler's progress bar.

        Returns:
            torch.Tensor of shape (1, C_out, H, W) in [-1, 1], resized back to
            the input's H x W. C_out is the channel count of the pipeline's
            output image. The tensor is float64 on the CPU, whatever the
            input's dtype or device.

        Raises:
            ValueError: if ``img`` is not a 4-D tensor with batch size 1.
        '''
        # An explicit check survives `python -O`, unlike the previous assert.
        if img.ndim != 4 or img.shape[0] != 1:
            raise ValueError(
                f"generate expects a (1, C, H, W) tensor, got shape {tuple(img.shape)}"
            )

        cuda_device = self._cuda_device(self.device)
        if cuda_device is not None and torch.cuda.is_available():
            # Release cached allocator blocks from earlier calls.
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()

        try:
            torch.manual_seed(int(self.seed))
        except Exception:
            # Deliberate fallback when self.seed cannot be used as an int seed.
            torch.manual_seed(6)

        # Tensor -> PIL (optionally 4x upscaled); keep the original size for later.
        pil_img, w, h, ow, oh = self.torch_to_PIL(img[0])

        pipe_kwargs = dict(
            prompt=self.prompt,
            embedded_guidance=float(self.guidance_scale),
            num_inference_steps=int(self.num_inference_steps),
            negative_prompt=self.negative_prompt,
            seed=self.seed,
            width=int(w),
            height=int(h),
            kontext_images=pil_img,
            progress_bar_cmd=tqdm if enable_tqdm else partial(tqdm, disable=True),
            # tiled=True  # Uncomment for VAE tiling (reduces memory usage)
        )

        result = self.pipe(**pipe_kwargs)  # pyright: ignore
        # Accept either an object exposing `.images` or a bare image.
        if hasattr(result, 'images'):
            generated_image = result.images[0]
        else:
            generated_image = result

        # PIL -> tensor at the original input resolution.
        return self.PIL_to_torch(generated_image, (ow, oh))

    def torch_to_PIL(self, img):
        """Convert a (C, H, W) tensor with values in [-1, 1] to a PIL image.

        Values are clamped to [0, 255] and truncated to uint8. When
        ``self.upscale`` is set, the image is resized to 4x with LANCZOS.

        Returns:
            (image, width, height, orig_width, orig_height). Width and height
            are the sizes after optional upscaling.
        """
        img = img.float().cpu()
        img = img.permute((1, 2, 0))  # (C, H, W) -> (H, W, C)
        img = (img + 1.0) / 2.0 * 255  # [-1, 1] -> [0, 255]
        img = img.clamp(0, 255).numpy().astype('uint8')
        img = Image.fromarray(img)

        orig_width, orig_height = img.size
        if self.upscale:
            target_size = (orig_width * 4, orig_height * 4)
            img = img.resize(target_size, Image.LANCZOS)
            width, height = target_size
        else:
            width, height = orig_width, orig_height
        return img, width, height, orig_width, orig_height

    def PIL_to_torch(self, img, size):
        """Resize a PIL image to ``size`` (width, height) and convert it to a
        float64 CPU tensor of shape (1, C, H, W) with values in [-1, 1]."""
        img = img.resize(size, Image.LANCZOS)
        img = (np.array(img).astype('float') / 255.0) * 2.0 - 1.0  # [0, 255] -> [-1, 1]
        img = torch.from_numpy(img).permute((2, 0, 1)).unsqueeze(0)  # (H, W, C) -> (1, C, H, W)
        return img

    @classmethod
    def from_json(cls, json_path, device='cuda:0', seed=42):
        """Build an instance from a JSON config file.

        Args:
            json_path: path to a UTF-8 JSON object with the keys documented on
                ``__init__``.
            device: device string stored into the config as ``"device"``
                (overrides any value in the file).
            seed: forwarded to the constructor. The default matches ``__init__``.
        """
        with open(json_path, "r", encoding="utf-8") as f:
            config = json.load(f)

        config['device'] = device
        return cls(config, seed=seed)


_RELIGHT_CONFIG_DIR = '/mnt/[anonymized_public_path]/DiffSynth-Studio'

# Maps each supported edit type (relighting, background swap, inpainting)
# to its JSON config file. Built once at import time instead of on every call.
RELIGHT_CONFIGS = {
    'warm_left': f'{_RELIGHT_CONFIG_DIR}/relight_warm_left.json',
    'evening': f'{_RELIGHT_CONFIG_DIR}/relight_evening.json',
    'dusk': f'{_RELIGHT_CONFIG_DIR}/relight_dusk.json',
    'streetlight': f'{_RELIGHT_CONFIG_DIR}/relight_streetlight.json',
    'neon': f'{_RELIGHT_CONFIG_DIR}/relight_neon.json',
    'spotlight': f'{_RELIGHT_CONFIG_DIR}/relight_spotlight.json',
    'bg_classroom': f'{_RELIGHT_CONFIG_DIR}/bg_classroom.json',
    'bg_wood': f'{_RELIGHT_CONFIG_DIR}/bg_wood.json',
    'bg_dirt': f'{_RELIGHT_CONFIG_DIR}/bg_dirt.json',
    'bg_grass': f'{_RELIGHT_CONFIG_DIR}/bg_grass.json',
    'inpaint_watch': f'{_RELIGHT_CONFIG_DIR}/inpaint_watch.json',
    'inpaint_apple': f'{_RELIGHT_CONFIG_DIR}/inpaint_apple.json',
}


def prepare_relighter(relight_type='warm_left', device='cuda:0'):
    """Build a ``FluxKontextInfer`` for one of the predefined edit types.

    Args:
        relight_type: a key of ``RELIGHT_CONFIGS``.
        device: device string forwarded to ``FluxKontextInfer.from_json``.

    Raises:
        KeyError: if ``relight_type`` is unknown. The message lists the
            valid types.
    """
    try:
        config_file = RELIGHT_CONFIGS[relight_type]
    except KeyError:
        valid = ', '.join(sorted(RELIGHT_CONFIGS))
        raise KeyError(
            f"Unknown relight_type {relight_type!r}; expected one of: {valid}"
        ) from None
    print(f'[INFO] Loading relight config: {config_file}')

    return FluxKontextInfer.from_json(config_file, device)
