import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import random
from omini.train_flux.train_aircraft_bg_crop_solar import AircraftSolarDataset
from omini.train_flux.trainer_mask_weighted import get_config, train, OminiModel
from omini.pipeline.flux_omini_solar import solar_transformer_forward, encode_images

def frequency_split_latents(latents, kernel_size=5, height=None, width=None):
    """
    Decompose packed latents into low-frequency and high-frequency components.

    The token sequence is laid out as a row-major ``height x width`` grid and
    blurred with a ``kernel_size x kernel_size`` box filter (low pass). The
    residual ``latents - low`` is the high-pass part, so ``low + high``
    reconstructs the input.

    Args:
        latents: [B, L, C] Packed latents.
        kernel_size: Positive, odd kernel size of the average-pooling low-pass
            filter. An even size would change the spatial size of the pooled
            output, so it is rejected.
        height: Optional token-grid height. Defaults to ``isqrt(L)`` (square
            grid), or ``L // width`` when only ``width`` is given.
        width: Optional token-grid width. Defaults to ``isqrt(L)``, or
            ``L // height`` when only ``height`` is given.

    Returns:
        low_flat: [B, L, C] Low-frequency component (intended to carry style/lighting).
        high_flat: [B, L, C] High-frequency component (intended to carry structure/edges).

    Raises:
        ValueError: If ``kernel_size`` is not a positive odd integer, or the
            grid dimensions are not positive or do not multiply to ``L``.
    """
    import math

    if kernel_size < 1 or kernel_size % 2 == 0:
        raise ValueError(f"kernel_size must be a positive odd integer, got {kernel_size}")

    B, L, C = latents.shape

    # Resolve the spatial grid; default is a square grid (the original behavior).
    if height is None and width is None:
        height = width = math.isqrt(L)
    elif height is None:
        height = L // width if width > 0 else 0
    elif width is None:
        width = L // height if height > 0 else 0
    if height <= 0 or width <= 0 or height * width != L:
        raise ValueError(
            f"Cannot arrange {L} tokens into a {height}x{width} grid; "
            "pass height/width explicitly for non-square latents"
        )

    # [B, L, C] -> [B, C, H, W]; tokens are assumed to be stored row-major.
    x = latents.transpose(1, 2).reshape(B, C, height, width)

    # Low pass: box blur. stride=1 with padding=k//2 keeps the spatial size for
    # odd k. count_include_pad=False averages border positions over in-bounds
    # elements only instead of mixing in the zero padding.
    low = F.avg_pool2d(
        x,
        kernel_size=kernel_size,
        stride=1,
        padding=kernel_size // 2,
        count_include_pad=False,
    )

    # High pass: the residual after removing the low-frequency part.
    high = x - low

    # Flatten back to [B, L, C].
    low_flat = low.flatten(2).transpose(1, 2)
    high_flat = high.flatten(2).transpose(1, 2)

    return low_flat, high_flat

class OminiFrequencyModel(OminiModel):
    """OminiModel that routes the subject condition through two frequency branches.

    Condition 0 (the subject) is split by ``frequency_split_latents`` into a
    low-frequency branch (adapter ``"subject_low"``) and a high-frequency branch
    (adapter ``"subject_high"``). The two branches share the same position ids.
    Every later condition is passed through unchanged with the ``"background"``
    adapter.
    """

    def training_step(self, batch, batch_idx):
        """Run one flow-matching training step and return the scalar loss.

        Args:
            batch: Collated batch. Reads ``"image"`` and ``"description"``.
                For i = 0, 1, ... it also reads ``"condition_{i}"`` plus the
                optional ``"position_delta_{i}"``, ``"position_scale_{i}"`` and
                ``"condition_latent_mask_{i}"``.
            batch_idx: Index of the batch (unused).

        Returns:
            MSE between the transformer prediction and the target ``x_1 - x_0``.
        """
        imgs, prompts = batch["image"], batch["description"]
        conditions = self._collect_conditions(batch)

        with torch.no_grad():
            imgs = imgs.to(self.device)

            # Encode the target images into packed latents and position ids.
            x_0, img_ids = encode_images(self.flux_pipe, imgs)
            x_0 = x_0.to(self.device)
            img_ids = img_ids.to(self.device)

            prompt_embeds, pooled_prompt_embeds, text_ids = self.flux_pipe.encode_prompt(
                prompt=prompts,
                prompt_2=None,
                device=self.flux_pipe.device,
                num_images_per_prompt=1,
            )
            prompt_embeds = prompt_embeds.to(self.device)
            pooled_prompt_embeds = pooled_prompt_embeds.to(self.device)
            text_ids = text_ids.to(self.device)

            # Logit-normal timestep sampling. x_t interpolates linearly between
            # data (t=0) and noise (t=1).
            t = torch.sigmoid(torch.randn((imgs.shape[0],), device=self.device))
            x_1 = torch.randn_like(x_0).to(self.device)
            t_ = t.unsqueeze(1).unsqueeze(1)
            x_t = ((1 - t_) * x_0 + t_ * x_1).to(self.dtype)

            condition_latents, condition_ids, used_adapters = self._encode_condition_branches(
                conditions
            )

            guidance = (
                torch.ones_like(t).to(self.device)
                if self.transformer.config.guidance_embeds
                else None
            )

        # Branches are [image, text, *condition branches]. With one subject and
        # one background condition this is 2 + 3 = 5 (Img, Txt, Sub_Low, Sub_High, Bg).
        branch_n = len(used_adapters)
        group_mask = self._build_group_mask(branch_n)

        # solar_transformer_forward is used without solar modulation (solar_params_list=None).
        transformer_out = solar_transformer_forward(
            self.transformer,
            image_features=[x_t, *condition_latents],
            text_features=[prompt_embeds],
            img_ids=[img_ids, *condition_ids],
            txt_ids=[text_ids],
            timesteps=[t, t] + [torch.zeros_like(t)] * len(condition_latents),
            pooled_projections=[pooled_prompt_embeds] * branch_n,
            guidances=[guidance] * branch_n,
            adapters=used_adapters,
            return_dict=False,
            group_mask=group_mask,
            solar_params_list=None,
        )
        pred = transformer_out[0]

        # Target is d(x_t)/dt = x_1 - x_0 for the linear interpolation above.
        target = x_1 - x_0
        loss = F.mse_loss(pred, target)

        self.log_loss = loss.item()
        return loss

    @staticmethod
    def _collect_conditions(batch):
        """Return ``(image, position_delta, position_scale, latent_mask)`` per ``condition_{i}`` key."""
        collected = []
        for i in range(1000):
            if f"condition_{i}" not in batch:
                break
            collected.append(
                (
                    batch[f"condition_{i}"],
                    batch.get(f"position_delta_{i}", [[0, 0]]),
                    batch.get(f"position_scale_{i}", [1.0])[0],
                    batch.get(f"condition_latent_mask_{i}", None),
                )
            )
        return collected

    def _encode_condition_branches(self, conditions):
        """Encode conditions into per-branch latents, position ids and adapter names.

        Condition 0 is split into low/high frequency branches. Every other
        condition becomes a single ``"background"`` branch.

        Args:
            conditions: List of ``(image, position_delta, position_scale,
                latent_mask)`` tuples, as built by ``_collect_conditions``.

        Returns:
            ``(condition_latents, condition_ids, used_adapters)``.
            ``used_adapters`` starts with ``[None, None]`` for the image and
            text branches, followed by one adapter name per condition branch.
        """
        import warnings

        condition_latents, condition_ids = [], []
        used_adapters = [None, None]

        for idx, (cond, p_delta, p_scale, latent_mask) in enumerate(conditions):
            if latent_mask is not None:
                # The mask is not applied here. Selecting a token subset before the
                # split would break the spatial reshape in frequency_split_latents.
                warnings.warn(
                    f"condition_latent_mask_{idx} is ignored by OminiFrequencyModel; "
                    "all condition tokens are used.",
                    stacklevel=2,
                )

            c_latents, c_ids = encode_images(self.flux_pipe, cond.to(self.device))
            c_latents = c_latents.to(self.device)
            c_ids = c_ids.to(self.device)

            # Adjust position ids before branching so both subject branches share them.
            # Columns 1 onwards are scaled; columns 1 and 2 are shifted by the delta.
            if p_scale != 1.0:
                scale_bias = (p_scale - 1.0) / 2
                c_ids[:, 1:] *= p_scale
                c_ids[:, 1:] += scale_bias
            c_ids[:, 1] += p_delta[0][0]
            c_ids[:, 2] += p_delta[0][1]

            if idx == 0:
                # Subject: low-frequency branch first, then high-frequency branch.
                low_latents, high_latents = frequency_split_latents(c_latents)
                condition_latents += [low_latents, high_latents]
                condition_ids += [c_ids, c_ids]
                used_adapters += ["subject_low", "subject_high"]
            else:
                condition_latents.append(c_latents)
                condition_ids.append(c_ids)
                used_adapters.append("background")

        return condition_latents, condition_ids, used_adapters

    def _build_group_mask(self, branch_n):
        """Build the [branch_n, branch_n] bool attention-group mask.

        Image and text branches (indices 0 and 1) see everything. Condition
        branches see only themselves unless ``inter_condition_attention`` is
        set. They are also cut off from image/text when
        ``independent_condition`` is set.
        """
        group_mask = torch.ones([branch_n, branch_n], dtype=torch.bool, device=self.device)
        if not self.model_config.get("inter_condition_attention", False):
            group_mask[2:, 2:] = torch.eye(branch_n - 2, dtype=torch.bool, device=self.device)
        if self.model_config.get("independent_condition", False):
            group_mask[2:, :2] = False
        return group_mask

    def configure_optimizers(self):
        """Build the optimizer over all parameters with ``requires_grad``.

        Parameters whose name contains ``"subject_high"`` get their own param
        group, so they can be inspected separately. Both groups use the same
        learning rate.

        Returns:
            ``torch.optim.AdamW``, ``torch.optim.SGD`` or ``prodigyopt.Prodigy``,
            chosen by ``optimizer_config["type"]`` (default ``"AdamW"``).

        Raises:
            ValueError: If the optimizer type is not supported.
        """
        params_high, params_low_bg = [], []
        for name, param in self.named_parameters():
            if not param.requires_grad:
                continue
            (params_high if "subject_high" in name else params_low_bg).append(param)

        print("Parameter-group config (uniform LR):")
        print(f"  - Low-Freq/BG Params: {len(params_low_bg)} tensors")
        print(f"  - High-Freq Params:   {len(params_high)} tensors")

        # Kept for reference, e.g. by a gradient-logging callback.
        self.trainable_params = params_low_bg + params_high

        optimizer_config = self.optimizer_config or {}
        opt_name = optimizer_config.get("type", "AdamW")
        opt_params = dict(optimizer_config.get("params") or {})

        # Prodigy treats lr as a multiplier on its adaptive step size (recommended 1.0).
        # AdamW/SGD need an absolute learning rate.
        default_lr = 1.0 if opt_name == "Prodigy" else 1e-4
        base_lr = opt_params.pop("lr", default_lr)

        # Skip empty groups (e.g. no "subject_high" adapter parameters).
        param_groups = [{"params": group} for group in (params_low_bg, params_high) if group]

        if opt_name == "Prodigy":
            import prodigyopt

            return prodigyopt.Prodigy(param_groups, lr=base_lr, **opt_params)
        if opt_name == "AdamW":
            return torch.optim.AdamW(param_groups, lr=base_lr, **opt_params)
        if opt_name == "SGD":
            return torch.optim.SGD(param_groups, lr=base_lr, **opt_params)
        raise ValueError(f"Unsupported optimizer type: {opt_name!r}")

def main():
    """Train the frequency-gated subject/background LoRA model.

    Reads the config via ``get_config()`` and builds an ``AircraftSolarDataset``
    from ``config["train"]["dataset"]``. It then creates an
    ``OminiFrequencyModel`` with one adapter slot per transformer branch and
    passes both to ``train``.
    """
    config = get_config()
    training_config = config["train"]
    dataset_config = training_config["dataset"]

    print("=" * 70)
    print("Aircraft Frequency-Gated LoRA Training (Innovation 2)")
    print("Structure (High-Freq) vs Style (Low-Freq) Decoupling")
    print("=" * 70)

    dataset = AircraftSolarDataset(
        dataset_root=dataset_config["dataset_root"],
        condition_size=tuple(dataset_config["condition_size"]),
        target_size=tuple(dataset_config["target_size"]),
        # Configurable; defaults to the previously hard-coded 0.1.
        drop_text_prob=dataset_config.get("drop_text_prob", 0.1),
    )

    # Adapters, one per branch (same order OminiFrequencyModel builds per step):
    # 0: None (Image)
    # 1: None (Text)
    # 2: "subject_low" (Low Freq Subject)
    # 3: "subject_high" (High Freq Subject)
    # 4: "background" (Background)
    adapter_names = [None, None, "subject_low", "subject_high", "background"]

    # Map the configured dtype name. Unknown or missing names fall back to float32,
    # as before, but "float16" is no longer silently promoted to float32.
    dtype_by_name = {
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
        "float32": torch.float32,
    }
    dtype_name = config.get("dtype", "float32")
    if dtype_name not in dtype_by_name:
        print(f"Warning: unknown dtype {dtype_name!r}, falling back to float32")
    dtype = dtype_by_name.get(dtype_name, torch.float32)

    model = OminiFrequencyModel(
        flux_pipe_id=config["flux_path"],
        lora_path=None,
        lora_config=training_config.get("lora_config", None),
        device="cuda",
        dtype=dtype,
        model_config=config.get("model", {}),
        adapter_names=adapter_names,
        optimizer_config=training_config.get("optimizer", None),
        gradient_checkpointing=training_config.get("gradient_checkpointing", False),
    )

    train(dataset, model, config)

if __name__ == "__main__":
    main()