# This project uses Stable Diffusion, a model developed by Stability AI and released under the CreativeML Open RAIL-M license.


import logging
import math
import os
import random
import torch.nn as nn
from pathlib import Path
from typing import Iterable, Optional, Dict, Any
from tqdm.auto import tqdm
import json
import matplotlib.pyplot as plt
from ruamel.yaml import YAML
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torchvision import transforms
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from transformers import CLIPTextModel, CLIPTokenizer
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from model import model_types
from config import parse_args
from utils_model import save_model, load_model, load_weights
from utils_data import get_dataloader

# Module-level logger obtained from accelerate.logging.get_logger; used by the
# helper functions and the training loop below.
logger = get_logger(__name__)


def load_attention_heads(attention_dir: Path, sample_idx: int) -> Dict[str, torch.Tensor]:
    """
    Load the saved attention maps for one training sample.

    Every ``*.pt`` file in ``attention_dir / f"sample_{sample_idx:04d}"`` is
    loaded onto the CPU. The loaded object is expected to be a dict with an
    ``'attention_weights'`` entry and an optional ``'layer_name'`` entry; the
    file stem is used as the layer name when ``'layer_name'`` is absent.

    Files are visited in sorted filename order. ``Path.glob`` order is
    filesystem-dependent, so sorting keeps the resulting dict order, and the
    stacking order used by ``aggregate_attention_heads``, reproducible.

    Args:
        attention_dir: Root directory holding the per-sample sub-directories.
        sample_idx: Index of the sample (formatted with 4-digit zero padding).

    Returns:
        Dictionary mapping layer names to attention tensors. It is empty if the
        sample directory does not exist or holds no usable files.
        - Files without ``'attention_weights'`` are skipped.
        - Files that fail to load or parse are logged as warnings and skipped.
    """
    sample_attention_dir = attention_dir / f"sample_{sample_idx:04d}"
    attention_data: Dict[str, torch.Tensor] = {}

    if not sample_attention_dir.exists():
        return attention_data

    for pt_file in sorted(sample_attention_dir.glob("*.pt")):
        try:
            # NOTE: torch.load unpickles the file; only use on trusted data.
            data = torch.load(pt_file, map_location='cpu')
            layer_name = data.get('layer_name', pt_file.stem)
            attention_weights = data.get('attention_weights')
        except Exception as e:
            # Best-effort: one bad file should not abort the whole training step.
            logger.warning(f"Failed to load attention head from {pt_file}: {e}")
            continue

        if attention_weights is not None:
            attention_data[layer_name] = attention_weights

    return attention_data


def aggregate_attention_heads(attention_data: Dict[str, torch.Tensor],
                              aggregation_method: str = 'mean') -> Optional[torch.Tensor]:
    """
    Aggregate attention tensors from multiple layers into a single tensor.

    All tensors are stacked along a new leading dimension, so they must share
    one shape, and are then reduced over that dimension.

    Args:
        attention_data: Mapping of layer name -> attention tensor.
        aggregation_method: One of ``'mean'``, ``'max'``, ``'sum'`` or
            ``'weighted_mean'``. ``'weighted_mean'`` assigns raw weights by
            layer name, then normalizes them to sum to 1:
            - 2.0 if the name contains ``'mid_block'``;
            - otherwise 1.5 if it contains ``'cross'`` or ``'attn2'``;
            - otherwise 1.0.
            Integer inputs are averaged in float32.

    Returns:
        The aggregated tensor, or ``None`` when ``attention_data`` is empty.

    Raises:
        ValueError: If ``aggregation_method`` is not one of the supported names
            (checked only when ``attention_data`` is non-empty).
    """
    if not attention_data:
        return None

    supported_methods = ('mean', 'max', 'sum', 'weighted_mean')
    if aggregation_method not in supported_methods:
        raise ValueError(f"Unknown aggregation method: {aggregation_method}")

    # Leading dim indexes layers, in the dict's insertion order.
    stacked = torch.stack(list(attention_data.values()))

    if aggregation_method == 'mean':
        return stacked.mean(dim=0)
    if aggregation_method == 'max':
        return stacked.max(dim=0).values
    if aggregation_method == 'sum':
        return stacked.sum(dim=0)

    # 'weighted_mean': per-layer importance weights (customize as needed).
    def _layer_weight(layer_name: str) -> float:
        if 'mid_block' in layer_name:
            return 2.0  # Higher weight for mid block
        if 'cross' in layer_name or 'attn2' in layer_name:
            return 1.5  # Higher weight for cross attention
        return 1.0  # Standard weight for self attention

    weights = torch.tensor([_layer_weight(name) for name in attention_data],
                           dtype=torch.float32, device=stacked.device)
    weights = weights / weights.sum()  # Normalize weights to sum to 1

    # Keep floating inputs in their own dtype; promote integer inputs so the
    # fractional weights are not truncated.
    compute_dtype = stacked.dtype if stacked.is_floating_point() else torch.float32
    weights = weights.to(compute_dtype).view(-1, *([1] * (stacked.dim() - 1)))
    return (weights * stacked.to(compute_dtype)).sum(dim=0)


def apply_attention_regularization(model_pred: torch.Tensor,
                                   attention_heads: Optional[torch.Tensor],
                                   regularization_strength: float = 0.01) -> torch.Tensor:
    """
    Compute an attention-based regularization term.

    The loss is ``regularization_strength * (diversity + sparsity)``:
    - diversity is the negative mean variance of the attention map along dim 0;
    - sparsity is the mean entropy-style term ``-sum(a * log(a))`` along the
      last dim.
    Values are presumably non-negative attention weights; TODO confirm.

    If the map's last two dims differ from ``model_pred``'s, it is bilinearly
    resized to match. 2-D ``(H, W)`` and 3-D ``(C, H, W)`` maps keep their
    rank after resizing. 4-D maps get ``squeeze(0)`` applied afterwards.

    NOTE(review): the result depends only on ``attention_heads`` (loaded from
    disk), not on ``model_pred``'s values. Adding it to the loss therefore
    changes the reported loss but sends no gradient to the model. Confirm this
    is intended.

    Args:
        model_pred: Model predictions. Only their device and trailing spatial
            size are used.
        attention_heads: Aggregated attention tensor, or ``None``.
        regularization_strength: Multiplier applied to the combined term.

    Returns:
        A scalar float32 tensor on ``model_pred.device``. It is 0 when
        ``attention_heads`` is ``None``.

    Raises:
        ValueError: If a resize is needed and ``attention_heads`` has fewer
            than 2 dimensions.
    """
    if attention_heads is None:
        return torch.tensor(0.0, device=model_pred.device)

    # float32: in fp16 the 1e-8 epsilon below rounds to 0, so log(0) * 0 = NaN.
    attention_heads = attention_heads.to(device=model_pred.device, dtype=torch.float32)

    target_size = tuple(model_pred.shape[-2:])
    if tuple(attention_heads.shape[-2:]) != target_size:
        original_rank = attention_heads.dim()
        if original_rank < 2:
            raise ValueError(
                f"attention_heads must have at least 2 dims to be resized, got shape "
                f"{tuple(attention_heads.shape)}"
            )
        # F.interpolate(mode='bilinear') only accepts 4-D (N, C, H, W) input.
        resized = attention_heads
        while resized.dim() < 4:
            resized = resized.unsqueeze(0)
        resized = F.interpolate(resized, size=target_size, mode='bilinear', align_corners=False)
        if original_rank == 4:
            attention_heads = resized.squeeze(0)
        else:
            # Drop only the leading dims added above, restoring the original rank.
            attention_heads = resized.reshape(resized.shape[4 - original_rank:])

    # Diversity: encourage variation along dim 0 (negated to maximize variance).
    # Variance over a single slice is undefined (NaN), so treat it as zero.
    if attention_heads.shape[0] > 1:
        diversity_loss = -torch.var(attention_heads, dim=0).mean()
    else:
        diversity_loss = attention_heads.new_zeros(())

    # Sparsity: entropy along the last dim. clamp_min keeps log finite for any
    # slightly negative inputs.
    attention_entropy = -torch.sum(
        attention_heads * torch.log(attention_heads.clamp_min(0) + 1e-8), dim=-1
    )
    sparsity_loss = attention_entropy.mean()

    return regularization_strength * (diversity_loss + sparsity_loss)


def weighted_concept_loss(predicted_noise, true_noise, daam_heatmaps, alpha=1.0, beta=0.1):
    """
    Squared-error loss that up-weights regions highlighted by DAAM heatmaps.

    Each sample's heatmap is min-max normalized to [0, 1] over all of its
    non-batch elements. The per-element weight is ``alpha + beta * normalized``,
    and the loss is the mean of ``weight * (predicted_noise - true_noise) ** 2``.

    Args:
        predicted_noise: The noise predicted by the model.
        true_noise: The target noise, broadcast-compatible with
            ``predicted_noise``.
        daam_heatmaps: Heatmaps with the batch on dim 0 (any rank, contiguous
            or not), broadcast-compatible with the error tensor.
        alpha: Base weight for every element.
        beta: Extra weight scaled by the normalized heatmap value.

    Returns:
        Scalar tensor: the weighted mean squared error.
    """
    if daam_heatmaps.max() > 1.0:
        # Treat as 0-255 data. The per-sample min-max normalization below is
        # scale-invariant apart from its epsilon, so this rarely matters.
        daam_heatmaps = daam_heatmaps / 255.0

    # reshape() (unlike view()) also accepts non-contiguous heatmaps.
    flat = daam_heatmaps.reshape(daam_heatmaps.size(0), -1)
    # Shape (B, 1, ..., 1) so per-sample stats broadcast against any rank.
    broadcast_shape = (-1,) + (1,) * (daam_heatmaps.dim() - 1)
    batch_min = flat.min(dim=1).values.view(broadcast_shape)
    batch_max = flat.max(dim=1).values.view(broadcast_shape)
    normalized = (daam_heatmaps - batch_min) / (batch_max - batch_min + 1e-8)

    weights = alpha + beta * normalized
    squared_error = (predicted_noise - true_noise) ** 2
    return (weights * squared_error).mean()


def unfreeze_layers_unet(unet):
    """
    Print how many UNet parameters currently require gradients.

    Despite its name, this function does not change any ``requires_grad``
    flags. It only reports the count and returns ``unet`` unchanged.

    Args:
        unet: Any module exposing ``parameters()``.

    Returns:
        The same ``unet`` object.
    """
    n_trainable = 0
    for param in unet.parameters():
        if not param.requires_grad:
            continue
        n_trainable += param.numel()
    print("Num trainable params unet: ", n_trainable)
    return unet


def _capture_mid_block_features(unet, latents, timesteps, encoder_hidden_states):
    """
    Run ``unet`` once without gradients and capture its mid-block output.

    A forward hook on ``unet.mid_block`` stores a detached copy of that
    module's output. The hook is always removed, even if the forward pass
    raises.

    Returns:
        The captured tensor, or ``None`` if the hook never fired.
    """
    captured = {}

    def hook_fn(module, inputs, output):
        captured["features"] = output.detach().clone()

    hook = unet.mid_block.register_forward_hook(hook_fn)
    try:
        with torch.no_grad():
            unet(latents, timesteps, encoder_hidden_states, return_dict=True)
    finally:
        hook.remove()
    return captured.get("features")


def main():
    """
    Train a UNet with an attached conditioning module (``unet.set_controlnet``).

    Inputs per batch are images, captions, conditions and DAAM heatmaps.
    Configuration comes from ``parse_args()``. The following are written to
    ``args.output_dir``:
    - ``config.yaml``;
    - ``unet_{epoch}.pth`` every 10 epochs;
    - ``unet.pth`` periodically and at the end;
    - ``loss_history.png``.
    """
    args = parse_args()
    logging_dir = os.path.join(args.output_dir, args.logging_dir)

    os.makedirs(args.output_dir, exist_ok=True)
    yaml = YAML()
    # Snapshot the run configuration; the context manager closes the file.
    with open(os.path.join(args.output_dir, 'config.yaml'), 'w') as config_file:
        yaml.dump(vars(args), config_file)

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        # log_with=args.report_to,
        # NOTE(review): with log_with disabled no trackers are configured, so the
        # accelerator.log(...) calls below presumably record nothing -- confirm.
        project_dir=logging_dir,
    )

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )

    if args.seed is not None:
        set_seed(args.seed)

    if accelerator.is_main_process:
        os.makedirs(args.output_dir, exist_ok=True)

    tokenizer = CLIPTokenizer.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
    )
    text_encoder = CLIPTextModel.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="text_encoder",
        revision=args.revision,
    )
    vae = AutoencoderKL.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="vae",
        revision=args.revision,
    )
    unet = UNet2DConditionModel.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="unet",
        revision=args.revision,
    )

    if args.use_esd:
        load_model(unet, 'baselines/diffusers-nudity-ESDu1-UNET.pt')

    # Freeze the pretrained components; only modules that re-enable gradients
    # (presumably the attached `mlp`) are updated by the optimizer.
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    unet.requires_grad_(False)

    mlp = model_types[args.model_type](resolution=args.resolution // 64)

    unet.set_controlnet(mlp)
    unet = unfreeze_layers_unet(unet)

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()

    if args.scale_lr:
        args.learning_rate = (
                args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
        )

    optimizer = torch.optim.Adam(
        unet.parameters(),
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )
    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")

    def tokenize_captions(examples, is_train=True):
        captions = []
        for caption in examples:
            if isinstance(caption, str):
                captions.append(caption)
            elif isinstance(caption, (list, np.ndarray)):
                captions.append(random.choice(caption) if is_train else caption[0])
            else:
                raise ValueError(
                    f"Caption column should contain either strings or lists of strings."
                )
        inputs = tokenizer(captions, max_length=tokenizer.model_max_length, padding="do_not_pad", truncation=True)
        input_ids = inputs.input_ids
        return input_ids

    train_transforms = transforms.Compose(
        [
            transforms.Resize((args.resolution, args.resolution), interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
            transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )

    def collate_fn(examples):
        pixel_values = torch.stack([example[0] for example in examples])
        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

        input_ids = [example[1] for example in examples]
        padded_tokens = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt")

        input_conditions = torch.stack([example[2] for example in examples])

        # Stack heatmaps
        heatmaps = torch.stack([example[3] for example in examples])
        heatmaps = heatmaps.to(memory_format=torch.contiguous_format).float()

        # Handle heatmap input ids
        input_ids_heatmap = [example[4] for example in examples]
        padded_tokens_heatmap = tokenizer.pad({"input_ids": input_ids_heatmap}, padding=True, return_tensors="pt")

        # Extract sample indices if available for attention head loading
        sample_indices = []
        if len(examples[0]) > 5:  # Check if sample indices are included
            sample_indices = [example[5] for example in examples]

        return {
            "pixel_values": pixel_values,
            "input_ids": padded_tokens.input_ids,
            "attention_mask": padded_tokens.attention_mask,
            "input_conditions": input_conditions,
            "heatmaps": heatmaps,
            "input_ids_heatmap": padded_tokens_heatmap.input_ids,
            "attention_mask_heatmap": padded_tokens_heatmap.attention_mask,
            "sample_indices": sample_indices
        }

    train_dataloader = get_dataloader(args.train_data_dir, batch_size=args.train_batch_size, shuffle=True,
                                      transform=train_transforms, tokenizer=tokenize_captions, collate_fn=collate_fn,
                                      num_workers=0, max_concept_length=100, select=args.select)

    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
    )

    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    print('weight_dtype', weight_dtype)

    # The frozen encoders run in weight_dtype; the UNet keeps its loaded dtype.
    text_encoder.to(accelerator.device, dtype=weight_dtype)
    vae.to(accelerator.device, dtype=weight_dtype)
    unet.to(accelerator.device)

    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    if accelerator.is_main_process:
        exp_name = f'{args.output_dir}_prompt_{args.prompt}_lr{str(args.learning_rate)}'
        accelerator.init_trackers(
            project_name="diffusion-explainer",
            config={k: v for k, v in vars(args).items() if k != 'config'},
            init_kwargs={"wandb": {"name": exp_name}}
        )

    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataloader.dataset)}")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")

    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Steps")

    # Use the device the models were moved to above. A hard-coded "cuda"
    # breaks CPU runs and any process whose accelerator device is not cuda:0.
    device = accelerator.device
    # Under fp16/bf16 the VAE and text encoder emit half-precision tensors
    # while the UNet parameters stay in their loaded dtype. Cast UNet inputs
    # to match, or the first conv raises a dtype-mismatch error.
    unet_dtype = next(unet.parameters()).dtype
    print("Start training")
    loss_history = []
    global_step = 0

    # Directory to save latent data
    latent_dir = os.path.join(args.output_dir, 'latent_data')
    os.makedirs(latent_dir, exist_ok=True)

    # Add alpha and beta parameters for weighted loss
    alpha = getattr(args, 'loss_alpha', 1.0)
    beta = getattr(args, 'loss_beta', 0.0)

    # Attention head loading configuration
    attention_heads_dir = Path(args.train_data_dir) / "attention_heads"
    use_attention_heads = attention_heads_dir.exists()
    attention_aggregation_method = getattr(args, 'attention_aggregation', 'weighted_mean')
    attention_reg_strength = getattr(args, 'attention_reg_strength', 0.01)

    if use_attention_heads:
        logger.info(f"Found attention heads directory: {attention_heads_dir}")
        logger.info(f"Using attention aggregation method: {attention_aggregation_method}")
    else:
        logger.info("No attention heads directory found, proceeding without attention head features")

    for epoch in range(args.num_train_epochs):
        unet.train()
        for batch in train_dataloader:
            # Load attention heads if available (None entries mean "no data").
            batch_attention_heads = []
            if use_attention_heads and batch.get("sample_indices"):
                for sample_idx in batch["sample_indices"]:
                    attention_data = load_attention_heads(attention_heads_dir, sample_idx)
                    aggregated_attention = aggregate_attention_heads(
                        attention_data,
                        aggregation_method=attention_aggregation_method
                    )
                    batch_attention_heads.append(aggregated_attention)

            # Encode the images into scaled VAE latents (0.18215 is the SD latent scale).
            latents = vae.encode(batch["pixel_values"].to(weight_dtype).to(device)).latent_dist.sample()
            latents = latents * 0.18215

            noise = torch.randn_like(latents)
            bsz = latents.shape[0]
            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
            timesteps = timesteps.long()

            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
            encoder_hidden_states = text_encoder(batch["input_ids"].to(device))[0]
            encoder_hidden_states_heatmap = text_encoder(batch["input_ids_heatmap"].to(device))[0]

            # Encode heatmaps with the VAE the same way as images:
            # - the negated heatmap feeds the feature-extraction UNet pass;
            # - the original heatmap's latent weights the loss.
            with torch.no_grad():
                inverse_heatmaps = -batch["heatmaps"]
                heatmap_latents = vae.encode(inverse_heatmaps.to(weight_dtype).to(device)).latent_dist.sample()
                heatmap_latents = heatmap_latents * 0.18215

                heatmap_latents_w = vae.encode(batch["heatmaps"].to(weight_dtype).to(device)).latent_dist.sample()
                heatmap_latents_w = heatmap_latents_w * 0.18215

            heatmap_features = _capture_mid_block_features(
                unet,
                heatmap_latents.to(unet_dtype),
                timesteps,
                encoder_hidden_states_heatmap.to(unet_dtype),
            )
            if heatmap_features is None:
                logger.warning("Hook on unet.mid_block did not capture any features")
            else:
                logger.debug(f"Captured mid_block features with shape: {tuple(heatmap_features.shape)}")

            if noise_scheduler.config.prediction_type == "epsilon":
                target = noise
            elif noise_scheduler.config.prediction_type == "v_prediction":
                target = noise_scheduler.get_velocity(latents, noise, timesteps)
            else:
                raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

            # Forward pass with processed heatmap features
            model_pred_image = unet(
                noisy_latents.to(unet_dtype),
                timesteps,
                encoder_hidden_states.to(unet_dtype),
                controlnet_cond=batch["input_conditions"].to(device),
                heatmap_features=heatmap_features,
            ).sample

            # Use weighted concept loss instead of MSE loss
            loss = weighted_concept_loss(
                model_pred_image.float(),
                target.float(),
                heatmap_latents_w.to(device),
                alpha=alpha,
                beta=beta
            )

            # Attention regularization, averaged over the samples that actually
            # have attention heads (missing ones are None and are skipped).
            valid_attention_heads = [head for head in batch_attention_heads if head is not None]
            if valid_attention_heads:
                attention_reg_loss = sum(
                    apply_attention_regularization(model_pred_image, head, attention_reg_strength)
                    for head in valid_attention_heads
                ) / len(valid_attention_heads)
                loss = loss + attention_reg_loss

                # Log attention regularization loss
                if global_step % 10 == 0:
                    accelerator.log({
                        "attention_reg_loss": attention_reg_loss.item(),
                    }, step=global_step)

            # NOTE(review): the optimizer and scheduler step on every batch, and
            # nothing is passed through accelerator.prepare(). So
            # args.gradient_accumulation_steps only feeds the step/LR
            # bookkeeping above; it does not change the update frequency. With
            # accumulation > 1, global_step (counted per batch) also reaches
            # max_train_steps early. Confirm this is intended.
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

            step_loss = loss.detach().item()
            current_lr = lr_scheduler.get_last_lr()[0]

            progress_bar.update(1)
            global_step += 1
            accelerator.log({
                "train_loss": step_loss,
                "lr": current_lr,
            }, step=global_step)
            loss_history.append(step_loss)

            logs = {
                "step_loss": step_loss,
                "lr": current_lr
            }
            progress_bar.set_postfix(**logs)

            if global_step >= args.max_train_steps:
                break

        if (epoch + 1) % 10 == 0:
            save_model(unet, os.path.join(args.output_dir, f'unet_{epoch + 1}.pth'))

        if epoch % args.log_every_epochs == 0:
            save_model(unet, args.output_dir + '/unet.pth')

    save_model(unet, args.output_dir + '/unet.pth')
    plt.figure()
    plt.plot(loss_history)
    plt.savefig(args.output_dir + '/loss_history.png')
    plt.close()


# Run training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
