import torch
import numpy as np
import os
import json
import random
from scipy.stats import pearsonr
from torch.nn.attention import SDPBackend, sdpa_kernel
from PIL import Image
from diffusers import StableDiffusion3Pipeline

# --- 1. Load the SD3.5 Medium pipeline ---
# NOTE: this runs at import time. `pipe` and `transformer` are consumed by the
# `__main__` block below; `vae` is kept as a module-level handle.
model_id = "stabilityai/stable-diffusion-3.5-medium"
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
pipe = StableDiffusion3Pipeline.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
).to(device)
transformer, vae = pipe.transformer, pipe.vae

def set_seed(seed):
    """Seed every RNG this script draws from and request deterministic cuDNN.

    Args:
        seed: Integer seed applied to Python's ``random`` module, NumPy's
            global RNG, and PyTorch's CPU and CUDA generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Deterministic kernels alone are not sufficient: with benchmark mode on,
    # cuDNN may auto-tune to a different algorithm from one run to the next.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def compute_step_geometry(transformer, z, t, c, p, target_layers=range(38)):

    device = transformer.device
    dtype = torch.bfloat16

    #
    all_params = []
    for idx in target_layers:
        keyword = f"transformer_blocks.{idx}.attn.to_q"
        all_params.extend(
            [param for n, param in transformer.named_parameters() if keyword in n and param.requires_grad])

    if not all_params:
        return None

    with sdpa_kernel([SDPBackend.MATH]):
        #
        output = transformer(
            hidden_states=z,
            timestep=t,
            encoder_hidden_states=c,
            pooled_projections=p,
            return_dict=False
        )[0]

        # Loss
        loss = 0.5 * torch.sum(output ** 2)

        # 1.  LS (Local Scaling) -
        grads = torch.autograd.grad(loss, all_params, create_graph=True)
        flat_grads = torch.cat([g.reshape(-1) for g in grads])
        ls = torch.norm(flat_grads).item()

        # 2.  LC (Local Complexity) -  Power Iteration with Hessian
        #
        v = torch.randn_like(flat_grads)
        v = v / (torch.norm(v) + 1e-8)

        # maxium signi
        for _ in range(2):
            v_grads_scalar = torch.sum(flat_grads * v)
            hv = torch.autograd.grad(v_grads_scalar, all_params, retain_graph=True)
            flat_hv = torch.cat([g.contiguous().reshape(-1) for g in hv])
            v = flat_hv / (torch.norm(flat_hv) + 1e-8)

        lc = torch.dot(v, flat_hv).item()

        return {"step_lc": abs(lc), "step_ls": ls, "efficiency": abs(lc) / (ls + 1e-8)}

def encode_prompt_for_sd3(pipe, prompt, device, dtype):
    """Encode ``prompt`` through the pipeline's own ``encode_prompt`` method.

    The call is made under ``torch.no_grad()``, with classifier-free guidance
    disabled and no separate secondary/tertiary prompts. Of the four values
    ``encode_prompt`` returns, only the first and third are kept
    (NOTE(review): the discarded two are presumably the negative-prompt
    embeddings - confirm against the diffusers version in use).

    Returns:
        A ``(prompt_embeds, pooled_prompt_embeds)`` tuple, both cast to
        ``dtype``.
    """
    print(f" Prompt: '{prompt}' ...")

    encode_kwargs = {
        "prompt": prompt,
        "prompt_2": None,
        "prompt_3": None,
        "device": device,
        "do_classifier_free_guidance": False,
    }
    with torch.no_grad():
        text_embeds, _, pooled_embeds, _ = pipe.encode_prompt(**encode_kwargs)

    return text_embeds.to(dtype=dtype), pooled_embeds.to(dtype=dtype)

def run_stepwise_profiling(
        pipe,
        prompt,
        seed=42,
        inference_steps=28,
        latent_size=128
):
    """Record per-timestep geometry metrics of the transformer for one prompt.

    The prompt is encoded once, a single noise latent is drawn after seeding,
    and ``compute_step_geometry`` is evaluated at ``inference_steps`` evenly
    spaced timesteps from 999 down to 0. The latent is NOT advanced between
    steps: every timestep is evaluated on the same initial noise sample.

    Args:
        pipe: Pipeline providing ``device``, ``transformer`` and
            ``encode_prompt``.
        prompt: Text prompt to condition on.
        seed: Seed passed to ``set_seed`` before the latent is sampled.
        inference_steps: Number of timesteps to evaluate.
        latent_size: Height and width of the square 16-channel latent.

    Returns:
        A dict with ``prompt``, ``seed``, ``aggregate`` (mean LC, mean LS and
        their ratio) and ``steps`` (the per-step metric dicts, each extended
        with ``t`` and ``step_idx``). If no step produced metrics, the
        aggregate values are NaN.
    """
    device = pipe.device
    dtype = torch.bfloat16
    transformer = pipe.transformer

    # Text embeddings are fixed inputs; no graph is needed for them.
    with torch.no_grad():
        c_emb, p_emb = encode_prompt_for_sd3(pipe, prompt, device, dtype)

    # Seed right before sampling so the latent depends only on `seed`.
    set_seed(seed)
    z = torch.randn(1, 16, latent_size, latent_size, device=device, dtype=dtype, requires_grad=True)

    timesteps = np.linspace(999, 0, inference_steps)

    step_results = []

    print(f"Prompt: {prompt}")

    for i, t_val in enumerate(timesteps):
        # Keep the timestep in float32: bfloat16 has a spacing of 4 in
        # [512, 1024), so e.g. 999 would silently become 1000 and the model
        # would see a different t than the one recorded in the report.
        t_tensor = torch.tensor([float(t_val)], device=device, dtype=torch.float32)

        metrics = compute_step_geometry(transformer, z, t_tensor, c_emb, p_emb)
        if not metrics:
            continue

        metrics["t"] = float(t_val)
        metrics["step_idx"] = i
        step_results.append(metrics)
        print(f"Step {i:02d} (t={t_val:4.1f}) | LC: {metrics['step_lc']:.2f} | LS: {metrics['step_ls']:.2f}")

    # Average over all recorded steps. With nothing recorded, report NaN
    # explicitly instead of relying on np.mean([]) and its RuntimeWarning.
    if step_results:
        total_lc = float(np.mean([m["step_lc"] for m in step_results]))
        total_ls = float(np.mean([m["step_ls"] for m in step_results]))
        # Same epsilon guard as the per-step "efficiency" ratio.
        complexity_cost = total_lc / (total_ls + 1e-8)
    else:
        total_lc = total_ls = complexity_cost = float("nan")

    report = {
        "prompt": prompt,
        "seed": seed,
        "aggregate": {
            "avg_lc": total_lc,
            "avg_ls": total_ls,
            "complexity_cost": complexity_cost
        },
        "steps": step_results
    }

    return report

if __name__ == "__main__":
    # NOTE: make sure LD_PRELOAD is unset in the environment before running.
    if transformer:
        # Both runs share seed=1 and the default settings, so the prompt is
        # the only argument that differs between the two reports.
        report_normal = run_stepwise_profiling(pipe, prompt="a normal dog", seed=1)
        report_mutant = run_stepwise_profiling(pipe, prompt="a five-legged dog", seed=1)

    # Earlier per-layer batch experiment, kept for reference only; the helper
    # it calls is not defined in the visible part of this file.
    # for layer in [0, 11, 23]:  # range(24)
    #     # run_batch_seeds_experiment(pipe, layer_idx=layer, seeds=range(0, 100), prompt='A perfect square table with four legs.')
    #     run_batch_seeds_experiment(pipe, layer_idx=layer, seeds=range(0, 100), prompt='A triangular table with five legs.')

