import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import random
from omini.train_flux.train_aircraft_bg_crop_solar import AircraftSolarDataset
from omini.train_flux.trainer_mask_weighted import get_config, train, OminiModel
from omini.pipeline.flux_omini_solar import solar_transformer_forward, encode_images

def frequency_split_latents(latents, kernel_size=5, height=None, width=None):
    """
    Decompose token latents into low-frequency and high-frequency components.

    The token axis is laid out as a 2-D grid (row-major), blurred with a
    stride-1 average pool to obtain the low band, and the high band is the
    residual ``latents - low``. Consequently ``low + high`` reconstructs the
    input.

    Args:
        latents: Tensor of shape ``[B, L, C]``.
        kernel_size: Side length of the square averaging window. Must be a
            positive odd integer so that the pooled output keeps the grid size.
        height: Optional number of grid rows. If only one of ``height`` /
            ``width`` is given, the other is derived from ``L``.
        width: Optional number of grid columns.
            When neither is given, ``L`` must be a perfect square.

    Returns:
        Tuple ``(low, high)``; both tensors have shape ``[B, L, C]``.

    Raises:
        ValueError: If ``latents`` is not 3-D, ``kernel_size`` is not a
            positive odd integer, or the grid shape does not multiply to ``L``.
    """
    if latents.dim() != 3:
        raise ValueError(
            f"latents must have shape [B, L, C], got {tuple(latents.shape)}"
        )
    # An even kernel with padding k // 2 yields an output of size H + 1, which
    # would break (or silently broadcast in) the residual computation below.
    if kernel_size < 1 or kernel_size % 2 == 0:
        raise ValueError(
            f"kernel_size must be a positive odd integer, got {kernel_size}"
        )

    B, L, C = latents.shape

    if height is None and width is None:
        H = round(L ** 0.5)
        W = H
    elif height is None:
        W = width
        H = L // W if W > 0 else 0
    elif width is None:
        H = height
        W = L // H if H > 0 else 0
    else:
        H, W = height, width

    if H <= 0 or W <= 0 or H * W != L:
        raise ValueError(
            f"cannot arrange {L} tokens on a {H}x{W} grid; pass matching "
            "height/width for non-square latents"
        )

    # Reshape to spatial: [B, C, H, W]. reshape (not view) so the call does not
    # depend on the memory layout of the incoming tensor.
    x = latents.transpose(1, 2).reshape(B, C, H, W)

    # Low pass: box filter as a stand-in for a Gaussian blur. With
    # count_include_pad=False, border averages use only real grid values.
    pad = kernel_size // 2
    low = F.avg_pool2d(
        x, kernel_size=kernel_size, stride=1, padding=pad, count_include_pad=False
    )

    # High pass: original minus low.
    high = x - low

    # Flatten back to [B, L, C].
    low_flat = low.flatten(2).transpose(1, 2)
    high_flat = high.flatten(2).transpose(1, 2)

    return low_flat, high_flat

class OminiLowOriginalModel(OminiModel):
    """
    OminiModel variant that feeds the first condition image through two
    branches: a low-pass filtered copy (adapter ``"subject_low"``) and the
    unfiltered latents (adapter ``"subject_original"``). Every further
    condition is routed to the ``"background"`` adapter.
    """

    def training_step(self, batch, batch_idx):
        """
        Run one training step and return the MSE loss.

        Reads ``batch["image"]``, ``batch["description"]`` and, for each
        consecutive index ``i`` starting at 0, ``batch["condition_{i}"]`` with
        the optional ``position_delta_{i}`` / ``position_scale_{i}`` entries.
        ``condition_latent_mask_{i}`` entries, if present, are not applied by
        this step.

        Args:
            batch: Mapping produced by the data loader (keys listed above).
            batch_idx: Index of the batch; unused here.

        Returns:
            Scalar loss tensor: mean squared error between the transformer
            prediction and ``x_1 - x_0`` (noise minus clean latents).
        """
        imgs, prompts = batch["image"], batch["description"]

        # Gather condition inputs until the first missing "condition_{i}" key.
        conditions, position_deltas, position_scales = [], [], []
        cond_idx = 0
        while f"condition_{cond_idx}" in batch:
            conditions.append(batch[f"condition_{cond_idx}"])
            position_deltas.append(
                batch.get(f"position_delta_{cond_idx}", [[0, 0]])
            )
            position_scales.append(
                batch.get(f"position_scale_{cond_idx}", [1.0])[0]
            )
            cond_idx += 1

        with torch.no_grad():
            imgs = imgs.to(self.device)

            # Encode target images.
            x_0, img_ids = encode_images(self.flux_pipe, imgs)
            x_0 = x_0.to(self.device)
            img_ids = img_ids.to(self.device)

            # Encode prompts.
            prompt_embeds, pooled_prompt_embeds, text_ids = self.flux_pipe.encode_prompt(
                prompt=prompts,
                prompt_2=None,
                device=self.flux_pipe.device,
                num_images_per_prompt=1,
            )
            prompt_embeds = prompt_embeds.to(self.device)
            pooled_prompt_embeds = pooled_prompt_embeds.to(self.device)
            text_ids = text_ids.to(self.device)

            # Sample a per-example time in (0, 1) and blend clean latents with
            # Gaussian noise: x_t = (1 - t) * x_0 + t * x_1.
            t = torch.sigmoid(torch.randn((imgs.shape[0],), device=self.device))
            x_1 = torch.randn_like(x_0).to(self.device)
            t_ = t.unsqueeze(1).unsqueeze(1)
            x_t = ((1 - t_) * x_0 + t_ * x_1).to(self.dtype)

            condition_latents = []
            condition_ids = []

            # The first two branches are the noisy image and the text; neither
            # uses an adapter.
            used_adapters = [None, None]

            for idx, (cond, p_delta, p_scale) in enumerate(
                zip(conditions, position_deltas, position_scales)
            ):
                cond = cond.to(self.device)
                c_latents, c_ids = encode_images(self.flux_pipe, cond)
                c_latents = c_latents.to(self.device)
                c_ids = c_ids.to(self.device)

                # Apply position scale, then position delta, to the id columns.
                if p_scale != 1.0:
                    scale_bias = (p_scale - 1.0) / 2
                    c_ids[:, 1:] *= p_scale
                    c_ids[:, 1:] += scale_bias
                c_ids[:, 1] += p_delta[0][0]
                c_ids[:, 2] += p_delta[0][1]

                if idx == 0:
                    # Subject: only the low band of the split is needed; the
                    # second branch carries the unfiltered latents instead of
                    # the high band.
                    low_latents, _ = frequency_split_latents(c_latents)

                    condition_latents.append(low_latents)
                    condition_ids.append(c_ids)
                    used_adapters.append("subject_low")

                    condition_latents.append(c_latents)
                    condition_ids.append(c_ids)
                    used_adapters.append("subject_original")
                else:
                    # Every later condition is treated as background.
                    condition_latents.append(c_latents)
                    condition_ids.append(c_ids)
                    used_adapters.append("background")

            guidance = (
                torch.ones_like(t).to(self.device)
                if self.transformer.config.guidance_embeds
                else None
            )

        # --- Forward pass ---
        branch_n = len(used_adapters)

        group_mask = torch.ones([branch_n, branch_n], dtype=torch.bool).to(self.device)
        if not self.model_config.get("inter_condition_attention", False):
            # Condition branches attend only to themselves among conditions.
            group_mask[2:, 2:] = torch.eye(
                branch_n - 2, dtype=torch.bool, device=self.device
            )

        if self.model_config.get("independent_condition", False):
            # Condition branches do not attend to the image/text branches.
            group_mask[2:, :2] = False

        transformer_out = solar_transformer_forward(
            self.transformer,
            image_features=[x_t, *condition_latents],
            text_features=[prompt_embeds],
            img_ids=[img_ids, *condition_ids],
            txt_ids=[text_ids],
            # Image and text branches use t; condition branches use timestep 0.
            timesteps=[t, t] + [torch.zeros_like(t)] * len(condition_latents),
            pooled_projections=[pooled_prompt_embeds] * branch_n,
            guidances=[guidance] * branch_n,
            adapters=used_adapters,
            return_dict=False,
            group_mask=group_mask,
            solar_params_list=None,
        )
        pred = transformer_out[0]

        # --- Loss ---
        target = x_1 - x_0
        loss = F.mse_loss(pred, target)

        self.log_loss = loss.item()
        return loss

    def configure_optimizers(self):
        """
        Build the optimizer over all trainable parameters.

        Trainable parameters are split into two groups: those whose name
        contains ``"subject_original"`` and all others. Both groups receive
        the same hyperparameters. ``self.optimizer_config`` may be ``None`` or
        omit ``"params"``; in that case AdamW with ``lr=1e-4`` is used. Any
        ``"type"`` other than ``"Prodigy"`` is trained with AdamW.

        Returns:
            The constructed optimizer. ``self.trainable_params`` is also set
            to the flat list of parameters handed to it.
        """
        params_original = []
        params_low_bg = []

        for name, param in self.named_parameters():
            if not param.requires_grad:
                continue
            if "subject_original" in name:
                params_original.append(param)
            else:
                params_low_bg.append(param)

        # Tolerate a missing optimizer section instead of failing on None.
        opt_config = self.optimizer_config or {}
        opt_name = opt_config.get("type", "AdamW")
        raw_params = opt_config.get("params") or {}
        # lr is passed explicitly, so strip it from the forwarded kwargs.
        opt_params = {k: v for k, v in raw_params.items() if k != "lr"}
        base_lr = raw_params.get("lr", 1e-4)

        param_groups = [
            {"params": params_low_bg},
            {"params": params_original},
        ]

        self.trainable_params = params_low_bg + params_original

        if opt_name == "Prodigy":
            import prodigyopt
            optimizer = prodigyopt.Prodigy(param_groups, lr=base_lr, **opt_params)
        else:
            optimizer = torch.optim.AdamW(param_groups, lr=base_lr, **opt_params)

        return optimizer

def main():
    """Load the config, build the dataset and model, then start training."""
    config = get_config()
    train_cfg = config["train"]
    data_cfg = train_cfg["dataset"]

    banner = "=" * 70
    print(banner)
    print("Aircraft Low-Freq + Original Training")
    print("Structure (Original) + Style (Low-Freq)")
    print(banner)

    dataset_kwargs = {
        "dataset_root": data_cfg["dataset_root"],
        "condition_size": tuple(data_cfg["condition_size"]),
        "target_size": tuple(data_cfg["target_size"]),
        "drop_text_prob": 0.1,
    }
    dataset = AircraftSolarDataset(**dataset_kwargs)

    # Branch layout expected by the model:
    #   0: None               (image)
    #   1: None               (text)
    #   2: "subject_low"      (low-frequency subject)
    #   3: "subject_original" (original subject)
    #   4: "background"       (background)
    adapter_names = [None, None, "subject_low", "subject_original", "background"]

    flux_path = config["flux_path"]
    lora_cfg = train_cfg.get("lora_config", None)
    if config["dtype"] == "bfloat16":
        dtype = torch.bfloat16
    else:
        dtype = torch.float32
    model_cfg = config.get("model", {})
    optimizer_cfg = train_cfg.get("optimizer", None)
    use_checkpointing = train_cfg.get("gradient_checkpointing", False)

    model = OminiLowOriginalModel(
        flux_pipe_id=flux_path,
        lora_path=None,
        lora_config=lora_cfg,
        device="cuda",
        dtype=dtype,
        model_config=model_cfg,
        adapter_names=adapter_names,
        optimizer_config=optimizer_cfg,
        gradient_checkpointing=use_checkpointing,
    )

    train(dataset, model, config)

# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()