import math
import sys
from pathlib import Path
from typing import Optional, Union, List

sys.path.append(str(Path(__file__).absolute().parent.parent))

import torch
from diffusers.pipelines.pipeline_utils import ImagePipelineOutput
from diffusers.utils.torch_utils import randn_tensor
from diffusers import DDIMPipeline, AutoencoderKL


class CondDDIMPipeline(DDIMPipeline):
    """DDIM sampling pipeline with query-conditioned guidance and Itô log-density tracking.

    At every denoising step ``net`` is called with the current sample, the timestep, the
    query expression and a dict of running log-density estimates (``ll_state``). It returns
    the model output consumed by ``scheduler.step`` together with an optional
    ``score_cache`` (key -> score tensor). Each cached score updates the matching
    ``ll_state`` entry with the Itô density estimator of https://arxiv.org/abs/2412.17762
    (Theorem 1). When a VAE is supplied, the final sample is decoded and mapped to [0, 1].

    Adapted from the diffusers ``DDIMPipeline``.
    """

    def __init__(self, net, scheduler, vae: Optional[AutoencoderKL] = None):
        super().__init__(net, scheduler)
        self.vae = vae
        # Registered so the (optional) VAE is tracked as a pipeline component.
        self.register_modules(vae=vae)
        # Plain alias used throughout this class.
        self.net = net

    def calculate_ito_increment(self, xt, dx, score_i, t, d_tau):
        """Return the per-sample Itô increment of one model's log-density over a step.

        Computes ``<dx, s> + (<div, f> + <f - (g^2 / 2) * s, s>) * d_tau`` with the forward
        drift ``f = -0.5 * beta_t * x`` (divergence ``-0.5 * beta_t * d``) and
        ``g^2 = beta_t`` (Eq. 11 & 13 of https://arxiv.org/abs/2412.17762).

        Args:
            xt: Sample at the start of the step, shape ``(batch, ...)``.
            dx: Change of the sample over the step, same shape as ``xt``.
            score_i: Score estimate of one model, same shape as ``xt``.
            t: Timestep used to index ``self.scheduler.betas``.
            d_tau: Time increment multiplying the drift/divergence terms.

        Returns:
            Tensor of shape ``(batch,)``.
        """
        beta_t = self.scheduler.betas[t]
        # Reduce over every non-batch axis so samples of any rank are supported.
        reduce_dims = tuple(range(1, xt.ndim))
        # d = number of non-batch elements (e.g. C * H * W).
        d = xt.shape[1:].numel()

        # Term A: <dx, s_i> (directional alignment).
        term_a = torch.sum(dx * score_i, dim=reduce_dims)

        # Term B: forward drift and its divergence.
        drift_f = -0.5 * beta_t * xt
        div_f = -0.5 * beta_t * d

        # Inner kernel <f - (g^2 / 2) * s_i, s_i> with g^2 = beta_t.
        inner_kernel = drift_f - (beta_t / 2.0) * score_i
        term_b_alignment = torch.sum(inner_kernel * score_i, dim=reduce_dims)

        return term_a + (div_f + term_b_alignment) * d_tau

    def get_initial_ll(self, x_T):
        """Return the standard-Gaussian log-density of each sample in ``x_T``.

        Computes ``log p(x) = -0.5 * (d * log(2*pi) + ||x||^2)`` per sample, where ``d`` is
        the number of non-batch elements (12288 for a 3x64x64 image).

        Args:
            x_T: Tensor of shape ``(batch, ...)``, e.g. the initial noise.

        Returns:
            Tensor of shape ``(batch,)``.
        """
        batch_size = x_T.shape[0]
        d = x_T.shape[1:].numel()
        # reshape (not view) so non-contiguous inputs are accepted; an explicit width also
        # keeps an empty batch valid.
        norm_sq = x_T.reshape(batch_size, d).pow(2).sum(dim=1)
        return -0.5 * (d * math.log(2.0 * math.pi) + norm_sq)

    def _initial_sample_shape(self, batch_size):
        """Return the noise shape ``(batch_size, ...)`` derived from ``self.net.config``."""
        config = self.net.config
        if isinstance(config.sample_size, int):
            # Square sample: (batch, channels, size, size).
            return (batch_size, config.in_channels, config.sample_size, config.sample_size)
        if config.in_channels == 0:
            # With in_channels == 0 the non-batch shape comes from sample_size alone.
            return (batch_size, *config.sample_size)
        return (batch_size, config.in_channels, *config.sample_size)

    @staticmethod
    def _query_strings(expr):
        """Return ``str(expr)`` followed by the strings of its sub-expressions (pre-order).

        Sub-expressions are taken from a non-None ``expression`` attribute and from
        ``left``/``right`` attribute pairs.
        """
        strings = []
        stack = [expr]
        while stack:
            node = stack.pop()
            strings.append(str(node))
            children = []
            if getattr(node, "expression", None) is not None:
                children.append(node.expression)
            if hasattr(node, "left") and hasattr(node, "right"):
                children.extend((node.left, node.right))
            # Reversed so that children are visited in their natural order.
            stack.extend(reversed(children))
        return strings

    def _ito_time_delta(self, step_index):
        """Return ``d_tau`` for the step at ``step_index`` of ``self.scheduler.timesteps``.

        ``d_tau = |alpha_bar(t) - alpha_bar(t_next)|`` using ``scheduler.alphas_cumprod``;
        the last step uses a fixed ``1e-3``.
        """
        # NOTE(review): d_tau lives in alpha_bar space while calculate_ito_increment scales
        # the drift by the per-step betas[t]; confirm these time units are consistent, and
        # that the 1e-3 last-step epsilon is intended.
        timesteps = self.scheduler.timesteps
        alpha_t = self.scheduler.alphas_cumprod[timesteps[step_index]]
        if step_index < len(timesteps) - 1:
            alpha_next = self.scheduler.alphas_cumprod[timesteps[step_index + 1]]
            return torch.abs(alpha_t - alpha_next)
        return torch.tensor(1e-3).to(alpha_t.device)

    @torch.no_grad()
    def __call__(
        self,
        batch_size: int = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        eta: float = 0.0,
        num_inference_steps: int = 50,
        use_clipped_model_output: Optional[bool] = None,
        return_dict: bool = True,
        query: Optional[torch.Tensor] = None,
        guidance_dict: Optional[dict] = None,
        null_token: Optional[torch.Tensor] = None,
        image: Optional[torch.Tensor] = None,
        noise_percentage: Optional[float] = None,
        **kwargs,
    ):
        """Run DDIM sampling while tracking per-query log-density estimates.

        Args:
            batch_size: Number of samples to draw when ``image`` is not given.
            generator: Generator(s) for the initial noise and ``scheduler.step``; a list
                must contain exactly ``batch_size`` generators.
            eta: DDIM ``eta`` forwarded to ``scheduler.step``.
            num_inference_steps: Number of scheduler timesteps.
            use_clipped_model_output: Forwarded to ``scheduler.step``.
            return_dict: If True return ``ImagePipelineOutput``, else a 1-tuple.
            query: Query expression forwarded to ``net``. Its string form and those of its
                sub-expressions key the initial ``ll_state``.
            guidance_dict: Forwarded to ``net``.
            null_token: Forwarded to ``net``. When given, its device is the sampling device;
                otherwise ``image``'s device, or the pipeline device, is used.
            image: Optional starting sample. If None, Gaussian noise scaled by
                ``scheduler.init_noise_sigma`` is drawn.
            noise_percentage: Optional float in [0, 1]. If given, ``image`` is noised with
                ``scheduler.add_noise`` at ``timesteps[int(len(timesteps) * noise_percentage)]``
                and denoising starts at that same timestep, skipping the earlier ones.
                If None, the full schedule is used.
            **kwargs: Accepted for compatibility and ignored.

        Returns:
            ``ImagePipelineOutput(images=image)`` or ``(image,)``. With a VAE, ``image`` is
            the decoded output mapped to [0, 1]; otherwise it is the final sample.

        Raises:
            ValueError: If ``generator`` is a list whose length differs from ``batch_size``,
                or ``noise_percentage`` is outside [0, 1].
        """
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        # null_token is optional, so it cannot be the only source of the device.
        if null_token is not None:
            device = null_token.device
        elif image is not None:
            device = image.device
        else:
            device = self.device

        if image is None:
            image = randn_tensor(self._initial_sample_shape(batch_size), generator=generator, device=device)
            image = image * self.scheduler.init_noise_sigma
        image = image.to(device=device)

        self.scheduler.set_timesteps(num_inference_steps)

        # For debugging: partially noise a given image and denoise from there.
        if noise_percentage is not None:
            if not 0.0 <= noise_percentage <= 1.0:
                raise ValueError(f"noise_percentage must be in [0, 1], got {noise_percentage}.")
            num_steps = len(self.scheduler.timesteps)
            # Clamp so noise_percentage == 1.0 still selects a valid timestep.
            noise_step_idx = min(int(num_steps * noise_percentage), num_steps - 1)
            start_t = int(self.scheduler.timesteps[noise_step_idx])
            noise_timesteps = torch.full((image.shape[0],), start_t, device=device, dtype=torch.long)
            # NOTE(review): this noise ignores `generator`, so it is not reproducible via it.
            noise = torch.randn_like(image)
            image = self.scheduler.add_noise(image, noise, noise_timesteps)
            # Start AT the noising timestep so the model sees the sample's true noise level.
            self.scheduler.timesteps = self.scheduler.timesteps[noise_step_idx:]

        # One log-density entry per (sub-)expression of the query, all from the Gaussian prior.
        initial_ll = self.get_initial_ll(image)
        ll_state = {key: initial_ll.clone() for key in self._query_strings(query)}

        for step_index, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
            t = t.to(device=image.device)
            xt_current = image.clone()
            model_output, score_cache = self.net(
                xt_current,
                t,
                query,
                guidance_dict=guidance_dict,
                null_token=null_token,
                ll_state=ll_state,
                scheduler=self.scheduler,
            )
            image = self.scheduler.step(
                model_output,
                t,
                xt_current,
                eta=eta,
                use_clipped_model_output=use_clipped_model_output,
                generator=generator,
            ).prev_sample

            # Itô density update (Theorem 1 of https://arxiv.org/abs/2412.17762). It must run
            # before the next `net` call, which receives ll_state.
            if score_cache:
                dx = image - xt_current
                d_tau = self._ito_time_delta(step_index)
                for key, score_i in score_cache.items():
                    dll = self.calculate_ito_increment(xt_current, dx, score_i, t, d_tau)
                    if key not in ll_state:
                        # Keys first seen here start from the Gaussian prior of the sample.
                        ll_state[key] = self.get_initial_ll(xt_current)
                    ll_state[key] += dll

        if self.vae is not None:
            image = self.vae.decode(image / self.vae.config.scaling_factor)[0]
            image = (image.clamp(-1, 1) + 1) / 2
        if not return_dict:
            return (image,)
        return ImagePipelineOutput(images=image)
    





