#!/usr/bin/env python
# coding=utf-8

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Dict, Any, Union, List, Tuple
from diffusers.models.attention_processor import Attention
from diffusers.models.modeling_utils import ModelMixin
from diffusers.configuration_utils import ConfigMixin
from models.mask_aware_attention import SpatiallyModulatedAttention
import math
import os

class LatentUnetEncoder(nn.Module):
    """Map a single-channel mask to a small additive latent offset.

    ``mask_encoder`` is a stack of three 3x3 convolutions (1 -> 8 -> 16 ->
    ``latent_dim`` channels) with ReLU between them and a final Tanh; ``padding=1``
    keeps the spatial size unchanged. ``unet_dim`` is stored but not used here.
    """

    def __init__(self, latent_dim: int = 4, unet_dim: int = 1):
        super().__init__()
        self.latent_dim = latent_dim
        self.unet_dim = unet_dim

        widths = (1, 8, 16)
        layers = []
        for in_ch, out_ch in zip(widths[:-1], widths[1:]):
            layers.extend((nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.ReLU()))
        # Final projection to the latent channel count; Tanh bounds it to (-1, 1).
        layers.extend((nn.Conv2d(widths[-1], latent_dim, 3, padding=1), nn.Tanh()))
        self.mask_encoder = nn.Sequential(*layers)

    def encode_unet(self, mask: torch.Tensor) -> torch.Tensor:
        """Encode ``mask`` (1 input channel) and scale it by 0.1.

        Because of the final Tanh the result lies within [-0.1, 0.1].
        """
        encoded = self.mask_encoder(mask)
        return encoded * 0.1

class LayerAwareAttention(nn.Module):
    """Hold foreground/background LoRA adapters for one diffusers ``Attention`` module.

    Creates ``fg_to_q``, ``fg_to_k``, ``fg_to_v``, ``fg_to_out`` (seed offsets 0-3)
    and ``bg_to_q``, ``bg_to_k``, ``bg_to_v``, ``bg_to_out`` (seed offsets 100-103).
    Each is a ``LoRALinear`` around the matching projection of ``original_attention``
    (``to_out[0]`` for the output projection). It also creates the parameters that
    subclasses use to exchange or blend the two branches: ``inter_layer_gate``
    (initialised to 0), ``cross_layer_attn`` and ``harmonization_weight``
    (initialised to 0.1). This class defines no ``forward`` of its own.
    """

    def __init__(self, original_attention: Attention, rank: int = 16, alpha: float = 1.0):
        super().__init__()
        self.original_attention = original_attention
        self.rank = rank
        self.alpha = alpha

        base_layers = (
            ("to_q", original_attention.to_q),
            ("to_k", original_attention.to_k),
            ("to_v", original_attention.to_v),
            ("to_out", original_attention.to_out[0]),
        )
        # Register every fg adapter first, then every bg adapter; each one gets a
        # fixed seed offset so both branches start from different, reproducible values.
        for branch, seed_base in (("fg", 0), ("bg", 100)):
            for position, (suffix, layer) in enumerate(base_layers):
                adapter = LoRALinear(layer, rank, alpha, seed_offset=seed_base + position)
                setattr(self, f"{branch}_{suffix}", adapter)

        self.inter_layer_gate = nn.Parameter(torch.zeros(1))
        self.cross_layer_attn = nn.MultiheadAttention(
            embed_dim=original_attention.inner_dim,
            num_heads=original_attention.heads,
            batch_first=True,
        )

        self.harmonization_weight = nn.Parameter(torch.ones(1) * 0.1)

        head_count = original_attention.heads
        fallback_dim_head = original_attention.inner_dim // head_count
        self.heads = head_count
        self.dim_head = getattr(original_attention, 'dim_head', fallback_dim_head)
        self.scale = original_attention.scale

class LoRALinear(nn.Module):
    """Frozen linear layer plus a trainable low-rank update.

    ``forward(x) = original_layer(x) + ((x @ lora_A) @ lora_B) * (alpha / rank)``
    where ``lora_A`` is ``(in_features, rank)`` and ``lora_B`` is
    ``(rank, out_features)``. Both are drawn from a private generator seeded with
    ``42 + seed_offset`` (scaled by 0.1 and 0.01 respectively), so initialisation
    is reproducible and does not touch the global RNG. All parameters of
    ``original_layer`` are frozen (``requires_grad = False``).
    """

    def __init__(self, original_layer: nn.Linear, rank: int = 16, alpha: float = 1.0, seed_offset: int = 0):
        super().__init__()
        self.original_layer = original_layer
        self.rank = rank
        self.alpha = alpha

        weight = original_layer.weight
        rng = torch.Generator(device=weight.device)
        rng.manual_seed(42 + seed_offset)

        def _init(rows: int, cols: int, std: float) -> nn.Parameter:
            values = torch.randn(rows, cols, dtype=weight.dtype, device=weight.device, generator=rng)
            return nn.Parameter(values * std)

        # Draw order matters for reproducibility: lora_A first, then lora_B.
        self.lora_A = _init(original_layer.in_features, rank, 0.1)
        self.lora_B = _init(rank, original_layer.out_features, 0.01)

        self.lora_A.requires_grad_(True)
        self.lora_B.requires_grad_(True)

        for frozen in original_layer.parameters():
            frozen.requires_grad = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the frozen projection of ``x`` plus the scaled low-rank delta."""
        base = self.original_layer(x)
        scaling = self.alpha / self.rank
        # The adapters are cast to the input dtype so mixed-precision inputs work.
        down = x @ self.lora_A.to(x.dtype)
        delta = down @ self.lora_B.to(x.dtype) * scaling
        return base + delta

class ImprovedDualLoRAAttention(LayerAwareAttention):
    """Dual (foreground/background) LoRA replacement for a diffusers ``Attention``.

    The ``branch`` keyword selects the computation:

    * ``"foreground"``: foreground LoRA attention routed through
      ``ForegroundRegionProcessor`` (``self.fg_region_processor``).
    * ``"background"``: plain attention with the background LoRA projections.
    * anything else (default ``"both"``): both branches, optionally exchanged via
      ``cross_layer_attn`` and blended with the mask (see
      ``_compute_layer_aware_attention``).

    Only the q/k/v and ``to_out[0]`` projections of the wrapped module are used
    (through the LoRA adapters). ``attention_mask`` is accepted for signature
    compatibility but ignored.
    """

    def __init__(self, original_attention: Attention, rank: int = 16, alpha: float = 1.0, use_spatially_modulated: bool = False):
        super().__init__(original_attention, rank, alpha)
        self.use_spatially_modulated = use_spatially_modulated

        if self.use_spatially_modulated:
            dim = getattr(original_attention, 'embed_dim', None) or getattr(original_attention, 'inner_dim', None)
            num_heads = getattr(original_attention, 'num_heads', None) or getattr(original_attention, 'heads', None)
            self.spatially_modulated_attn = SpatiallyModulatedAttention(dim, num_heads, use_mask_aware=True, out_proj=original_attention.to_out[0])
        else:
            self.spatially_modulated_attn = None

        self.fg_region_processor = ForegroundRegionProcessor(
            dim=original_attention.inner_dim,
            num_heads=original_attention.heads
        )

    def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor] = None,
                attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
        """Dispatch to the attention path selected by ``branch``.

        ``mask`` and ``branch`` may arrive as direct keyword arguments or inside a
        ``cross_attention_kwargs`` dict; values in the dict take precedence.
        When ``encoder_hidden_states`` is None this is self-attention and
        ``hidden_states`` is used for keys/values.
        """
        mask = kwargs.get("mask", None)
        branch = kwargs.get("branch", "both")

        cross_attn_kwargs = kwargs.get("cross_attention_kwargs")
        if isinstance(cross_attn_kwargs, dict):
            mask = cross_attn_kwargs.get("mask", mask)
            branch = cross_attn_kwargs.get("branch", branch)

        # Identity (not equality) with hidden_states is later used to detect self-attention.
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        if branch == "foreground":
            return self._compute_foreground_region_attention(hidden_states, encoder_hidden_states, mask)
        if branch == "background":
            return self._compute_branch_attention(hidden_states, encoder_hidden_states, "background")
        return self._compute_layer_aware_attention(hidden_states, encoder_hidden_states, mask)

    @staticmethod
    def _token_grid(sequence_length: int, mask_hw) -> Tuple[int, int]:
        """Return the ``(h, w)`` token grid that holds ``sequence_length`` tokens.

        The mask's aspect ratio is used when it yields an exact factorisation of
        ``sequence_length`` (needed for non-square latents). Otherwise this falls
        back to the square grid ``int(sqrt(sequence_length))``. A square mask on a
        perfect-square sequence always gives the square grid.
        """
        mask_h, mask_w = int(mask_hw[0]), int(mask_hw[1])
        if sequence_length > 0 and mask_h > 0 and mask_w > 0:
            stride = math.sqrt(mask_h * mask_w / sequence_length)
            h = max(1, round(mask_h / stride))
            w = max(1, round(mask_w / stride))
            if h * w == sequence_length:
                return h, w
        side = int(math.sqrt(sequence_length))
        return side, side

    @staticmethod
    def _resize_mask(mask: torch.Tensor, h: int, w: int) -> torch.Tensor:
        """Return ``mask`` as float32, nearest-resized to ``(h, w)`` if needed."""
        if mask.shape[-2:] != (h, w):
            return F.interpolate(mask.float(), size=(h, w), mode='nearest')
        return mask.float()

    def _attend(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, batch_size: int) -> torch.Tensor:
        """Unmasked scaled dot-product attention over ``self.heads`` heads.

        Inputs are ``(batch, seq, heads * dim_head)``; the output has the same layout
        with the query's sequence length.
        """
        def split_heads(t: torch.Tensor) -> torch.Tensor:
            return t.view(batch_size, -1, self.heads, self.dim_head).transpose(1, 2)

        q, k, v = split_heads(q), split_heads(k), split_heads(v)
        attn_probs = F.softmax(torch.matmul(q, k.transpose(-2, -1)) * self.scale, dim=-1)
        out = torch.matmul(attn_probs, v)
        return out.transpose(1, 2).reshape(batch_size, -1, self.heads * self.dim_head)

    def _exchange_layers(self, fg_t: torch.Tensor, bg_t: torch.Tensor,
                         gate_weight: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Mix one fg/bg projection pair through ``cross_layer_attn``.

        Both directions attend over the *un-mixed* inputs. Each tensor is passed on
        its own, so its last dimension must equal ``cross_layer_attn``'s
        ``embed_dim`` (``inner_dim``). This assumes the k/v projections also output
        ``inner_dim`` features; TODO confirm for attention layers with custom
        ``kv_heads``.
        """
        enhanced_fg, _ = self.cross_layer_attn(fg_t, bg_t, bg_t)
        enhanced_bg, _ = self.cross_layer_attn(bg_t, fg_t, fg_t)
        mixed_fg = (1 - gate_weight) * fg_t + gate_weight * enhanced_fg
        mixed_bg = (1 - gate_weight) * bg_t + gate_weight * enhanced_bg
        return mixed_fg, mixed_bg

    def _compute_foreground_region_attention(self, hidden_states: torch.Tensor,
                                             encoder_hidden_states: torch.Tensor,
                                             mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Foreground path: region-wise composition by ``fg_region_processor``."""
        return self.fg_region_processor(
            hidden_states=hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            mask=mask,
            fg_attention_fn=lambda h, e: self._compute_branch_attention(h, e, "foreground")
        )

    def _compute_branch_attention(self, hidden_states: torch.Tensor,
                                  encoder_hidden_states: torch.Tensor, branch: str) -> torch.Tensor:
        """Attention with a single branch's LoRA projections.

        ``branch == "foreground"`` selects the fg adapters; any other value selects
        the bg adapters.
        """
        batch_size, _, _ = hidden_states.shape

        if branch == "foreground":
            to_q, to_k, to_v, to_out = self.fg_to_q, self.fg_to_k, self.fg_to_v, self.fg_to_out
        else:
            to_q, to_k, to_v, to_out = self.bg_to_q, self.bg_to_k, self.bg_to_v, self.bg_to_out

        attended = self._attend(
            to_q(hidden_states), to_k(encoder_hidden_states), to_v(encoder_hidden_states), batch_size
        )
        return to_out(attended)

    def _compute_layer_aware_attention(self, hidden_states: torch.Tensor,
                                       encoder_hidden_states: torch.Tensor,
                                       mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Combined fg/bg path.

        1. Foreground projections are computed. If spatially-modulated attention is
           enabled and this is self-attention, it runs on those projections and its
           result is returned directly.
        2. Background projections are computed. If ``inter_layer_gate`` > 0.1, each
           q/k/v pair is exchanged through ``cross_layer_attn``.
        3. With no mask the foreground result is returned. Otherwise the result is
           ``fg * m * hw + bg * (1 - m * hw)``, where ``m`` is the mask resized to
           the token grid and ``hw`` is ``harmonization_weight``.
        """
        batch_size, sequence_length, _ = hidden_states.shape

        fg_q = self.fg_to_q(hidden_states)
        fg_k = self.fg_to_k(encoder_hidden_states)
        fg_v = self.fg_to_v(encoder_hidden_states)

        if (self.use_spatially_modulated and self.spatially_modulated_attn is not None
                and encoder_hidden_states is hidden_states):
            if mask is not None:
                h, w = self._token_grid(sequence_length, mask.shape[-2:])
                mask_token = (self._resize_mask(mask, h, w).view(batch_size, h * w) > 0.5).float()
            else:
                mask_token = torch.zeros(batch_size, sequence_length, device=hidden_states.device)
            return self.spatially_modulated_attn(fg_q, fg_k, fg_v, mask_token=mask_token)

        bg_q = self.bg_to_q(hidden_states)
        bg_k = self.bg_to_k(encoder_hidden_states)
        bg_v = self.bg_to_v(encoder_hidden_states)

        if self.inter_layer_gate > 0.1:
            gate_weight = torch.sigmoid(self.inter_layer_gate)
            # Exchange per projection: q/k/v may have different sequence lengths in
            # cross-attention, and cross_layer_attn expects inner_dim-wide inputs.
            fg_q, bg_q = self._exchange_layers(fg_q, bg_q, gate_weight)
            fg_k, bg_k = self._exchange_layers(fg_k, bg_k, gate_weight)
            fg_v, bg_v = self._exchange_layers(fg_v, bg_v, gate_weight)

        fg_result = self.fg_to_out(self._attend(fg_q, fg_k, fg_v, batch_size))
        if mask is None:
            return fg_result

        bg_result = self.bg_to_out(self._attend(bg_q, bg_k, bg_v, batch_size))

        h, w = self._token_grid(sequence_length, mask.shape[-2:])
        # Cast to the result dtype so half-precision outputs stay half precision.
        mask_token = self._resize_mask(mask, h, w).view(batch_size, h * w, 1).to(fg_result.dtype)
        return (
            fg_result * mask_token * self.harmonization_weight
            + bg_result * (1 - mask_token * self.harmonization_weight)
        )


class ForegroundRegionProcessor(nn.Module):
    """Build the foreground-branch output from three mask-derived token regions.

    With a mask, each token position on the token grid belongs to one of:

    * the core (resized mask > 0.5), which receives
      ``hidden_states * original_feature_gate``;
    * the boundary ring from ``_compute_boundary_mask``, which receives
      ``boundary_predictor(fg_attention_fn(hidden_states, encoder_hidden_states))``;
    * the remainder ``clamp(1 - core - boundary, 0, 1)``, which receives
      ``0.1 * unet_encoder(hidden_states)``.

    Without a mask the call simply returns
    ``fg_attention_fn(hidden_states, encoder_hidden_states)``.

    Args:
        dim: feature size (last dimension) of ``hidden_states``.
        num_heads: stored as an attribute; not used in the computation here.
    """

    def __init__(self, dim: int, num_heads: int):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads

        self.original_feature_gate = nn.Parameter(torch.ones(1))

        self.boundary_predictor = nn.Sequential(
            nn.Linear(dim, dim),
            nn.ReLU(),
            nn.Linear(dim, dim)
        )

        self.unet_encoder = nn.Sequential(
            nn.Linear(dim, dim // 2),
            nn.ReLU(),
            nn.Linear(dim // 2, dim)
        )

        # Half-width, in tokens, of the box filter used by _compute_boundary_mask.
        self.boundary_width = 16

    @staticmethod
    def _token_grid(sequence_length: int, mask_hw) -> Tuple[int, int]:
        """Return the ``(h, w)`` token grid that holds ``sequence_length`` tokens.

        The mask's aspect ratio is used when it yields an exact factorisation of
        ``sequence_length`` (needed for non-square latents). Otherwise this falls
        back to the square grid ``int(sqrt(sequence_length))``. A square mask on a
        perfect-square sequence always gives the square grid.
        """
        mask_h, mask_w = int(mask_hw[0]), int(mask_hw[1])
        if sequence_length > 0 and mask_h > 0 and mask_w > 0:
            stride = math.sqrt(mask_h * mask_w / sequence_length)
            h = max(1, round(mask_h / stride))
            w = max(1, round(mask_w / stride))
            if h * w == sequence_length:
                return h, w
        side = int(math.sqrt(sequence_length))
        return side, side

    def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor,
                mask: Optional[torch.Tensor] = None, fg_attention_fn=None) -> torch.Tensor:
        """Compose the region-wise output (see class docstring).

        Args:
            hidden_states: ``(batch, sequence_length, dim)`` token features.
            encoder_hidden_states: passed through to ``fg_attention_fn``.
            mask: spatial mask, nearest-resized to the token grid. The ``view``
                calls below require one channel and the same batch size as
                ``hidden_states``, i.e. ``(batch, 1, H, W)``.
            fg_attention_fn: callable ``(hidden, encoder_hidden) -> features``.

        Returns:
            Tensor shaped like ``hidden_states``. The masks are cast to
            ``hidden_states.dtype`` so they do not promote the result's dtype.
        """
        batch_size, sequence_length, _ = hidden_states.shape

        if mask is None:
            return fg_attention_fn(hidden_states, encoder_hidden_states)

        h, w = self._token_grid(sequence_length, mask.shape[-2:])
        if mask.shape[-2:] != (h, w):
            mask_resized = F.interpolate(mask.float(), size=(h, w), mode='nearest')
        else:
            mask_resized = mask.float()

        core_mask = (mask_resized.view(batch_size, h * w) > 0.5).float()
        boundary_mask = self._compute_boundary_mask(mask_resized, self.boundary_width).view(batch_size, h * w)
        unet_mask = torch.clamp(1.0 - core_mask - boundary_mask, 0, 1)

        # The masks are built in float32; casting keeps fp16/bf16 features from
        # being promoted to float32 in the combined output.
        feature_dtype = hidden_states.dtype
        core_mask = core_mask.to(feature_dtype).unsqueeze(-1)
        boundary_mask = boundary_mask.to(feature_dtype).unsqueeze(-1)
        unet_mask = unet_mask.to(feature_dtype).unsqueeze(-1)

        original_features = hidden_states * self.original_feature_gate
        boundary_features = self.boundary_predictor(fg_attention_fn(hidden_states, encoder_hidden_states))
        unet_features = self.unet_encoder(hidden_states) * 0.1

        return (
            original_features * core_mask
            + boundary_features * boundary_mask
            + unet_features * unet_mask
        )

    def _compute_boundary_mask(self, mask: torch.Tensor, boundary_width: int) -> torch.Tensor:
        """Return the ring just outside the mask, shaped ``(batch, 1, h, w)``.

        A ``(2 * boundary_width + 1)``-wide box filter is applied to channel 0 of
        ``mask``. Positions where the window mean exceeds 0.1 are marked, and the
        mask itself is subtracted (clamped to [0, 1]). All samples are processed in
        a single batched convolution.
        """
        kernel_size = boundary_width * 2 + 1
        kernel = torch.ones(1, 1, kernel_size, kernel_size, device=mask.device, dtype=mask.dtype) / (kernel_size ** 2)

        single_channel = mask[:, :1]
        dilated = F.conv2d(single_channel, kernel, padding=kernel_size // 2)
        dilated = (dilated > 0.1).to(single_channel.dtype)

        return torch.clamp(dilated - single_channel, 0, 1)

class LayerUNet(ModelMixin, ConfigMixin):
    """Dual-LoRA wrapper around a diffusers UNet.

    On construction every ``Attention`` module inside ``original_unet`` is replaced
    in place by an ``ImprovedDualLoRAAttention``, and a ``LatentUnetEncoder`` is
    created. ``forward`` adds a mask-derived offset to ``sample`` and passes
    ``mask`` / ``branch`` to the attention layers through
    ``cross_attention_kwargs``.
    """

    # Adapter attribute names on each LoRA attention module, in parameter order.
    _LORA_PROJECTIONS = (
        "fg_to_q", "fg_to_k", "fg_to_v", "fg_to_out",
        "bg_to_q", "bg_to_k", "bg_to_v", "bg_to_out",
    )
    # Alpha parameter names on ``spatially_modulated_attn``.
    _ALPHA_NAMES = ("alpha_fg2fg", "alpha_fg2bg", "alpha_bg2fg", "alpha_bg2bg")

    def __init__(self, original_unet, lora_rank: int = 16, lora_alpha: float = 1.0, use_spatially_modulated: bool = False):
        super().__init__()
        self.original_unet = original_unet
        self.lora_rank = lora_rank
        self.lora_alpha = lora_alpha
        self.use_spatially_modulated = use_spatially_modulated

        self.unet_encoder = LatentUnetEncoder()

        self._replace_attention_layers()

    def _replace_attention_layers(self):
        """Recursively swap each ``Attention`` child of ``original_unet`` for a LoRA wrapper."""
        def replace_attention_recursive(module, name=""):
            for child_name, child_module in module.named_children():
                if isinstance(child_module, Attention):
                    new_attention = ImprovedDualLoRAAttention(
                        child_module, self.lora_rank, self.lora_alpha, self.use_spatially_modulated
                    )
                    setattr(module, child_name, new_attention)
                else:
                    replace_attention_recursive(child_module, f"{name}.{child_name}")

        replace_attention_recursive(self.original_unet)

    def forward(self, sample: torch.Tensor, timestep, encoder_hidden_states: torch.Tensor,
                mask: Optional[torch.Tensor] = None, branch: str = "both",
                added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, **kwargs):
        """Run the wrapped UNet with mask conditioning.

        Steps:
        * ``text_embeds`` / ``time_ids`` shaped ``(B, 1, D)`` are squeezed to
          ``(B, D)``. This is done on a copy; the caller's dict is not modified.
        * If ``mask`` is given, ``unet_encoder.encode_unet(mask)`` is bilinearly
          resized to ``sample``'s spatial size, its channels are cycled or
          truncated to ``sample``'s channel count, and it is added to ``sample``
          in ``sample``'s dtype.
        * ``mask`` and ``branch`` are added to a copy of ``cross_attention_kwargs``.
        """
        if added_cond_kwargs is not None and "text_embeds" in added_cond_kwargs:
            added_cond_kwargs = dict(added_cond_kwargs)
            for key in ("text_embeds", "time_ids"):
                value = added_cond_kwargs.get(key)
                if value is not None and value.dim() == 3 and value.shape[1] == 1:
                    added_cond_kwargs[key] = value.squeeze(1)

        if mask is not None:
            unet_offset = self.unet_encoder.encode_unet(mask)
            if sample.shape[-2:] != unet_offset.shape[-2:]:
                unet_offset = F.interpolate(
                    unet_offset,
                    size=sample.shape[-2:],
                    mode='bilinear',
                    align_corners=False
                )

            # Cycle offset channels (c0, c1, ..., c0, c1, ...) up to sample's channel
            # count, or truncate when the offset has more channels.
            channels = sample.shape[1]
            offset_channels = unet_offset.shape[1]
            if offset_channels < channels:
                repeats = (channels + offset_channels - 1) // offset_channels
                unet_offset = unet_offset.repeat(1, repeats, 1, 1)
            unet_offset = unet_offset[:, :channels]

            sample = sample + unet_offset.to(sample.dtype)

        cross_attention_kwargs = dict(kwargs.get("cross_attention_kwargs") or {})
        cross_attention_kwargs["mask"] = mask
        cross_attention_kwargs["branch"] = branch
        kwargs["cross_attention_kwargs"] = cross_attention_kwargs

        return self.original_unet(
            sample=sample,
            timestep=timestep,
            encoder_hidden_states=encoder_hidden_states,
            added_cond_kwargs=added_cond_kwargs,
            **kwargs
        )

    def get_lora_parameters(self):
        """Return the trainable adapter parameters as a list in a stable order.

        For each LoRA attention module (depth-first over ``original_unet``) the list
        contains:
        * the ``lora_A`` / ``lora_B`` pairs of the eight projections;
        * ``inter_layer_gate`` and ``harmonization_weight``;
        * the four spatial-modulation alphas (when present);
        * the ``fg_region_processor`` parameters.

        ``unet_encoder``'s parameters come last. ``cross_layer_attn`` is not included.
        """
        lora_params = []

        def collect_lora_params(module):
            if hasattr(module, 'fg_to_q') and hasattr(module.fg_to_q, 'lora_A'):
                for proj_name in self._LORA_PROJECTIONS:
                    proj = getattr(module, proj_name)
                    lora_params.extend([proj.lora_A, proj.lora_B])

                for gate_name in ('inter_layer_gate', 'harmonization_weight'):
                    if hasattr(module, gate_name):
                        lora_params.append(getattr(module, gate_name))

                spatial = getattr(module, 'spatially_modulated_attn', None)
                if spatial is not None:
                    lora_params.extend(getattr(spatial, name) for name in self._ALPHA_NAMES)

                frp = getattr(module, 'fg_region_processor', None)
                if frp is not None:
                    lora_params.append(frp.original_feature_gate)
                    lora_params.extend(frp.boundary_predictor.parameters())
                    lora_params.extend(frp.unet_encoder.parameters())

            for child in module.children():
                collect_lora_params(child)

        collect_lora_params(self.original_unet)

        lora_params.extend(self.unet_encoder.parameters())

        return lora_params

    def load_lora_weights(self, load_path: str, load_alpha: bool = True):
        """Load dual-LoRA weights from a flat state-dict file.

        Attention-level keys look like ``"unet.<module path>.<param>"``. Latent mask
        encoder keys look like ``"unet_encoder.<param>"``. A missing file or a load
        error is reported with ``print`` and leaves the current weights untouched.
        Loaded tensors are moved to the device and dtype of the parameter they
        replace.

        When ``load_alpha`` is true, spatial-modulation alphas are read afterwards
        from the sibling file ``<stem>_alpha<ext>`` (for example ``lora.pth`` ->
        ``lora_alpha.pth``) via ``load_alpha_weights``.
        """
        if not os.path.exists(load_path):
            print(f"❌ LoRA weights file not found: {load_path}")
            print("   Model will use randomly initialized LoRA weights")
            return

        try:
            # NOTE(review): torch.load unpickles arbitrary objects; only load trusted
            # checkpoints (newer torch versions support weights_only=True).
            lora_state_dict = torch.load(load_path, map_location="cpu")
        except Exception as e:
            print(f"❌ Failed to load LoRA weights file: {e}")
            print("   Model will use randomly initialized LoRA weights")
            return

        def assign(param, value):
            # The tensors were mapped to CPU; keep each parameter on its own
            # device and dtype instead of silently moving it to CPU.
            param.data = value.to(device=param.device, dtype=param.dtype)

        def assign_lora_A(target, saved_A):
            # Accept lora_A stored as-is or transposed. The as-saved layout is tried
            # first so a correctly stored square matrix is never flipped.
            expected = target.lora_A.shape
            if saved_A.shape == expected:
                assign(target.lora_A, saved_A)
            elif saved_A.T.shape == expected:
                assign(target.lora_A, saved_A.T)
            else:
                return False
            return True

        def safe_load_lora(target_module, key_prefix, prefix):
            # Load one projection's lora_A/lora_B pair; return True on success.
            try:
                lora_A_key = f"{prefix}{key_prefix}.lora_A"
                lora_B_key = f"{prefix}{key_prefix}.lora_B"
                if lora_A_key not in lora_state_dict or lora_B_key not in lora_state_dict:
                    return False

                if not assign_lora_A(target_module, lora_state_dict[lora_A_key]):
                    print(f"⚠️ Skipping {key_prefix} due to dimension mismatch")
                    return False

                saved_B = lora_state_dict[lora_B_key]
                if saved_B.shape != target_module.lora_B.shape:
                    print(f"⚠️ Skipping {key_prefix} due to dimension mismatch")
                    return False
                assign(target_module.lora_B, saved_B)
                return True
            except Exception as e:
                print(f"⚠️ Error loading {key_prefix}: {e}")
                return False

        def load_prefixed_submodule(submodule, key_prefix, label):
            # Non-strict load of every key under key_prefix into submodule.
            sub_dict = {k[len(key_prefix):]: v
                        for k, v in lora_state_dict.items()
                        if k.startswith(key_prefix)}
            if sub_dict:
                try:
                    submodule.load_state_dict(sub_dict, strict=False)
                except Exception as e:
                    print(f"⚠️ Warning: Failed to load {label} weights: {e}")

        def load_attention_module(module, prefix):
            # Load every adapter of one LoRA attention module; fg_to_q must match.
            try:
                saved_lora_A = lora_state_dict[f"{prefix}fg_to_q.lora_A"]
                saved_lora_B = lora_state_dict[f"{prefix}fg_to_q.lora_B"]

                expected_lora_A_shape = module.fg_to_q.lora_A.shape
                expected_lora_B_shape = module.fg_to_q.lora_B.shape

                print(f"Debug: Expected lora_A shape: {expected_lora_A_shape}, got: {saved_lora_A.shape}")
                print(f"Debug: Expected lora_B shape: {expected_lora_B_shape}, got: {saved_lora_B.shape}")

                if not assign_lora_A(module.fg_to_q, saved_lora_A):
                    print(f"⚠️ Warning: lora_A dimension mismatch for {prefix}fg_to_q, using random weights...")
                    return False

                if saved_lora_B.shape != expected_lora_B_shape:
                    print(f"⚠️ Warning: lora_B dimension mismatch for {prefix}fg_to_q, using random weights...")
                    return False
                assign(module.fg_to_q.lora_B, saved_lora_B)

                for key_prefix in self._LORA_PROJECTIONS[1:]:
                    safe_load_lora(getattr(module, key_prefix), key_prefix, prefix)

                for gate_name in ("inter_layer_gate", "harmonization_weight"):
                    gate_key = f"{prefix}{gate_name}"
                    if gate_key in lora_state_dict:
                        assign(getattr(module, gate_name), lora_state_dict[gate_key])

                frp = getattr(module, 'fg_region_processor', None)
                if frp is not None:
                    gate_key = f"{prefix}fg_region_processor.original_feature_gate"
                    if gate_key in lora_state_dict:
                        assign(frp.original_feature_gate, lora_state_dict[gate_key])

                    load_prefixed_submodule(
                        frp.boundary_predictor,
                        f"{prefix}fg_region_processor.boundary_predictor.",
                        "boundary_predictor",
                    )
                    load_prefixed_submodule(
                        frp.unet_encoder,
                        f"{prefix}fg_region_processor.unet_encoder.",
                        "unet_encoder",
                    )

                return True

            except Exception as e:
                print(f"❌ Error loading LoRA weights for {prefix}: {e}")
                print("   Using randomly initialized weights instead")
                return False

        def load_lora_state_dict(module, prefix=""):
            # A LoRA attention module with saved weights is loaded and not descended
            # into; anything else is searched recursively.
            if (hasattr(module, 'fg_to_q') and hasattr(module.fg_to_q, 'lora_A')
                    and f"{prefix}fg_to_q.lora_A" in lora_state_dict):
                return load_attention_module(module, prefix)

            for child_name, child_module in module.named_children():
                load_lora_state_dict(child_module, f"{prefix}{child_name}.")

            return True

        try:
            load_lora_state_dict(self.original_unet, "unet.")

            encoder_prefix = "unet_encoder."
            encoder_dict = {k[len(encoder_prefix):]: v
                            for k, v in lora_state_dict.items()
                            if k.startswith(encoder_prefix)}
            if encoder_dict:
                self.unet_encoder.load_state_dict(encoder_dict, strict=False)

            print(f"✅ Loaded LayerDiffusion LoRA weights from {load_path}")

        except Exception as e:
            print(f"❌ Failed to load LoRA weights: {e}")
            print("   Model will use randomly initialized LoRA weights")
            return

        if not load_alpha:
            print("   Skipping alpha parameter loading, using default values (0.0)")
            return

        # "<stem>_alpha<ext>"; splitting off the extension guarantees the alpha path
        # never resolves to load_path itself.
        stem, ext = os.path.splitext(load_path)
        alpha_path = f"{stem}_alpha{ext}"

        if os.path.exists(alpha_path):
            try:
                self.load_alpha_weights(alpha_path)
            except Exception as e:
                print(f"❌ Failed to load alpha weights: {e}")
                print("   Using default alpha values (0.0)")
        else:
            print(f"⚠️ Alpha parameters file not found: {alpha_path}")
            print("   Using default alpha values (0.0)")

    def load_alpha_weights(self, alpha_path: str):
        """Load the four ``spatially_modulated_attn`` alphas of every wrapped layer.

        Keys look like ``"unet.<module path>.spatially_modulated_attn.alpha_fg2fg"``.
        A layer is loaded only when its ``alpha_fg2fg`` key is present; the other
        three keys must then exist too (KeyError otherwise). Values are moved to
        each parameter's device and dtype.
        """
        # NOTE(review): torch.load unpickles arbitrary objects; only load trusted files.
        alpha_state_dict = torch.load(alpha_path, map_location="cpu")

        def load_alpha_state_dict(module, prefix=""):
            spatial = getattr(module, 'spatially_modulated_attn', None)
            if spatial is not None and f"{prefix}spatially_modulated_attn.alpha_fg2fg" in alpha_state_dict:
                for name in self._ALPHA_NAMES:
                    value = alpha_state_dict[f"{prefix}spatially_modulated_attn.{name}"]
                    param = getattr(spatial, name)
                    param.data = value.to(device=param.device, dtype=param.dtype)

            for child_name, child_module in module.named_children():
                load_alpha_state_dict(child_module, f"{prefix}{child_name}.")

        load_alpha_state_dict(self.original_unet, "unet.")
        print(f"Loaded SpatiallyModulated alpha parameters from {alpha_path}")

    def reset_alpha_to_default(self):
        """Zero every spatial-modulation alpha in place."""
        def reset_alpha_recursive(module):
            spatial = getattr(module, 'spatially_modulated_attn', None)
            if spatial is not None:
                for name in self._ALPHA_NAMES:
                    getattr(spatial, name).data.zero_()

            for child_module in module.children():
                reset_alpha_recursive(child_module)

        reset_alpha_recursive(self.original_unet)
        print("Reset all alpha parameters to default values (0.0)")

    def get_alpha_summary(self):
        """Return count / mean / min / max / population std of all alpha values."""
        alpha_values = []

        def collect_alpha_values(module):
            spatial = getattr(module, 'spatially_modulated_attn', None)
            if spatial is not None:
                alpha_values.extend(getattr(spatial, name).item() for name in self._ALPHA_NAMES)

            for child_module in module.children():
                collect_alpha_values(child_module)

        collect_alpha_values(self.original_unet)

        if not alpha_values:
            return {'count': 0, 'mean': 0, 'min': 0, 'max': 0, 'std': 0}

        count = len(alpha_values)
        mean = sum(alpha_values) / count  # computed once, reused for std
        return {
            'count': count,
            'mean': mean,
            'min': min(alpha_values),
            'max': max(alpha_values),
            'std': (sum((x - mean) ** 2 for x in alpha_values) / count) ** 0.5
        }

    def set_mdma_params(self, current_step: int, total_steps: int, mdma_lambda: float = 1.0):
        """No-op placeholder kept for interface compatibility."""
        pass

    @property
    def config(self):
        """Expose the wrapped UNet's config."""
        return self.original_unet.config

    @property
    def add_embedding(self):
        """Expose the wrapped UNet's ``add_embedding`` module."""
        return self.original_unet.add_embedding

    def to(self, *args, **kwargs):
        """Move the wrapper and its two main submodules; returns ``self``."""
        super().to(*args, **kwargs)
        self.original_unet = self.original_unet.to(*args, **kwargs)
        self.unet_encoder = self.unet_encoder.to(*args, **kwargs)
        return self
