import os
import re
import glob
import numpy as np
from PIL import Image
import torch
import torch.nn.functional as F
from typing import Dict, Optional

from .base import BaseMetric

class BrightEndingMetric(BaseMetric):
    """Per-seed "Bright Ending" metric based on the classifier-free-guidance noise gap.

    For every denoising step the difference ``text_noise - uncond_noise`` is
    measured globally (``d_score``) and weighted by a Bright Ending (BE)
    attention map (``ld_score``). The BE map is taken from the live attention
    controller when available, otherwise from attention-map PNGs on disk.
    """

    # Default step key / file prefix of the attention maps that are used.
    # Presumably the last step of a 50-step schedule -- confirm with the sampler.
    _DEFAULT_FINAL_STEP = 49
    # Default number of attention heads assumed to be folded into the first
    # axis of the cached attention maps.
    _DEFAULT_NUM_HEADS = 8
    # Extracts the numeric token index from names like ``step49_*_token07.png``.
    _TOKEN_RE = re.compile(r"_token(\d+)\.png")

    @property
    def name(self) -> str:
        """Metric name; also used as the prefix of the printed summary line."""
        return "BrightEnding_LD_Score"

    @property
    def metric_type(self) -> str:
        """Metric category string (``"per_seed"`` for this metric)."""
        return "per_seed"

    @torch.no_grad()
    def measure(self, intermediates: Dict, **kwargs) -> Dict:
        """Compute the Localized Detection (LD) score and the global D score.

        Per step t, ``diff_t = text_noise[t] - uncond_noise[t]``.

        * ``d_score``  = mean_t ||diff_t||_2
        * ``ld_score`` = mean_t ||diff_t * m||_2 / (mean(m) + 1e-6)

        Because the mask ``m`` appears in both numerator and denominator, the
        LD score is (up to the epsilon) invariant to a global rescaling of the
        mask. Raw attention weights (cache) and [0, 1] PNG intensities (disk)
        are therefore comparable.

        The BE map is looked up in two stages:

        1. ``controller.latest_attention_maps[final_step]``: average of the
           ``'down_0'`` and ``'down_1'`` maps, last entry of the last axis.
        2. Otherwise, PNGs ``step{final_step}_*_token<N>.png`` in
           ``attention_map_dir``: average of all files with the largest N.

        Args:
            intermediates: Dict with ``'text_noise'`` and ``'uncond_noise'``
                sequences of tensors. Each is assumed to be ``[N, C, H, W]``
                (spatial size is read from ``shape[2:]``) -- confirm with caller.
            **kwargs:
                controller: Optional object exposing ``latest_attention_maps``.
                attention_map_dir: Optional directory for the disk fallback.
                final_step: Step key / file prefix to use (default 49).
                num_heads: Heads folded into cached maps (default 8).

        Returns:
            ``{"d_score": float}`` when no BE map is available. Otherwise also
            ``"ld_score"`` and ``"be_intensity"`` (mean BE map value).

        Raises:
            ValueError: If no (text, uncond) noise pairs are provided.
        """
        controller = kwargs.get("controller")
        attention_map_dir = kwargs.get("attention_map_dir")
        final_step = kwargs.get("final_step", self._DEFAULT_FINAL_STEP)
        num_heads = kwargs.get("num_heads", self._DEFAULT_NUM_HEADS)

        uncond_noise = intermediates.get("uncond_noise", [])
        text_noise = intermediates.get("text_noise", [])
        # Computed once and shared by the global and the localized score.
        noise_diff_traj = [tn - un for tn, un in zip(text_noise, uncond_noise)]
        if not noise_diff_traj:
            raise ValueError(
                "BrightEndingMetric needs non-empty 'text_noise' and "
                "'uncond_noise' trajectories in `intermediates`."
            )

        # Global magnitude: mean over steps of ||text - uncond||_2.
        d_global = torch.mean(
            torch.stack([diff.norm(p=2) for diff in noise_diff_traj])
        ).item()

        # Stage 1: live controller cache (baseline run).
        be_map = self._be_map_from_cache(controller, final_step, num_heads)

        # Stage 2: fall back to disk (mitigated run, or cache miss).
        if be_map is None and attention_map_dir and os.path.exists(attention_map_dir):
            print("Warning: BE map not found in live cache. Attempting to load from disk...")
            be_map = self._be_map_from_disk(attention_map_dir, final_step)

        if be_map is None:
            print("Warning: Could not find BE map from either live cache or disk. Returning only d_score.")
            return {"d_score": d_global}

        # Raw attention values are used as weights on purpose (no min-max
        # normalization): the denominator below divides out the mask scale.
        mask = self._as_single_mask(be_map)

        # LD = (1/T) * sum_t ||diff_t * m||_2 / ((1/N) * sum_i m_i)
        numerator = torch.mean(
            torch.stack(
                [(diff * self._resize_mask(mask, diff)).norm(p=2) for diff in noise_diff_traj]
            )
        )
        denominator = torch.mean(self._resize_mask(mask, noise_diff_traj[0])) + 1e-6
        ld_score = numerator / denominator

        be_intensity = torch.mean(mask).item()
        print(f"{self.name} - LD Score: {ld_score.item()}, D Score: {d_global}, BE Mean: {be_intensity}")
        return {
            "ld_score": ld_score.item(),
            "d_score": d_global,
            "be_intensity": be_intensity,
        }

    @staticmethod
    def _be_map_from_cache(controller, final_step: int, num_heads: int) -> Optional[torch.Tensor]:
        """Return the BE map from ``controller.latest_attention_maps``, or ``None``.

        Averages the ``'down_0'`` and ``'down_1'`` maps stored under
        ``final_step`` and keeps the last entry of the last axis (presumably
        the final prompt token -- confirm against the controller). Then it
        averages over ``num_heads`` heads assumed to be folded into the first
        axis, giving ``[B, side, side]``.
        """
        cache = getattr(controller, "latest_attention_maps", None) if controller else None
        if not cache:
            return None
        step_maps = cache.get(final_step, {})
        down0 = step_maps.get("down_0")
        down1 = step_maps.get("down_1")
        if down0 is None or down1 is None:
            return None
        last_token_attn = ((down0 + down1) / 2.0)[:, :, -1]
        # The second axis is treated as a flattened square spatial grid.
        side = int(last_token_attn.shape[1] ** 0.5)
        return last_token_attn.reshape(-1, num_heads, side, side).mean(1)

    @classmethod
    def _be_map_from_disk(cls, attention_map_dir: str, final_step: int) -> Optional[torch.Tensor]:
        """Average the step-``final_step`` PNG maps that have the highest token index.

        Files must match ``step{final_step}_*_token<N>.png``. Grayscale values
        are scaled to [0, 1]. Returns ``None`` if no such file exists.
        """
        pattern = os.path.join(glob.escape(attention_map_dir), f"step{final_step}_*_token*.png")
        indexed_files = []
        for path in glob.glob(pattern):
            match = cls._TOKEN_RE.search(os.path.basename(path))
            if match:  # skip names whose token part is not numeric
                indexed_files.append((int(match.group(1)), path))
        if not indexed_files:
            return None

        # Filter the already-parsed list instead of re-globbing with a
        # zero-padded index, so unpadded / wider-padded names still match.
        last_token_idx = max(idx for idx, _ in indexed_files)
        maps = []
        for idx, path in indexed_files:
            if idx != last_token_idx:
                continue
            with Image.open(path) as img:  # release the file handle promptly
                gray = np.array(img.convert("L")) / 255.0
            maps.append(torch.from_numpy(gray).float())
        return torch.stack(maps).mean(dim=0)

    @staticmethod
    def _as_single_mask(be_map: torch.Tensor) -> torch.Tensor:
        """Collapse ``be_map`` (``[..., H, W]``) into one ``[1, 1, H, W]`` float32 mask.

        Leading axes (e.g. the one left by the cache path after folding heads)
        are averaged. The mask then broadcasts over every sample and channel of
        the noise differences, instead of being lined up against their channel
        axis.
        """
        spatial = be_map.float().reshape(-1, *be_map.shape[-2:]).mean(dim=0)
        return spatial[None, None]

    @staticmethod
    def _resize_mask(mask: torch.Tensor, like: torch.Tensor) -> torch.Tensor:
        """Bilinearly resize ``mask`` to ``like.shape[2:]`` on ``like``'s device."""
        return F.interpolate(
            mask.to(like.device),
            size=like.shape[2:],
            mode="bilinear",
            align_corners=False,
        )