import lightning as L
from diffusers.pipelines import FluxPipeline
import torch
import wandb
import os
import yaml
from peft import LoraConfig, get_peft_model_state_dict
from torch.utils.data import DataLoader
import time

from typing import List

import prodigyopt

from ..pipeline.flux_omini import transformer_forward, encode_images


def get_rank():
    """Return this process's local rank.

    The rank is read from the ``LOCAL_RANK`` environment variable.

    Returns:
        int: The value of ``LOCAL_RANK``, or 0 when the variable is unset or
        cannot be parsed as an integer.
    """
    try:
        return int(os.environ.get("LOCAL_RANK", 0))
    except ValueError:
        # Malformed value (e.g. empty string): fall back to rank 0.
        return 0


def get_config():
    """Load the YAML configuration file named by ``OMINI_CONFIG``.

    Returns:
        The object produced by ``yaml.safe_load`` for the file's contents.

    Raises:
        AssertionError: If the ``OMINI_CONFIG`` environment variable is unset.
    """
    config_path = os.environ.get("OMINI_CONFIG")
    if config_path is None:
        # Raised explicitly (instead of using ``assert``) so the check still
        # runs under ``python -O``; the exception type is kept for callers.
        raise AssertionError("Please set the OMINI_CONFIG environment variable")
    with open(config_path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)


def init_wandb(wandb_config, run_name):
    """Start a Weights & Biases run, printing a message instead of raising on failure.

    Args:
        wandb_config: Mapping that provides the ``"project"`` name.
        run_name: Name given to the run.

    Nothing is initialized when the ``WANDB_API_KEY`` environment variable is
    unset. Any exception raised while initializing is caught and printed, so
    the caller continues either way.
    """
    # Explicit check rather than ``assert`` so it is not stripped by ``python -O``.
    if os.environ.get("WANDB_API_KEY") is None:
        print("Failed to initialize WanDB:", "WANDB_API_KEY is not set")
        return

    import wandb

    try:
        wandb.init(
            project=wandb_config["project"],
            name=run_name,
            config={},
        )
    except Exception as e:
        # Best effort: logging is optional, so the failure is only reported.
        print("Failed to initialize WanDB:", e)


class OminiModel(L.LightningModule):
    """Lightning module that trains LoRA adapters on the transformer of a Flux pipeline.

    The pipeline's text encoders and VAE are frozen. The parameters that
    require gradients once the adapters are set up are collected in
    ``lora_layers`` and are the only ones handed to the optimizer.
    """

    def __init__(
        self,
        flux_pipe_id: str,
        lora_path: str = None,
        lora_config: dict = None,
        device: str = "cuda",
        dtype: torch.dtype = torch.bfloat16,
        model_config: dict = None,
        adapter_names: List[str] = None,
        optimizer_config: dict = None,
        gradient_checkpointing: bool = False,
    ):
        """Load the Flux pipeline, freeze it and set up the LoRA adapters.

        Args:
            flux_pipe_id: Identifier passed to ``FluxPipeline.from_pretrained``.
            lora_path: Directory holding ``<adapter_name>.safetensors`` files
                to load. Takes precedence over ``lora_config`` when set.
            lora_config: Keyword arguments for ``LoraConfig``, used to create
                fresh adapters when ``lora_path`` is not given.
            device: Device the pipeline and this module are moved to.
            dtype: Dtype the pipeline is loaded in and this module is cast to.
            model_config: Model options read during training
                (``max_sequence_length``, ``independent_condition``).
                Defaults to an empty dict.
            adapter_names: Adapter name per branch, forwarded to
                ``transformer_forward``; ``None`` entries are ignored when
                building the set of adapters. Defaults to
                ``[None, None, "default"]``.
            optimizer_config: Dict with ``"type"`` and ``"params"`` consumed
                by ``configure_optimizers``.
            gradient_checkpointing: Value assigned to the transformer's
                ``gradient_checkpointing`` attribute.
        """
        # Initialize the LightningModule
        super().__init__()
        # ``None`` defaults avoid sharing one mutable default object between instances.
        self.model_config = {} if model_config is None else model_config
        self.optimizer_config = optimizer_config
        if adapter_names is None:
            adapter_names = [None, None, "default"]

        # Load the Flux pipeline
        self.flux_pipe: FluxPipeline = FluxPipeline.from_pretrained(
            flux_pipe_id, torch_dtype=dtype
        ).to(device)
        self.transformer = self.flux_pipe.transformer
        self.transformer.gradient_checkpointing = gradient_checkpointing
        self.transformer.train()

        # Freeze the Flux pipeline
        self.flux_pipe.text_encoder.requires_grad_(False).eval()
        self.flux_pipe.text_encoder_2.requires_grad_(False).eval()
        self.flux_pipe.vae.requires_grad_(False).eval()
        self.adapter_names = adapter_names
        self.adapter_set = {name for name in adapter_names if name is not None}

        # Initialize LoRA layers
        self.lora_layers = self.init_lora(lora_path, lora_config)

        self.to(device).to(dtype)

    def init_lora(self, lora_path: str, lora_config: dict):
        """Load or create the LoRA adapters and return the trainable parameters.

        Args:
            lora_path: Directory with one ``<adapter_name>.safetensors`` file
                per adapter. When truthy, weights are loaded from it.
            lora_config: Keyword arguments for ``LoraConfig``; used to add new
                adapters when ``lora_path`` is falsy.

        Returns:
            list: The transformer parameters whose ``requires_grad`` is True
            after the adapters have been set up.

        Raises:
            AssertionError: If neither ``lora_path`` nor ``lora_config`` is given.
        """
        if not (lora_path or lora_config):
            # Explicit raise so the check is not stripped by ``python -O``.
            raise AssertionError("Either lora_path or lora_config must be provided")

        if lora_path:
            # Load the adapter weights from the checkpoint directory.
            print(f"  Loading LoRA weights from {lora_path}...")
            for adapter_name in self.adapter_set:
                weight_file = os.path.join(lora_path, f"{adapter_name}.safetensors")

                if not os.path.exists(weight_file):
                    print(f"    ⚠️  Warning: {weight_file} not found, skipping {adapter_name}")
                    continue

                try:
                    # load_lora_weights is called with the adapter name so the
                    # weights are registered under that adapter.
                    self.flux_pipe.load_lora_weights(
                        lora_path,
                        weight_name=f"{adapter_name}.safetensors",
                        adapter_name=adapter_name,
                    )
                    print(f"    ✓ Loaded {adapter_name}")
                except Exception as e:
                    print(f"    ❌ Failed to load {adapter_name}: {e}")

            # Explicitly activate every adapter after loading.
            print(f"  Activating all adapters: {list(self.adapter_set)}")
            try:
                self.transformer.set_adapters(list(self.adapter_set))
                print("    ✓ All adapters activated")
            except Exception as e:
                print(f"    ⚠️  set_adapters failed: {e}, trying enable_adapters...")
                for adapter_name in self.adapter_set:
                    try:
                        self.transformer.enable_adapters(adapter_name)
                    except Exception as enable_error:
                        # Still best effort, but no longer silent.
                        print(
                            f"    ⚠️  enable_adapters failed for {adapter_name}: {enable_error}"
                        )
        else:
            # Training from scratch: add new adapters built from lora_config.
            for adapter_name in self.adapter_set:
                self.transformer.add_adapter(
                    LoraConfig(**lora_config), adapter_name=adapter_name
                )
            # Activate all adapters (best effort, failure is reported).
            try:
                self.transformer.set_adapters(list(self.adapter_set))
            except Exception as e:
                print(f"    ⚠️  set_adapters failed: {e}")

        # TODO: Check if this is correct (p.requires_grad)
        return [p for p in self.transformer.parameters() if p.requires_grad]

    def save_lora(self, path: str):
        """Write each adapter's weights to ``<path>/<adapter_name>.safetensors``."""
        for adapter_name in self.adapter_set:
            FluxPipeline.save_lora_weights(
                save_directory=path,
                weight_name=f"{adapter_name}.safetensors",
                transformer_lora_layers=get_peft_model_state_dict(
                    self.transformer, adapter_name=adapter_name
                ),
                safe_serialization=True,
            )

    @staticmethod
    def _copy_optimizer_state(optimizer_state: dict, device, non_blocking: bool = False) -> dict:
        """Return a copy of an optimizer state dict with its state tensors on ``device``.

        The per-parameter dicts under ``"state"`` are rebuilt rather than
        edited in place, so the object passed in is left untouched. All other
        top-level entries are reused as they are.
        """
        moved = dict(optimizer_state)
        if "state" in optimizer_state:
            moved["state"] = {
                param_id: {
                    key: (
                        value.to(device, non_blocking=non_blocking)
                        if isinstance(value, torch.Tensor)
                        else value
                    )
                    for key, value in param_state.items()
                }
                for param_id, param_state in optimizer_state["state"].items()
            }
        return moved

    def save_checkpoint(self, path: str, trainer, epoch: int, global_step: int, save_optimizer: bool = True):
        """Save the LoRA weights and the training state into ``path``.

        Writes the adapter weights via ``save_lora`` and a
        ``training_state.pt`` file holding the epoch, the global step and,
        optionally, the optimizer state (with tensors copied to CPU).

        Args:
            path: Output directory; created if missing.
            trainer: Unused; kept for interface compatibility.
            epoch: Epoch number stored in the checkpoint.
            global_step: Step number stored in the checkpoint.
            save_optimizer: Whether to include the optimizer state. A failure
                to collect it is reported and the checkpoint is saved without it.
        """
        os.makedirs(path, exist_ok=True)

        # Save the LoRA weights
        self.save_lora(path)

        checkpoint = {
            'epoch': epoch,
            'global_step': global_step,
        }

        if save_optimizer:
            try:
                with torch.no_grad():
                    # Copy to CPU instead of rewriting the returned dict: the
                    # per-parameter dicts from ``state_dict()`` are shared with
                    # the live optimizer, so in-place edits would move the
                    # optimizer's own state off the training device.
                    checkpoint['optimizer_state_dict'] = self._copy_optimizer_state(
                        self.optimizers().state_dict(), "cpu"
                    )
                print("  ✓ Saved optimizer state")
            except Exception as e:
                print(f"  ⚠️  Failed to save optimizer state: {e}")
                print("  Saving without optimizer state")

        checkpoint_path = os.path.join(path, 'training_state.pt')
        torch.save(checkpoint, checkpoint_path)

        # Drop the temporary copy before releasing cached GPU memory.
        del checkpoint
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        print(f"  Saved checkpoint to {checkpoint_path}")

    def load_checkpoint(self, path: str):
        """Read ``training_state.pt`` from ``path``.

        Returns:
            The loaded checkpoint dict, or ``None`` when the file does not exist.
        """
        checkpoint_path = os.path.join(path, 'training_state.pt')
        if not os.path.exists(checkpoint_path):
            print(f"  No checkpoint found at {checkpoint_path}")
            return None

        # Load onto the CPU first to avoid a GPU memory spike.
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from trusted sources (consider ``weights_only=True`` where supported).
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        print(f"  Loaded checkpoint from {checkpoint_path}")
        print(f"  Resuming from epoch {checkpoint['epoch']}, step {checkpoint['global_step']}")
        return checkpoint

    def configure_optimizers(self):
        """Build the optimizer over the LoRA parameters.

        The optimizer class is selected by ``optimizer_config["type"]``
        (``"AdamW"``, ``"Prodigy"`` or ``"SGD"``) and constructed with
        ``optimizer_config["params"]``. If ``_checkpoint_optimizer_state`` has
        been set on this module, that state is loaded into the new optimizer
        and the attribute is removed.

        Raises:
            NotImplementedError: If the optimizer type is not supported.
        """
        # Freeze the transformer
        self.transformer.requires_grad_(False)
        opt_config = self.optimizer_config

        # Set the trainable parameters
        self.trainable_params = self.lora_layers

        # Unfreeze trainable parameters
        for p in self.trainable_params:
            p.requires_grad_(True)

        # Initialize the optimizer
        if opt_config["type"] == "AdamW":
            optimizer = torch.optim.AdamW(self.trainable_params, **opt_config["params"])
        elif opt_config["type"] == "Prodigy":
            optimizer = prodigyopt.Prodigy(
                self.trainable_params,
                **opt_config["params"],
            )
        elif opt_config["type"] == "SGD":
            optimizer = torch.optim.SGD(self.trainable_params, **opt_config["params"])
        else:
            raise NotImplementedError("Optimizer not implemented.")

        # Restore optimizer state if resuming from checkpoint
        if hasattr(self, '_checkpoint_optimizer_state'):
            saved_state = self._checkpoint_optimizer_state
            delattr(self, '_checkpoint_optimizer_state')

            if saved_state is None:
                print("  No optimizer state in checkpoint, starting with fresh optimizer")
            else:
                print("  Restoring optimizer state from checkpoint...")
                try:
                    # Make sure every state tensor is on this module's device.
                    optimizer.load_state_dict(
                        self._copy_optimizer_state(
                            saved_state, self.device, non_blocking=True
                        )
                    )
                    print("  ✓ Successfully restored optimizer state")
                except Exception as e:
                    print(f"  ⚠️  Failed to restore optimizer state: {e}")
                    print("  Starting with fresh optimizer instead")

        return optimizer

    def training_step(self, batch, batch_idx):
        """Compute the training loss for one batch.

        Reads ``"image"`` and ``"description"`` from ``batch``, an optional
        ``"image_latent_mask"``, and for each consecutive ``i`` starting at 0
        the keys ``"condition_{i}"``, ``"position_delta_{i}"``,
        ``"position_scale_{i}"`` and ``"condition_latent_mask_{i}"``.

        With ``x_0`` the encoded image, ``x_1`` Gaussian noise and ``t`` a
        sigmoid of a normal sample, the transformer is run on
        ``x_t = (1 - t) * x_0 + t * x_1`` and the loss is the mean squared
        error between its output and ``x_1 - x_0``.

        Side effects: sets ``last_t`` (mean of ``t``) and updates ``log_loss``
        (exponential moving average of the loss with factor 0.95).

        Returns:
            The loss tensor for this batch.
        """
        imgs, prompts = batch["image"], batch["description"]
        image_latent_mask = batch.get("image_latent_mask", None)

        # Get the conditions and position deltas from the batch
        conditions, position_deltas, position_scales, latent_masks = [], [], [], []
        for i in range(1000):
            if f"condition_{i}" not in batch:
                break
            conditions.append(batch[f"condition_{i}"])
            position_deltas.append(batch.get(f"position_delta_{i}", [[0, 0]]))
            position_scales.append(batch.get(f"position_scale_{i}", [1.0])[0])
            latent_masks.append(batch.get(f"condition_latent_mask_{i}", None))

        # Prepare inputs
        with torch.no_grad():
            # Prepare image input
            x_0, img_ids = encode_images(self.flux_pipe, imgs)

            # Prepare text input
            (
                prompt_embeds,
                pooled_prompt_embeds,
                text_ids,
            ) = self.flux_pipe.encode_prompt(
                prompt=prompts,
                prompt_2=None,
                prompt_embeds=None,
                pooled_prompt_embeds=None,
                device=self.flux_pipe.device,
                num_images_per_prompt=1,
                max_sequence_length=self.model_config.get("max_sequence_length", 512),
                lora_scale=None,
            )

            # Prepare t and x_t
            t = torch.sigmoid(torch.randn((imgs.shape[0],), device=self.device))
            x_1 = torch.randn_like(x_0).to(self.device)
            t_ = t.unsqueeze(1).unsqueeze(1)
            x_t = ((1 - t_) * x_0 + t_ * x_1).to(self.dtype)
            if image_latent_mask is not None:
                # The first sample's mask is applied to every sample in the batch.
                x_0 = x_0[:, image_latent_mask[0]]
                x_1 = x_1[:, image_latent_mask[0]]
                x_t = x_t[:, image_latent_mask[0]]
                img_ids = img_ids[image_latent_mask[0]]

            # Prepare conditions
            condition_latents, condition_ids = [], []
            for cond, p_delta, p_scale, latent_mask in zip(
                conditions, position_deltas, position_scales, latent_masks
            ):
                # Prepare conditions
                c_latents, c_ids = encode_images(self.flux_pipe, cond)
                # Scale the position (see OminiControl2)
                if p_scale != 1.0:
                    scale_bias = (p_scale - 1.0) / 2
                    c_ids[:, 1:] *= p_scale
                    c_ids[:, 1:] += scale_bias
                # Add position delta (see OminiControl); only the first
                # position delta of the batch is used.
                c_ids[:, 1] += p_delta[0][0]
                c_ids[:, 2] += p_delta[0][1]
                # Append to the list
                if latent_mask is not None:
                    # NOTE(review): the image branch indexes with
                    # ``[:, mask[0]]`` while the latents here use the full
                    # mask — confirm this difference is intended.
                    c_latents, c_ids = c_latents[latent_mask], c_ids[latent_mask[0]]
                condition_latents.append(c_latents)
                condition_ids.append(c_ids)

            # Prepare guidance
            guidance = (
                torch.ones_like(t).to(self.device)
                if self.transformer.config.guidance_embeds
                else None
            )

        branch_n = 2 + len(conditions)
        group_mask = torch.ones([branch_n, branch_n], dtype=torch.bool).to(self.device)
        # Disable the attention cross different condition branches
        group_mask[2:, 2:] = torch.diag(torch.tensor([1] * len(conditions)))
        # Disable the attention from condition branches to image branch and text branch
        if self.model_config.get("independent_condition", False):
            group_mask[2:, :2] = False

        # Forward pass
        transformer_out = transformer_forward(
            self.transformer,
            image_features=[x_t, *(condition_latents)],
            text_features=[prompt_embeds],
            img_ids=[img_ids, *(condition_ids)],
            txt_ids=[text_ids],
            # There are three timesteps for the three branches
            # (text, image, and the condition)
            timesteps=[t, t] + [torch.zeros_like(t)] * len(conditions),
            # Same as above
            pooled_projections=[pooled_prompt_embeds] * branch_n,
            guidances=[guidance] * branch_n,
            # The LoRA adapter names of each branch
            adapters=self.adapter_names,
            return_dict=False,
            group_mask=group_mask,
        )
        pred = transformer_out[0]

        # Compute loss
        step_loss = torch.nn.functional.mse_loss(pred, (x_1 - x_0), reduction="mean")
        self.last_t = t.mean().item()

        # Read the scalar once; each ``.item()`` call synchronizes with the device.
        loss_value = step_loss.item()
        if hasattr(self, "log_loss"):
            self.log_loss = self.log_loss * 0.95 + loss_value * 0.05
        else:
            self.log_loss = loss_value
        return step_loss

    def generate_a_sample(self):
        """Placeholder for sample generation; always raises ``NotImplementedError``."""
        raise NotImplementedError("Generate a sample not implemented.")


class TrainingCallback(L.Callback):
    """Lightning callback that reports progress, saves checkpoints and generates samples."""

    def __init__(self, run_name, training_config: dict = None, test_function=None, resume_step: int = 0):
        """Read the reporting/saving settings from ``training_config``.

        Args:
            run_name: Name of the run; used as a sub-directory of the save path.
            training_config: Training options. Keys read here:
                ``print_every_n_steps`` (default 10), ``save_interval``
                (default 1000), ``sample_interval`` (default 1000),
                ``save_path`` (default ``"./output"``) and ``wandb``.
                Defaults to an empty dict.
            test_function: Optional callable invoked as
                ``test_function(pl_module, output_dir, file_name)`` at every
                sample interval.
            resume_step: Initial value of the step counter.
        """
        # ``None`` default avoids sharing one mutable default dict between instances.
        training_config = {} if training_config is None else training_config
        self.run_name = run_name
        self.training_config = training_config

        self.print_every_n_steps = training_config.get("print_every_n_steps", 10)
        self.save_interval = training_config.get("save_interval", 1000)
        self.sample_interval = training_config.get("sample_interval", 1000)
        self.save_path = training_config.get("save_path", "./output")

        self.wandb_config = training_config.get("wandb", None)
        self.use_wandb = (
            wandb is not None and os.environ.get("WANDB_API_KEY") is not None
        )

        self.total_steps = resume_step
        self.test_function = test_function

        # Used to track training speed and a running average of the loss.
        self.step_start_time = None
        self.recent_losses = []
        self.max_recent_losses = 100

    @staticmethod
    def _gradient_stats(pl_module):
        """Return ``(mean, max)`` of the L2 gradient norms over parameters that have a gradient.

        Both values are 0 when no parameter has a gradient.
        """
        total = 0
        largest = 0
        count = 0
        for param in pl_module.parameters():
            if param.grad is None:
                continue
            # One ``.item()`` per parameter: each call synchronizes with the device.
            norm = param.grad.norm(2).item()
            total += norm
            largest = max(largest, norm)
            count += 1
        if count > 0:
            total /= count
        return total, largest

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        """Update counters, then log, print, checkpoint and sample as configured."""
        self.total_steps += 1
        should_print = self.total_steps % self.print_every_n_steps == 0

        # Gradient norms are only needed for wandb and the periodic print, so
        # they are skipped on steps where neither consumes them.
        gradient_size, max_gradient_size = 0, 0
        if self.use_wandb or should_print:
            gradient_size, max_gradient_size = self._gradient_stats(pl_module)

        # Measure the time elapsed since the previous batch ended.
        current_time = time.time()
        if self.step_start_time is not None:
            step_time = current_time - self.step_start_time
        else:
            step_time = 0
        self.step_start_time = current_time

        # Keep a window of the most recent losses; the reported loss is
        # multiplied back by the gradient-accumulation factor.
        loss_value = outputs["loss"].item() * trainer.accumulate_grad_batches
        self.recent_losses.append(loss_value)
        if len(self.recent_losses) > self.max_recent_losses:
            self.recent_losses.pop(0)
        avg_loss = sum(self.recent_losses) / len(self.recent_losses)

        # Get current learning rate
        current_lr = trainer.optimizers[0].param_groups[0]['lr']

        # Report to wandb on every step
        if self.use_wandb:
            wandb.log(
                {
                    # Previously written under a duplicate "steps" key, which
                    # silently discarded the batch index.
                    "batch_idx": batch_idx,
                    "steps": self.total_steps,
                    "epoch": trainer.current_epoch,
                    "gradient_size": gradient_size,
                    "learning_rate": current_lr,
                    "loss": loss_value,
                    "avg_loss": avg_loss,
                    "t": pl_module.last_t,
                }
            )

        # Print training progress every n steps
        if should_print:
            print(
                f"Epoch: {trainer.current_epoch}, Steps: {self.total_steps}, Batch: {batch_idx}, "
                f"Loss: {pl_module.log_loss:.4f}, Avg Loss: {avg_loss:.4f}, "
                f"LR: {current_lr:.2e}, "
                f"Grad: {gradient_size:.4f}, Max Grad: {max_gradient_size:.4f}, "
                f"Time/Step: {step_time:.3f}s"
            )

        # Save checkpoint (LoRA weights + training state) at specified intervals
        if self.total_steps % self.save_interval == 0:
            save_opt = self.training_config.get("save_optimizer_state", True)
            print(
                f"Epoch: {trainer.current_epoch}, Steps: {self.total_steps} - Saving checkpoint"
            )
            pl_module.save_checkpoint(
                f"{self.save_path}/{self.run_name}/ckpt/{self.total_steps}",
                trainer,
                trainer.current_epoch,
                self.total_steps,
                save_optimizer=save_opt,
            )

        # Generate and save a sample image at specified intervals
        if self.total_steps % self.sample_interval == 0 and self.test_function:
            print(
                f"Epoch: {trainer.current_epoch}, Steps: {self.total_steps} - Generating a sample"
            )
            pl_module.eval()
            self.test_function(
                pl_module,
                f"{self.save_path}/{self.run_name}/output",
                f"lora_{self.total_steps}",
            )
            pl_module.train()


def train(dataset, trainable_model, config, test_function):
    """Run training of ``trainable_model`` on ``dataset`` with a Lightning trainer.

    Args:
        dataset: Dataset wrapped in a shuffling ``DataLoader``.
        trainable_model: Module to train. When resuming, its
            ``load_checkpoint`` is called and ``_checkpoint_optimizer_state``
            is set on it.
        config: Full configuration; ``config["train"]`` holds the training
            options (``dataloader_workers`` and ``accumulate_grad_batches``
            are required; ``batch_size``, ``max_steps``, ``max_epochs``,
            ``gradient_clip_val``, ``save_path``, ``wandb`` and
            ``resume_from_checkpoint`` are optional).
        test_function: Passed to ``TrainingCallback`` for periodic sampling.

    The configuration is written to ``<save_path>/<run_name>/config.yaml`` on
    the main process before training starts.
    """
    # Initialize
    rank = get_rank()
    is_main_process = rank == 0
    torch.cuda.set_device(rank)

    # Trade float32 matmul precision for speed on GPUs with Tensor Cores.
    # 'highest' keeps full float32 (most precise, slowest), 'high' may use
    # TF32, and 'medium' may use bfloat16 internally (fastest, least precise).
    torch.set_float32_matmul_precision('medium')

    training_config = config["train"]

    # Check if resuming from checkpoint
    resume_checkpoint = training_config.get("resume_from_checkpoint", None)
    is_resuming = bool(resume_checkpoint) and os.path.exists(resume_checkpoint)
    resume_step = 0

    if is_resuming:
        print(f"\n{'='*70}")
        print(f"Resuming from checkpoint: {resume_checkpoint}")
        print(f"{'='*70}\n")
        # Checkpoints live at <save_path>/<run_name>/ckpt/<step>; normalize
        # first so a trailing slash does not shift the path components.
        run_dir = os.path.dirname(os.path.dirname(os.path.normpath(resume_checkpoint)))
        run_name = os.path.basename(run_dir)
    else:
        run_name = time.strftime("%Y%m%d-%H%M%S")

    # Initialize WanDB
    wandb_config = training_config.get("wandb", None)
    if wandb_config is not None and is_main_process:
        init_wandb(wandb_config, run_name)

    print("Rank:", rank)
    if is_main_process:
        print("Config:", config)

    # Initialize dataloader
    print("Dataset length:", len(dataset))
    train_loader = DataLoader(
        dataset,
        batch_size=training_config.get("batch_size", 1),
        shuffle=True,
        num_workers=training_config["dataloader_workers"],
    )

    # Callbacks for testing and saving checkpoints
    if is_main_process:
        callbacks = [TrainingCallback(run_name, training_config, test_function, resume_step)]
    else:
        callbacks = []

    # Initialize trainer
    trainer = L.Trainer(
        accumulate_grad_batches=training_config["accumulate_grad_batches"],
        callbacks=callbacks,
        enable_checkpointing=False,
        enable_progress_bar=True,
        logger=False,
        max_steps=training_config.get("max_steps", -1),
        max_epochs=training_config.get("max_epochs", -1),
        gradient_clip_val=training_config.get("gradient_clip_val", 0.5),
    )

    setattr(trainer, "training_config", training_config)
    setattr(trainable_model, "training_config", training_config)

    # Load checkpoint if resuming
    if is_resuming:
        checkpoint_data = trainable_model.load_checkpoint(resume_checkpoint)
        if checkpoint_data:
            resume_step = checkpoint_data['global_step']
            # Stash the optimizer state on the model; the model's
            # configure_optimizers loads it into the freshly built optimizer.
            trainable_model._checkpoint_optimizer_state = checkpoint_data.get('optimizer_state_dict', None)
            if is_main_process:
                callbacks[0].total_steps = resume_step

    # Save the training config
    save_path = training_config.get("save_path", "./output")
    if is_main_process:
        os.makedirs(f"{save_path}/{run_name}", exist_ok=True)
        with open(f"{save_path}/{run_name}/config.yaml", "w") as f:
            yaml.dump(config, f)

    # Start training
    trainer.fit(trainable_model, train_loader)
