import json
import os
import time
from dataclasses import dataclass

import diffusers
import torch
import torch.nn.functional as F
from accelerate import Accelerator
from diffusers import DDPMScheduler
from diffusers.optimization import get_cosine_schedule_with_warmup
from sscompiler.compiler.abstract import AbstractTransformer
from sscompiler.compiler.layers import PortableIA3Adapter, PortableLoRAAdapter
from sscompiler.compiler.layers.peft import mark_adapters_as_trainable
from sscompiler.utils.constants import TARGET_MODULES
from torchvision import transforms
from tqdm.auto import tqdm

from datasets import load_dataset

# Per-model architecture kwargs, looked up by `TrainingConfig.model` and passed
# as keyword arguments to DiTTransformer2DModel. The path is resolved against
# the current working directory. JSON text is UTF-8, so decode it explicitly
# instead of relying on the platform's locale encoding.
with open("configs.json", "r", encoding="utf-8") as f:
    configs = json.load(f)


@dataclass
class TrainingConfig:
    """Hyperparameters and feature switches for the DiT training run.

    Every attribute is an annotated dataclass field with a default, so
    ``TrainingConfig()`` reproduces the values below. Individual settings can
    be overridden at construction time, e.g. ``TrainingConfig(num_epochs=3)``.
    """

    # Optimisation schedule.
    num_epochs: int = 1
    gradient_accumulation_steps: int = 1
    learning_rate: float = 1e-4
    lr_warmup_steps: int = 500
    # Not read by the training loop in this script.
    save_image_epochs: int = 10
    save_model_epochs: int = 30
    # Passed to accelerate's Accelerator: `no` for float32, `fp16`/`bf16` for
    # automatic mixed precision.
    mixed_precision: str = "bf16"
    # Output root; TensorBoard logs go to <output_dir>/logs.
    output_dir: str = "ddpm-butterflies-128"

    # Model / data shape.
    model: str = "DiT-XL"  # key into configs.json and AbstractTransformer model_dir
    image_size: int = 256  # training images are resized to image_size x image_size
    patch_size: int = 2
    train_batch_size: int = 1

    # Forwarded as `activation_based` to the LoRA / IA3 adapters.
    hf: bool = False

    # Training-mode switches: inject LoRA adapters, inject IA3 adapters, and/or
    # make every model parameter trainable.
    lora: bool = False
    ia3: bool = False
    full_train: bool = True


# Module-level configuration instance read by the rest of the script.
config = TrainingConfig()


# Training images: the "train" split of the huggan Smithsonian butterflies
# subset, loaded through Hugging Face `datasets`.
dataset = load_dataset("huggan/smithsonian_butterflies_subset", split="train")


def add_alpha(image_tensor):
    """Append a constant alpha channel of ones to an image tensor.

    Args:
        image_tensor: Tensor laid out as ``[C, H, W]``, optionally with extra
            leading dimensions (``[..., C, H, W]``).

    Returns:
        A tensor of shape ``[..., C + 1, H, W]`` whose last channel is all
        ones. The alpha channel is allocated with ``image_tensor``'s dtype and
        device, so the concatenation works for GPU and half-precision inputs
        and does not promote the result to float32.
    """
    alpha_shape = list(image_tensor.shape)
    alpha_shape[-3] = 1  # a single channel; leading and spatial dims unchanged
    alpha_channel = image_tensor.new_ones(alpha_shape)
    # Concatenate along the channel dimension (-3 == 0 for a [C, H, W] input).
    return torch.cat([image_tensor, alpha_channel], dim=-3)


# Per-image preprocessing pipeline. `transform` converts images to RGB before
# they reach this pipeline, so there are three channels until the final step.
preprocess = transforms.Compose(
    [
        transforms.Resize((config.image_size, config.image_size)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),  # PIL image -> float tensor [C, H, W]
        # Map each of the three RGB channels from [0, 1] to [-1, 1]. This runs
        # before the alpha channel exists, so three mean/std values are needed.
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        # Append a constant alpha channel of ones -> [4, H, W]. Passing the
        # module-level function (not a lambda) keeps the pipeline picklable,
        # e.g. for DataLoader worker processes started with "spawn".
        transforms.Lambda(add_alpha),
    ]
)


def transform(examples):
    """Convert a batch of PIL images to RGB and run each through ``preprocess``.

    ``examples`` is the batch dict supplied by ``dataset.set_transform``; only
    its ``"image"`` column is read. The returned dict holds one preprocessed
    tensor per input image, in input order, under the ``"images"`` key.
    """
    processed = []
    for pil_image in examples["image"]:
        rgb = pil_image.convert("RGB")
        processed.append(preprocess(rgb))
    return {"images": processed}


# Register `transform` so it is applied to examples as they are read from the
# dataset, instead of preprocessing the whole split up front.
dataset.set_transform(transform)

# Train on a small fixed slice: the first 20 * train_batch_size examples, which
# gives 20 batches per epoch (assuming the split has at least that many rows).
subset_indices = list(range(20 * config.train_batch_size))
subset_dataset = torch.utils.data.Subset(dataset, subset_indices)

# Create the DataLoader with the subset; the slice is reshuffled every epoch.
train_dataloader = torch.utils.data.DataLoader(
    subset_dataset, batch_size=config.train_batch_size, shuffle=True
)


# Build the DiT transformer from the architecture kwargs stored under
# `config.model` in configs.json; patch_size comes from TrainingConfig.
# NOTE(review): there is no from_pretrained call here, so the weights are
# whatever the constructor initialises — confirm that training from scratch
# is intended.
model = diffusers.DiTTransformer2DModel(
    patch_size=config.patch_size,
    **configs[config.model],
)

# Wrap the model in sscompiler's AbstractTransformer using the "DiT" module
# groups from TARGET_MODULES. The adapter-injection code below addresses
# modules through these group names.
at = AbstractTransformer(
    model_dir=config.model,
    groups=TARGET_MODULES["DiT"],
    auto_model=model,
)

# Optional LoRA: build a PortableLoRAAdapter for every group name registered on
# `at` (all keys of at.groups). Each factory call receives a module `x` and
# reads x.in_features / x.out_features, so the targeted modules are presumably
# linear layers — confirm against TARGET_MODULES["DiT"]. `r=8` is presumably
# the LoRA rank. `activation_based` is taken from config.hf.
if config.lora:
    at.inject_adapter(
        list(at.groups.keys()),
        lambda x: PortableLoRAAdapter(
            x,
            in_features=x.in_features,
            out_features=x.out_features,
            r=8,
            activation_based=config.hf,
        ),
    )
    # Presumably leaves only adapter parameters trainable — confirm in sscompiler.
    mark_adapters_as_trainable(at.auto_model)
    at.print_trainable_parameters()

# Optional IA3: groups "query"/"key"/"value"/"output" get adapters built with
# is_ffn=False, and groups "ada"/"ada1"/"ada2"/"up"/"down" get is_ffn=True.
# These names presumably have to be keys of at.groups — TODO confirm.
if config.ia3:
    at.inject_adapter(
        ["query", "key", "value", "output"],
        lambda x: PortableIA3Adapter(
            x,
            in_features=x.in_features,
            out_features=x.out_features,
            activation_based=config.hf,
            is_ffn=False,
        ),
    )
    at.inject_adapter(
        ["ada", "ada1", "ada2", "up", "down"],
        lambda x: PortableIA3Adapter(
            x,
            in_features=x.in_features,
            out_features=x.out_features,
            activation_based=config.hf,
            is_ffn=True,
        ),
    )
    # Presumably leaves only adapter parameters trainable — confirm in sscompiler.
    mark_adapters_as_trainable(at.auto_model)
    at.print_trainable_parameters()

# Continue with the module held by AbstractTransformer (it carries any adapters
# injected above).
model = at.auto_model

if config.full_train:
    # Make every parameter of the model trainable, base weights included.
    # NOTE(review): if lora/ia3 are enabled together with full_train, this also
    # unfreezes the weights that mark_adapters_as_trainable presumably froze —
    # confirm that combination is intended.
    model.requires_grad_(True)


# DDPM noise schedule with 1000 training timesteps. train_loop samples one
# timestep per image from [0, num_train_timesteps) and calls add_noise with it.
noise_scheduler = DDPMScheduler(num_train_timesteps=1000)

# AdamW over all of the model's parameters.
optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
# Warmup followed by cosine decay, sized as one scheduler step per batch.
# NOTE(review): with this file's defaults (lr_warmup_steps=500, 20 batches x
# 1 epoch on a single process), training ends long before warmup completes, so
# the learning rate never reaches config.learning_rate — confirm intended.
lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=config.lr_warmup_steps,
    num_training_steps=(len(train_dataloader) * config.num_epochs),
)


def train_loop(
    config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler
):
    """Train ``model`` to predict the noise added to images, timing each epoch.

    Args:
        config: Training settings. This function reads ``mixed_precision``,
            ``gradient_accumulation_steps``, ``output_dir`` and ``num_epochs``.
        model: Denoising transformer, called as
            ``model(noisy_images, timesteps, class_labels=..., return_dict=False)``.
            Element 0 of the result is used as the noise prediction.
        noise_scheduler: Supplies ``config.num_train_timesteps`` and
            ``add_noise(clean, noise, timesteps)``.
        optimizer: Optimizer stepped inside the gradient-accumulation context.
        train_dataloader: Yields dicts whose ``"images"`` entry is a batched
            image tensor.
        lr_scheduler: Learning-rate scheduler stepped once per batch.

    Side effects:
        Creates ``config.output_dir`` on the main process, logs loss / lr /
        step to TensorBoard under ``<output_dir>/logs``, and prints the
        wall-clock time of each epoch.
    """
    # Initialize accelerator and tensorboard logging
    accelerator = Accelerator(
        mixed_precision=config.mixed_precision,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        log_with="tensorboard",
        project_dir=os.path.join(config.output_dir, "logs"),
    )
    if accelerator.is_main_process:
        os.makedirs(config.output_dir, exist_ok=True)
        # `log_with` only selects the tracker backend. Until init_trackers()
        # runs there is no active tracker, and accelerator.log() records nothing.
        accelerator.init_trackers("train_example")

    # Prepare everything; unpack in the same order the objects were passed in.
    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, lr_scheduler
    )

    global_step = 0

    for epoch in range(config.num_epochs):
        progress_bar = tqdm(
            total=len(train_dataloader), disable=not accelerator.is_local_main_process
        )
        progress_bar.set_description(f"Epoch {epoch}")

        # perf_counter is monotonic, which makes it the right clock for durations.
        start = time.perf_counter()
        for batch in train_dataloader:
            clean_images = batch["images"]
            bs = clean_images.shape[0]
            # Gaussian noise with the same shape, dtype and device as the images.
            noise = torch.randn_like(clean_images)

            # Sample a random timestep for each image
            timesteps = torch.randint(
                0,
                noise_scheduler.config.num_train_timesteps,
                (bs,),
                device=clean_images.device,
                dtype=torch.int64,
            )

            # Forward diffusion: noise each image to the level of its timestep.
            noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)

            # Every sample is conditioned on class label 0.
            labels = torch.zeros(bs, dtype=torch.long, device=noisy_images.device)

            with accelerator.accumulate(model):
                # Predict the noise residual
                noise_pred = model(
                    noisy_images, timesteps, class_labels=labels, return_dict=False
                )[0]
                if noise.size(1) != noise_pred.size(1):
                    # Tile the noise target to twice its channel count so its
                    # shape matches the prediction.
                    # NOTE(review): if the extra output channels are a
                    # learned-variance head, regressing them onto the noise is
                    # questionable; supervising only noise_pred[:, :C] may be
                    # what was intended — confirm.
                    noise = noise.repeat(1, 2, 1, 1)

                loss = F.mse_loss(noise_pred, noise)
                accelerator.backward(loss)

                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            progress_bar.update(1)
            logs = {
                "loss": loss.detach().item(),
                "lr": lr_scheduler.get_last_lr()[0],
                "step": global_step,
            }
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)
            global_step += 1
        end = time.perf_counter()
        progress_bar.close()

        print("time taken:", end - start)

    # Finish the tracker(s) so buffered TensorBoard events are flushed.
    accelerator.end_training()


# Positional arguments for train_loop, in the order of its signature.
args = (
    config,
    model,
    noise_scheduler,
    optimizer,
    train_dataloader,
    lr_scheduler,
)
train_loop(*args)

# Peak memory occupied by CUDA tensors during the run, reported in GiB.
print(f"max memory allocated: {torch.cuda.max_memory_allocated() / 2**30}")
