import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import random
from omini.train_flux.train_aircraft_bg_crop_solar import OminiSolarModel, AircraftSolarDataset
from omini.train_flux.trainer_mask_weighted import get_config, train
from omini.pipeline.flux_omini_solar import solar_transformer_forward, encode_images

def frequency_split_latents(latents, kernel_size=5, height=None, width=None):
    """
    Decompose latents into Low-Frequency and High-Frequency components.

    The token sequence is laid out as a 2D grid and blurred with a stride-1
    average pool; the blurred grid is the low-frequency part and the residual
    ``x - low`` is the high-frequency part, so ``low + high`` reconstructs the
    input.

    Args:
        latents: Tensor of shape [B, L, C] (batch, tokens, channels).
        kernel_size: Side length of the square averaging window. Must be a
            positive odd integer; an even window with ``kernel_size // 2``
            padding would change the grid size and break the residual.
        height: Optional grid height. If only one of ``height``/``width`` is
            given, the other is derived from ``L``.
        width: Optional grid width. If both are omitted the grid is assumed
            square, i.e. ``L`` must be a perfect square.

    Returns:
        Tuple ``(low_flat, high_flat)``, each of shape [B, L, C].

    Raises:
        ValueError: If ``latents`` is not 3-D, ``kernel_size`` is not a
            positive odd integer, or the grid size does not match ``L``.
    """
    if latents.dim() != 3:
        raise ValueError(
            f"latents must have shape [B, L, C], got {tuple(latents.shape)}"
        )
    if kernel_size < 1 or kernel_size % 2 == 0:
        raise ValueError(
            f"kernel_size must be a positive odd integer, got {kernel_size}"
        )

    B, L, C = latents.shape

    # Resolve the spatial grid; default to a square layout as before.
    if height is None and width is None:
        H = int(round(L ** 0.5))
        W = H
    elif height is None:
        W = int(width)
        H = L // W if W > 0 else 0
    elif width is None:
        H = int(height)
        W = L // H if H > 0 else 0
    else:
        H, W = int(height), int(width)

    if H <= 0 or W <= 0 or H * W != L:
        raise ValueError(
            f"cannot arrange {L} tokens as a {H}x{W} grid; pass matching "
            "height/width for non-square latents"
        )

    # Reshape to spatial: [B, C, H, W]. reshape (not view) so any input
    # memory layout is accepted.
    x = latents.transpose(1, 2).reshape(B, C, H, W)
    pad = kernel_size // 2
    # count_include_pad=False keeps border averages unbiased by zero padding.
    low = F.avg_pool2d(
        x, kernel_size=kernel_size, stride=1, padding=pad, count_include_pad=False
    )
    high = x - low

    # Back to token layout: [B, L, C]
    low_flat = low.flatten(2).transpose(1, 2)
    high_flat = high.flatten(2).transpose(1, 2)

    return low_flat, high_flat

class OminiSolarLowOriginalModel(OminiSolarModel):
    """
    OminiSolarModel variant that feeds the first condition image to the
    transformer twice: once low-pass filtered (adapter "subject_low") and once
    unmodified (adapter "subject_original"). Every further condition uses the
    "background" adapter, and the second condition (index 1) additionally
    drives the solar encoder.
    """

    def training_step(self, batch, batch_idx):
        """
        Run one training step and return the MSE loss.

        Args:
            batch: Mapping with "image" and "description", optionally
                "target_mask", and consecutively numbered "condition_{i}"
                entries with optional "position_delta_{i}" /
                "position_scale_{i}" companions.
            batch_idx: Index of the batch (unused here).

        Returns:
            Scalar loss tensor: MSE between the transformer prediction and
            ``x_1 - x_0``. The value is also stored on ``self.log_loss``.
        """
        imgs, prompts = batch["image"], batch["description"]
        target_mask = batch.get("target_mask", None)

        # Conditions: collect consecutively numbered entries until the first gap.
        conditions, position_deltas, position_scales = [], [], []
        i = 0
        while f"condition_{i}" in batch:
            conditions.append(batch[f"condition_{i}"])
            position_deltas.append(batch.get(f"position_delta_{i}", [[0, 0]]))
            position_scales.append(batch.get(f"position_scale_{i}", [1.0])[0])
            i += 1

        # Encoding, noising and condition preparation need no gradients.
        with torch.no_grad():
            imgs = imgs.to(self.device)
            if target_mask is not None:
                target_mask = target_mask.to(self.device)

            # Encode Images
            x_0, img_ids = encode_images(self.flux_pipe, imgs)
            x_0 = x_0.to(self.device)
            img_ids = img_ids.to(self.device)

            # Encode Prompts
            prompt_embeds, pooled_prompt_embeds, text_ids = self.flux_pipe.encode_prompt(
                prompt=prompts,
                prompt_2=None,
                device=self.flux_pipe.device,
                num_images_per_prompt=1,
            )
            prompt_embeds = prompt_embeds.to(self.device)
            pooled_prompt_embeds = pooled_prompt_embeds.to(self.device)
            text_ids = text_ids.to(self.device)

            # Noise: per-sample t in (0, 1) from a sigmoid of a standard normal,
            # then a linear blend between the clean latents x_0 and noise x_1.
            t = torch.sigmoid(torch.randn((imgs.shape[0],), device=self.device))
            x_1 = torch.randn_like(x_0).to(self.device)
            t_ = t.unsqueeze(1).unsqueeze(1)
            x_t = ((1 - t_) * x_0 + t_ * x_1).to(self.dtype)

            # --- 1. Process Conditions with Low/Original Splitting ---
            condition_latents = []
            condition_ids = []
            bg_latents_for_solar = None
            used_adapters = [None, None]  # Img, Txt

            for idx, (cond, p_delta, p_scale) in enumerate(
                zip(conditions, position_deltas, position_scales)
            ):
                cond = cond.to(self.device)
                c_latents, c_ids = encode_images(self.flux_pipe, cond)
                c_latents = c_latents.to(self.device)
                c_ids = c_ids.to(self.device)

                # Save background latents for solar encoder
                if idx == 1:
                    bg_latents_for_solar = c_latents

                # Apply Position Delta/Scale (in place on the id columns 1 and 2)
                if p_scale != 1.0:
                    scale_bias = (p_scale - 1.0) / 2
                    c_ids[:, 1:] *= p_scale
                    c_ids[:, 1:] += scale_bias
                c_ids[:, 1] += p_delta[0][0]
                c_ids[:, 2] += p_delta[0][1]

                if idx == 0:  # Subject
                    low_latents, _ = frequency_split_latents(c_latents)

                    # Branch 1: Low Frequency
                    condition_latents.append(low_latents)
                    condition_ids.append(c_ids)
                    used_adapters.append("subject_low")

                    # Branch 2: Original Image (Replaces High Freq); it shares
                    # the position ids of the low-frequency branch.
                    condition_latents.append(c_latents)
                    condition_ids.append(c_ids)
                    used_adapters.append("subject_original")
                else:  # Background (idx 1); any later condition also lands here
                    condition_latents.append(c_latents)
                    condition_ids.append(c_ids)
                    used_adapters.append("background")

            guidance = (
                torch.ones_like(t).to(self.device)
                if self.transformer.config.guidance_embeds
                else None
            )

        # --- 2. Solar Optimization Logic ---
        # Runs outside no_grad so the solar encoder/projectors receive gradients.
        solar_params_list = []
        if bg_latents_for_solar is not None and target_mask is not None:
            B, L, C = bg_latents_for_solar.shape
            # NOTE(review): assumes the background latent grid is square — confirm.
            H = int(L ** 0.5)
            W = H

            bg_spatial = bg_latents_for_solar.transpose(1, 2).view(B, C, H, W)
            bg_spatial = bg_spatial.to(torch.float32)
            target_mask_f32 = target_mask.to(torch.float32)

            context_vector = self.solar_encoder(bg_spatial, target_mask_f32)

            # Each projector output is split in half along dim 1 into (scale, shift).
            for proj in self.solar_projectors:
                params = proj(context_vector)
                scale, shift = params.chunk(2, dim=1)
                scale = scale.unsqueeze(1).to(self.dtype)
                shift = shift.unsqueeze(1).to(self.dtype)
                solar_params_list.append((scale, shift))
        else:
            # No background or no mask: pass all-zero (scale, shift) pairs.
            inner_dim = (
                self.transformer.config.num_attention_heads
                * self.transformer.config.attention_head_dim
            )
            zero = torch.zeros(
                (imgs.shape[0], 1, inner_dim), device=self.device, dtype=self.dtype
            )
            solar_params_list = [(zero, zero)] * len(self.solar_projectors)

        # --- 3. Forward Pass (Unified) ---
        branch_n = len(used_adapters)
        group_mask = torch.ones([branch_n, branch_n], dtype=torch.bool).to(self.device)
        if not self.model_config.get("inter_condition_attention", False):
            # Disable inter-condition attention: within the condition block
            # each branch keeps only its own diagonal entry.
            group_mask[2:, 2:] = torch.eye(
                branch_n - 2, dtype=torch.bool, device=self.device
            )

        if self.model_config.get("independent_condition", False):
            # Clear the condition rows' entries for the first two branches.
            group_mask[2:, :2] = False

        transformer_out = solar_transformer_forward(
            self.transformer,
            image_features=[x_t, *condition_latents],
            text_features=[prompt_embeds],
            img_ids=[img_ids, *condition_ids],
            txt_ids=[text_ids],
            # Condition branches are given timestep 0.
            timesteps=[t, t] + [torch.zeros_like(t)] * len(condition_latents),
            pooled_projections=[pooled_prompt_embeds] * branch_n,
            guidances=[guidance] * branch_n,
            adapters=used_adapters,
            return_dict=False,
            group_mask=group_mask,
            solar_params_list=solar_params_list,
        )
        pred = transformer_out[0]

        # --- Loss Calculation ---
        target = x_1 - x_0
        loss = F.mse_loss(pred, target)

        self.log_loss = loss.item()
        return loss

    def configure_optimizers(self):
        """
        Build the optimizer over the LoRA layers and the solar components.

        ``self.optimizer_config`` may be ``None`` or lack "type"/"params"; the
        fallback is AdamW with ``lr=1e-4``. The collected parameter list is
        also stored on ``self.trainable_params``.

        Returns:
            A ``prodigyopt.Prodigy`` instance when the configured type is
            "Prodigy", otherwise a ``torch.optim.AdamW`` instance.
        """
        params_to_optimize = []

        # 1. LoRA
        if self.lora_layers:
            params_to_optimize.extend(self.lora_layers)

        # 2. Solar Components
        params_to_optimize.extend(self.solar_encoder.parameters())
        params_to_optimize.extend(self.solar_projectors.parameters())

        self.trainable_params = params_to_optimize

        # Optimizer Config (tolerate a missing config or missing "params").
        opt_config = self.optimizer_config or {}
        opt_name = opt_config.get("type", "AdamW")
        raw_params = opt_config.get("params") or {}
        # lr is passed explicitly so a default can apply when it is omitted.
        opt_params = {k: v for k, v in raw_params.items() if k != "lr"}
        base_lr = raw_params.get("lr", 1e-4)

        if opt_name == "Prodigy":
            # Imported lazily so prodigyopt is only required when selected.
            import prodigyopt
            optimizer = prodigyopt.Prodigy(params_to_optimize, lr=base_lr, **opt_params)
        else:
            if opt_name != "AdamW":
                import warnings
                warnings.warn(
                    f"Unrecognized optimizer type {opt_name!r}; falling back to AdamW."
                )
            optimizer = torch.optim.AdamW(params_to_optimize, lr=base_lr, **opt_params)

        return optimizer

def main():
    """
    Entry point: load the config, build the dataset and model, start training.

    Reads the configuration via ``get_config()``, constructs an
    ``AircraftSolarDataset`` from the "train" -> "dataset" section and an
    ``OminiSolarLowOriginalModel`` with the subject-low / subject-original /
    background adapters, then hands both to ``train``.
    """
    config = get_config()
    train_cfg = config["train"]
    data_cfg = train_cfg["dataset"]

    banner = "=" * 70
    print(banner)
    print("Aircraft Solar + Low-Freq + Original Training")
    print("Combining Solar Context Encoder & Low/Original LoRA")
    print(banner)

    dataset_kwargs = {
        "dataset_root": data_cfg["dataset_root"],
        "condition_size": tuple(data_cfg["condition_size"]),
        "target_size": tuple(data_cfg["target_size"]),
        "drop_text_prob": 0.1,
    }
    dataset = AircraftSolarDataset(**dataset_kwargs)

    # Adapter slots, in branch order:
    #   0: None (Image)
    #   1: None (Text)
    #   2: "subject_low"
    #   3: "subject_original"
    #   4: "background"
    adapter_names = [None, None]
    adapter_names += ["subject_low", "subject_original", "background"]

    flux_pipe_id = config["flux_path"]
    lora_config = train_cfg.get("lora_config", None)
    # Anything other than the string "bfloat16" selects float32.
    if config["dtype"] == "bfloat16":
        dtype = torch.bfloat16
    else:
        dtype = torch.float32

    model = OminiSolarLowOriginalModel(
        flux_pipe_id=flux_pipe_id,
        lora_path=None,
        lora_config=lora_config,
        device="cuda",
        dtype=dtype,
        model_config=config.get("model", {}),
        adapter_names=adapter_names,
        optimizer_config=train_cfg.get("optimizer", None),
        gradient_checkpointing=train_cfg.get("gradient_checkpointing", False),
    )

    train(dataset, model, config)

# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()