"""
Trainer with Normalized Mask-Weighted Loss
Adds a normalized mask-weighted loss on top of the original trainer.
Formula:
  diff_squared = (pred - target) ** 2
  weights = torch.ones_like(mask) + (lambda_weight - 1.0) * mask
  numerator = (diff_squared * weights).sum()
  denominator = weights.sum() + 1e-8
  loss = numerator / denominator
"""

import os
import time
from collections import deque
from typing import List

import lightning as L
import prodigyopt
import torch
import torch.nn.functional as F
import wandb
import yaml
from diffusers.pipelines import FluxPipeline
from peft import LoraConfig, get_peft_model_state_dict
from torch.utils.data import DataLoader

from ..pipeline.flux_omini import transformer_forward, encode_images


def get_rank():
    """Return this process's local rank.

    Reads the ``LOCAL_RANK`` environment variable (set by distributed
    launchers). Falls back to 0 when the variable is unset or not an integer.
    Unlike a bare ``except:``, this does not swallow ``KeyboardInterrupt`` or
    ``SystemExit``.
    """
    try:
        return int(os.environ["LOCAL_RANK"])
    except (KeyError, ValueError):
        return 0


def get_config():
    """Load the YAML training config whose path is in ``OMINI_CONFIG``.

    Returns:
        The parsed config (result of ``yaml.safe_load``).

    Raises:
        AssertionError: if ``OMINI_CONFIG`` is not set. The check is explicit
            rather than an ``assert`` statement so it still runs under
            ``python -O``. The exception type is kept for existing callers.
    """
    config_path = os.environ.get("OMINI_CONFIG")
    if config_path is None:
        raise AssertionError("Please set the OMINI_CONFIG environment variable")
    # Explicit encoding so parsing does not depend on the platform's locale.
    with open(config_path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)
    return config


def init_wandb(wandb_config, run_name):
    """Start a Weights & Biases run, best-effort.

    Args:
        wandb_config: Mapping that must contain ``"project"``.
        run_name: Name given to the wandb run.

    Failures (missing ``WANDB_API_KEY``, missing ``"project"``, or any error
    raised by ``wandb.init``) are printed and otherwise ignored, so training
    can proceed without wandb.
    """
    # Explicit check instead of ``assert``: an assert is stripped under
    # ``python -O``, which would let wandb.init run without an API key.
    if os.environ.get("WANDB_API_KEY") is None:
        print("Failed to initialize WanDB: WANDB_API_KEY is not set")
        return
    try:
        wandb.init(
            project=wandb_config["project"],
            name=run_name,
            config={},
        )
    except Exception as e:
        print("Failed to initialize WanDB:", e)


class OminiModel(L.LightningModule):
    """LightningModule that trains LoRA adapters on a Flux transformer.

    The Flux text encoders and VAE are frozen. Only the parameters collected by
    ``init_lora`` are optimized. If a batch provides ``target_mask``, the
    flow-matching objective is a normalized, mask-weighted MSE (see
    ``_mask_weighted_loss``). Otherwise a plain mean MSE is used.
    """

    def __init__(
        self,
        flux_pipe_id: str,
        lora_path: str = None,
        lora_config: dict = None,
        device: str = "cuda",
        dtype: torch.dtype = torch.bfloat16,
        model_config: dict = None,
        adapter_names: List[str] = None,
        optimizer_config: dict = None,
        gradient_checkpointing: bool = False,
    ):
        """Load the Flux pipeline, freeze its encoders and attach LoRA adapters.

        Args:
            flux_pipe_id: Identifier/path passed to ``FluxPipeline.from_pretrained``.
            lora_path: Pre-trained LoRA weights to load (not implemented yet).
            lora_config: Keyword arguments for ``peft.LoraConfig``. One adapter is
                added per non-``None`` entry of ``adapter_names``.
            device: Device the pipeline and this module are moved to.
            dtype: Dtype the pipeline and this module are cast to.
            model_config: Training options read later, e.g.
                ``max_sequence_length``, ``independent_condition`` and
                ``mask_loss_weight``. Defaults to ``{}``.
            adapter_names: Adapter name per branch, passed as ``adapters`` to
                ``transformer_forward``. Defaults to ``[None, None, "default"]``.
            optimizer_config: ``{"type": ..., "params": {...}}`` consumed by
                ``configure_optimizers``.
            gradient_checkpointing: Value assigned to the transformer's
                ``gradient_checkpointing`` attribute.
        """
        super().__init__()
        # Build defaults per instance instead of sharing mutable default arguments.
        if model_config is None:
            model_config = {}
        if adapter_names is None:
            adapter_names = [None, None, "default"]
        self.model_config = model_config
        self.optimizer_config = optimizer_config

        # Load the Flux pipeline.
        self.flux_pipe: FluxPipeline = FluxPipeline.from_pretrained(
            flux_pipe_id, torch_dtype=dtype
        ).to(device)
        self.transformer = self.flux_pipe.transformer
        self.transformer.gradient_checkpointing = gradient_checkpointing
        self.transformer.train()

        # Freeze the text encoders and the VAE; they only encode inputs.
        self.flux_pipe.text_encoder.requires_grad_(False).eval()
        self.flux_pipe.text_encoder_2.requires_grad_(False).eval()
        self.flux_pipe.vae.requires_grad_(False).eval()
        self.adapter_names = adapter_names
        self.adapter_set = {name for name in adapter_names if name is not None}

        # Initialize LoRA layers.
        self.lora_layers = self.init_lora(lora_path, lora_config)

        self.to(device).to(dtype)

    def init_lora(self, lora_path: str, lora_config: dict):
        """Add LoRA adapters to the transformer and return the parameters to train.

        Returns:
            List of transformer parameters that have ``requires_grad=True``
            after the adapters were added.

        Raises:
            AssertionError: if neither ``lora_path`` nor ``lora_config`` is given.
            NotImplementedError: if ``lora_path`` is given (loading is unsupported).
        """
        assert lora_path or lora_config
        if lora_path:
            # TODO: Implement loading pre-trained LoRA weights.
            raise NotImplementedError
        for adapter_name in self.adapter_set:
            self.transformer.add_adapter(
                LoraConfig(**lora_config), adapter_name=adapter_name
            )
        # TODO: Check that requires_grad selects exactly the LoRA parameters.
        return [p for p in self.transformer.parameters() if p.requires_grad]

    def save_lora(self, path: str):
        """Write each adapter's weights to ``<path>/<adapter_name>.safetensors``."""
        for adapter_name in self.adapter_set:
            FluxPipeline.save_lora_weights(
                save_directory=path,
                weight_name=f"{adapter_name}.safetensors",
                transformer_lora_layers=get_peft_model_state_dict(
                    self.transformer, adapter_name=adapter_name
                ),
                safe_serialization=True,
            )

    @staticmethod
    def _optimizer_state_to_cpu(optimizer_state: dict) -> dict:
        """Return a copy of an optimizer ``state_dict`` with tensor state on CPU.

        ``Optimizer.state_dict()`` hands back the optimizer's *live* per-parameter
        state dicts. Overwriting their entries in place would move the
        optimizer's moment buffers to CPU in the middle of training. Build new
        dicts instead.
        """
        if "state" not in optimizer_state:
            return optimizer_state
        cpu_state = {
            param_id: {
                key: value.detach().cpu() if isinstance(value, torch.Tensor) else value
                for key, value in param_state.items()
            }
            for param_id, param_state in optimizer_state["state"].items()
        }
        return {**optimizer_state, "state": cpu_state}

    def save_checkpoint(self, path: str, trainer, epoch: int, global_step: int, save_optimizer: bool = True):
        """Save the LoRA weights plus resumable training state under ``path``.

        Writes the adapter weights via ``save_lora`` and a ``training_state.pt``
        holding ``epoch`` and ``global_step``. If ``save_optimizer`` is true and
        the state can be retrieved, the file also holds a CPU copy of the
        optimizer ``state_dict``.

        Args:
            path: Output directory (created if missing).
            trainer: Unused; kept for interface compatibility.
            epoch: Epoch number stored in the checkpoint.
            global_step: Step counter stored in the checkpoint.
            save_optimizer: Whether to include the optimizer state.
        """
        os.makedirs(path, exist_ok=True)

        # Save the LoRA weights.
        self.save_lora(path)

        checkpoint = {
            'epoch': epoch,
            'global_step': global_step,
        }

        # Optionally save the optimizer state as CPU copies (the live optimizer
        # state on the training device is left untouched).
        if save_optimizer:
            try:
                checkpoint['optimizer_state_dict'] = self._optimizer_state_to_cpu(
                    self.optimizers().state_dict()
                )
                print("  ✓ Saved optimizer state")
            except Exception as e:
                print(f"  ⚠️  Failed to save optimizer state: {e}")
                print("  Saving without optimizer state")

        checkpoint_path = os.path.join(path, 'training_state.pt')
        torch.save(checkpoint, checkpoint_path)
        del checkpoint

        # Return cached, unused device memory to the allocator.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        print(f"  Saved checkpoint to {checkpoint_path}")

    def load_checkpoint(self, path: str):
        """Load ``<path>/training_state.pt`` onto the CPU.

        Only the training-state file is read; the LoRA weights stored in the
        same directory are not loaded here.

        Returns:
            The checkpoint dict (``epoch``, ``global_step`` and optionally
            ``optimizer_state_dict``), or ``None`` if the file does not exist.
        """
        checkpoint_path = os.path.join(path, 'training_state.pt')
        if not os.path.exists(checkpoint_path):
            print(f"  No checkpoint found at {checkpoint_path}")
            return None

        # Load onto the CPU first to avoid a device-memory peak.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        print(f"  Loaded checkpoint from {checkpoint_path}")
        print(f"  Resuming from epoch {checkpoint['epoch']}, step {checkpoint['global_step']}")
        return checkpoint

    def configure_optimizers(self):
        """Create the optimizer over the LoRA parameters, restoring saved state if any.

        ``optimizer_config["type"]`` selects ``"AdamW"``, ``"Prodigy"`` or
        ``"SGD"``; ``optimizer_config["params"]`` is forwarded as keyword
        arguments. If ``self._checkpoint_optimizer_state`` was set before
        training starts, that state is loaded (best-effort) and the attribute
        is removed.

        Raises:
            NotImplementedError: for an unknown optimizer type.
        """
        # Freeze the whole transformer, then re-enable only the LoRA parameters.
        self.transformer.requires_grad_(False)
        opt_config = self.optimizer_config

        self.trainable_params = self.lora_layers
        for p in self.trainable_params:
            p.requires_grad_(True)

        if opt_config["type"] == "AdamW":
            optimizer = torch.optim.AdamW(self.trainable_params, **opt_config["params"])
        elif opt_config["type"] == "Prodigy":
            optimizer = prodigyopt.Prodigy(
                self.trainable_params,
                **opt_config["params"],
            )
        elif opt_config["type"] == "SGD":
            optimizer = torch.optim.SGD(self.trainable_params, **opt_config["params"])
        else:
            raise NotImplementedError("Optimizer not implemented.")

        # Restore optimizer state stashed before fit() when resuming.
        if hasattr(self, '_checkpoint_optimizer_state'):
            optimizer_state = self._checkpoint_optimizer_state
            if optimizer_state is not None:
                print("  Restoring optimizer state from checkpoint...")
                try:
                    # Move tensor state onto this module's device before loading.
                    for param_state in optimizer_state.get('state', {}).values():
                        for k, v in param_state.items():
                            if isinstance(v, torch.Tensor):
                                param_state[k] = v.to(self.device, non_blocking=True)
                    optimizer.load_state_dict(optimizer_state)
                    print("  ✓ Successfully restored optimizer state")
                except Exception as e:
                    print(f"  ⚠️  Failed to restore optimizer state: {e}")
                    print("  Starting with fresh optimizer instead")
            else:
                print("  No optimizer state in checkpoint, starting with fresh optimizer")
            # The stashed state is only needed once.
            delattr(self, '_checkpoint_optimizer_state')

        return optimizer

    @staticmethod
    def _infer_latent_grid(num_tokens, height, width):
        """Return ``(h, w)`` with ``h * w == num_tokens`` and ``h / w ≈ height / width``.

        For square inputs this is ``(sqrt(N), sqrt(N))``, which matches the
        previous square-only assumption.

        Raises:
            ValueError: if no such integer grid exists.
        """
        latent_h = max(1, round((num_tokens * height / width) ** 0.5))
        latent_w = num_tokens // latent_h
        if latent_h * latent_w != num_tokens:
            raise ValueError(
                f"Cannot map {num_tokens} latent tokens onto a grid with "
                f"aspect ratio {height}x{width}"
            )
        return latent_h, latent_w

    @staticmethod
    def _log_mask_stats(diff_squared, mask_flat, mask_latent):
        """Log mean squared error inside/outside the mask and the mask ratio to wandb.

        Uses ``commit=False`` so these values are attached to the next
        committed wandb step instead of creating an extra one. Logging is
        best-effort debug output.
        """
        with torch.no_grad():
            sq = diff_squared.detach().float()
            inside = (mask_flat > 0.5).float()  # [B, N, 1], broadcast over channels
            channels = sq.shape[-1]
            inside_count = inside.sum() * channels
            outside_count = (1.0 - inside).sum() * channels
            mask_region_loss = (
                ((sq * inside).sum() / inside_count).item() if inside_count > 0 else 0
            )
            non_mask_region_loss = (
                ((sq * (1.0 - inside)).sum() / outside_count).item() if outside_count > 0 else 0
            )
            try:
                wandb.log({
                    "train/mask_region_loss": mask_region_loss,
                    "train/non_mask_region_loss": non_mask_region_loss,
                    "train/mask_ratio": mask_latent.mean().item(),
                }, commit=False)
            except Exception:
                # Debug metrics only; never interrupt training for them.
                pass

    def _mask_weighted_loss(self, pred, target, target_mask, num_full_tokens, image_latent_mask=None):
        """Normalized mask-weighted MSE between ``pred`` and ``target``.

        ``weights = 1 + (lambda_weight - 1) * mask`` per latent token, broadcast
        over the channel dimension. The loss is
        ``sum(weights * (pred - target)**2) / sum(weights)`` over all elements,
        so with ``lambda_weight == 1`` it equals the plain mean MSE.

        Args:
            pred, target: ``[B, N, C]`` tensors (N tokens after any masking).
            target_mask: Spatial mask ``[B, 1, H, W]`` (``[B, H, W]`` also
                accepted), in ``[0, 1]`` or ``[0, 255]``.
            num_full_tokens: Token count of the full latent grid before
                ``image_latent_mask`` selection.
            image_latent_mask: Optional token selection applied to the image
                latents. The same selection is applied to the mask.
        """
        lambda_weight = self.model_config.get("mask_loss_weight", 20.0)
        batch_size = pred.shape[0]

        if target_mask.dim() == 3:
            target_mask = target_mask.unsqueeze(1)
        mask_h, mask_w = target_mask.shape[-2:]
        latent_h, latent_w = self._infer_latent_grid(num_full_tokens, mask_h, mask_w)

        # Downsample the mask onto the latent grid (float for interpolation).
        mask_latent = F.interpolate(
            target_mask.float(),
            size=(latent_h, latent_w),
            mode='bilinear',
            align_corners=False,
        )
        # Normalize 0-255 masks to [0, 1].
        if mask_latent.max() > 1.0:
            mask_latent = mask_latent / 255.0

        # [B, 1, h, w] -> [B, h*w, 1] to line up with the token axis of pred.
        mask_flat = mask_latent.reshape(batch_size, 1, -1).permute(0, 2, 1)
        if image_latent_mask is not None:
            # Keep only the tokens that were kept for x_0 / x_t.
            mask_flat = mask_flat[:, image_latent_mask[0]]

        diff_squared = (pred - target) ** 2  # [B, N, C]
        weights = torch.ones_like(mask_flat) + (lambda_weight - 1.0) * mask_flat  # [B, N, 1]

        numerator = (diff_squared * weights).sum()
        # Each token weight applies to all C channels, so it counts C times.
        denominator = weights.sum() * diff_squared.shape[-1] + 1e-8
        loss = numerator / denominator

        # Only compute the debug statistics when a wandb run can receive them.
        if wandb.run is not None:
            self._log_mask_stats(diff_squared, mask_flat, mask_latent)

        return loss

    def training_step(self, batch, batch_idx):
        """Compute the flow-matching loss for one batch.

        Reads ``image`` and ``description`` from the batch. Optional keys are
        ``image_latent_mask``, ``target_mask`` and, per condition ``i``,
        ``condition_{i}``, ``position_delta_{i}``, ``position_scale_{i}`` and
        ``condition_latent_mask_{i}``. Also updates ``self.last_t`` and the
        exponentially smoothed ``self.log_loss``.

        Returns:
            The scalar training loss.
        """
        imgs, prompts = batch["image"], batch["description"]
        image_latent_mask = batch.get("image_latent_mask", None)
        # Optional spatial mask; when present the loss is mask-weighted.
        target_mask = batch.get("target_mask", None)

        # Collect condition_0, condition_1, ... until the first missing index.
        conditions, position_deltas, position_scales, latent_masks = [], [], [], []
        for i in range(1000):
            if f"condition_{i}" not in batch:
                break
            conditions.append(batch[f"condition_{i}"])
            position_deltas.append(batch.get(f"position_delta_{i}", [[0, 0]]))
            position_scales.append(batch.get(f"position_scale_{i}", [1.0])[0])
            latent_masks.append(batch.get(f"condition_latent_mask_{i}", None))

        with torch.no_grad():
            # Encode the target image into latent tokens and their position ids.
            x_0, img_ids = encode_images(self.flux_pipe, imgs)
            # Token count before the optional image_latent_mask selection; the
            # target_mask is resampled onto this full grid.
            num_full_tokens = x_0.shape[1]

            # Encode the text prompt.
            (
                prompt_embeds,
                pooled_prompt_embeds,
                text_ids,
            ) = self.flux_pipe.encode_prompt(
                prompt=prompts,
                prompt_2=None,
                prompt_embeds=None,
                pooled_prompt_embeds=None,
                device=self.flux_pipe.device,
                num_images_per_prompt=1,
                max_sequence_length=self.model_config.get("max_sequence_length", 512),
                lora_scale=None,
            )

            # Sample t from a logit-normal distribution and interpolate between
            # the data latents (x_0) and noise (x_1).
            t = torch.sigmoid(torch.randn((imgs.shape[0],), device=self.device))
            x_1 = torch.randn_like(x_0).to(self.device)
            t_ = t.unsqueeze(1).unsqueeze(1)
            x_t = ((1 - t_) * x_0 + t_ * x_1).to(self.dtype)
            if image_latent_mask is not None:
                # Keep only the selected tokens; the first sample's mask is
                # used for the whole batch.
                x_0 = x_0[:, image_latent_mask[0]]
                x_1 = x_1[:, image_latent_mask[0]]
                x_t = x_t[:, image_latent_mask[0]]
                img_ids = img_ids[image_latent_mask[0]]

            # Encode the condition images and adjust their position ids.
            condition_latents, condition_ids = [], []
            for cond, p_delta, p_scale, latent_mask in zip(
                conditions, position_deltas, position_scales, latent_masks
            ):
                c_latents, c_ids = encode_images(self.flux_pipe, cond)
                # Scale the position (see OminiControl2).
                if p_scale != 1.0:
                    scale_bias = (p_scale - 1.0) / 2
                    c_ids[:, 1:] *= p_scale
                    c_ids[:, 1:] += scale_bias
                # Add the position delta (see OminiControl); only the first
                # delta in the batch is used.
                c_ids[:, 1] += p_delta[0][0]
                c_ids[:, 2] += p_delta[0][1]
                if latent_mask is not None:
                    # Select tokens along dim 1 like the image branch, keeping
                    # the batch dimension and keeping latents aligned with c_ids.
                    c_latents = c_latents[:, latent_mask[0]]
                    c_ids = c_ids[latent_mask[0]]
                condition_latents.append(c_latents)
                condition_ids.append(c_ids)

            # Guidance input is only needed for guidance-distilled transformers.
            guidance = (
                torch.ones_like(t).to(self.device)
                if self.transformer.config.guidance_embeds
                else None
            )

        branch_n = 2 + len(conditions)
        group_mask = torch.ones([branch_n, branch_n], dtype=torch.bool).to(self.device)
        # Disable attention across different condition branches.
        group_mask[2:, 2:] = torch.diag(torch.tensor([1] * len(conditions)))
        # Optionally disable attention from condition branches to the image and text branches.
        if self.model_config.get("independent_condition", False):
            group_mask[2:, :2] = False

        transformer_out = transformer_forward(
            self.transformer,
            image_features=[x_t, *(condition_latents)],
            text_features=[prompt_embeds],
            img_ids=[img_ids, *(condition_ids)],
            txt_ids=[text_ids],
            # Timesteps per branch: t for text and image, 0 for each condition.
            timesteps=[t, t] + [torch.zeros_like(t)] * len(conditions),
            pooled_projections=[pooled_prompt_embeds] * branch_n,
            guidances=[guidance] * branch_n,
            # The LoRA adapter names of each branch.
            adapters=self.adapter_names,
            return_dict=False,
            group_mask=group_mask,
        )
        pred = transformer_out[0]

        # Flow-matching velocity target.
        target = x_1 - x_0

        if target_mask is not None:
            step_loss = self._mask_weighted_loss(
                pred, target, target_mask, num_full_tokens, image_latent_mask
            )
        else:
            # Without a mask, use the standard mean MSE.
            step_loss = F.mse_loss(pred, target, reduction="mean")

        self.last_t = t.mean().item()

        # Exponential moving average of the loss for logging.
        self.log_loss = (
            step_loss.item()
            if not hasattr(self, "log_loss")
            else self.log_loss * 0.95 + step_loss.item() * 0.05
        )
        return step_loss

    def generate_a_sample(self):
        """Placeholder; not implemented."""
        pass


class TrainingCallback(L.Callback):
    """Logs training progress and periodically saves checkpoints and samples.

    ``train`` attaches it only on the main process.
    """

    def __init__(
        self,
        test_function,
        training_config,
        save_path,
        run_name,
        resume_step=0,
    ):
        """
        Args:
            test_function: Optional ``callable(pl_module, save_dir, prefix)`` used
                to generate samples; a falsy value disables sampling.
            training_config: The ``train`` config section. Reads
                ``sample_interval`` and ``save_interval`` (default 500),
                ``print_every_n_steps`` (default 10) and ``save_optimizer_state``.
            save_path: Root output directory.
            run_name: Sub-directory name for this run.
            resume_step: Initial value of the batch counter when resuming.
        """
        super().__init__()
        self.test_function = test_function
        self.training_config = training_config
        self.save_path = save_path
        self.run_name = run_name
        self.sample_interval = training_config.get("sample_interval", 500)
        self.save_interval = training_config.get("save_interval", 500)
        self.print_every_n_steps = training_config.get("print_every_n_steps", 10)
        self.total_steps = resume_step
        self.step_start_time = time.time()
        # Last 100 smoothed losses, used for "avg_loss"; deque drops the oldest.
        self.loss_history = deque(maxlen=100)

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        """Update counters, log metrics and trigger checkpoint/sample saving."""
        self.total_steps += 1

        # Wall-clock time since the previous batch ended.
        now = time.time()
        step_time = now - self.step_start_time
        self.step_start_time = now

        # Per-parameter gradient norms, computed once (one host sync each).
        # "gradient_norm" is the sum of these norms, not the global L2 norm.
        grad_norms = [
            p.grad.norm().item() for p in pl_module.trainable_params if p.grad is not None
        ]
        gradient_size = sum(grad_norms)
        # default avoids a ValueError when no parameter has a gradient.
        max_gradient_size = max(grad_norms, default=0.0)

        self.loss_history.append(pl_module.log_loss)
        avg_loss = sum(self.loss_history) / len(self.loss_history)

        current_lr = trainer.optimizers[0].param_groups[0]['lr']

        # Log only when a wandb run exists (init_wandb is best-effort).
        if wandb.run is not None:
            try:
                wandb.log({
                    "train/loss": pl_module.log_loss,
                    "train/avg_loss": avg_loss,
                    "train/lr": current_lr,
                    "train/gradient_norm": gradient_size,
                    "train/max_gradient": max_gradient_size,
                    "train/step_time": step_time,
                    "train/steps": self.total_steps,
                })
            except Exception as e:
                print(f"  ⚠️  wandb.log failed: {e}")

        if self.total_steps % self.print_every_n_steps == 0:
            print(
                f"Epoch: {trainer.current_epoch}, Steps: {self.total_steps}, Batch: {batch_idx}, "
                f"Loss: {pl_module.log_loss:.4f}, Avg Loss: {avg_loss:.4f}, "
                f"LR: {current_lr:.2e}, "
                f"Grad: {gradient_size:.4f}, Max Grad: {max_gradient_size:.4f}, "
                f"Time/Step: {step_time:.3f}s"
            )

        # Save checkpoint (LoRA weights + training state) at specified intervals.
        if self.total_steps % self.save_interval == 0:
            save_opt = self.training_config.get("save_optimizer_state", True)
            print(
                f"Epoch: {trainer.current_epoch}, Steps: {self.total_steps} - Saving checkpoint"
            )
            pl_module.save_checkpoint(
                f"{self.save_path}/{self.run_name}/ckpt/{self.total_steps}",
                trainer,
                trainer.current_epoch,
                self.total_steps,
                save_optimizer=save_opt
            )

        # Generate and save a sample image at specified intervals.
        if self.total_steps % self.sample_interval == 0 and self.test_function:
            print(
                f"Epoch: {trainer.current_epoch}, Steps: {self.total_steps} - Generating a sample"
            )
            pl_module.eval()
            self.test_function(
                pl_module,
                f"{self.save_path}/{self.run_name}/sample",
                f"step_{self.total_steps}",
            )
            pl_module.train()


def train(dataset, trainable_model, config, test_function=None):
    """Build the DataLoader, callback and Lightning Trainer, then run training.

    Args:
        dataset: Dataset yielding the batch dicts consumed by
            ``trainable_model.training_step``.
        trainable_model: The LightningModule to train (e.g. ``OminiModel``).
        config: Full config dict. ``config["train"]`` holds the training
            options and ``config["dtype"] == "bfloat16"`` selects bf16-mixed
            precision.
        test_function: Optional sampling callable passed to ``TrainingCallback``.
    """
    rank = get_rank()
    is_main_process = rank == 0

    training_config = config["train"]
    wandb_config = training_config.get("wandb", {})
    save_path = training_config.get("save_path", "./output")

    # One device per entry in CUDA_VISIBLE_DEVICES.
    cuda_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "0")
    print(f"Using CUDA devices: {cuda_devices}")
    device_count = len(cuda_devices.split(","))

    # The timestamp doubles as the run name; wandb is started on the main process only.
    run_name = time.strftime("%Y%m%d-%H%M%S")
    if is_main_process:
        print(f"Run name: {run_name}")
        init_wandb(wandb_config, run_name)

    num_workers = training_config.get("dataloader_workers", 4)
    train_loader = DataLoader(
        dataset,
        batch_size=training_config["batch_size"],
        shuffle=True,
        num_workers=num_workers,
        persistent_workers=num_workers > 0,
    )

    # Resume bookkeeping. Only training_state.pt is read by load_checkpoint.
    # NOTE(review): the LoRA weights saved next to it are not reloaded here,
    # so training resumes from freshly initialized adapters. Confirm this is
    # intended.
    resume_checkpoint = training_config.get("resume_from_checkpoint", None)
    resume_step = 0
    if resume_checkpoint and os.path.exists(resume_checkpoint):
        checkpoint_data = trainable_model.load_checkpoint(resume_checkpoint)
        if checkpoint_data:
            resume_step = checkpoint_data['global_step']
            # configure_optimizers() runs at the start of trainer.fit(), before
            # the first step, and consumes this stashed state.
            trainable_model._checkpoint_optimizer_state = checkpoint_data.get('optimizer_state_dict', None)
    elif resume_checkpoint:
        print(f"  Resume checkpoint not found: {resume_checkpoint}")

    callbacks = [
        TrainingCallback(
            test_function,
            training_config,
            save_path,
            run_name,
            resume_step=resume_step,
        )
    ] if is_main_process else []

    # Save the training config alongside the run outputs.
    if is_main_process:
        os.makedirs(f"{save_path}/{run_name}", exist_ok=True)
        with open(f"{save_path}/{run_name}/config.yaml", "w") as f:
            yaml.dump(config, f)

    trainer = L.Trainer(
        accelerator="gpu" if torch.cuda.is_available() else "cpu",
        devices=device_count,
        precision="bf16-mixed" if config["dtype"] == "bfloat16" else "32-true",
        callbacks=callbacks,
        max_epochs=training_config.get("max_epochs", -1),
        max_steps=training_config.get("max_steps", -1),
        accumulate_grad_batches=training_config.get("accumulate_grad_batches", 1),
        log_every_n_steps=1,
        logger=False,
    )

    # Expose the training config to callbacks and the model.
    setattr(trainer, "training_config", training_config)
    setattr(trainable_model, "training_config", training_config)

    print("=" * 70)
    print("Starting training...")
    print("=" * 70)
    trainer.fit(trainable_model, train_loader)

    print("Training completed!")
