import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import random
from omini.train_flux.train_aircraft_bg_crop_solar import OminiSolarModel, AircraftSolarDataset, RSSolarContextEncoder
from omini.train_flux.trainer_mask_weighted import get_config, train
from omini.pipeline.flux_omini_solar import solar_transformer_forward, encode_images

class OminiSolarSemanticModel(OminiSolarModel):
    """OminiSolarModel with an additional differentiable "semantic" loss.

    Each training step:
      1. encodes the target image, the prompt and every ``condition_{i}`` image;
      2. runs ``self.solar_encoder`` on the background condition (index 1) and
         maps its context vector through each of ``self.solar_projectors`` to a
         ``(scale, shift)`` pair, which is passed to
         ``solar_transformer_forward`` as ``solar_params_list``;
      3. adds ``lambda_semantic * MSE(encoder(z_0_hat), encoder(background))``
         to the flow-matching loss, where ``z_0_hat`` is the clean latent
         implied by the predicted velocity.

    Model-config keys read by this class (via ``self.model_config``):
        inter_condition_attention (bool, default False)
        independent_condition (bool, default False)
        lambda_semantic (float, default 0.1): weight of the semantic loss.
        detach_semantic_target (bool, default False): when True, the
            background context is detached before being used as the
            semantic-loss target. NOTE(review): with the default (False) and a
            trainable ``solar_encoder``, the semantic MSE can presumably also be
            reduced by making the encoder's outputs less input-dependent.
    """

    @staticmethod
    def _token_grid(num_tokens, height=None, width=None):
        """Return ``(rows, cols)`` with ``rows * cols == num_tokens``.

        The grid is first derived from the pixel aspect ratio ``height / width``
        (when both are given and non-zero). If that does not tile
        ``num_tokens`` exactly, a square grid is tried.

        Raises:
            ValueError: if neither candidate grid contains exactly
                ``num_tokens`` cells.
        """
        candidates = []
        if height and width:
            rows = max(1, int(round((num_tokens * height / width) ** 0.5)))
            candidates.append((rows, num_tokens // rows))
        side = int(round(num_tokens ** 0.5))
        candidates.append((side, side))
        for rows, cols in candidates:
            if rows * cols == num_tokens:
                return rows, cols
        raise ValueError(
            f"Cannot arrange {num_tokens} latent tokens into a 2-D grid "
            f"(pixel size hint: {height}x{width})."
        )

    def training_step(self, batch, batch_idx):
        """Return the combined flow-matching + semantic loss for one batch.

        Args:
            batch: dict with ``"image"`` (target pixels), ``"description"``
                (prompts), an optional ``"target_mask"``, and per-condition keys
                ``condition_{i}`` / ``position_delta_{i}`` /
                ``position_scale_{i}`` / ``condition_latent_mask_{i}`` for
                contiguous ``i`` starting at 0.
            batch_idx: unused.

        Returns:
            Scalar tensor ``loss_diff + lambda_semantic * loss_semantic``.
            ``loss_semantic`` is 0 when there is no background condition or no
            ``target_mask``.

        Raises:
            ValueError: if the target or background latent tokens cannot be
                arranged into a 2-D grid for the solar encoder.
        """
        imgs, prompts = batch["image"], batch["description"]
        target_mask = batch.get("target_mask", None)

        # Conditions: gather condition_{i} until the first missing index.
        conditions, position_deltas, position_scales, latent_masks = [], [], [], []
        for i in range(1000):
            if f"condition_{i}" not in batch:
                break
            conditions.append(batch[f"condition_{i}"])
            position_deltas.append(batch.get(f"position_delta_{i}", [[0, 0]]))
            position_scales.append(batch.get(f"position_scale_{i}", [1.0])[0])
            latent_masks.append(batch.get(f"condition_latent_mask_{i}", None))

        # condition_0: subject, condition_1: background. This is an assumed
        # dataset key order (as in AircraftBackgroundCropDataset); verify
        # against AircraftSolarDataset.
        background_img = conditions[1] if len(conditions) > 1 else None

        with torch.no_grad():
            imgs = imgs.to(self.device)
            if target_mask is not None:
                target_mask = target_mask.to(self.device)

            # Encode the target image into latent tokens and their position ids.
            x_0, img_ids = encode_images(self.flux_pipe, imgs)
            x_0 = x_0.to(self.device)
            img_ids = img_ids.to(self.device)

            # Encode prompts.
            prompt_embeds, pooled_prompt_embeds, text_ids = self.flux_pipe.encode_prompt(
                prompt=prompts,
                prompt_2=None,
                device=self.flux_pipe.device,
                num_images_per_prompt=1,
            )
            prompt_embeds = prompt_embeds.to(self.device)
            pooled_prompt_embeds = pooled_prompt_embeds.to(self.device)
            text_ids = text_ids.to(self.device)

            # Flow-matching interpolation x_t = (1 - t) * x_0 + t * x_1, with a
            # per-sample t = sigmoid(N(0, 1)) and Gaussian noise x_1.
            t = torch.sigmoid(torch.randn((imgs.shape[0],), device=self.device))
            x_1 = torch.randn_like(x_0).to(self.device)
            t_ = t.unsqueeze(1).unsqueeze(1)
            x_t = ((1 - t_) * x_0 + t_ * x_1).to(self.dtype)

            # Encode conditions and adjust their position ids.
            condition_latents, condition_ids = [], []
            bg_latents_for_solar = None

            for idx, (cond, p_delta, p_scale, _latent_mask) in enumerate(zip(
                conditions, position_deltas, position_scales, latent_masks
            )):
                # NOTE(review): condition_latent_mask_{i} is collected but not
                # applied anywhere in this method.
                cond = cond.to(self.device)
                c_latents, c_ids = encode_images(self.flux_pipe, cond)
                c_latents = c_latents.to(self.device)
                c_ids = c_ids.to(self.device)

                # Keep the background latents for the solar encoder. Only the
                # ids are modified below; the latents themselves are unchanged.
                if idx == 1:
                    bg_latents_for_solar = c_latents

                # Scale id columns 1: by p_scale and add (p_scale - 1) / 2, then
                # offset columns 1 and 2 by the per-condition delta.
                if p_scale != 1.0:
                    scale_bias = (p_scale - 1.0) / 2
                    c_ids[:, 1:] *= p_scale
                    c_ids[:, 1:] += scale_bias
                c_ids[:, 1] += p_delta[0][0]
                c_ids[:, 2] += p_delta[0][1]

                condition_latents.append(c_latents)
                condition_ids.append(c_ids)

            guidance = torch.ones_like(t).to(self.device) if self.transformer.config.guidance_embeds else None

        # --- Solar modulation parameters from the background ---
        solar_params_list = []
        context_vector_bg = None
        target_mask_f32 = None

        # Target latent tokens: [B, L, C] (C = packed channels).
        B, L, C = x_0.shape

        if bg_latents_for_solar is not None and target_mask is not None:
            # 1. [B, L_bg, C_bg] -> [B, C_bg, H_bg, W_bg]. The grid comes from the
            #    background's own token count (and pixel aspect ratio), because
            #    condition_size may differ from target_size.
            bg_b, bg_len, bg_c = bg_latents_for_solar.shape
            bg_h, bg_w = self._token_grid(bg_len, *background_img.shape[-2:])
            bg_spatial = bg_latents_for_solar.transpose(1, 2).reshape(bg_b, bg_c, bg_h, bg_w)

            # The solar network is run in FP32.
            bg_spatial = bg_spatial.to(torch.float32)
            target_mask_f32 = target_mask.to(torch.float32)

            # 2. Background context vector (presumably [B, 1024]). It both
            #    modulates generation and serves as the semantic-loss target.
            context_vector_bg = self.solar_encoder(bg_spatial, target_mask_f32)

            # 3. One (scale, shift) pair per projector, each unsqueezed to
            #    [B, 1, D] so it can broadcast over the token dimension.
            for proj in self.solar_projectors:
                params = proj(context_vector_bg)
                scale, shift = params.chunk(2, dim=1)
                scale = scale.unsqueeze(1).to(self.dtype)
                shift = shift.unsqueeze(1).to(self.dtype)
                solar_params_list.append((scale, shift))
        else:
            # Fallback: zero scale/shift for every projector slot.
            inner_dim = self.transformer.config.num_attention_heads * self.transformer.config.attention_head_dim
            zero = torch.zeros((imgs.shape[0], 1, inner_dim), device=self.device, dtype=self.dtype)
            solar_params_list = [(zero, zero)] * len(self.solar_projectors)

        # --- Forward pass with solar params ---
        # Branches: 2 base branches followed by one branch per condition.
        branch_n = 2 + len(conditions)
        group_mask = torch.ones([branch_n, branch_n], dtype=torch.bool).to(self.device)
        if not self.model_config.get("inter_condition_attention", False):
            # Condition branches only see themselves among condition branches.
            group_mask[2:, 2:] = torch.diag(torch.tensor([1] * len(conditions), device=self.device))
        if self.model_config.get("independent_condition", False):
            group_mask[2:, :2] = False

        transformer_out = solar_transformer_forward(
            self.transformer,
            image_features=[x_t, *(condition_latents)],
            text_features=[prompt_embeds],
            img_ids=[img_ids, *(condition_ids)],
            txt_ids=[text_ids],
            # Condition branches get timestep 0.
            timesteps=[t, t] + [torch.zeros_like(t)] * len(conditions),
            pooled_projections=[pooled_prompt_embeds] * branch_n,
            guidances=[guidance] * branch_n,
            adapters=self.adapter_names,
            return_dict=False,
            group_mask=group_mask,
            solar_params_list=solar_params_list,
        )
        pred = transformer_out[0]  # predicted velocity

        # --- Flow-matching loss: v_target = x_1 - x_0 ---
        target = x_1 - x_0
        loss_diff = F.mse_loss(pred, target)

        # --- Differentiable semantic loss ---
        loss_semantic = torch.tensor(0.0, device=self.device)

        if context_vector_bg is not None and target_mask is not None:
            # 1. Clean latent implied by the velocity prediction:
            #    x_t = x_0 + t * (x_1 - x_0) = x_0 + t * v  =>  x_0 = x_t - t * v.
            t_expand = t.view(B, 1, 1).to(self.dtype)
            z_0_hat = x_t - t_expand * pred

            # 2. [B, L, C] -> [B, C, H, W] using the target's own token grid.
            H_latent, W_latent = self._token_grid(L, *imgs.shape[-2:])
            z_0_hat_spatial = z_0_hat.transpose(1, 2).reshape(B, C, H_latent, W_latent).to(torch.float32)

            # 3. Encode the predicted clean latent with an all-zero mask. The
            #    intent (per the original design) is that zeros mark the whole
            #    image as valid, so the encoder also sees the generated object;
            #    confirm against RSSolarContextEncoder's mask convention.
            pred_mask = torch.zeros_like(target_mask_f32)
            context_vector_pred = self.solar_encoder(z_0_hat_spatial, pred_mask)

            # 4. Pull the predicted context towards the background context.
            semantic_target = context_vector_bg
            if self.model_config.get("detach_semantic_target", False):
                semantic_target = semantic_target.detach()
            loss_semantic = F.mse_loss(context_vector_pred, semantic_target)

        # Total loss; lambda_semantic defaults to 0.1.
        lambda_semantic = self.model_config.get("lambda_semantic", 0.1)

        total_loss = loss_diff + lambda_semantic * loss_semantic

        self.log_loss = total_loss.item()

        return total_loss

def main():
    """Build the aircraft solar dataset and semantic model from config, then train.

    Reads ``config["train"]["dataset"]`` for the dataset paths/sizes, and the
    top-level ``flux_path``/``dtype``/``model`` keys plus ``train`` sub-keys for
    the model, before handing both to ``train``.
    """
    cfg = get_config()
    train_cfg = cfg["train"]
    data_cfg = train_cfg["dataset"]

    rule = "=" * 70
    for banner_line in (
        rule,
        "Aircraft Solar Semantic Training (Innovation 2)",
        "Network: RSSolarContextEncoder -> V-Modulation + Semantic Loss",
        rule,
    ):
        print(banner_line)

    train_set = AircraftSolarDataset(
        dataset_root=data_cfg["dataset_root"],
        condition_size=tuple(data_cfg["condition_size"]),
        target_size=tuple(data_cfg["target_size"]),
        drop_text_prob=0.1,
    )

    # Keyword arguments are evaluated in the same order as a direct call.
    model_kwargs = dict(
        flux_pipe_id=cfg["flux_path"],
        lora_path=None,
        lora_config=train_cfg.get("lora_config", None),
        device="cuda",
        dtype=torch.bfloat16 if cfg["dtype"] == "bfloat16" else torch.float32,
        model_config=cfg.get("model", {}),
        # One entry per branch: two base branches without an adapter, then
        # condition_0 (subject) and condition_1 (background).
        adapter_names=[None, None, "subject", "background"],
        optimizer_config=train_cfg.get("optimizer", None),
        gradient_checkpointing=train_cfg.get("gradient_checkpointing", False),
    )
    semantic_model = OminiSolarSemanticModel(**model_kwargs)

    train(train_set, semantic_model, cfg)

if __name__ == "__main__":
    # Start training only when this file is executed as a script, not on import.
    main()