import json
import os
from dataclasses import dataclass

import diffusers
import torch
import torch.nn.functional as F
from accelerate import Accelerator
from diffusers import DDPMScheduler
from diffusers.optimization import get_cosine_schedule_with_warmup
from sscompiler.compiler import (
    AbstractTransformer,
    PortableIA3Adapter,
    PortableLoRAAdapter,
    mark_adapters_as_trainable,
)
from sscompiler.utils.constants import TARGET_MODULES
from torchvision import transforms
from tqdm.auto import tqdm

from datasets import load_dataset

# Per-model constructor keyword arguments, keyed by model name.  JSON is
# read as UTF-8 explicitly so the result does not depend on the platform's
# default text encoding.
with open("configs.json", "r", encoding="utf-8") as config_file:
    configs = json.load(config_file)


@dataclass
class TrainingConfig:
    """Settings for the adapter fine-tuning run.

    None of the attributes below carry a type annotation, so ``@dataclass``
    registers no fields: they are ordinary class attributes that every
    instance reads as defaults, ``TrainingConfig()`` accepts no arguments,
    and an instance's own ``__dict__`` starts out empty.
    """

    # --- model -----------------------------------------------------------
    model_name = "DiT-XL"  # key into the loaded model configs
    image_size = 256  # images are resized to image_size x image_size
    patch_size = 2  # forwarded to the transformer constructor

    # --- adapters --------------------------------------------------------
    lora = True  # inject LoRA adapters
    ia3 = False  # inject IA3 adapters

    # --- optimisation ----------------------------------------------------
    num_epochs = 1
    train_batch_size = 3
    gradient_accumulation_steps = 1
    learning_rate = 1e-4
    lr_warmup_steps = 500
    # Precision mode handed to Accelerate: "no" (float32), "fp16" or "bf16".
    mixed_precision = "bf16"

    # --- bookkeeping -----------------------------------------------------
    output_dir = "ddpm-butterflies-128"  # run directory; logs go in a subfolder
    seed = 0


config = TrainingConfig()

# Apply the configured seed before anything random happens (weight
# initialisation, shuffling, random flips, noise/timestep sampling) so that
# repeated runs draw the same torch random numbers.
torch.manual_seed(config.seed)


dataset = load_dataset("huggan/smithsonian_butterflies_subset", split="train")


def add_alpha(image_tensor):
    """Append one constant channel of ones to a channel-first image tensor.

    The extra channel is created with ``image_tensor.new_ones`` so it shares
    the input's dtype and device; the result therefore keeps the input's
    dtype instead of being promoted to float32, and concatenation works for
    tensors that do not live on the CPU.

    Args:
        image_tensor: Tensor shaped ``[C, H, W]`` (e.g. ``C == 3`` for RGB).
            Any number of leading batch dimensions is also accepted, i.e.
            ``[..., C, H, W]``.

    Returns:
        A new tensor shaped ``[..., C + 1, H, W]`` whose last channel is all
        ones and whose other channels equal ``image_tensor``.
    """
    # Same shape as the input but with a single channel on the channel axis
    # (third from the end).
    ones_shape = (*image_tensor.shape[:-3], 1, *image_tensor.shape[-2:])
    alpha_channel = image_tensor.new_ones(ones_shape)
    return torch.cat([image_tensor, alpha_channel], dim=-3)


# Per-image pipeline, applied in order: resize to a square of side
# config.image_size, random horizontal flip, conversion to a tensor,
# normalisation with mean 0.5 / std 0.5, and finally one extra all-ones
# channel appended by add_alpha.
_preprocess_steps = [
    transforms.Resize((config.image_size, config.image_size)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
    transforms.Lambda(add_alpha),
]
preprocess = transforms.Compose(_preprocess_steps)


def transform(examples):
    """Turn a batch of dataset rows into preprocessed image tensors.

    Args:
        examples: Mapping with an ``"image"`` entry holding an iterable of
            images that support ``.convert("RGB")``.

    Returns:
        ``{"images": [...]}`` with one ``preprocess`` output per input
        image, in the same order.
    """
    processed = []
    for image in examples["image"]:
        rgb_image = image.convert("RGB")
        processed.append(preprocess(rgb_image))
    return {"images": processed}


# Preprocess images lazily, whenever rows are read from the dataset.
dataset.set_transform(transform)

# Train on a small leading slice of the data: 20 batches' worth of examples.
# The size is clamped to the dataset length so a shorter dataset cannot
# produce out-of-range indices.
subset_size = min(20 * config.train_batch_size, len(dataset))
subset_indices = list(range(subset_size))
subset_data = torch.utils.data.Subset(dataset, subset_indices)

train_dataloader = torch.utils.data.DataLoader(
    subset_data,
    batch_size=config.train_batch_size,
    shuffle=True,
)

print(config.model_name)

# Constructor keyword arguments for the selected model are looked up by name
# in the loaded configs; patch_size is supplied separately from
# TrainingConfig.
model_kwargs = configs[config.model_name]
model = diffusers.DiTTransformer2DModel(
    patch_size=config.patch_size,
    **model_kwargs,
)

# Wrap the freshly built model so adapters can be injected into the module
# groups listed under TARGET_MODULES["DiT"].
at = AbstractTransformer(
    auto_model=model,
    model_dir=config.model_name,
    groups=TARGET_MODULES["DiT"],
)


def _lora_adapter(module):
    """Wrap ``module`` in a PortableLoRAAdapter (r=8, lora_alpha=8)."""
    return PortableLoRAAdapter(
        module,
        in_features=module.in_features,
        out_features=module.out_features,
        r=8,
        lora_alpha=8,
        activation_based=True,
    )


def _ia3_adapter(module, is_ffn):
    """Wrap ``module`` in a PortableIA3Adapter with the given ``is_ffn`` flag."""
    return PortableIA3Adapter(
        module,
        in_features=module.in_features,
        out_features=module.out_features,
        is_ffn=is_ffn,
        activation_based=True,
    )


# LoRA goes on every group the AbstractTransformer knows about.
if config.lora:
    at.inject_adapter(list(at.groups.keys()), _lora_adapter)

# IA3 is injected in two passes that differ only in the is_ffn flag:
# False for the query/key/value/output groups, True for the remaining ones.
if config.ia3:
    at.inject_adapter(
        ["query", "key", "value", "output"],
        lambda module: _ia3_adapter(module, is_ffn=False),
    )
    at.inject_adapter(
        ["up", "down", "ada1", "ada2", "ada"],
        lambda module: _ia3_adapter(module, is_ffn=True),
    )

# NOTE(review): trainability is marked on the original `model` object before
# `model` is rebound to `at.auto_model`; presumably both names refer to the
# same module — confirm against AbstractTransformer.
mark_adapters_as_trainable(model)
model = at.auto_model
at.print_trainable_parameters()


noise_scheduler = DDPMScheduler(num_train_timesteps=100)
# Module-level value only; train_loop samples its own per-batch timesteps
# into a local variable of the same name.
timesteps = torch.LongTensor([50])


optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)

# One scheduler step per batch per epoch.
# NOTE(review): if config.lr_warmup_steps is larger than this total, the
# schedule never gets past warmup — check against the configured values.
total_training_steps = len(train_dataloader) * config.num_epochs
lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=config.lr_warmup_steps,
    num_training_steps=total_training_steps,
)


def train_loop(
    config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler
):
    """Train ``model`` to predict the noise that ``noise_scheduler`` adds to images.

    For every batch the loop draws Gaussian noise and one random timestep per
    image, noises the clean images with ``noise_scheduler.add_noise``, runs
    the model on the result and minimises the MSE between its prediction and
    the noise.  Loss, learning rate and step index are shown on a progress
    bar and logged through Accelerate's TensorBoard tracker.

    Args:
        config: Settings object; this function reads ``mixed_precision``,
            ``gradient_accumulation_steps``, ``output_dir`` and
            ``num_epochs``.
        model: Network called as ``model(noisy_images, timesteps,
            class_labels=..., return_dict=False)``; element 0 of the result
            is used as the noise prediction.
        noise_scheduler: Supplies ``config.num_train_timesteps`` and
            ``add_noise``.
        optimizer: Optimizer stepped for every batch inside the
            gradient-accumulation context.
        train_dataloader: Yields dict batches with an ``"images"`` tensor
            and, optionally, a ``"labels"`` entry.
        lr_scheduler: Learning-rate scheduler stepped together with the
            optimizer.

    Returns:
        None.
    """
    # Initialize accelerator and tensorboard logging
    accelerator = Accelerator(
        mixed_precision=config.mixed_precision,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        log_with="tensorboard",
        project_dir=os.path.join(config.output_dir, "logs"),
    )
    if accelerator.is_main_process:
        if config.output_dir is not None:
            os.makedirs(config.output_dir, exist_ok=True)
        accelerator.init_trackers("train_example")

    # prepare() returns the objects in the same order they were passed in.
    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, lr_scheduler
    )

    global_step = 0

    for epoch in range(config.num_epochs):
        # The context manager closes the bar at the end of each epoch so
        # successive epochs do not leave unfinished bars behind.
        with tqdm(
            total=len(train_dataloader),
            disable=not accelerator.is_local_main_process,
        ) as progress_bar:
            progress_bar.set_description(f"Epoch {epoch}")

            for step, batch in enumerate(train_dataloader):
                clean_images = batch["images"]
                # Sample noise to add to the images
                noise = torch.randn(clean_images.shape, device=clean_images.device)
                bs = clean_images.shape[0]

                # Sample a random timestep for each image
                timesteps = torch.randint(
                    0,
                    noise_scheduler.config.num_train_timesteps,
                    (bs,),
                    device=clean_images.device,
                    dtype=torch.int64,
                )

                # Forward diffusion: mix noise into the clean images
                # according to each image's timestep.
                noisy_images = noise_scheduler.add_noise(
                    clean_images, noise, timesteps
                )

                # Batches without labels fall back to class 0 for every image.
                labels = batch.get("labels", None)
                if labels is None:
                    labels = torch.zeros(noisy_images.size(0), dtype=torch.long).to(
                        noisy_images.device
                    )

                with accelerator.accumulate(model):
                    # Predict the noise residual
                    noise_pred = model(
                        noisy_images, timesteps, class_labels=labels, return_dict=False
                    )[0]

                    # NOTE(review): when the prediction has a different
                    # channel count than the noise, the noise is tiled twice
                    # along the channel axis to match (the original comment
                    # mentions 8 output channels) — confirm this is the
                    # intended target for the extra channels.
                    target = noise
                    if target.size(1) != noise_pred.size(1):
                        target = target.repeat(1, 2, 1, 1)

                    loss = F.mse_loss(noise_pred, target)
                    accelerator.backward(loss)

                    # Clip only on steps where gradients are synchronised,
                    # i.e. when an accumulation window completes.
                    if accelerator.sync_gradients:
                        accelerator.clip_grad_norm_(model.parameters(), 1.0)
                    optimizer.step()
                    lr_scheduler.step()
                    optimizer.zero_grad()

                progress_bar.update(1)
                logs = {
                    "loss": loss.detach().item(),
                    "lr": lr_scheduler.get_last_lr()[0],
                    "step": global_step,
                }
                progress_bar.set_postfix(**logs)
                accelerator.log(logs, step=global_step)
                global_step += 1

    # Finalise the trackers started by init_trackers so logged data is
    # flushed before the function returns.
    accelerator.end_training()


args = (config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler)

train_loop(*args)

# Report the settings used for this run.  Unpacking the config into print()
# as keyword arguments either prints nothing (empty instance dict) or raises
# TypeError (unknown keyword), so gather the public, non-callable attributes
# into a dict instead.  dir() sees attributes stored on the class as well as
# on the instance.
config_values = {
    name: getattr(config, name)
    for name in dir(config)
    if not name.startswith("_") and not callable(getattr(config, name))
}
print(config_values)

# Peak allocated memory per visible CUDA device, converted from bytes to GiB.
for device in range(torch.cuda.device_count()):
    print(
        f"GPU {device} max memory:",
        torch.cuda.max_memory_allocated(device) / (1024**3),
    )
