import os
import random
from dataclasses import dataclass, field

import torch
import torch.nn.functional as F
from diffusers import DDPMScheduler, UNet2DConditionModel
from diffusers.models import AutoencoderKL
from diffusers.training_utils import compute_snr
from einops import rearrange
from omegaconf import OmegaConf
from PIL import Image

from ..pipelines.pipeline_jigsaw3D_i2mv_sdxl import Jigsaw3DI2MVSDXLPipeline

from ..schedulers.scheduling_shift_snr import ShiftSNRScheduler
from ..utils.core import find
from ..utils.typing import *
from .base import BaseSystem
from .utils import encode_prompt, vae_encode

from ..utils.jigsaw import apply_jigsaw_mask, apply_content_block_jigsaw, vae_feature_apply_jigsaw_mask

def compute_embeddings(
    prompt_batch,
    empty_prompt_indices,
    text_encoders,
    tokenizers,
    is_train=True,
    **kwargs,
):
    """Encode a batch of prompts and assemble the extra UNet conditioning.

    Prompts flagged in ``empty_prompt_indices`` are replaced by the empty
    string before encoding. The replacement is done on a local copy, so the
    caller's ``prompt_batch`` is left untouched.

    Args:
        prompt_batch: Sequence of prompts, one per sample.
        empty_prompt_indices: Array-like exposing ``.shape[0]`` and integer
            indexing; a truthy entry at position ``i`` blanks prompt ``i``.
        text_encoders: Forwarded unchanged to ``encode_prompt``.
        tokenizers: Forwarded unchanged to ``encode_prompt``.
        is_train: Forwarded unchanged to ``encode_prompt``.
        **kwargs: Must contain ``original_size``, ``target_size`` and
            ``crops_coords_top_left``. The three are concatenated with ``+``,
            so they must be sequences of the same kind (e.g. tuples).

    Returns:
        Dict with keys ``"prompt_embeds"``, ``"text_embeds"`` and
        ``"time_ids"``. ``"time_ids"`` has one identical row per prompt and is
        placed on the device/dtype of ``prompt_embeds``, as is
        ``"text_embeds"``.

    Raises:
        KeyError: If any of the three required size entries is missing from
            ``kwargs``.
    """
    original_size = kwargs["original_size"]
    target_size = kwargs["target_size"]
    crops_coords_top_left = kwargs["crops_coords_top_left"]

    # Copy before editing: blanking entries directly in `prompt_batch` would
    # leak the prompt dropout back into the caller's batch (and would fail
    # outright for immutable sequences such as tuples).
    prompts = list(prompt_batch)
    for i in range(empty_prompt_indices.shape[0]):
        if empty_prompt_indices[i]:
            prompts[i] = ""

    # NOTE(review): the literal 0 is encode_prompt's fourth positional
    # argument -- presumably its own empty-prompt proportion, disabled because
    # dropping is already handled above. Confirm against encode_prompt.
    prompt_embeds, pooled_prompt_embeds = encode_prompt(
        prompts, text_encoders, tokenizers, 0, is_train
    )
    add_text_embeds = pooled_prompt_embeds.to(
        device=prompt_embeds.device, dtype=prompt_embeds.dtype
    )

    # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
    add_time_ids = list(original_size + crops_coords_top_left + target_size)
    # Shape (1, K) -> (num_prompts, K): every sample gets the same size ids.
    add_time_ids = torch.tensor([add_time_ids])
    add_time_ids = add_time_ids.repeat(len(prompts), 1)
    add_time_ids = add_time_ids.to(
        device=prompt_embeds.device, dtype=prompt_embeds.dtype
    )

    unet_added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
    return {"prompt_embeds": prompt_embeds, **unet_added_cond_kwargs}


class MVAdapterImageSDXLSystem(BaseSystem):
    """System built on ``BaseSystem`` whose config defaults target SDXL.

    As visible here, the class only declares its configuration schema and a
    ``configure`` hook that defers to the base class; no training or
    evaluation logic is defined in this class body.
    """

    @dataclass
    class Config(BaseSystem.Config):
        """Configuration schema, extending ``BaseSystem.Config``.

        Field names, order and defaults are the interface. How each field is
        consumed is not visible in this class body, so the comments below
        describe only what the declarations themselves establish.
        """

        # Model / Adapter
        # Model identifiers/paths; the optional ones default to None except
        # the VAE, which defaults to an fp16-fix SDXL VAE checkpoint name.
        pretrained_model_name_or_path: str = "stabilityai/stable-diffusion-xl-base-1.0"
        pretrained_vae_name_or_path: Optional[str] = "madebyollin/sdxl-vae-fp16-fix"
        pretrained_adapter_name_or_path: Optional[str] = None
        pretrained_unet_name_or_path: Optional[str] = None
        # default_factory gives each Config instance its own dict.
        init_adapter_kwargs: Dict[str, Any] = field(default_factory=dict)

        # Precision switches for the VAE and CLIP components (both on by default).
        use_fp16_vae: bool = True
        use_fp16_clip: bool = True

        # Training
        # default_factory gives each Config instance its own list.
        trainable_modules: List[str] = field(default_factory=list)
        train_cond_encoder: bool = True
        # Drop probabilities, all disabled (0.0) by default.
        # NOTE(review): presumably per-sample dropout of the prompt / image /
        # condition inputs for classifier-free guidance -- confirm in the
        # training code.
        prompt_drop_prob: float = 0.0
        image_drop_prob: float = 0.0
        cond_drop_prob: float = 0.0

        gradient_checkpointing: bool = False

        # Noise sampler
        noise_scheduler_kwargs: Dict[str, Any] = field(default_factory=dict)
        noise_offset: float = 0.0
        input_perturbation: float = 0.0
        # Optional: None is an accepted value for both of the next two fields.
        snr_gamma: Optional[float] = 5.0
        prediction_type: Optional[str] = None
        # NOTE(review): the shift_noise_* fields presumably parameterize
        # ShiftSNRScheduler (imported by this module) -- confirm at the use site.
        shift_noise: bool = False
        shift_noise_mode: str = "interpolated"
        shift_noise_scale: float = 1.0

        # Evaluation
        eval_seed: int = 0
        eval_num_inference_steps: int = 30
        eval_guidance_scale: float = 1.0
        eval_height: int = 512
        eval_width: int = 512

    # Annotation only: declares that `cfg` is an instance of the Config above.
    cfg: Config
    def configure(self) -> None:
        """Run ``BaseSystem.configure``; no additional setup is done here."""
        super().configure()

    ################################ The training stage is still to be organized and will be released later
    