import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import random
from omini.train_flux.train_aircraft_bg_crop_solar import OminiSolarModel, AircraftSolarDataset, RSSolarContextEncoder
from omini.train_flux.trainer_mask_weighted import get_config, train
from omini.pipeline.flux_omini_solar import solar_transformer_forward, encode_images
from omini.pipeline.flux_omini import specify_lora, apply_rotary_emb, clip_hidden_states

# --- Custom Block Forwards to Avoid Unpacking Error ---
def solar_block_forward_channel_mix(
    self,
    image_hidden_states: list[torch.FloatTensor],
    text_hidden_states: list[torch.FloatTensor],
    tembs: list[torch.FloatTensor],
    adapters: list[str],
    position_embs=None,
    attn_forward=None, # Will be passed as solar_attn_forward_channel_mix
    solar_params: tuple = None,
    **kwargs: dict,
):
    """Run one dual-stream (text + image) transformer block with solar modulation.

    ``tembs`` and ``adapters`` are indexed text-first: entry ``k`` belongs to
    text branch ``k`` and entry ``len(text_hidden_states) + k`` belongs to image
    branch ``k``. ``solar_params`` is handed to ``attn_forward`` untouched (it is
    a tuple that must not be unpacked here).

    Returns:
        ``(image_out, text_out)``: the updated image and text hidden states, each
        passed through ``clip_hidden_states``.
    """
    n_text = len(text_hidden_states)

    # Pre-attention norm + modulation. Text branches use the context norm
    # without a LoRA scope; each image branch activates its own adapter.
    text_mod = [
        self.norm1_context(t_state, emb=tembs[k])
        for k, t_state in enumerate(text_hidden_states)
    ]
    image_mod = []
    for k, i_state in enumerate(image_hidden_states):
        slot = n_text + k
        with specify_lora((self.norm1.linear,), adapters[slot]):
            image_mod.append(self.norm1(i_state, emb=tembs[slot]))

    # Joint attention; element 0 of each norm output is the normalized input.
    img_attn_output, txt_attn_output = attn_forward(
        self.attn,
        hidden_states=[mod[0] for mod in image_mod],
        hidden_states2=[mod[0] for mod in text_mod],
        position_embs=position_embs,
        adapters=adapters,
        solar_params=solar_params,  # forwarded as the whole tuple
        **kwargs,
    )

    # Text stream: gated attention residual, then gated feed-forward residual.
    text_out = []
    for k, t_state in enumerate(text_hidden_states):
        _, gate_msa, shift_mlp, scale_mlp, gate_mlp = text_mod[k]
        t_state = t_state + txt_attn_output[k] * gate_msa.unsqueeze(1)
        t_normed = (
            self.norm2_context(t_state) * (1 + scale_mlp.unsqueeze(1))
            + shift_mlp.unsqueeze(1)
        )
        t_state = self.ff_context(t_normed) * gate_mlp.unsqueeze(1) + t_state
        text_out.append(clip_hidden_states(t_state))

    # Image stream: same structure, with the FF output projection under the branch's LoRA.
    image_out = []
    for k, i_state in enumerate(image_hidden_states):
        _, gate_msa, shift_mlp, scale_mlp, gate_mlp = image_mod[k]
        residual = (i_state + img_attn_output[k] * gate_msa.unsqueeze(1)).to(i_state.dtype)
        i_normed = self.norm2(residual) * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
        with specify_lora((self.ff.net[2],), adapters[n_text + k]):
            residual = residual + self.ff(i_normed) * gate_mlp.unsqueeze(1)
        image_out.append(clip_hidden_states(residual))
    return image_out, text_out


def solar_single_block_forward_channel_mix(
    self,
    hidden_states: list[torch.FloatTensor],
    tembs: list[torch.FloatTensor],
    adapters: list[str],
    position_embs=None,
    attn_forward=None,
    solar_params: tuple = None,
    **kwargs: dict,
):
    """Run one single-stream transformer block with solar modulation.

    Each branch ``k`` in ``hidden_states`` uses ``tembs[k]`` and ``adapters[k]``.
    The attention path and the MLP path run in parallel on the normed input. Their
    outputs are concatenated, projected by ``proj_out``, gated, and added back
    to the residual. ``solar_params`` is forwarded to ``attn_forward`` as-is.

    Returns:
        List of updated hidden states, one per branch, clipped with
        ``clip_hidden_states``.
    """
    normed, gate_list, mlp_branch = [], [], []
    for idx, state in enumerate(hidden_states):
        with specify_lora((self.norm.linear, self.proj_mlp), adapters[idx]):
            normed_state, gate = self.norm(state, emb=tembs[idx])
            mlp_branch.append(self.act_mlp(self.proj_mlp(normed_state)))
        normed.append(normed_state)
        gate_list.append(gate)

    # No text split here: every branch is passed as ``hidden_states``.
    attn_outputs = attn_forward(
        self.attn,
        normed,
        adapters,
        position_embs=position_embs,
        solar_params=solar_params,  # forwarded as the whole tuple
        **kwargs
    )

    outputs = []
    for idx, residual in enumerate(hidden_states):
        with specify_lora((self.proj_out,), adapters[idx]):
            fused = torch.cat([attn_outputs[idx], mlp_branch[idx]], dim=2)
            fused = gate_list[idx].unsqueeze(1) * self.proj_out(fused) + residual
            outputs.append(clip_hidden_states(fused))

    return outputs

# --- Custom Attention Forward with Gated Non-Linear Modulation ---
def solar_attn_forward_channel_mix(
    attn,
    hidden_states: list[torch.FloatTensor],
    adapters: list[str],
    hidden_states2: list[torch.FloatTensor] = None,
    position_embs: list[torch.Tensor] = None,
    group_mask: torch.Tensor = None,
    cache_mode: str = None,
    to_cache: list[torch.Tensor] = None,
    cache_storage: list[torch.Tensor] = None,
    solar_params: tuple = None, # (scale, shift, gate)
    **kwargs: dict,
) -> torch.FloatTensor:
    """Multi-branch joint attention with gated non-linear modulation of image values.

    Each branch is projected to Q/K/V separately:
      * text branches use the ``add_*`` projections;
      * image branches use ``to_q/to_k/to_v`` under their own LoRA adapter.

    Every query then attends over the concatenated keys/values of the branches
    permitted by ``group_mask``.

    When ``solar_params`` is given, each image-branch value is rewritten before the
    head reshape as ``V + sigmoid(gate) * tanh(V * scale + shift)``.

    Args:
        attn: Attention module providing the projections, q/k norms and ``heads``.
        hidden_states: Image-branch inputs. The shape of the first one is unpacked
            as ``(batch, seq, dim)``.
        adapters: LoRA adapter name per branch, text branches first. Image branch
            ``i`` uses ``adapters[i + len(hidden_states2)]``.
        hidden_states2: Text-branch inputs (dual-stream blocks). ``None`` or empty
            for single-stream blocks.
        position_embs: Optional rotary embeddings, one per branch, in
            text-then-image order.
        group_mask: Optional ``[branches, branches]`` boolean mask. Query branch
            ``i`` only attends to key branch ``j`` when ``group_mask[i][j]`` is true.
        cache_mode: ``"write"`` appends the K/V of branches flagged in ``to_cache``
            to ``cache_storage[attn.cache_idx]``. ``"read"`` appends the cached K/V
            to every query's keys/values.
        to_cache: Per-branch flags used with ``cache_mode="write"``.
        cache_storage: Per-layer ``(keys, values)`` lists indexed by ``attn.cache_idx``.
        solar_params: ``(scale, shift, gate)``, each broadcastable against the flat
            ``[B, L, inner_dim]`` value. The training_step in this module builds
            them as ``[B, 1, inner_dim]``.
        **kwargs: Accepted and ignored.

    Returns:
        ``(image_outputs, text_outputs)`` when text branches are given, otherwise
        just the list of image outputs.
    """
    # A None default avoids the shared-mutable-default pitfall of ``= []``.
    if hidden_states2 is None:
        hidden_states2 = []
    bs, _, _ = hidden_states[0].shape
    h2_n = len(hidden_states2)

    # The modulation parameters (and therefore the sigmoid gate) are the same for
    # every image branch, so compute them once instead of per branch.
    if solar_params is not None:
        scale, shift, gate = solar_params
        gate_act = torch.sigmoid(gate)

    queries, keys, values = [], [], []

    # Text branch (standard, no modulation)
    for i, hidden_state in enumerate(hidden_states2):
        query = attn.add_q_proj(hidden_state)
        key = attn.add_k_proj(hidden_state)
        value = attn.add_v_proj(hidden_state)
        head_dim = key.shape[-1] // attn.heads
        reshape_fn = lambda x: x.view(bs, -1, attn.heads, head_dim).transpose(1, 2)
        query, key, value = map(reshape_fn, (query, key, value))
        query, key = attn.norm_added_q(query), attn.norm_added_k(key)
        queries.append(query)
        keys.append(key)
        values.append(value)

    # Image branch (with gated non-linear modulation).
    # NOTE(review): single-stream callers pass no hidden_states2, so every entry of
    # hidden_states is modulated there (possibly including text tokens, depending on
    # how single-block inputs are packed) -- confirm this is intended.
    for i, hidden_state in enumerate(hidden_states):
        with specify_lora((attn.to_q, attn.to_k, attn.to_v), adapters[i + h2_n]):
            query = attn.to_q(hidden_state)
            key = attn.to_k(hidden_state)
            value = attn.to_v(hidden_state)  # [B, L, inner_dim]

            if solar_params is not None:
                # V_new = V + sigmoid(gate) * tanh(V * scale + shift), element-wise on
                # the flat inner_dim, before splitting into heads.
                value = value + gate_act * torch.tanh(value * scale + shift)

        head_dim = key.shape[-1] // attn.heads
        reshape_fn = lambda x: x.view(bs, -1, attn.heads, head_dim).transpose(1, 2)

        query, key, value = map(reshape_fn, (query, key, value))
        query, key = attn.norm_q(query), attn.norm_k(key)

        queries.append(query)
        keys.append(key)
        values.append(value)

    # Apply rotary embedding
    if position_embs is not None:
        queries = [apply_rotary_emb(q, position_embs[i]) for i, q in enumerate(queries)]
        keys = [apply_rotary_emb(k, position_embs[i]) for i, k in enumerate(keys)]

    if cache_mode == "write":
        for i, (k, v) in enumerate(zip(keys, values)):
            if to_cache[i]:
                cache_storage[attn.cache_idx][0].append(k)
                cache_storage[attn.cache_idx][1].append(v)

    attn_outputs = []
    for i, query in enumerate(queries):
        keys_, values_ = [], []
        for j, (k, v) in enumerate(zip(keys, values)):
            if (group_mask is not None) and not (group_mask[i][j].item()):
                continue
            keys_.append(k)
            values_.append(v)
        if cache_mode == "read":
            keys_.extend(cache_storage[attn.cache_idx][0])
            values_.extend(cache_storage[attn.cache_idx][1])

        attn_output = F.scaled_dot_product_attention(
            query, torch.cat(keys_, dim=2), torch.cat(values_, dim=2)
        ).to(query.dtype)
        # [B, heads, L, head_dim] -> [B, L, heads * head_dim]
        attn_output = attn_output.transpose(1, 2).reshape(bs, -1, attn.heads * head_dim)
        attn_outputs.append(attn_output)

    h_out, h2_out = [], []
    for i, hidden_state in enumerate(hidden_states2):
        h2_out.append(attn.to_add_out(attn_outputs[i]))

    for i, hidden_state in enumerate(hidden_states):
        h = attn_outputs[i + h2_n]
        if getattr(attn, "to_out", None) is not None:
            with specify_lora((attn.to_out[0],), adapters[i + h2_n]):
                h = attn.to_out[0](h)
        h_out.append(h)

    return (h_out, h2_out) if h2_n else h_out


class OminiSolarChannelMixModel(OminiSolarModel):
    """OminiSolarModel variant that uses gated non-linear modulation of attention values.

    The parent's ``solar_projectors`` are replaced by per-layer ``SiLU -> Linear``
    heads. Each head maps the solar context vector to ``(scale, shift, gate)``.
    These are injected into the custom attention forward, which applies
    ``V + sigmoid(gate) * tanh(V * scale + shift)`` to image-branch values.
    """

    def __init__(self, *args, **kwargs):
        # Build the parent model first. Only the *count* of its solar_projectors is
        # reused; the projectors themselves are replaced below.
        super().__init__(*args, **kwargs)

        transformer_cfg = self.transformer.config
        inner_dim = transformer_cfg.num_attention_heads * transformer_cfg.attention_head_dim
        self.cm_inner_dim = inner_dim

        # Each projector emits scale, shift and gate for the full attention inner dim.
        total_layers = len(self.solar_projectors)
        out_dim = inner_dim * 3

        self.solar_projectors = nn.ModuleList([
            nn.Sequential(
                nn.SiLU(),  # Non-linearity as per diagram
                # 1024 must match the width of the context vector from self.solar_encoder.
                nn.Linear(1024, out_dim),
            )
            for _ in range(total_layers)
        ]).to(self.device).to(torch.float32)

        # Zero init => scale = shift = gate = 0, so tanh(...) = 0 and V is unchanged
        # at the start of training.
        for proj in self.solar_projectors:
            nn.init.zeros_(proj[1].weight)
            nn.init.zeros_(proj[1].bias)

        print(f"Initialized Gated Non-Linear Projectors: 1024 -> {out_dim} (scale, shift, gate)")

    def training_step(self, batch, batch_idx):
        """Compute the flow-matching MSE loss for one batch, with solar modulation.

        Expects the keys ``"image"`` and ``"description"``, and optionally
        ``"target_mask"``. It also reads ``condition_{i}`` entries, each with
        optional ``position_delta_{i}``, ``position_scale_{i}`` and
        ``condition_latent_mask_{i}``.

        Solar parameters are only produced when a second condition (index 1,
        used as the background) and ``target_mask`` are both present. Otherwise
        an empty ``solar_params_list`` is passed to the transformer.

        NOTE(review): ``target_mask`` only feeds the solar encoder; the loss below
        is an unweighted MSE. Confirm whether a mask-weighted loss was intended.

        Returns:
            Scalar MSE loss between the prediction and ``x_1 - x_0``.
        """
        imgs, prompts = batch["image"], batch["description"]
        target_mask = batch.get("target_mask", None)

        # Collect condition_0, condition_1, ... until the first missing index.
        # latent_masks is gathered but not used further in this method.
        conditions, position_deltas, position_scales, latent_masks = [], [], [], []
        for i in range(1000):
            if f"condition_{i}" not in batch:
                break
            conditions.append(batch[f"condition_{i}"])
            position_deltas.append(batch.get(f"position_delta_{i}", [[0, 0]]))
            position_scales.append(batch.get(f"position_scale_{i}", [1.0])[0])
            latent_masks.append(batch.get(f"condition_latent_mask_{i}", None))

        with torch.no_grad():
            imgs = imgs.to(self.device)
            if target_mask is not None:
                target_mask = target_mask.to(self.device)

            x_0, img_ids = encode_images(self.flux_pipe, imgs)
            x_0 = x_0.to(self.device)
            img_ids = img_ids.to(self.device)

            prompt_embeds, pooled_prompt_embeds, text_ids = self.flux_pipe.encode_prompt(
                prompt=prompts,
                prompt_2=None,
                device=self.flux_pipe.device,
                num_images_per_prompt=1,
            )
            prompt_embeds = prompt_embeds.to(self.device)
            pooled_prompt_embeds = pooled_prompt_embeds.to(self.device)
            text_ids = text_ids.to(self.device)

            # Flow matching: x_t interpolates between data x_0 and noise x_1.
            t = torch.sigmoid(torch.randn((imgs.shape[0],), device=self.device))
            x_1 = torch.randn_like(x_0).to(self.device)
            t_ = t.unsqueeze(1).unsqueeze(1)
            x_t = ((1 - t_) * x_0 + t_ * x_1).to(self.dtype)

            condition_latents, condition_ids = [], []
            bg_latents_for_solar = None

            for idx, (cond, p_delta, p_scale, latent_mask) in enumerate(zip(
                conditions, position_deltas, position_scales, latent_masks
            )):
                cond = cond.to(self.device)
                c_latents, c_ids = encode_images(self.flux_pipe, cond)
                c_latents = c_latents.to(self.device)
                c_ids = c_ids.to(self.device)

                # The second condition is treated as the background for the solar encoder.
                if idx == 1:
                    bg_latents_for_solar = c_latents

                if p_scale != 1.0:
                    scale_bias = (p_scale - 1.0) / 2
                    c_ids[:, 1:] *= p_scale
                    c_ids[:, 1:] += scale_bias
                c_ids[:, 1] += p_delta[0][0]
                c_ids[:, 2] += p_delta[0][1]

                condition_latents.append(c_latents)
                condition_ids.append(c_ids)

            guidance = torch.ones_like(t).to(self.device) if self.transformer.config.guidance_embeds else None

        # --- Solar parameters (outside no_grad so the encoder/projectors train) ---
        solar_params_list = []
        if bg_latents_for_solar is not None and target_mask is not None:
            B, L, C = bg_latents_for_solar.shape
            # The token sequence is folded back into a square H x W grid; fail with a
            # clear message instead of an opaque view/reshape error if it is not square.
            H = round(L ** 0.5)
            if H * H != L:
                raise ValueError(
                    f"Solar encoder expects a square background latent grid, got {L} tokens"
                )
            W = H
            bg_spatial = bg_latents_for_solar.transpose(1, 2).reshape(B, C, H, W).to(torch.float32)
            context_vector = self.solar_encoder(bg_spatial, target_mask.to(torch.float32))

            for proj in self.solar_projectors:
                # [B, 3 * inner_dim] -> three [B, inner_dim] chunks -> [B, 1, inner_dim]
                scale, shift, gate = proj(context_vector).chunk(3, dim=-1)
                solar_params_list.append(
                    tuple(p.unsqueeze(1).to(self.dtype) for p in (scale, shift, gate))
                )
        # Otherwise solar_params_list stays empty and is passed through as such.

        # --- Forward pass ---
        # Branch order: text, target image, then one branch per condition.
        branch_n = 2 + len(conditions)
        group_mask = torch.ones([branch_n, branch_n], dtype=torch.bool).to(self.device)
        if not self.model_config.get("inter_condition_attention", False):
            # Conditions attend only to themselves among the condition branches.
            group_mask[2:, 2:] = torch.diag(torch.tensor([1] * len(conditions), device=self.device))
        if self.model_config.get("independent_condition", False):
            group_mask[2:, :2] = False

        transformer_out = solar_transformer_forward(
            self.transformer,
            image_features=[x_t, *(condition_latents)],
            text_features=[prompt_embeds],
            img_ids=[img_ids, *(condition_ids)],
            txt_ids=[text_ids],
            timesteps=[t, t] + [torch.zeros_like(t)] * len(conditions),
            pooled_projections=[pooled_prompt_embeds] * branch_n,
            guidances=[guidance] * branch_n,
            adapters=self.adapter_names,
            return_dict=False,
            group_mask=group_mask,
            solar_params_list=solar_params_list,
            attn_forward=solar_attn_forward_channel_mix,  # Inject Custom Attention
            block_forward=solar_block_forward_channel_mix,  # Inject Custom Block Forward
            single_block_forward=solar_single_block_forward_channel_mix,  # Inject Custom Single Block Forward
        )
        pred = transformer_out[0]

        target = x_1 - x_0
        loss = F.mse_loss(pred, target)

        self.log_loss = loss.item()
        return loss

def main():
    """Build the dataset and the channel-mix solar model from config, then train.

    Reads ``config["train"]["dataset"]`` for the dataset paths and sizes.
    ``drop_text_prob`` may optionally be set there (default 0.1).
    """
    config = get_config()
    training_config = config["train"]
    dataset_config = training_config["dataset"]

    print("=" * 70)
    print("Aircraft Solar Training (Level 1: Gated Non-Linear Modulation)")
    print("V_new = V + sigmoid(gate) * tanh(V * scale + shift)")
    print("=" * 70)

    dataset = AircraftSolarDataset(
        dataset_root=dataset_config["dataset_root"],
        condition_size=tuple(dataset_config["condition_size"]),
        target_size=tuple(dataset_config["target_size"]),
        # Previously hard-coded; now overridable from config with the same default.
        drop_text_prob=dataset_config.get("drop_text_prob", 0.1),
    )

    model = OminiSolarChannelMixModel(
        flux_pipe_id=config["flux_path"],
        lora_path=None,
        lora_config=training_config.get("lora_config", None),
        device="cuda",
        dtype=torch.bfloat16 if config["dtype"] == "bfloat16" else torch.float32,
        model_config=config.get("model", {}),
        # Presumably one adapter per branch (text, target, subject, background) --
        # matches training_step treating condition index 1 as the background.
        adapter_names=[None, None, "subject", "background"],
        optimizer_config=training_config.get("optimizer", None),
        gradient_checkpointing=training_config.get("gradient_checkpointing", False),
    )

    train(dataset, model, config)


if __name__ == "__main__":
    main()