import torch
import numpy as np
from enum import Enum
import math

import torch.nn.functional as F
from utils.tools import resize_and_center_crop, numpy2pytorch, pad, decode_latents, encode_video

# Enum defining possible light source positions
class BGSource(Enum):
    """Direction from which the relighting gradient is lit.

    Each member's value is its human-readable label. ``NONE`` requests no
    directional light gradient.
    """

    # No directional light.
    NONE = "None"
    # Brightest at the left edge, fading towards the right.
    LEFT = "Left Light"
    # Brightest at the right edge, fading towards the left.
    RIGHT = "Right Light"
    # Brightest at the top edge, fading towards the bottom.
    TOP = "Top Light"
    # Brightest at the bottom edge, fading towards the top.
    BOTTOM = "Bottom Light"

class Relighter:
    """Relight a video by running a diffusion pipeline from a light-direction latent.

    The constructor pre-computes everything that stays constant across calls:
    the positive/negative prompt embeddings and a "background" latent encoding
    the requested light direction, each repeated once per frame. Calling the
    instance encodes the input video, runs ``pipeline`` starting from that
    background latent (with the video latents passed as ``concat_conds``
    cross-attention conditioning) and decodes the resulting latents.
    """

    def __init__(self,
                 pipeline,  # The diffusion pipeline (likely a StableDiffusionImg2ImgPipeline)
                 relight_prompt="",  # Prompt to guide the relighting process
                 num_frames=16,  # Number of frames in the video
                 image_width=512,  # Width of each frame
                 image_height=512,  # Height of each frame
                 num_samples=1,  # Number of samples to generate
                 steps=15,  # Number of denoising steps
                 cfg=2,  # Classifier-free guidance scale
                 lowres_denoise=0.9,  # Strength of denoising (preserves more original content at lower values)
                 bg_source=BGSource.RIGHT,  # Default light source position
                 generator=None,  # Random number generator for reproducibility
                 ):
        """Store settings and pre-compute per-frame prompt embeddings and background latent."""
        # Store parameters
        self.pipeline = pipeline
        self.image_width = image_width
        self.image_height = image_height
        self.num_samples = num_samples
        self.steps = steps
        self.cfg = cfg
        self.lowres_denoise = lowres_denoise
        self.bg_source = bg_source
        self.generator = generator
        self.device = pipeline.device
        self.num_frames = num_frames
        self.vae = self.pipeline.vae  # Used for encoding/decoding images

        # Prompt conditioning: the user prompt gets a quality suffix appended.
        self.a_prompt = "best quality"
        self.n_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
        positive_prompt = relight_prompt + ', ' + self.a_prompt
        negative_prompt = self.n_prompt
        tokenizer = self.pipeline.tokenizer
        device = self.device
        vae = self.vae

        # These tensors are cached on the instance. Computing them without
        # autograd keeps them from holding computation graphs (and memory)
        # for the lifetime of the object.
        with torch.no_grad():
            conds, unconds = self.encode_prompt_pair(tokenizer, device, positive_prompt, negative_prompt)

            # Gradient image encoding the light direction (None for BGSource.NONE).
            input_bg = self.create_background()
            if self.bg_source == BGSource.NONE:
                # No light source: start from random latent noise. Latents use
                # the NCHW layout (batch, channels, height // 8, width // 8),
                # so height must come before width.
                shape = (1, 4, self.image_height // 8, self.image_width // 8)
                bg_latent = torch.randn(shape, generator=generator, device=device, dtype=vae.dtype)
            else:
                # Encode the gradient image to a latent with the VAE.
                bg = resize_and_center_crop(input_bg, self.image_width, self.image_height)
                bg_latent = numpy2pytorch([bg], device, vae.dtype)
                bg_latent = vae.encode(bg_latent).latent_dist.mode() * vae.config.scaling_factor

        # Same background latent and text conditions for every frame keeps the
        # lighting consistent across the video.
        self.bg_latent = bg_latent.repeat(self.num_frames, 1, 1, 1)
        self.conds = conds.repeat(self.num_frames, 1, 1)
        self.unconds = unconds.repeat(self.num_frames, 1, 1)

    def encode_prompt_inner(self, tokenizer, txt):
        """
        Tokenize and encode a text prompt with the pipeline's text encoder.

        Long prompts are split into chunks of ``model_max_length - 2`` tokens.
        Each chunk is wrapped in BOS/EOS and passed through ``pad`` with the EOS
        id as filler (presumably padding to ``model_max_length``). All chunks
        are then encoded as one batch.

        Args:
            tokenizer: tokenizer exposing ``model_max_length``, ``bos_token_id``,
                ``eos_token_id`` and returning ``input_ids`` when called.
            txt: prompt string.

        Returns:
            ``last_hidden_state`` of the text encoder, one batch row per chunk.
            There is always at least one row, even for an empty prompt.
        """
        max_length = tokenizer.model_max_length
        chunk_length = max_length - 2  # leave room for BOS and EOS
        id_start = tokenizer.bos_token_id
        id_end = tokenizer.eos_token_id
        id_pad = id_end  # pad with the EOS token

        # Tokenize without special tokens; BOS/EOS are added per chunk below.
        tokens = tokenizer(txt, truncation=False, add_special_tokens=False)["input_ids"]

        # max(len, 1) guarantees one (BOS, EOS) chunk for an empty prompt.
        # Otherwise the text encoder would receive an empty batch.
        chunks = [
            [id_start] + tokens[i: i + chunk_length] + [id_end]
            for i in range(0, max(len(tokens), 1), chunk_length)
        ]
        chunks = [pad(ck, id_pad, max_length) for ck in chunks]

        token_ids = torch.tensor(chunks, dtype=torch.int64, device=self.device)
        return self.pipeline.text_encoder(token_ids).last_hidden_state

    def encode_prompt_pair(self, tokenizer, device, positive_prompt, negative_prompt):
        """
        Encode positive and negative prompts into embeddings of identical shape.

        The side with fewer chunks is tiled, repeating its chunks in order, and
        truncated to the other side's chunk count. Each side's chunks are then
        concatenated along the token axis into one sequence.

        Returns:
            ``(cond, uncond)`` tensors on ``device``, each with a leading batch
            dimension of 1.
        """
        c = self.encode_prompt_inner(tokenizer, positive_prompt)
        uc = self.encode_prompt_inner(tokenizer, negative_prompt)

        # Tile the shorter encoding so both have the same number of chunks.
        max_chunk = max(len(c), len(uc))
        c = torch.cat([c] * math.ceil(max_chunk / len(c)), dim=0)[:max_chunk]
        uc = torch.cat([uc] * math.ceil(max_chunk / len(uc)), dim=0)[:max_chunk]

        # Join the chunks along the token dimension into a single batch entry.
        c = torch.cat([p[None, ...] for p in c], dim=1)
        uc = torch.cat([p[None, ...] for p in uc], dim=1)

        return c.to(device), uc.to(device)

    def create_background(self):
        """
        Create an RGB gradient image encoding the light direction.

        Returns:
            ``None`` for ``BGSource.NONE``. Otherwise a ``uint8`` array of shape
            ``(image_height, image_width, 3)`` that is brightest (255) on the
            lit side and fades to 0 on the opposite side.

        Raises:
            ValueError: if ``self.bg_source`` is not a known ``BGSource``.
        """
        max_pix = 255  # Maximum pixel intensity (white)
        min_pix = 0    # Minimum pixel intensity (black)

        print(f"max light pix:{max_pix}, min light pix:{min_pix}")

        if self.bg_source == BGSource.NONE:
            return None

        if self.bg_source == BGSource.LEFT:
            # Left bright -> right dark; one row tiled down the image.
            row = np.linspace(max_pix, min_pix, self.image_width)
            image = np.tile(row, (self.image_height, 1))
        elif self.bg_source == BGSource.RIGHT:
            # Left dark -> right bright.
            row = np.linspace(min_pix, max_pix, self.image_width)
            image = np.tile(row, (self.image_height, 1))
        elif self.bg_source == BGSource.TOP:
            # Top bright -> bottom dark; one column tiled across the image.
            column = np.linspace(max_pix, min_pix, self.image_height)[:, None]
            image = np.tile(column, (1, self.image_width))
        elif self.bg_source == BGSource.BOTTOM:
            # Top dark -> bottom bright.
            column = np.linspace(min_pix, max_pix, self.image_height)[:, None]
            image = np.tile(column, (1, self.image_width))
        else:
            raise ValueError('Wrong initial latent!')

        # Replicate the grayscale gradient into three identical channels.
        return np.stack((image,) * 3, axis=-1).astype(np.uint8)

    @torch.no_grad()
    def __call__(self, input_video, init_latent=None, input_strength=None):
        """
        Relight ``input_video`` and return the decoded result.

        Args:
            input_video: video accepted by ``encode_video`` (format defined there).
            init_latent: optional starting latent. ``None`` uses the cached,
                per-frame background (light-direction) latent.
            input_strength: optional denoising strength. A falsy value (``None``
                or ``0``) falls back to ``self.lowres_denoise``.

        Returns:
            Whatever ``decode_latents`` produces for the generated latents.
        """
        # Encode the input video to latent space
        input_latent = encode_video(self.vae, input_video) * self.vae.config.scaling_factor

        light_strength = input_strength if input_strength else self.lowres_denoise

        # Compare with None explicitly: bool() of a multi-element tensor raises,
        # so a truthiness test would reject every caller-supplied latent.
        if init_latent is None:
            init_latent = self.bg_latent

        latents = self.pipeline(
            image=init_latent,  # Starting point (lighting gradient)
            strength=light_strength,  # How much of the starting latent to re-noise
            prompt_embeds=self.conds,  # Positive prompt embeddings
            negative_prompt_embeds=self.unconds,  # Negative prompt embeddings
            width=self.image_width,
            height=self.image_height,
            # Assuming the usual img2img schedule (about num_inference_steps *
            # strength steps are run), divide by the strength actually in use so
            # roughly self.steps denoising steps are executed.
            num_inference_steps=int(round(self.steps / light_strength)),
            num_images_per_prompt=self.num_samples,
            generator=self.generator,
            output_type='latent',  # Return latents instead of images
            guidance_scale=self.cfg,  # Classifier-free guidance scale
            cross_attention_kwargs={'concat_conds': input_latent},  # Input video as extra conditioning
        ).images.to(self.pipeline.vae.dtype)

        # Decode the generated latents back to pixel space
        relight_video = decode_latents(self.vae, latents)
        return relight_video