import os
import sys
import torch
import random
from tqdm import trange
from ..camera_utils import get_camera  

class SAILOptimizer:
    """
    Implements the Sharpness-Aware Initialization (SAIL) algorithm to find
    an optimal initial noise `x_T` that mitigates memorization.

    Only the noise tensor is handed to the optimizer; the wrapped model's
    weights are never stepped and never receive gradients from this class.
    """

    def __init__(self, model, **kwargs):
        """
        Store the model and read the optimization hyperparameters.

        Args:
            model: Object exposing ``device``, ``apply_model`` and, for
                :meth:`optimize_noise`, ``get_learned_conditioning``,
                ``num_frames``, ``image_size`` and
                ``model.diffusion_model.in_channels``.
            **kwargs: Optional overrides: ``optim_steps`` (default 20),
                ``lr`` (0.05), ``alpha`` (0.05), ``delta`` (1e-3) and
                ``l_thres`` (8.0). Unknown keys are ignored.
        """
        self.model = model
        self.device = model.device
        # Hyperparameters from the paper
        self.optim_steps = kwargs.get("optim_steps", 20)
        self.lr = kwargs.get("lr", 0.05)
        self.alpha = kwargs.get("alpha", 0.05)  # Regularization term
        self.delta = kwargs.get("delta", 1e-3)  # Perturbation for finite difference
        self.early_stop_thres = kwargs.get("l_thres", 8.0)  # From paper appendix

    def _get_score_diff(self, x_t, t, c_, uc_):
        """Helper to compute the score difference s_delta.

        Returns ``apply_model(x_t, t, c_) - apply_model(x_t, t, uc_)``, i.e.
        the conditional prediction minus the unconditional one.
        """
        cond_noise = self.model.apply_model(x_t, t, c_)
        uncond_noise = self.model.apply_model(x_t, t, uc_)
        return cond_noise - uncond_noise

    def _sail_loss(self, x_T, t, c_, uc_):
        """Build the SAIL loss for the current noise.

        Returns:
            A scalar tensor (sharpness term + ``alpha * ||x_T||^2``), or
            ``None`` when the score difference has zero norm and therefore
            no perturbation direction can be formed.
        """
        s_delta_at_x = self._get_score_diff(x_T, t, c_, uc_)

        # The perturbation direction is treated as a constant: it is built
        # under no_grad so no gradient flows through the normalisation.
        with torch.no_grad():
            s_delta_norm = torch.linalg.norm(s_delta_at_x)
            if s_delta_norm == 0:
                return None
            perturbation = self.delta * s_delta_at_x / s_delta_norm

        s_delta_at_x_perturbed = self._get_score_diff(x_T + perturbation, t, c_, uc_)

        # SAIL loss from Algorithm 2
        sharpness_loss = torch.sum((s_delta_at_x_perturbed - s_delta_at_x) ** 2)
        regularization_loss = self.alpha * torch.sum(x_T ** 2)
        return sharpness_loss + regularization_loss

    def _refine_noise(self, x_T, t, c_, uc_):
        """Run up to ``optim_steps`` Adam updates on ``x_T`` and return it detached.

        ``x_T`` must be a leaf tensor; it is switched to ``requires_grad`` and
        updated in place. The loop stops early when the score difference has
        zero norm, when the loss is not finite, or when the loss drops below
        ``early_stop_thres``; in each case the noise from before the current
        step is returned.
        """
        x_T.requires_grad_(True)
        optimizer = torch.optim.Adam([x_T], lr=self.lr)

        # Callers often sample under torch.no_grad(); re-enable autograd
        # locally so the loss always carries a graph back to x_T.
        with torch.enable_grad():
            for _ in range(self.optim_steps):
                optimizer.zero_grad()

                loss = self._sail_loss(x_T, t, c_, uc_)
                if loss is None:
                    break

                # A NaN/inf loss would fail the threshold test below and the
                # Adam step would then write NaNs into x_T; keep the last
                # finite noise instead.
                if not torch.isfinite(loss.detach()):
                    break

                if loss.item() < self.early_stop_thres:
                    break

                # Differentiate w.r.t. x_T only. A plain loss.backward() would
                # also accumulate (never-zeroed) gradients into every model
                # parameter that requires grad.
                (x_T.grad,) = torch.autograd.grad(loss, [x_T])
                optimizer.step()

        return x_T.detach()

    def optimize_noise(self, prompt: str, base_seed: int):
        """
        Runs the optimization loop for a given prompt to find a better initial noise `x_T`.

        Args:
            prompt: Text prompt passed to ``model.get_learned_conditioning``.
            base_seed: Seed for the generator that draws the starting noise.

        Returns:
            The optimized noise, detached from autograd, with shape
            ``(num_frames, in_channels, image_size // 8, image_size // 8)``.

        Note:
            The camera elevation and start azimuth come from the global
            ``random`` module, so ``base_seed`` alone does not make the
            result reproducible.
        """
        generator = torch.Generator(device=self.device).manual_seed(base_seed)
        num_frames = self.model.num_frames
        latent_size = self.model.image_size // 8
        shape = (
            num_frames,
            self.model.model.diffusion_model.in_channels,
            latent_size,
            latent_size,
        )
        x_T = torch.randn(shape, generator=generator, device=self.device)

        # Conditional / unconditional (empty prompt) contexts, one copy per frame.
        c = self.model.get_learned_conditioning([prompt]).to(self.device)
        uc = self.model.get_learned_conditioning([""]).to(self.device)
        c_ = {"context": c.repeat(num_frames, 1, 1)}
        uc_ = {"context": uc.repeat(num_frames, 1, 1)}

        camera = get_camera(
            num_frames,
            elevation=random.randint(-15, 30),
            azimuth_start=random.randint(0, 360),
            azimuth_span=360,
        )
        # Both branches share the same camera tensor and frame count.
        c_["camera"] = uc_["camera"] = camera.repeat(1, 1).to(self.device)
        c_["num_frames"] = uc_["num_frames"] = num_frames

        t = torch.tensor([999] * num_frames, device=self.device)  # First denoising step

        return self._refine_noise(x_T, t, c_, uc_)