#!/usr/bin/env python3
"""
Simplified training script that reads all configuration from a YAML file.
Usage: python train_simple.py --config configs/train_config_example.yaml

All training parameters (model paths, hyperparameters, concepts, etc.) are specified in the YAML config.
No additional command-line arguments are needed.
"""

import argparse
import json
import os
import random
from typing import Any, Callable, Dict, List, Optional, Union
from torchvision.transforms.functional import to_pil_image
import copy
from pathlib import Path
from functools import partial

import pandas as pd
import torch
import torch.nn as nn
import yaml


# Try to import wandb, but handle gracefully if not available
try:
    import wandb
    WANDB_AVAILABLE = hasattr(wandb, 'init')
except ImportError:
    wandb = None
    WANDB_AVAILABLE = False

from diffusers import FluxPipeline
from accelerate import cpu_offload
from tools.prompt_process import encode_prompt, _get_clip_prompt_embeds
from tools.scheduler_process import FlowMatchEulerDiscreteScheduler
from torchvision.transforms.functional import to_tensor
from generate_bare_flux import retrieve_timesteps, inference_latent_sample, generate_one_image_from_prompt
from accelerate import Accelerator
from tools.scheduler_process import FlowMatchEulerDiscreteScheduler
from utils.esd_utils import latent_sample, predict_noise, flux_pack_latents, _prepare_latent_image_ids
from transformers import CLIPTokenizer, PretrainedConfig, T5TokenizerFast
from accelerate.utils import ProjectConfiguration, set_seed as hf_set_seed
from torchvision.transforms.functional import to_pil_image
from tqdm import tqdm

from diffusers import FluxPipeline
from diffusers import (
    AutoencoderKL,
    FlowMatchEulerDiscreteScheduler,
    FluxPipeline,
    FluxTransformer2DModel,
)

from hyper_lora import HyperLoRALinear, HypernetworkManager, inject_hyper_lora
from ldm.models.diffusion.ddimcopy import DDIMSampler
from utils.sampling import sample_model
from utils import print_trainable_parameters
from diffusers.utils.torch_utils import randn_tensor


class Cache:
    """
    Unified cache for all text embeddings and latents.

    Stores:
    - Target prompts: latents + text embeddings (prompt_embeds, pooled_prompt_embeds, text_ids)
    - Mapping prompts: text embeddings for mapping concepts (what targets should map to)
    - Diagnostic prompts: text embeddings only (no latents)
    - Unconditional: text embeddings for empty prompt

    Uses Flux's built-in pooled_prompt_embeds (768-dim from CLIP text_encoder_one) for HyperLoRA context,
    eliminating the need for a separate CLIP model.

    All cached tensors are kept on CPU and moved to the requested device by the
    ``get_*`` accessors; only ``latent_image_ids`` lives on the working device.
    """

    def __init__(
        self,
        target_prompts: List[str],
        mapping_prompts: List[str],
        diagnostic_prompts: List[str],
        transformer,
        noise_scheduler,
        text_encoders: List,
        tokenizers: List,
        device: torch.device,
        max_ddim_steps: int = 28,
        height: int = 512,
        width: int = 512,
        num_channels_latents: int = 16,
        seed: int = 42,
        weight_dtype: torch.dtype = torch.bfloat16,
        guidance: float = 3.0,
        cache_path: Optional[str] = None,
    ):
        """Encode every prompt and pre-compute the latents for the target prompts.

        Args:
            target_prompts: Prompts that get both text embeddings and latents.
            mapping_prompts: Prompts that only get text embeddings (may be empty).
            diagnostic_prompts: Prompts that only get text embeddings (may be empty).
            transformer: Model called to predict noise while building the latents.
            noise_scheduler: Scheduler used to step the latents.
            text_encoders: Encoders forwarded to ``compute_text_embeddings``.
            tokenizers: Tokenizers forwarded to ``compute_text_embeddings``.
            device: Device used for encoding and for the initial noise.
            max_ddim_steps: Number of step counts (1..max_ddim_steps) to cache.
            height: Image height in pixels.
            width: Image width in pixels.
            num_channels_latents: Channel count of the unpacked latents.
            seed: Seed for the initial noise shared by all prompts.
            weight_dtype: Dtype of the cached latents and model inputs.
            guidance: Guidance value passed to the transformer.
            cache_path: Only checked for existence here; nothing is read from it.

        Raises:
            ValueError: If ``cache_path`` is given but does not exist, or if
                ``max_ddim_steps`` is outside [1, 1000].
        """
        if cache_path and not os.path.exists(cache_path):
            raise ValueError("Cache path doesn't exist!")

        if max_ddim_steps < 1 or max_ddim_steps > 1000:
            raise ValueError(f"max_ddim_steps must be between 1 and 1000, got {max_ddim_steps}")

        # Copy the prompt lists: add_embedding() appends to them, and the
        # caller's own lists must not grow behind its back.
        self.target_prompts = list(target_prompts)
        self.mapping_prompts = list(mapping_prompts)
        self.diagnostic_prompts = list(diagnostic_prompts)
        self.max_ddim_steps = max_ddim_steps
        self.seed = seed
        self.device = device
        self.weight_dtype = weight_dtype
        self.dirty = False  # Track if cache was modified (for lazy caching)

        self.target_prompt_to_idx = {prompt: idx for idx, prompt in enumerate(target_prompts)}
        self.mapping_prompt_to_idx = {prompt: idx for idx, prompt in enumerate(mapping_prompts)}
        self.diagnostic_prompt_to_idx = {prompt: idx for idx, prompt in enumerate(diagnostic_prompts)}

        # Latents are 8x smaller than the image and are then packed 2x2,
        # hence the extra "// 2" on each spatial dimension.
        vae_scale_factor = 8
        latent_h = height // vae_scale_factor
        latent_w = width // vae_scale_factor
        self.latent_image_ids = self._prepare_latent_image_ids(
            1, latent_h // 2, latent_w // 2, device, weight_dtype
        )

        print(f"[Cache] Caching {len(target_prompts)} target prompts × {max_ddim_steps} steps (seed={seed})...")
        print(f"[Cache] Caching {len(mapping_prompts)} mapping prompts...")
        print(f"[Cache] Caching {len(diagnostic_prompts)} diagnostic prompts...")

        self.target_embeddings = self._compute_text_embeddings(
            target_prompts, text_encoders, tokenizers, device, desc="Target embeddings"
        )

        if mapping_prompts:
            self.mapping_embeddings = self._compute_text_embeddings(
                mapping_prompts, text_encoders, tokenizers, device, desc="Mapping embeddings"
            )
        else:
            self.mapping_embeddings = None

        print("[Cache] Computing unconditional embeddings...")
        uncond_embeds, uncond_pooled, uncond_text_ids = compute_text_embeddings(
            "", text_encoders, tokenizers, device
        )
        self.uncond_embeddings = {
            'prompt_embeds': uncond_embeds.cpu(),
            'pooled_prompt_embeds': uncond_pooled.cpu(),
            'text_ids': uncond_text_ids.cpu(),
        }

        if diagnostic_prompts:
            self.diagnostic_embeddings = self._compute_text_embeddings(
                diagnostic_prompts, text_encoders, tokenizers, device, desc="Diagnostic embeddings"
            )
        else:
            self.diagnostic_embeddings = None

        self.latents = self._compute_all_latents(
            transformer, noise_scheduler, num_channels_latents,
            height, width, seed, guidance
        )

        self._print_memory_usage()

    def _prepare_latent_image_ids(self, batch_size, height, width, device, dtype):
        """Build the (height * width, 3) positional-id table for the packed latents.

        Column 0 stays zero, column 1 holds the row index and column 2 the
        column index. ``batch_size`` is accepted but not used.
        """
        latent_image_ids = torch.zeros(height, width, 3)
        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
        latent_image_ids = latent_image_ids.reshape(height * width, 3)
        return latent_image_ids.to(device=device, dtype=dtype)

    def _compute_text_embeddings(self, prompts, text_encoders, tokenizers, device, desc="Text embeddings"):
        """Encode ``prompts`` one by one and stack the results on CPU.

        Returns a dict with keys ``prompt_embeds``, ``pooled_prompt_embeds`` and
        ``text_ids``; each value is the per-prompt results concatenated on dim 0.
        """
        embeddings = {'prompt_embeds': [], 'pooled_prompt_embeds': [], 'text_ids': []}
        for prompt in tqdm(prompts, desc=desc):
            prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings(
                prompt, text_encoders, tokenizers, device
            )
            embeddings['prompt_embeds'].append(prompt_embeds.cpu())
            embeddings['pooled_prompt_embeds'].append(pooled_prompt_embeds.cpu())
            embeddings['text_ids'].append(text_ids.cpu())

        # NOTE(review): slicing these with [idx:idx+1] later assumes every
        # per-prompt tensor (including text_ids) has a leading batch dimension
        # of size 1 - confirm against encode_prompt.
        embeddings['prompt_embeds'] = torch.cat(embeddings['prompt_embeds'], dim=0)
        embeddings['pooled_prompt_embeds'] = torch.cat(embeddings['pooled_prompt_embeds'], dim=0)
        embeddings['text_ids'] = torch.cat(embeddings['text_ids'], dim=0)
        return embeddings

    def _compute_all_latents(self, transformer, noise_scheduler, num_channels_latents, height, width, seed, guidance):
        """Pre-compute one latent per (target prompt, step count) pair.

        For every target prompt and every ``k`` in 1..max_ddim_steps, the same
        seeded initial noise is run through all timesteps of a ``k``-step
        schedule and the result is stored at ``[prompt_idx, k - 1]``.

        Returns:
            CPU tensor of shape
            (num_prompts, max_ddim_steps, packed_seq_len, packed_channels).
        """
        from utils.esd_utils import flux_pack_latents, calculate_shift, retrieve_timesteps
        import numpy as np

        num_prompts = len(self.target_prompts)
        vae_scale_factor = 8
        latent_h = height // vae_scale_factor
        latent_w = width // vae_scale_factor
        packed_seq_len = (latent_h // 2) * (latent_w // 2)
        packed_channels = num_channels_latents * 4

        latents_cache = torch.zeros(
            num_prompts, self.max_ddim_steps, packed_seq_len, packed_channels,
            dtype=self.weight_dtype, device='cpu'
        )

        # One noise sample shared by all prompts and all step counts.
        generator = torch.Generator(device=self.device).manual_seed(seed)
        shape = (1, num_channels_latents, latent_h, latent_w)
        initial_noise = randn_tensor(shape, generator=generator, dtype=self.weight_dtype, device=self.device)
        guidance_tensor = torch.tensor([guidance], device=self.device, dtype=self.weight_dtype)

        pbar = tqdm(total=num_prompts * self.max_ddim_steps, desc="Latents")

        with torch.no_grad():
            for prompt_idx in range(num_prompts):
                prompt_embeds = self.target_embeddings['prompt_embeds'][prompt_idx:prompt_idx+1].to(self.device)
                pooled_prompt_embeds = self.target_embeddings['pooled_prompt_embeds'][prompt_idx:prompt_idx+1].to(self.device)
                text_ids = self.target_embeddings['text_ids'][prompt_idx:prompt_idx+1].to(self.device)
                # Drop the batch dimension when text_ids is stored as 3-D.
                if text_ids.dim() == 3:
                    text_ids = text_ids[0]
                text_ids = text_ids.to(dtype=self.weight_dtype)

                for ddim_step in range(1, self.max_ddim_steps + 1):
                    # Restart from the initial noise for every step count.
                    latents = flux_pack_latents(initial_noise.clone(), 1, num_channels_latents, latent_h, latent_w)
                    image_seq_len = latents.shape[1]
                    sigmas = np.linspace(1.0, 1 / ddim_step, ddim_step)
                    mu = calculate_shift(
                        image_seq_len,
                        noise_scheduler.config.base_image_seq_len,
                        noise_scheduler.config.max_image_seq_len,
                        noise_scheduler.config.base_shift,
                        noise_scheduler.config.max_shift,
                    )
                    timesteps_tensor, _ = retrieve_timesteps(noise_scheduler, ddim_step, transformer.device, None, sigmas, mu=mu)
                    latents = latents.to(transformer.device).to(self.weight_dtype)

                    for t in timesteps_tensor:
                        timestep = t.expand(latents.shape[0]).to(self.weight_dtype)
                        noise_pred = transformer(
                            hidden_states=latents,
                            timestep=timestep / 1000,
                            guidance=guidance_tensor,
                            pooled_projections=pooled_prompt_embeds.to(self.weight_dtype),
                            encoder_hidden_states=prompt_embeds.to(self.weight_dtype),
                            txt_ids=text_ids,
                            img_ids=self.latent_image_ids,
                            return_dict=False,
                        )
                        # return_dict=False may still hand back a tuple/list.
                        if isinstance(noise_pred, (tuple, list)):
                            noise_pred = noise_pred[0]
                        latents = noise_scheduler.step(noise_pred, t, latents, return_dict=False)[0]

                    latents_cache[prompt_idx, ddim_step - 1] = latents.cpu()
                    pbar.update(1)

        pbar.close()
        return latents_cache

    @staticmethod
    def _lookup_index(prompt_to_idx: dict, prompt: str, kind: str, hint: str = "") -> int:
        """Return the cache index of ``prompt`` or raise a descriptive KeyError."""
        try:
            return prompt_to_idx[prompt]
        except KeyError:
            raise KeyError(f"{kind} prompt '{prompt}' not in cache.{hint}") from None

    @staticmethod
    def _slice_embeddings(embeddings: dict, idx: int, device: torch.device) -> tuple:
        """Return (prompt_embeds, pooled_prompt_embeds, text_ids) for row ``idx`` on ``device``."""
        return (
            embeddings['prompt_embeds'][idx:idx+1].to(device),
            embeddings['pooled_prompt_embeds'][idx:idx+1].to(device),
            embeddings['text_ids'][idx:idx+1].to(device),
        )

    def get_target(self, prompt: str, ddim_step: int, device: torch.device) -> tuple:
        """Get latent and embeddings for a target prompt.

        Args:
            prompt: A cached target prompt.
            ddim_step: 1-based step count, in [1, number of cached steps].
            device: Device the returned tensors are moved to.

        Returns:
            (latent, prompt_embeds, pooled_prompt_embeds, text_ids, latent_image_ids).

        Raises:
            KeyError: If ``prompt`` is not a cached target prompt.
            IndexError: If ``ddim_step`` is out of range. Without this check a
                value <= 0 would wrap around and return the wrong latent.
        """
        prompt_idx = self._lookup_index(self.target_prompt_to_idx, prompt, "Target")
        cached_steps = self.latents.shape[1]
        if not 1 <= ddim_step <= cached_steps:
            raise IndexError(f"ddim_step must be between 1 and {cached_steps}, got {ddim_step}")
        latent = self.latents[prompt_idx, ddim_step - 1].unsqueeze(0).to(device)
        prompt_embeds, pooled_prompt_embeds, text_ids = self._slice_embeddings(
            self.target_embeddings, prompt_idx, device
        )
        return latent, prompt_embeds, pooled_prompt_embeds, text_ids, self.latent_image_ids.to(device)

    def get_target_pooled(self, prompt: str, device: torch.device) -> torch.Tensor:
        """Get pooled_prompt_embeds for HyperLoRA context (768-dim)."""
        idx = self._lookup_index(self.target_prompt_to_idx, prompt, "Target")
        return self.target_embeddings['pooled_prompt_embeds'][idx:idx+1].to(device)

    def get_mapping(self, prompt: str, device: torch.device) -> tuple:
        """Get text embeddings for a mapping prompt.

        Raises:
            ValueError: If no mapping embeddings exist at all.
            KeyError: If ``prompt`` has not been cached.
        """
        if self.mapping_embeddings is None:
            raise ValueError("No mapping prompts cached")
        idx = self._lookup_index(
            self.mapping_prompt_to_idx, prompt, "Mapping", " Use add_embedding() first."
        )
        return self._slice_embeddings(self.mapping_embeddings, idx, device)

    def get_mapping_pooled(self, prompt: str, device: torch.device) -> torch.Tensor:
        """Get pooled_prompt_embeds for a mapping prompt (for HyperLoRA context)."""
        if self.mapping_embeddings is None:
            raise ValueError("No mapping prompts cached")
        idx = self._lookup_index(
            self.mapping_prompt_to_idx, prompt, "Mapping", " Use add_embedding() first."
        )
        return self.mapping_embeddings['pooled_prompt_embeds'][idx:idx+1].to(device)

    def add_embedding(
        self,
        prompt: str,
        prompt_embeds: torch.Tensor,
        pooled_prompt_embeds: torch.Tensor,
        text_ids: torch.Tensor,
        embedding_type: str = 'mapping',
    ):
        """Append embeddings for a new prompt; a prompt already cached is a no-op.

        The tensors are concatenated to the stored ones on dim 0, so they must
        carry a leading batch dimension. Sets ``self.dirty`` when something is
        added.

        Note: with ``embedding_type='target'`` only the text embeddings are
        added; no latents are computed, so ``get_target`` cannot serve the
        new prompt.

        Args:
            prompt: The prompt text used as the cache key.
            prompt_embeds: Prompt embeddings for ``prompt``.
            pooled_prompt_embeds: Pooled embeddings for ``prompt``.
            text_ids: Text ids for ``prompt``.
            embedding_type: One of 'target', 'mapping' or 'diagnostic'.

        Raises:
            KeyError: If ``embedding_type`` is not one of the three kinds.
        """
        # Get the appropriate storage based on embedding type
        type_config = {
            'target': (self.target_prompts, self.target_prompt_to_idx, 'target_embeddings'),
            'mapping': (self.mapping_prompts, self.mapping_prompt_to_idx, 'mapping_embeddings'),
            'diagnostic': (self.diagnostic_prompts, self.diagnostic_prompt_to_idx, 'diagnostic_embeddings'),
        }

        prompts_list, prompt_to_idx, embeddings_attr = type_config[embedding_type]

        if prompt in prompt_to_idx:
            return  # Already cached

        new_entries = {
            'prompt_embeds': prompt_embeds.cpu(),
            'pooled_prompt_embeds': pooled_prompt_embeds.cpu(),
            'text_ids': text_ids.cpu(),
        }
        embeddings_dict = getattr(self, embeddings_attr)

        if embeddings_dict is None:
            # First embedding of this kind: it becomes the storage itself.
            setattr(self, embeddings_attr, new_entries)
        else:
            # Concatenate to existing
            for key, tensor in new_entries.items():
                embeddings_dict[key] = torch.cat([embeddings_dict[key], tensor], dim=0)

        # Update prompt tracking
        idx = len(prompts_list)
        prompts_list.append(prompt)
        prompt_to_idx[prompt] = idx
        self.dirty = True
        print(f"[Cache] Added {embedding_type} embedding for: {prompt[:50]}... (total: {len(prompts_list)})")

    def get_uncond(self, device: torch.device) -> tuple:
        """Get unconditional (empty prompt) embeddings."""
        return (
            self.uncond_embeddings['prompt_embeds'].to(device),
            self.uncond_embeddings['pooled_prompt_embeds'].to(device),
            self.uncond_embeddings['text_ids'].to(device),
        )

    def get_diagnostic(self, prompt: str, device: torch.device) -> tuple:
        """Get text embeddings for a diagnostic prompt.

        Raises:
            ValueError: If no diagnostic embeddings exist at all.
            KeyError: If ``prompt`` has not been cached.
        """
        if self.diagnostic_embeddings is None:
            raise ValueError("No diagnostic prompts cached")
        idx = self._lookup_index(
            self.diagnostic_prompt_to_idx, prompt, "Diagnostic", " Use add_embedding() first."
        )
        return self._slice_embeddings(self.diagnostic_embeddings, idx, device)

    def get_diagnostic_pooled(self, prompt: str, device: torch.device) -> torch.Tensor:
        """Get pooled_prompt_embeds for a diagnostic prompt (for HyperLoRA context)."""
        if self.diagnostic_embeddings is None:
            raise ValueError("No diagnostic prompts cached")
        idx = self._lookup_index(
            self.diagnostic_prompt_to_idx, prompt, "Diagnostic", " Use add_embedding() first."
        )
        return self.diagnostic_embeddings['pooled_prompt_embeds'][idx:idx+1].to(device)

    def __contains__(self, prompt: str) -> bool:
        """Return True when ``prompt`` is a cached target prompt."""
        return prompt in self.target_prompt_to_idx

    def __len__(self) -> int:
        """Return the number of (target prompt, step count) pairs."""
        return len(self.target_prompts) * self.max_ddim_steps

    @staticmethod
    def _tensors_mb(tensors) -> float:
        """Return the total storage of ``tensors`` in MiB."""
        return sum(t.element_size() * t.nelement() for t in tensors) / (1024 ** 2)

    def _print_memory_usage(self):
        """Print the memory footprint of every cached component in MB."""
        latent_mem = self._tensors_mb([self.latents])
        target_mem = self._tensors_mb(self.target_embeddings.values())
        mapping_mem = 0
        if self.mapping_embeddings:
            mapping_mem = self._tensors_mb(self.mapping_embeddings.values())
        uncond_mem = self._tensors_mb(self.uncond_embeddings.values())
        diag_mem = 0
        if self.diagnostic_embeddings:
            diag_mem = self._tensors_mb(self.diagnostic_embeddings.values())
        total = latent_mem + target_mem + mapping_mem + uncond_mem + diag_mem
        print(f"[Cache] Memory: latents={latent_mem:.1f}MB, target={target_mem:.1f}MB, mapping={mapping_mem:.1f}MB, uncond={uncond_mem:.1f}MB, diag={diag_mem:.1f}MB, total={total:.1f}MB")

    def save(self, path: str):
        """Serialize the cache to ``path`` with ``torch.save``, creating parent directories."""
        dirname = os.path.dirname(path)
        if dirname:
            os.makedirs(dirname, exist_ok=True)
        data = {
            'target_prompts': self.target_prompts,
            'mapping_prompts': self.mapping_prompts,
            'diagnostic_prompts': self.diagnostic_prompts,
            'max_ddim_steps': self.max_ddim_steps,
            'seed': self.seed,
            'latents': self.latents,
            'target_embeddings': self.target_embeddings,
            'mapping_embeddings': self.mapping_embeddings,
            'uncond_embeddings': self.uncond_embeddings,
            'diagnostic_embeddings': self.diagnostic_embeddings,
            'latent_image_ids': self.latent_image_ids.cpu(),
        }
        torch.save(data, path)
        print(f"[Cache] Saved to {path}")

    @classmethod
    def load(
        cls,
        path: str,
        device: torch.device,
        weight_dtype: torch.dtype = torch.bfloat16,
        expected_target_prompts: Optional[List[str]] = None,
        expected_mapping_prompts: Optional[List[str]] = None,
        expected_diagnostic_prompts: Optional[List[str]] = None,
        expected_ddim_steps: Optional[int] = None,
        expected_seed: Optional[int] = None,
    ):
        """Restore a cache written by :meth:`save` without running ``__init__``.

        Every ``expected_*`` argument is optional; ``None`` skips that check.

        Args:
            path: File produced by :meth:`save`.
            device: Device for ``latent_image_ids`` (other tensors stay on CPU).
            weight_dtype: Dtype for ``latent_image_ids``.
            expected_target_prompts: Prompts that must all be present as targets.
            expected_mapping_prompts: Reported when missing, never enforced.
            expected_diagnostic_prompts: Reported when missing, never enforced.
            expected_ddim_steps: Must equal the cached ``max_ddim_steps``.
            expected_seed: Must equal the cached seed when the file has one.

        Raises:
            ValueError: If the target prompts, step count or seed of the file
                do not match the expected values.
        """
        print(f"[Cache] Loading from {path}...")
        # NOTE(security): weights_only=False unpickles arbitrary objects, which
        # can execute code. Only load cache files from a trusted source.
        data = torch.load(path, map_location='cpu', weights_only=False)

        instance = object.__new__(cls)
        instance.target_prompts = data['target_prompts']
        instance.mapping_prompts = data.get('mapping_prompts', [])
        instance.diagnostic_prompts = data.get('diagnostic_prompts', [])
        instance.max_ddim_steps = data['max_ddim_steps']
        instance.seed = data.get('seed')

        # Target latents cannot be added later, so a stale file must be rejected.
        if expected_target_prompts is not None:
            cached_targets = set(instance.target_prompts)
            missing = [p for p in expected_target_prompts if p not in cached_targets]
            if missing:
                raise ValueError(
                    f"Cache at {path} is missing {len(missing)} expected target prompt(s), "
                    f"e.g. '{missing[0]}'"
                )
        if expected_ddim_steps is not None and instance.max_ddim_steps != expected_ddim_steps:
            raise ValueError(
                f"Cache at {path} has max_ddim_steps={instance.max_ddim_steps}, "
                f"expected {expected_ddim_steps}"
            )
        if expected_seed is not None:
            if instance.seed is None:
                # Files without a stored seed cannot be verified.
                print(f"[Cache] Warning: cache has no stored seed; cannot verify expected seed {expected_seed}")
            elif instance.seed != expected_seed:
                raise ValueError(f"Cache at {path} has seed={instance.seed}, expected {expected_seed}")

        # Mapping and diagnostic prompts can be added on-the-fly (no strict validation)
        for kind, expected, cached in (
            ('mapping', expected_mapping_prompts, instance.mapping_prompts),
            ('diagnostic', expected_diagnostic_prompts, instance.diagnostic_prompts),
        ):
            if expected is not None:
                cached_set = set(cached)
                absent = sum(1 for p in expected if p not in cached_set)
                if absent:
                    print(f"[Cache] {absent} expected {kind} prompt(s) not in cache; add them with add_embedding()")

        instance.device = device
        instance.weight_dtype = weight_dtype
        instance.target_prompt_to_idx = {prompt: idx for idx, prompt in enumerate(instance.target_prompts)}
        instance.mapping_prompt_to_idx = {prompt: idx for idx, prompt in enumerate(instance.mapping_prompts)}
        instance.diagnostic_prompt_to_idx = {prompt: idx for idx, prompt in enumerate(instance.diagnostic_prompts)}
        instance.latents = data['latents']
        instance.target_embeddings = data['target_embeddings']
        instance.mapping_embeddings = data.get('mapping_embeddings', None)
        instance.uncond_embeddings = data['uncond_embeddings']
        instance.diagnostic_embeddings = data.get('diagnostic_embeddings', None)
        instance.latent_image_ids = data['latent_image_ids'].to(device=device, dtype=weight_dtype)
        instance.dirty = False  # Will be set to True if embeddings are added on-the-fly
        instance._print_memory_usage()
        return instance


class CombinedCFGModel:
    """Wrapper that uses different models for conditional and unconditional passes.

    The batch handed to ``apply_model`` is expected to be the concatenation
    ``[uncond, cond]``; the first half is routed to ``uncond_model`` and the
    second half to ``cond_model``. Any attribute not defined on the wrapper is
    forwarded to ``cond_model``.
    """

    def __init__(self, cond_model, uncond_model):
        """Store both models; ``device`` is taken from ``cond_model``."""
        self.cond_model = cond_model
        self.uncond_model = uncond_model
        self.device = cond_model.device

    def apply_model(self, x, t, c):
        """Run each half of the batch through its own model and re-join the outputs.

        Args:
            x: Batched input whose first dimension has even size.
            t: Timesteps, batched like ``x``.
            c: Conditioning; either something sliceable on dim 0, or a dict
                whose values are sliceable or lists of sliceables.

        Returns:
            ``torch.cat([uncond_out, cond_out], dim=0)``.

        Raises:
            ValueError: If the batch size is odd and cannot be split in two.
        """
        # When DDIMSampler uses guidance, it concatenates [uncond, cond] inputs
        # We split and route to different models
        total = x.shape[0]
        if total % 2 != 0:
            # Explicit raise instead of assert so the check survives python -O.
            raise ValueError(f"Expected an even batch of [uncond, cond] inputs, got batch size {total}")
        half = total // 2

        x_uncond, x_cond = x[:half], x[half:]
        t_uncond, t_cond = t[:half], t[half:]

        # Split conditioning
        if isinstance(c, dict):
            c_uncond = {}
            c_cond = {}
            for key, value in c.items():
                if isinstance(value, list):
                    c_uncond[key] = [v[:half] for v in value]
                    c_cond[key] = [v[half:] for v in value]
                else:
                    c_uncond[key] = value[:half]
                    c_cond[key] = value[half:]
        else:
            c_uncond = c[:half]
            c_cond = c[half:]

        # Route unconditional to model_orig, conditional to model (with LoRA)
        out_uncond = self.uncond_model.apply_model(x_uncond, t_uncond, c_uncond)
        out_cond = self.cond_model.apply_model(x_cond, t_cond, c_cond)

        return torch.cat([out_uncond, out_cond], dim=0)

    def get_learned_conditioning(self, prompts):
        """Encode ``prompts`` with the conditional model."""
        return self.cond_model.get_learned_conditioning(prompts)

    def decode_first_stage(self, z):
        """Decode latents ``z`` with the conditional model."""
        return self.cond_model.decode_first_stage(z)

    def eval(self):
        """Put both wrapped models in eval mode and return the wrapper."""
        self.cond_model.eval()
        self.uncond_model.eval()
        return self

    def __getattr__(self, name):
        """Forward unknown attributes to ``cond_model``.

        ``__getattr__`` only runs when normal lookup fails. If ``cond_model``
        itself is missing (for example while ``copy`` or ``pickle`` rebuilds an
        instance before restoring ``__dict__``), forwarding would call this
        method again and recurse forever, so those names and dunder probes
        raise AttributeError instead.
        """
        if name in ("cond_model", "uncond_model") or (name.startswith("__") and name.endswith("__")):
            raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}")
        return getattr(self.cond_model, name)


def prompt_augmentation(content, augment=True):
    """Return prompt variants for a concept.

    With ``augment`` false the result is just ``[content]``; otherwise
    ``content`` is substituted into every template below, in order.
    """
    if not augment:
        return [content]

    # object augmentation
    templates = (
        "{} in a photo",
        "{} in a snapshot",
        "A snapshot of {}",
        "A photograph showcasing {}",
        "An illustration of {}",
        "A digital rendering of {}",
        "A visual representation of {}",
        "A graphic of {}",
        "A shot of {}",
        "A photo of {}",
        "A black and white image of {}",
        "A depiction in portrait form of {}",
        "A scene depicting {} during a public gathering",
        "{} captured in an image",
        "A depiction created with oil paints capturing {}",
        "An image of {}",
        "A drawing capturing the essence of {}",
        "An official photograph featuring {}",
        "A detailed sketch of {}",
        "{} during sunset/sunrise",
        "{} in a detailed portrait",
        "An official photo of {}",
        "Historic photo of {}",
        "Detailed portrait of {}",
        "A painting of {}",
        "HD picture of {}",
        "Magazine cover capturing {}",
        "Painting-like image of {}",
        "Hand-drawn art of {}",
        "An oil portrait of {}",
        "{} in a sketch painting",
    )
    return [template.format(content) for template in templates]


def load_config(config_path: str) -> tuple:
    """Load configuration from YAML file.

    The file is expected to hold a mapping whose first key names the
    configuration (e.g. 'MACE') and whose value holds the actual settings.

    Args:
        config_path: Path to the YAML file.

    Returns:
        A ``(settings, config_name)`` tuple: the mapping stored under the
        first top-level key, and that key.

    Raises:
        ValueError: If the file is empty, its top level is not a mapping, or
            the first section is not a mapping.
    """
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)

    # safe_load returns None for an empty file and may return a list/scalar.
    if not isinstance(config, dict) or not config:
        raise ValueError(
            f"Config file '{config_path}' must contain a non-empty top-level mapping"
        )

    # Extract the first config key (e.g., 'MACE')
    config_name = next(iter(config))
    section = config[config_name]
    if not isinstance(section, dict):
        raise ValueError(
            f"Config section '{config_name}' in '{config_path}' must be a mapping, "
            f"got {type(section).__name__}"
        )
    return section, config_name


def parse_args():
    """Read the command line and return the parsed namespace.

    The only option is the mandatory ``--config`` path to the YAML file.
    """
    cli = argparse.ArgumentParser(description="Simplified HyperLoRA Training for Stable Diffusion")
    cli.add_argument("--config", required=True, type=str, help="Path to YAML configuration file")
    namespace = cli.parse_args()
    return namespace


def create_quick_sampler(model, sampler, image_size: int, ddim_steps: int, ddim_eta: float):
    """Bind the fixed sampling settings and return a four-argument sampler.

    The returned callable takes ``(conditioning, scale, start_code, till_T)``
    and forwards everything to ``sample_model`` with a square
    ``image_size`` x ``image_size`` output and ``verbose=False``.
    """

    def quick_sample(conditioning, scale, start_code, till_T):
        # Height and width are both image_size.
        return sample_model(
            model,
            sampler,
            conditioning,
            image_size,
            image_size,
            ddim_steps,
            scale,
            ddim_eta,
            start_code=start_code,
            till_T=till_T,
            verbose=False,
        )

    return quick_sample


def load_text_encoders(class_one, class_two, pretrained_model_name_or_path):
    """Load both text encoders and freeze them.

    ``class_one`` is loaded from the ``text_encoder`` subfolder and
    ``class_two`` from ``text_encoder_2``. Each encoder is switched to eval
    mode and all of its parameters stop requiring gradients.

    Returns:
        ``(text_encoder_one, text_encoder_two)``.
    """
    sources = ((class_one, "text_encoder"), (class_two, "text_encoder_2"))
    encoders = tuple(
        encoder_cls.from_pretrained(
            pretrained_model_name_or_path, subfolder=subfolder, revision=None, variant=None
        )
        for encoder_cls, subfolder in sources
    )

    # Disable gradients + set eval mode
    for encoder in encoders:
        encoder.eval()
        for param in encoder.parameters():
            param.requires_grad_(False)

    text_encoder_one, text_encoder_two = encoders
    return text_encoder_one, text_encoder_two


def import_model_class_from_model_name_or_path(
        pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
):
    """Return the text-encoder class named by a checkpoint's config.

    Reads the config stored under ``subfolder`` and looks at the first entry
    of its ``architectures`` list. Only ``CLIPTextModel`` and
    ``T5EncoderModel`` are recognised; the matching class is imported lazily.

    Raises:
        ValueError: If the architecture is neither of the two supported ones.
    """
    encoder_config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path, subfolder=subfolder, revision=revision
    )
    architecture = encoder_config.architectures[0]

    if architecture == "CLIPTextModel":
        from transformers import CLIPTextModel

        return CLIPTextModel

    if architecture == "T5EncoderModel":
        from transformers import T5EncoderModel

        return T5EncoderModel

    raise ValueError(f"{architecture} is not supported.")


def compute_text_embeddings(prompts, text_encoders, tokenizers, device, max_sequence_length=256):
    """Encode one prompt or a list of prompts and move the results to ``device``.

    A single string is wrapped into a one-element list before it is handed
    to ``encode_prompt``.

    Returns:
        ``(prompt_embeds, pooled_prompt_embeds, text_ids)``, each on ``device``.
    """
    # prompts: List[str] or str
    prompt_batch = [prompts] if isinstance(prompts, str) else prompts

    embeds, pooled, ids = encode_prompt(
        text_encoders, tokenizers, prompt_batch, max_sequence_length
    )
    return tuple(tensor.to(device) for tensor in (embeds, pooled, ids))


def compute_clip_pooled_embeddings(prompts, text_encoder, tokenizer, device):
    """Compute pooled CLIP embeddings for ``prompts``.

    Args:
        prompts: Prompt(s) forwarded unchanged to ``_get_clip_prompt_embeds``.
        text_encoder: CLIP text encoder; its token-embedding weight is used to
            find the device the model currently lives on.
        tokenizer: Tokenizer matching ``text_encoder``.
        device: Target device, or ``None`` to compute on the encoder's device.

    Returns:
        The pooled embeddings on the resolved compute device.

    Note:
        If ``device`` differs from the encoder's device, ``text_encoder.to()``
        is called on the encoder that was passed in.
    """
    model_device = text_encoder.text_model.embeddings.token_embedding.weight.device

    # Fall back to the encoder's own device when the caller gives none.
    compute_device = model_device if device is None else torch.device(device)

    if compute_device != model_device:
        text_encoder = text_encoder.to(compute_device)

    # Use the resolved device throughout: the raw ``device`` argument may be
    # None, which is what the fallback above exists to handle.
    pooled_embeds = _get_clip_prompt_embeds(
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        prompt=prompts,
        device=compute_device,
        num_images_per_prompt=1,
    )
    return pooled_embeds.to(compute_device)


def generate_images(
        sampler,
        model,
        prompt: str,
        device: torch.device,
        steps: int = 50,
        eta: float = 0.0,
        batch_size: int = 1,
        start_code: torch.Tensor = None,
        guidance_scale: float = 7.5,
):
    """
    Sample images for ``prompt`` with classifier-free guidance.

    The model is switched to eval mode. When ``start_code`` is omitted, a
    random (batch_size, 4, 64, 64) tensor is drawn on ``device``; otherwise
    the batch size and sample shape are taken from ``start_code``. Sampling
    and decoding run without gradients, with autocast enabled only on CUDA.

    Returns the decoded samples rescaled from [-1, 1] to [0, 1] and clamped.
    """
    x_T = start_code
    if x_T is None:
        x_T = torch.randn(batch_size, 4, 64, 64, device=device)

    num_samples = x_T.shape[0]
    on_cuda = device.type == "cuda"

    model.eval()
    with torch.no_grad():
        with torch.autocast(device_type=device.type, enabled=on_cuda):
            positive = model.get_learned_conditioning([prompt] * num_samples)
            negative = model.get_learned_conditioning([""] * num_samples)

            latents, _ = sampler.sample(
                S=steps,
                conditioning={"c_crossattn": [positive]},
                batch_size=num_samples,
                shape=x_T.shape[1:],
                verbose=False,
                unconditional_guidance_scale=guidance_scale,
                unconditional_conditioning={"c_crossattn": [negative]},
                eta=eta,
                x_T=x_T,
            )
            images = model.decode_first_stage(latents)
            # Map [-1, 1] to [0, 1] before clamping.
            rescaled = (images + 1.0) / 2.0
            return torch.clamp(rescaled, 0.0, 1.0)


def main():
    args = parse_args()

    # Load configuration
    config, config_name = load_config(args.config)
    print(f"=== Training with config: {config_name} ===")
    print(f"Config file: {args.config}")

    # Extract key parameters with defaults
    learning_rate_remove = config.get('learning_rate_remove', 1e-5)
    learning_rate_retain = config.get('learning_rate_retain', 1e-5)
    max_train_steps = config.get('max_train_steps', 120)
    hyper_train_steps = config.get('hyper_train_steps', 500)  # Steps for hypernetwork context
    rank = config.get('rank', 1)
    lora_alpha = config.get('lora_alpha', 8)  # LoRA alpha parameter
    internal_size = config.get('internal_size', 100)
    seed = config.get('seed', 2024)
    resolution = config.get('resolution', 512)
    use_pooler = config.get('use_pooler', True)
    use_orig_concat = config.get('use_orig_concat', False)
    gradient_accumulation_steps = config.get('gradient_accumulation_steps', 1)

    # Multi-concept configuration
    concepts = config.get('concepts', [])
    mapping_concept = config.get('mapping_concept', [])
    retain_csv_path = config.get('retain_csv_path', None)  # Path to CSV with retain prompts

    # Augmentation flags
    augment_target = config.get('augment_target', True)  # Whether to augment target concepts
    augment_retain = config.get('augment_retain', False)  # Whether to augment retain prompts from CSV

    # Paths
    output_dir = config.get('output_dir', './output')
    final_save_path = config.get('final_save_path', './saved_model/LoRA_fusion_model')
    pretrained_model_name_or_path = config.get('pretrained_model_name_or_path', "black-forest-labs/FLUX.1-dev")

    # Training settings
    ddim_steps = 28
    negative_guidance = config.get('negative_guidance', 2.0)
    internal_lr = config.get('internal_lr', 1e-4)  # Simulated lr for hypernetwork gradient matching

    # Diagnostic prompts for image generation during training
    diagnostic_prompts = config.get('diagnostic_prompts', [])
    if not diagnostic_prompts:
        # Default diagnostic prompts if none provided
        diagnostic_prompts = [
            f"a photo of {concepts[0]}" if concepts else "a photo of a person",
            "a photo of a cat",
            "a photo of a car"
        ]

    print(f"Training steps: {max_train_steps}")
    print(f"Hypernetwork steps: {hyper_train_steps}")
    print(f"Learning rate (remove): {learning_rate_remove}")
    print(f"Learning rate (retain): {learning_rate_retain}")
    print(f"LoRA rank: {rank}")
    print(f"LoRA alpha: {lora_alpha}")
    print(f"Target concepts: {len(concepts)}")
    print("=" * 48)

    # Set seed
    if seed is not None:
        hf_set_seed(seed)

    # Setup Accelerator
    accelerator_project_config = ProjectConfiguration(
        project_dir=output_dir,
        logging_dir=config.get('logging_dir', 'logs'),
    )

    # Only pass log_with if wandb is available
    log_with = config.get('report_to', 'wandb') if WANDB_AVAILABLE else None
    accelerator = Accelerator(
        gradient_accumulation_steps=gradient_accumulation_steps,
        mixed_precision=config.get('mixed_precision', None),
        log_with=log_with,
        project_config=accelerator_project_config,
    )

    is_main = accelerator.is_main_process

    # Initialize W&B if needed and available
    use_wandb = config.get('report_to') == 'wandb' and WANDB_AVAILABLE
    if is_main and use_wandb:
        wandb.init(
            project="UnHype",
            name=f"{config_name}_training",
            config=config
        )
        wandb.define_metric("learning_rate_remove", summary="last")
        wandb.define_metric("learning_rate_retain", summary="last")
    elif is_main and config.get('report_to') == 'wandb' and not WANDB_AVAILABLE:
        print("Warning: wandb requested but not available. Disabling wandb logging.")

    tokenizer_one = CLIPTokenizer.from_pretrained(
        "openai/clip-vit-large-patch14"
    )

    tokenizer_two = T5TokenizerFast.from_pretrained(
        "google/t5-v1_1-base"
    )

    # import correct text encoder classes
    text_encoder_cls_one = import_model_class_from_model_name_or_path(
        pretrained_model_name_or_path, None
    )
    text_encoder_cls_two = import_model_class_from_model_name_or_path(
        pretrained_model_name_or_path, None, subfolder="text_encoder_2"
    )

    noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
        pretrained_model_name_or_path, subfolder="scheduler"
    )
    text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two,
                                                            pretrained_model_name_or_path)
    text_encoder_one =text_encoder_one.to(accelerator.device)
    text_encoder_two = text_encoder_two.to(accelerator.device)
    vae = AutoencoderKL.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="vae",
        revision=None,
        variant=None,
    )

    weight_dtype = torch.bfloat16 if accelerator.device.type == "cuda" else torch.float32

    model = FluxTransformer2DModel.from_pretrained(
        pretrained_model_name_or_path, torch_dtype=weight_dtype,
        subfolder="transformer", revision=None, variant=None
    ).to(accelerator.device)

    model.requires_grad_(False)
    vae.requires_grad_(False)
    text_encoder_one.requires_grad_(False)
    text_encoder_two.requires_grad_(False)

    tokenizers = [tokenizer_one, tokenizer_two]
    text_encoders = [text_encoder_one, text_encoder_two]
    #text_encoder_one = text_encoder_one.to(accelerator.device)
    #text_encoder_two = text_encoder_two.to(accelerator.device)

    # # Load models
    # model_orig, sampler_orig, model, sampler_unused = get_models(
    #     model_config_path, pretrained_model_path, accelerator.device
    # )
    #
    # # Freeze original model
    # for p in model_orig.model.diffusion_model.parameters():
    #     p.requires_grad = False
    # model_orig.eval()
    #
    # # Freeze trainable model backbone
    # for p in model.model.diffusion_model.parameters():
    #     p.requires_grad = False
    #
    # # Setup HyperLoRA
    model.hyper = HypernetworkManager()

    # Flux's pooled_prompt_embeds is always 768-dim (from built-in CLIP text_encoder_one)
    clip_size = 768
    target_modules = ["attn.add_v_proj", "attn.to_v", "attn.to_out.0"]

    hyper_lora_factory = partial(
        HyperLoRALinear,
        clip_size=clip_size,
        rank=rank,
        alpha=lora_alpha,
        train_steps=hyper_train_steps,
        use_orig_concat=use_orig_concat,
        dtype=torch.float32,
        internal_size=internal_size,
    )

    hyper_lora_layers = inject_hyper_lora(
        model, target_modules, hyper_lora_factory
    )

    for layer_name, layer in hyper_lora_layers:
        layer.set_parent_model(model)
        #layer.to(dtype=torch.bfloat16)

    # Setup optimizer
    trainable_params = [p for p in model.parameters() if p.requires_grad]

    if is_main:
        print(f"Total trainable parameter tensors: {len(trainable_params)}")
        print_trainable_parameters(model)

    optimizer_remove = torch.optim.Adam(model.parameters(), lr=learning_rate_remove)
    optimizer_retain = torch.optim.Adam(model.parameters(), lr=learning_rate_retain)

    gamma = config.get('gamma', 0.9)
    step_size = config.get('step_size', 300)

    drop_lr_on_plateau = config.get('drop_lr_on_plateau', False)
    if drop_lr_on_plateau:
        plateau_factor = config.get('plateau_factor', 0.1)
        plateau_patience_remove = config.get('plateau_patience_remove', 10)
        plateau_patience_retain = config.get('plateau_patience_retain', 10)

        scheduler_remove = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer_remove,
            mode='min',
            factor=plateau_factor,
            patience=plateau_patience_remove,
            verbose=True
        )
        scheduler_retain = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer_retain,
            mode='min',
            factor=plateau_factor,
            patience=plateau_patience_retain,
            verbose=True
        )
        if is_main:
            print(f"Using separate ReduceLROnPlateau schedulers:")
            print(f"  Remove: lr={learning_rate_remove}, factor={plateau_factor}, patience={plateau_patience_remove}")
            print(f"  Retain: lr={learning_rate_retain}, factor={plateau_factor}, patience={plateau_patience_retain}")
    else:
        scheduler_remove = torch.optim.lr_scheduler.MultiStepLR(
            optimizer_remove, milestones=step_size, gamma=gamma
        )
        scheduler_retain = torch.optim.lr_scheduler.MultiStepLR(
            optimizer_retain, milestones=step_size, gamma=gamma
        )
        if is_main:
            print(f"Using MultiStepLR schedulers (step_size={step_size}, gamma={gamma})")

    # Prepare for distributed training
    model, optimizer_remove, optimizer_retain = accelerator.prepare(model, optimizer_remove, optimizer_retain)

    # Register HyperLoRA layers after prepare
    for layer_name, layer in hyper_lora_layers:
        layer.set_parent_model(accelerator.unwrap_model(model))
        accelerator.unwrap_model(model).hyper.add_hyperlora(layer_name, layer.hyper_lora)

    # Prepare concept embeddings
    target_concepts = []
    # concepts is now a simple list of concept strings
    for concept_text in concepts:
        target_concepts.append(concept_text)

    print(f"Target concepts for removal: {target_concepts}")
    print(f"Total target concepts: {len(target_concepts)}")

    # # Create concept embeddings (not currently used, cached in LatentCache instead)
    # target_embeddings = []
    # for concept in target_concepts:
    #     inputs = encode(concept)
    #     with torch.no_grad():
    #         if use_pooler:
    #             emb = clip_text_encoder(inputs).pooler_output.detach()
    #         else:
    #             emb = clip_text_encoder(inputs).last_hidden_state.detach()
    #     target_embeddings.append(emb)

    # # Mapping concept embeddings (not currently used, may be added later)
    # mapping_embeddings = []
    # for concept in mapping_concept:
    #     inputs = encode(concept)
    #     with torch.no_grad():
    #         if use_pooler:
    #             emb = clip_text_encoder(inputs).pooler_output.detach()
    #         else:
    #             emb = clip_text_encoder(inputs).last_hidden_state.detach()
    #     mapping_embeddings.append(emb)

    # Retain prompts - load from CSV file with 'prompt' column
    # Retain embeddings use CLIP only (no T5) and are cached to disk
    retain_prompts = []
    retain_embeddings = []

    if retain_csv_path and os.path.exists(retain_csv_path):
        print(f"Loading retain prompts from CSV: {retain_csv_path}")
        df = pd.read_csv(retain_csv_path)
        if 'prompt' not in df.columns:
            raise ValueError(f"CSV file must have a 'prompt' column. Found columns: {df.columns.tolist()}")

        base_prompts = df['prompt'].dropna().tolist()
        if augment_retain:
            for prompt in base_prompts:
                if prompt.startswith("A photo of the "):
                    prompt = prompt[len("A photo of the "):]
                augmented = prompt_augmentation(prompt, augment=True)
                retain_prompts.extend(augmented)
        else:
            retain_prompts = base_prompts

        cache_dir = os.path.join(output_dir, "cache")
        if is_main:
            os.makedirs(cache_dir, exist_ok=True)

        csv_name = os.path.basename(retain_csv_path).replace('.csv', '')
        cache_key = f"{csv_name}_aug{augment_retain}_pooler{use_pooler}"
        cache_path = os.path.join(cache_dir, f"retain_embeddings_{cache_key}.pt")

        cache_exists = os.path.exists(cache_path)
        if not cache_exists:
            if is_main:
                for prompt in tqdm(retain_prompts, desc="Creating retain embeddings"):
                    with torch.no_grad():
                        emb = compute_clip_pooled_embeddings(
                            prompt, text_encoder_one, tokenizer_one, accelerator.device
                        )
                    retain_embeddings.append(emb.squeeze().cpu())
                torch.save(retain_embeddings, cache_path)
            accelerator.wait_for_everyone()

        retain_embeddings = torch.load(cache_path, map_location='cpu')
        retain_embeddings = [emb.to(accelerator.device) for emb in retain_embeddings]

    print(f"Mapping concepts: {len(mapping_concept)}")
    print(f"Retain prompts: {len(retain_prompts)} prompts loaded")

    # Create augmented target prompts
    all_augmented_prompts = []
    if augment_target:
        for concept in target_concepts:
            all_augmented_prompts.extend(prompt_augmentation(concept, augment=True))
    else:
        all_augmented_prompts = target_concepts.copy()
    print(f"Total augmented target prompts for caching: {len(all_augmented_prompts)}")

    # Create augmented mapping prompts (must match target augmentation)
    all_augmented_mapping = []
    if len(mapping_concept) > 0:
        # Ensure we have a mapping concept for each target concept
        mapping_per_target = []
        for i, concept in enumerate(target_concepts):
            # Use corresponding mapping concept, or reuse first one if not enough
            if i < len(mapping_concept):
                mapping_per_target.append(mapping_concept[i])
            else:
                mapping_per_target.append(mapping_concept[0])

        # Apply same augmentation as targets
        if augment_target:
            for mapping_text in mapping_per_target:
                all_augmented_mapping.extend(prompt_augmentation(mapping_text, augment=True))
        else:
            all_augmented_mapping = mapping_per_target.copy()

        print(f"Total augmented mapping prompts for caching: {len(all_augmented_mapping)}")

        # Verify lengths match
        if len(all_augmented_mapping) != len(all_augmented_prompts):
            raise ValueError(
                f"Mismatch: {len(all_augmented_prompts)} target prompts but "
                f"{len(all_augmented_mapping)} mapping prompts. They must match!"
            )
    else:
        print("Warning: No mapping concepts provided. Using empty list.")
        all_augmented_mapping = []

    # Unified cache setup
    use_cache = config.get('use_latent_cache', True)
    cache = None

    cache_dir = config.get('cache_dir', 'latents_cache')
    default_cache_name = concepts[0].replace(' ', '_').replace(',', '')[:30] if concepts else 'default'
    cache_name = config.get('cache_name', default_cache_name)
    cache_path = os.path.join(cache_dir, f"{cache_name}_cache.pt")

    if use_cache and is_main:
        cache_seed = seed if seed else 42
        if os.path.exists(cache_path):
            print('before load')
            cache = Cache.load(
                cache_path, accelerator.device, weight_dtype,
                expected_target_prompts=all_augmented_prompts,
                expected_mapping_prompts=all_augmented_mapping,
                expected_diagnostic_prompts=diagnostic_prompts,
                expected_ddim_steps=ddim_steps,
                expected_seed=cache_seed
            )

        if cache is None:
            base_for_cache = accelerator.unwrap_model(model)
            with base_for_cache.hyper.no_lora():
                cache = Cache(
                    target_prompts=all_augmented_prompts,
                    mapping_prompts=all_augmented_mapping,
                    diagnostic_prompts=diagnostic_prompts,
                    transformer=base_for_cache,
                    noise_scheduler=noise_scheduler,
                    text_encoders=text_encoders,
                    tokenizers=tokenizers,
                    device=accelerator.device,
                    max_ddim_steps=ddim_steps,
                    height=512,
                    width=512,
                    num_channels_latents=vae.config.latent_channels,
                    seed=cache_seed,
                    weight_dtype=weight_dtype,
                    guidance=3.0,
                    cache_path=cache_path,
                )
            cache.save(cache_path)

    accelerator.wait_for_everyone()

    criterion = torch.nn.MSELoss()
    losses = []

    # quick_sampler = create_quick_sampler(
    #    accelerator.unwrap_model(model), sampler, resolution, ddim_steps, ddim_eta
    # )
    base = accelerator.unwrap_model(model)
    diag_pipe = FluxPipeline(
        transformer=base,  # <-- THIS is your live transformer
        vae=vae,
        scheduler=noise_scheduler,
        text_encoder=text_encoder_one,
        tokenizer=tokenizer_one,
        text_encoder_2=text_encoder_two,
        tokenizer_2=tokenizer_two,
    )
    #diag_pipe.to(accelerator.device)

    #diag_pipe.transformer.to(device=device, dtype=weight_dtype).eval()
    #diag_pipe.text_encoder.to(device=device, dtype=weight_dtype).eval()
    #diag_pipe.text_encoder_2.to(device=device, dtype=weight_dtype).eval()

    # VAE decode must be fp32
    #diag_pipe.vae.to(device=device, dtype=torch.float32).eval()

    # Make pipeline execution device CUDA
    #diag_pipe = diag_pipe.to(accelerator.device)

    diag_pipe.set_progress_bar_config(disable=True)

    pbar = tqdm(range(max_train_steps), disable=not accelerator.is_local_main_process)

    # Training weights for combining removal and retain losses
    remove_weight = config.get('remove_weight', 1.0)  # Weight for removal loss
    retain_weight = config.get('retain_weight', 0.001)  # Weight for retain loss

    print(f"Loss weights: remove={remove_weight:.3f}, retain={retain_weight:.3f}")

    for iteration in pbar:
        base = accelerator.unwrap_model(model)
        #
        # #optimizer.zero_grad(set_to_none=True)

        #vae_config_shift_factor = diag_pipe.vae.config.shift_factor
        #vae_config_scaling_factor = diag_pipe.vae.config.scaling_factor
        vae_config_block_out_channels = diag_pipe.vae.config.block_out_channels

        # # Random timestep
        #steps = torch.arange(0, ddim_steps, device=accelerator.device)  # (0 to ddim_steps-1)

        # # normalize to (0, 1]
        # s = steps.float() / float(ddim_steps - 1)
        # # Piecewise weights (nudity removal bias)
        # # Early steps dominate structure & semantics
        # w = torch.where(
        #     s > 0.70, torch.full_like(s, 6.0),  # EARLY (semantic / composition)
        #     torch.where(
        #         s > 0.30, torch.full_like(s, 2.5),  # MID
        #         torch.full_like(s, 0.5)  # LATE (texture)
        #     )
        # )
        # # Normalize weights → probabilities
        # probs = w / w.sum()
        #
        # # Sample ONE timestep index
        # idx = torch.multinomial(probs, num_samples=1)

        t_enc_ddpm = torch.randint(
            low=0,
            high=ddim_steps,
            size=(1,),
            device=accelerator.device
        )
        #og_num = round((int(t_enc) / ddim_steps) * 100)
        #og_num_lim = round((int(t_enc + 1) / ddim_steps) * 1000)
        #t_enc_ddpm = torch.randint(og_num, og_num_lim, (1,), device=accelerator.device)
        vae_scale_factor = 2 ** (len(vae_config_block_out_channels))

        num_channels = vae.config.latent_channels

        image_h = 512
        image_w = 512

        latent_h = image_h // vae_scale_factor
        latent_w = image_w // vae_scale_factor
        bsz = 1
        # --- Create dummy latent ---
        model_input = torch.zeros(
            (bsz, num_channels, latent_h, latent_w),
            device=vae.device,
            dtype=weight_dtype,
        )

        # (ESD) start_guidance = 3
        start_guidance = 3
        start_guidance = torch.tensor([start_guidance], device=accelerator.device)
        start_guidance = start_guidance.expand(model_input.shape[0])

        with accelerator.accumulate(model):
            # # REMOVAL LOSS: Push target concepts towards mapping concepts
            # # Use accelerator process index to select a GPU-specific index
            #
            rank = accelerator.process_index
            world_size = accelerator.num_processes

            # All valid indices for THIS GPU only: rank, rank+world_size, ...
            valid_indices = list(range(rank, len(target_concepts), world_size))

            if len(valid_indices) == 0:
                # Fallback in case there are fewer samples than processes
                concept_idx = rank % len(target_concepts)
            else:
                # Randomly pick one index from this GPU's slice
                concept_idx = random.choice(valid_indices)

            target_text = target_concepts[concept_idx]

            # Get corresponding mapping concept
            if len(mapping_concept) > 0:
                mapping_text = (
                    mapping_concept[concept_idx]
                    if concept_idx < len(mapping_concept)
                    else mapping_concept[0]
                )
            else:
                mapping_text = ""  # Fallback to empty if no mapping concepts

            if augment_target:
                augmented_prompts = prompt_augmentation(target_text, augment=True)

                # Shard augmentation indices per rank as well
                valid_aug_indices = list(range(rank, len(augmented_prompts), world_size))
                if len(valid_aug_indices) == 0:
                    aug_idx = rank % len(augmented_prompts)
                else:
                    aug_idx = random.choice(valid_aug_indices)

                target_text_augmented = augmented_prompts[aug_idx]

                # Mapping augmentation (must use same aug_idx as target)
                if mapping_text:
                    augmented_mapping = prompt_augmentation(mapping_text, augment=True)
                    if aug_idx >= len(augmented_mapping):
                        aug_idx = aug_idx % len(augmented_mapping)
                    mapping_text_augmented = augmented_mapping[aug_idx]
                else:
                    mapping_text_augmented = ""
            else:
                target_text_augmented = target_text
                mapping_text_augmented = mapping_text

            # Get pooled embedding for HyperLoRA context (CLIP-only, computed fresh - not from cache)
            hyper_emb_target = compute_clip_pooled_embeddings(
                target_text_augmented, text_encoder_one, tokenizer_one, accelerator.device
            )


            # Get mapping concept embeddings from cache (emb_0 = mapping concept for loss)
            if cache is not None and mapping_text_augmented and mapping_text_augmented in cache.mapping_prompt_to_idx:
                emb_0, pooled_emb_0, text_ids_0 = cache.get_mapping(mapping_text_augmented, accelerator.device)
            else:
                # Fallback: compute on-the-fly and add to cache for future use
                with torch.no_grad():
                    if mapping_text_augmented:
                        emb_0, pooled_emb_0, text_ids_0 = compute_text_embeddings(
                            mapping_text_augmented, text_encoders, tokenizers, accelerator.device
                        )
                        # Add to cache for future iterations (lazy caching)
                        if cache is not None:
                            cache.add_embedding(
                                mapping_text_augmented, emb_0, pooled_emb_0, text_ids_0,
                                embedding_type='mapping'
                            )
                    else:
                        # If no mapping concept, fall back to unconditional
                        emb_0, pooled_emb_0, text_ids_0 = compute_text_embeddings(
                            "", text_encoders, tokenizers, accelerator.device
                        )

            # Get target embeddings and latents from cache
            if cache is not None and target_text_augmented in cache:
                z, emb_p, pooled_emb_p, text_ids_p, latent_image_ids = cache.get_target(
                    target_text_augmented, int(t_enc), accelerator.device
                )
            else:
                with torch.no_grad():
                    emb_p, pooled_emb_p, text_ids_p = compute_text_embeddings(
                        target_text_augmented, text_encoders, tokenizers, accelerator.device
                    )

            #     # Get text conditioning for Stable Diffusion
            #     emb_p = base.get_learned_conditioning([target_text_augmented])  # target prompt (positive)
            #     emb_n = base.get_learned_conditioning([target_text_augmented])  # target prompt (negative, to be erased)
            #     emb_m = base.get_learned_conditioning([target_text_augmented])  # mapping prompt (what target should map to)
            # # Random timestep for HyperLoRA context
            rank = accelerator.process_index
            world_size = accelerator.num_processes
            #
            # # Timesteps assigned to THIS rank: rank, rank + world_size, ...
            valid_timesteps = torch.arange(rank, hyper_train_steps, world_size, device=accelerator.device)
            if valid_timesteps.numel() == 0:
                # Fallback in case hyper_train_steps < world_size
                rtimestep = int(torch.randint(0, hyper_train_steps, (1,), device=accelerator.device))
            else:
                # Sample index into this rank’s slice
                idx = torch.randint(0, valid_timesteps.numel(), (1,), device=accelerator.device)
                rtimestep = int(valid_timesteps[idx])

            print(
                f"[Rank {rank} | Device {accelerator.device}] "
                f"idx={concept_idx} | Target: {target_text_augmented} | Mapping: {mapping_text_augmented} | Timestep: {rtimestep}"
            )

            with torch.no_grad():
                t_ddpm = t_enc_ddpm.to(accelerator.device)  # DON'T cast to bf16
                base.hyper.set_context(hyper_emb_target.to(dtype=weight_dtype),
                                       torch.tensor([rtimestep], dtype=weight_dtype, device=accelerator.device))
                _, current_timestep = base.hyper.get_context()
                base.hyper.compute_and_cache_loras(hyper_emb_target.to(dtype=weight_dtype),
                                                   current_timestep.to(dtype=weight_dtype))
                #base.hyper.retain_grad_for_cached_lora()
                #base.hyper.retain_grad_for_cached_lora()
                if True:
                    #with base.hyper.no_lora():
                    z, latent_image_ids, timestep = latent_sample(model,
                                                            noise_scheduler,
                                                            1,
                                                            model_input.shape[1],
                                                            512,
                                                            512,
                                                            emb_p.to(accelerator.device),
                                                            pooled_emb_p.to(accelerator.device),
                                                            text_ids_p.to(accelerator.device),
                                                            start_guidance,
                                                            int(ddim_steps),
                                                            stop_at_step=int(t_ddpm.item()))
                with base.hyper.no_lora():
                    e_0 = predict_noise(
                        model, z, emb_0.to(dtype=weight_dtype), pooled_emb_0.to(dtype=weight_dtype), text_ids_0, latent_image_ids,
                        guidance=start_guidance,
                        timesteps=timestep,
                        CPU_only=True,
                    )
                    e_p = predict_noise(
                        model, z, emb_p.to(dtype=weight_dtype), pooled_emb_p.to(dtype=weight_dtype), text_ids_p, latent_image_ids,
                        guidance=start_guidance,
                        timesteps=timestep,
                        CPU_only=True,
                    )

            base.hyper.set_context(hyper_emb_target, torch.tensor([rtimestep], device=accelerator.device))
            _, current_timestep = base.hyper.get_context()
            base.hyper.compute_and_cache_loras(hyper_emb_target, current_timestep)
            base.hyper.retain_grad_for_cached_lora()

            with torch.no_grad():
                flat = base.hyper.flatten_cached_from_cache()
                print("LoRA norm:", flat.norm().item(), "maxabs:", flat.abs().max().item())

            with torch.no_grad():
                dp = (e_p - e_0).float()
                print(f"[e_p - e_0] L2={dp.norm().item():.4e}, maxabs={dp.abs().max().item():.4e}")

            with torch.no_grad():
                with base.hyper.no_lora():
                    pred_off = predict_noise(model, z, emb_p.to(dtype=weight_dtype), pooled_emb_p.to(dtype=weight_dtype), text_ids_p, latent_image_ids,
                                guidance=start_guidance, timesteps=timestep, CPU_only=True)
                pred_on = predict_noise(model, z, emb_p.to(dtype=weight_dtype), pooled_emb_p.to(dtype=weight_dtype), text_ids_p, latent_image_ids,
                                guidance=start_guidance, timesteps=timestep, CPU_only=True)
                print("functional delta meanabs:", (pred_on - pred_off).abs().mean().item())

            e_n = predict_noise(model, z, emb_p.to(dtype=weight_dtype), pooled_emb_p.to(dtype=weight_dtype), text_ids_p, latent_image_ids,
                                guidance=start_guidance, timesteps=timestep, CPU_only=True)
            e_0.requires_grad = False
            e_p.requires_grad = False


            #loss_aux = criterion(e_n.to(accelerator.device), e_0.to(accelerator.device) - (
            #           negative_guidance * (e_p.to(accelerator.device) - e_0.to(accelerator.device))))
            loss_aux = criterion(
                e_n.float().to(accelerator.device),
                (e_0 - negative_guidance * (e_p - e_0)).float().to(accelerator.device)
            )

            flat = base.hyper.flatten_cached_from_cache()
            #lora_cache_reg = 1e-3 * flat.float().pow(2).mean()
            #loss_aux = loss_aux# + lora_cache_reg

            print("e_n.requires_grad:", e_n.requires_grad)
            print("loss_aux.requires_grad:", loss_aux.requires_grad)
            accelerator.backward(loss_aux)
            #print_alpha_grad_norms(base.hyper)
            # --- use cached LoRA grads instead of live-tensor grads ---
            grads_flat_t = base.hyper.flatten_cached_grads_from_cache()
            if grads_flat_t is None:
                raise RuntimeError(
                    "No gradients found in cached LoRA tensors. Ensure cache is built with graph intact and retain_grad() was called.")

            grads_flat_t = (-1.0 * internal_lr) * grads_flat_t.detach()

            # ---- norm diagnostics ----
            with torch.no_grad():
                delta_norm = torch.norm(grads_flat_t, p=2).item()
                print(f"[Target Δθ] L2 norm: {delta_norm:.4e}")

            #for p in trainable_params:
            #    if p.grad is not None:
            #        p.grad = None

            _, current_timestep = accelerator.unwrap_model(model).hyper.get_context()
            base.hyper.set_context(hyper_emb_target, current_timestep)
            base.hyper.compute_and_cache_loras(hyper_emb_target, current_timestep)
            tensors_flat_t = base.hyper.flatten_cached_from_cache()

            base.hyper.set_context(hyper_emb_target, (current_timestep + 1))
            base.hyper.compute_and_cache_loras(hyper_emb_target, (current_timestep + 1))
            tensors_flat_t1 = base.hyper.flatten_cached_from_cache()

            # Match the SGD step: (θ_{t+1} - θ_t) ≈ -lr * g_t
            delta_live = tensors_flat_t1 - tensors_flat_t
            loss_remove = remove_weight * criterion(delta_live, grads_flat_t)
            accelerator.backward(loss_remove)

            if len(retain_embeddings) > 0:
                # Sample multiple retain concepts
                num_retain_samples = min(10, len(retain_embeddings))
                sampled_retain_embs = random.sample(retain_embeddings, num_retain_samples)

                # Batch process retain concepts
                batch_retain_embs = (
                    torch.stack(sampled_retain_embs, dim=0)
                        .to(device=accelerator.device, dtype=torch.bfloat16)
                )
                hyper = base.hyper
                batch_prompts = batch_retain_embs.repeat(hyper_train_steps // num_retain_samples, 1)
                B = batch_prompts.shape[0]
                perm = torch.randperm(B, device=batch_prompts.device)
                batch_prompts = batch_prompts[perm]

                # Compute LoRAs at t=0
                dtype = next(hyper.parameters()).dtype  # hyper’s param dtype (bf16 if you casted it)

                hyper.compute_and_cache_loras(
                    batch_prompts.to(dtype=dtype).to(dtype=weight_dtype),
                    torch.zeros(B, device=accelerator.device, dtype=dtype),
                )

                tensors_flat_t0 = hyper.flatten_cached_from_cache()

                # Compute LoRAs at t = 1, 2, ..., B (one timestep value per batch row)
                dtype = next(hyper.parameters()).dtype

                t_ = (torch.arange(B, device=accelerator.device, dtype=dtype) % B) + 1
                hyper.compute_and_cache_loras(
                    batch_prompts.to(dtype=dtype),
                    t_,
                )
                tensors_flat_t1 = hyper.flatten_cached_from_cache()

                #Loss: minimize change in LoRA weights across timesteps
                delta = tensors_flat_t1 - tensors_flat_t0
                loss_retain = retain_weight * delta.pow(2).mean()
            else:
                loss_retain = torch.tensor(0.0, device=accelerator.device)
            accelerator.backward(loss_retain)

            loss_remove_log = loss_remove.clone().detach()
            loss_retain_log = loss_retain.clone().detach()

            if accelerator.sync_gradients:
                #before = _snapshot_params(base.hyper)

                optimizer_remove.step()
                optimizer_retain.step()

                optimizer_remove.zero_grad(set_to_none=True)
                optimizer_retain.zero_grad(set_to_none=True)

                alpha_name, alpha_param = None, None
                for n, p in model.named_parameters():
                    if n.endswith(".hyper_lora.alpha"):
                        alpha_name, alpha_param = n, p
                        break

                print("alpha:", alpha_name)

                #after_remove = _snapshot_params(base.hyper)
                #_print_modified(before, after_remove, "after optimizer_remove.step()")

                #after_retain = _snapshot_params(base.hyper)
                #_print_modified(after_remove, after_retain, "after optimizer_retain.step()")
                #_print_modified(before, after_retain, "total after both steps")

                if drop_lr_on_plateau:
                    scheduler_remove.step(loss_remove.detach())
                    scheduler_retain.step(loss_retain.detach())
                else:
                    scheduler_remove.step()
                    scheduler_retain.step()

        # Gather loss across devices
        with torch.no_grad():
            loss_retain_reduced = accelerator.gather(loss_retain_log).mean()
            loss_remove_reduced = accelerator.gather(loss_remove_log).mean()

        losses.append(float(loss_remove_reduced.item() + loss_retain_reduced.item()))

        if accelerator.is_main_process and use_wandb:
            current_lr_remove = optimizer_remove.param_groups[0]['lr']
            current_lr_retain = optimizer_retain.param_groups[0]['lr']
            wandb.log({
                "loss_retain": float(loss_retain_reduced.item()),
                "loss_remove": float(loss_remove_reduced.item()),
                "learning_rate_remove": current_lr_remove,
                "learning_rate_retain": current_lr_retain,
            }, step=iteration)

        if is_main:
            pbar.set_postfix({
                "retain": f"{float(loss_retain_reduced.item()):.6f}",
                "remove": f"{float(loss_remove_reduced.item()):.6f}"
            })

        # Generate sample images periodically
        if is_main and use_wandb and (iteration + 1) % 800 == 0:
            # Generate images for diagnostic prompts from config
            for diag_idx, diag_prompt in enumerate(diagnostic_prompts):

                # Get diagnostic embeddings for HyperLoRA (CLIP-only, computed fresh - not from cache)
                hyper_emb_diag = compute_clip_pooled_embeddings(
                    diag_prompt, text_encoder_one, tokenizer_one, accelerator.device
                )

                # Get T5 text embeddings for image generation. NOTE: the cache lookup below is
                # currently disabled (`if False`), so the embeddings are always recomputed.
                if False:#cache is not None and diag_prompt in cache.diagnostic_prompt_to_idx:
                    cached_text_emb = cache.get_diagnostic(diag_prompt, accelerator.device)
                else:
                    cached_text_emb = compute_text_embeddings(
                        diag_prompt, text_encoders, tokenizers, accelerator.device
                    )

                diag_time_steps = [hyper_train_steps]

                device = accelerator.device
                #weight_dtype = torch.bfloat16 if device.type == "cuda" else torch.float32

                # 2) Ensure pipeline uses the *live* transformer (no copies)
                base = accelerator.unwrap_model(model)
                diag_pipe.transformer = base  # make sure pipe uses current model

                # 3) Move ONLY what you need to GPU for diagnostics (encoders+VAE), then move back
                #    Avoid diag_pipe.to(device) if you're tight on VRAM; move components explicitly.

                imgs_per_prompt = []
                #diag_pipe.text_encoder.to(device=device, dtype=weight_dtype).eval()
                #diag_pipe.text_encoder_2.to(device=device, dtype=weight_dtype).eval()
                #diag_pipe.vae.to(device=device, dtype=weight_dtype).eval()  # single dtype

                #diag_pipe.text_encoder.to(device=device, dtype=weight_dtype).eval()
                #diag_pipe.text_encoder_2.to(device=device, dtype=weight_dtype).eval()

                for h_step in diag_time_steps:
                    h_step_tensor = torch.tensor([h_step], device=device)

                    # Enable these if you want hyper-time to change the result
                    base.hyper.set_context(hyper_emb_diag, h_step_tensor)
                    base.hyper.compute_and_cache_loras(hyper_emb_diag,
                                                       h_step_tensor)


                    #diag_pipe.vae.to(device=device, dtype=torch.float32).eval()

                    imgs = generate_one_image_from_prompt(
                        prompt=diag_prompt,
                        transformer=base,
                        vae=vae,
                        noise_scheduler=noise_scheduler,
                        text_encoders=text_encoders,
                        tokenizers=tokenizers,
                        height=512,
                        width=512,
                        num_inference_steps=28,
                        weight_dtype=weight_dtype,
                        seed=seed,
                        cached_embeddings=cached_text_emb,
                    )

                    imgs_per_prompt.append(imgs)

                # 6) Log a single concatenated image to W&B
                if len(imgs_per_prompt) > 0:
                    row_tensors = []

                    for imgs in imgs_per_prompt:
                        if imgs is None:
                            continue

                        # Convert the returned image (assumed to be a PIL.Image) to a tensor clamped to [0, 1]
                        imgs = to_tensor(imgs).clamp(0, 1)
                        row_tensors.append(imgs)

                    if len(row_tensors) > 0:
                        # Concatenate horizontally to form a row: (C, H, sum_W)
                        row = torch.cat(row_tensors, dim=2)

                        # Clean prompt for wandb key (remove spaces and special chars)
                        safe_key = diag_prompt.replace(" ", "_").replace(",", "")[:50]

                        wandb.log(
                            {
                                f"diagnostic_{diag_idx}_{safe_key}": wandb.Image(
                                    to_pil_image(row),
                                    caption=f"{diag_prompt} | hyper steps: {diag_time_steps}",
                                )
                            },
                            step=iteration,
                        )
            # Move back to CPU (fine)
            #diag_pipe.text_encoder.to("cpu")
            #diag_pipe.text_encoder_2.to("cpu")
            #diag_pipe.vae.to("cpu")

        # Save model
        accelerator.wait_for_everyone()
        if is_main and ((iteration % 100 == 0) or (iteration == max_train_steps - 1)):
            print(f"Final loss: {losses[-1]:.6f}")
            print(f"Average loss: {sum(losses) / len(losses):.6f}")

            # Create output directory
            os.makedirs(output_dir, exist_ok=True)
            os.makedirs(final_save_path, exist_ok=True)

            # Save LoRA weights
            model_unwrapped = accelerator.unwrap_model(model)
            lora_state_dict = {k: v.cpu() for k, v in model_unwrapped.state_dict().items() if ".hyper_lora." in k}
            lora_path = os.path.join(final_save_path, f"hyper_lora_em_{iteration}.pth")
            accelerator.save(lora_state_dict, lora_path)
            print(f"Model saved to: {lora_path}")

        # Save config
        config_save = {
            "config_name": config_name,
            "concepts": concepts,
            "mapping_concept": mapping_concept,
            "retain_csv_path": retain_csv_path,
            "augment_target": augment_target,
            "augment_retain": augment_retain,
            "num_retain_prompts": len(retain_prompts),
            "rank": rank,
            "learning_rate_remove": learning_rate_remove,
            "learning_rate_retain": learning_rate_retain,
            "max_train_steps": max_train_steps,
            "hyper_train_steps": hyper_train_steps,
            "final_loss": losses[-1],
            "average_loss": sum(losses) / len(losses),
        }

        with open(os.path.join(final_save_path, "train_config.json"), "w") as f:
            json.dump(config_save, f, indent=2)

    # Save cache if it was modified (lazy caching added new embeddings)
    if is_main and cache is not None and cache.dirty:
        print(f"[Cache] Saving updated cache with {len(cache.mapping_prompts)} mapping prompts...")
        cache.save(cache_path)

    if is_main and use_wandb:
        wandb.finish()


# Script entry point: call main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
