import argparse
import os
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
from tqdm import tqdm
from transformers import CLIPTokenizer, CLIPTextModel
from pathlib import Path

# Number of diffusion timesteps the DDPM noise scheduler is configured with
TIME = 1000


class SimpleDataset(Dataset):
    """Dataset that pairs every image under ``data_dir/prompt`` with one fixed prompt.

    Exactly ``batch_size`` samples are exposed. Files are ordered by name; when
    the directory holds fewer files than ``batch_size`` the existing ones are
    cycled, otherwise only the first ``batch_size`` files are kept. Every
    regular file in the directory is treated as an image.

    Args:
        data_dir: Root directory that contains one sub-directory per prompt.
        prompt: Sub-directory name, also used verbatim as the text prompt.
        tokenizer: Callable tokenizer exposing ``model_max_length``.
        size: Target edge length for the resize and center crop.
        batch_size: Number of samples the dataset exposes (must be >= 1).

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
        FileNotFoundError: If the prompt directory contains no files.
    """

    def __init__(self, data_dir, prompt, tokenizer, size=512, batch_size=1):
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        # Collect all files sorted by name so the sample order is reproducible
        prompt_dir = Path(data_dir) / prompt
        all_files = sorted(
            (f for f in prompt_dir.glob("*") if f.is_file()),
            key=lambda f: f.name,
        )
        if not all_files:
            # Fail early with a clear message instead of dividing by zero below
            raise FileNotFoundError(f"No image files found in {prompt_dir}")

        # Index modulo the file count: repeats files when there are too few,
        # and reduces to the first batch_size files when there are enough.
        self.image_paths = [
            all_files[i % len(all_files)] for i in range(batch_size)
        ]
        self.prompt = prompt
        self.tokenizer = tokenizer

        # Resize + center crop to a square, then scale pixel values with
        # mean 0.5 / std 0.5
        self.transform = transforms.Compose([
            transforms.Resize(size),
            transforms.CenterCrop(size),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])

    def __len__(self):
        """Return the number of samples exposed by the dataset."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return the transformed image and the tokenized prompt for ``idx``.

        Returns:
            dict with ``"pixel_values"`` (transformed image tensor) and
            ``"input_ids"`` (prompt token ids without the batch axis).
        """
        # Close the file handle as soon as the pixels have been converted
        with Image.open(self.image_paths[idx]) as raw_image:
            image = raw_image.convert("RGB")
        pixel = self.transform(image)

        # Pad AND truncate to the model length so every sample has the same
        # number of tokens, even for prompts longer than the model supports.
        tokens = self.tokenizer(
            self.prompt,
            padding="max_length",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        )
        # Drop only the leading batch axis added by return_tensors="pt"
        return {"pixel_values": pixel, "input_ids": tokens.input_ids.squeeze(0)}


def calculate_weight_changes(initial_weights, final_weights):
    """Measure how far a set of weights moved between two model states.

    The L2 norms are taken over all tensors as if they were flattened into a
    single vector, but are accumulated tensor by tensor so no concatenated
    copy of the whole model is ever materialized.

    Args:
        initial_weights: Mapping from parameter name to tensor (state dict)
            before training.
        final_weights: Mapping with the same keys after training.

    Returns:
        dict with ``"norm_initial"`` (L2 norm of the initial weights),
        ``"norm_delta"`` (L2 norm of ``final - initial``) and
        ``"delta_ratio"`` (``norm_delta / norm_initial``, or 0.0 when the
        initial norm is zero). An empty mapping yields all zeros.

    Raises:
        KeyError: If ``final_weights`` lacks a key of ``initial_weights``.
    """
    sq_sum_initial = 0.0
    sq_sum_delta = 0.0

    with torch.no_grad():
        for key, initial in initial_weights.items():
            # Cast to float so integer buffers and half-precision weights can
            # be squared, and align devices so the subtraction is always valid.
            initial = initial.detach().float()
            final = final_weights[key].detach().to(initial.device).float()

            sq_sum_initial += initial.pow(2).sum().item()
            sq_sum_delta += (final - initial).pow(2).sum().item()

    norm_initial = sq_sum_initial ** 0.5
    norm_delta = sq_sum_delta ** 0.5
    ratio = norm_delta / norm_initial if norm_initial > 0 else 0.0

    return {
        "norm_initial": norm_initial,
        "norm_delta": norm_delta,
        "delta_ratio": ratio,
    }


def _save_flat_gradients(model, path):
    """Save all available parameter gradients of ``model`` as one flat CPU vector."""
    flat_grads = [
        # reshape (not view) so non-contiguous gradients are handled too
        param.grad.detach().cpu().reshape(-1)
        for param in model.parameters()
        if param.grad is not None
    ]
    all_gradients = torch.cat(flat_grads) if flat_grads else torch.tensor([])
    torch.save(all_gradients, path)


def train(args):
    """Fine-tune the UNet on one prompt and record artifacts for later analysis.

    Args:
        args: Namespace providing ``model_path``, ``data_dir``, ``prompt``,
            ``output_dir``, ``epochs``, ``check_epochs`` and ``batch_size``.

    Files written below ``<output_dir>/<prompt>``:
        * ``epoch_0.pt`` - UNet weights before training.
        * ``epoch_<check_epochs>.pt`` - weights at the start of the epoch
          whose index equals ``check_epochs``.
        * ``gradients_<check_epochs>.pt`` - flattened gradients of the first
          batch of that epoch.
        * ``epoch_<epochs>.pt`` - final UNet weights.
        * ``training_noises/`` - every sampled noise tensor and timestep,
          plus the latents and text embedding of the last processed batch.
        * ``weight_changes.json`` - norms of the initial weights and of the
          weight delta, and their ratio.
    """
    import json

    # Device configuration
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Initialize model components
    tokenizer = CLIPTokenizer.from_pretrained(args.model_path, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(args.model_path, subfolder="text_encoder").to(device)
    vae = AutoencoderKL.from_pretrained(args.model_path, subfolder="vae").to(device)
    unet = UNet2DConditionModel.from_pretrained(args.model_path, subfolder="unet").to(device)

    # Configure noise scheduler
    noise_scheduler = DDPMScheduler(
        num_train_timesteps=TIME,
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
    )
    # Single source of truth for the timestep range sampled during training
    num_train_timesteps = noise_scheduler.config.num_train_timesteps

    # Freeze VAE and text encoder
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)

    # Latent scaling factor: read it from the VAE config when present and fall
    # back to the previously hard-coded value otherwise.
    latent_scale = getattr(vae.config, "scaling_factor", 0.18215)

    # Prepare dataset and dataloader
    dataset = SimpleDataset(args.data_dir, args.prompt, tokenizer, batch_size=args.batch_size)
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=4,
        # Pinned host memory only helps host-to-GPU copies
        pin_memory=device.type == "cuda",
    )

    # Set up optimizer
    optimizer = torch.optim.Adam(unet.parameters(), lr=1e-5)

    # Create output directories
    output_dir = Path(args.output_dir) / args.prompt
    output_dir.mkdir(parents=True, exist_ok=True)
    noise_save_path = output_dir / "training_noises"
    noise_save_path.mkdir(parents=True, exist_ok=True)

    # Initialize tracking variables
    all_noises = []
    all_timesteps_list = []
    # Last processed batch; stays None when no training step is executed
    latents = None
    text_emb = None

    # Save initial weights
    ori_path = output_dir / "epoch_0.pt"
    torch.save(unet.state_dict(), ori_path)

    # NOTE(review): the UNet is kept in eval mode while it is optimized.
    # Presumably deliberate so that mode-dependent layers stay deterministic
    # for the gradient analysis - confirm, otherwise switch to unet.train().
    unet.eval()

    # Training loop
    for epoch in range(args.epochs):
        epoch_loss = 0.0
        num_batches = 0
        is_check_epoch = epoch == args.check_epochs

        # Save checkpoint at specified epoch
        if is_check_epoch:
            torch.save(unet.state_dict(), output_dir / f"epoch_{args.check_epochs}.pt")

        for batch_idx, batch in enumerate(dataloader):
            # Transfer data to device
            pixels = batch["pixel_values"].to(device)
            input_ids = batch["input_ids"].to(device)

            # Encode images to latent space and the prompt to embeddings
            with torch.no_grad():
                latents = vae.encode(pixels).latent_dist.sample() * latent_scale
                text_emb = text_encoder(input_ids)[0]

            # Generate noise and timesteps
            noise = torch.randn_like(latents)
            if is_check_epoch:
                # The checkpoint epoch samples from a fixed mid-range window
                timesteps = torch.randint(400, 600, (latents.shape[0],), device=device)
            else:
                # Lower bound 1: timestep 0 is never sampled
                timesteps = torch.randint(1, num_train_timesteps, (latents.shape[0],), device=device)

            # Store noise and timesteps for analysis
            all_noises.append(noise.detach().cpu().clone())
            all_timesteps_list.append(timesteps.detach().cpu().clone())

            # Add noise to latents
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            # Predict noise
            noise_pred = unet(noisy_latents, timesteps, text_emb).sample

            # Calculate loss
            loss = F.mse_loss(noise_pred, noise)
            epoch_loss += loss.item()
            num_batches += 1

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()

            # Save gradients of the first batch at the checkpoint epoch,
            # before the optimizer consumes them
            if is_check_epoch and batch_idx == 0:
                _save_flat_gradients(unet, output_dir / f"gradients_{epoch}.pt")

            optimizer.step()

        # Print epoch statistics (guard against an empty dataloader)
        avg_loss = epoch_loss / max(num_batches, 1)
        print(f"\nEpoch {epoch + 1}/{args.epochs} | Average Loss: {avg_loss:.4f}")

    # Save final weights
    final_path = output_dir / f"epoch_{args.epochs}.pt"
    torch.save(unet.state_dict(), final_path)
    print(f"Saved final UNet weights to {final_path}")

    # Save training artifacts
    torch.save(all_noises, noise_save_path / "all_noises.pt")
    torch.save(all_timesteps_list, noise_save_path / "all_timesteps.pt")
    if latents is not None:
        torch.save(latents.detach().cpu(), noise_save_path / "img.pt")
        torch.save(text_emb.detach().cpu(), noise_save_path / "text_emb.pt")
    else:
        # Happens when epochs is 0 or the dataloader yields no batch
        print("No training step was executed; skipping img.pt and text_emb.pt")

    # Calculate weight changes; load on CPU so the comparison does not place
    # two additional full copies of the UNet on the training device
    initial_weights = torch.load(ori_path, map_location="cpu")
    final_weights = torch.load(final_path, map_location="cpu")
    weight_stats = calculate_weight_changes(initial_weights, final_weights)

    # Print and save weight analysis
    print("\nWeight Change Analysis:")
    print(f"Norm of Initial Weights: {weight_stats['norm_initial']:.4f}")
    print(f"Norm of Weight Changes: {weight_stats['norm_delta']:.4f}")
    print(f"Ratio (Δ/Initial): {weight_stats['delta_ratio']:.6f}")

    stats_path = output_dir / "weight_changes.json"
    with open(stats_path, 'w') as f:
        json.dump(weight_stats, f, indent=4)
    print(f"Saved weight change statistics to {stats_path}")


if __name__ == "__main__":
    # Build the command line interface from a table of (flag, options) pairs;
    # the order of the table is the order shown in --help.
    cli = argparse.ArgumentParser(
        description="Fine-tune UNet model for Stable Diffusion with DreamBooth-like training"
    )
    option_specs = [
        ("--model_path", {
            "default": "sd_model/tiny-sd",
            "help": "Path to pre-trained model directory containing tokenizer, text_encoder, vae and unet",
        }),
        ("--data_dir", {
            "default": "dataset",
            "help": "Root directory containing image datasets organized by prompt",
        }),
        ("--prompt", {
            "default": "dog",
            "help": "Specific prompt/directory name under data_dir containing training images",
        }),
        ("--output_dir", {
            "default": "ckpts/model_ckpt",
            "help": "Directory to save training outputs, checkpoints and analysis results",
        }),
        ("--epochs", {
            "type": int,
            "default": 200,
            "help": "Total number of training epochs to run",
        }),
        ("--check_epochs", {
            "type": int,
            "default": 100,
            "help": "Epoch number at which to capture and save gradient data",
        }),
        ("--batch_size", {
            "type": int,
            "default": 1,
            "help": "Number of images per batch and total images to use for training",
        }),
    ]
    for flag, options in option_specs:
        cli.add_argument(flag, **options)

    # Parse the command line and hand the result straight to the trainer
    train(cli.parse_args())