import os
import json
import numpy as np
import torch
import torchvision
import torch.nn.functional as F
from torch.utils.data import Dataset
from torchvision import transforms
from diffusers import DDPMScheduler, UNet2DModel, DDIMScheduler
import pytorch_lightning as pl
from pytorch_lightning import LightningModule
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint
import argparse
import time
from PIL import Image

import wandb

# --- Training hyperparameters -------------------------------------------------
NUM_RUNS = 1       # NOTE(review): not referenced in this file's visible code
IMAGE_SIZE = 64    # side length (pixels) used for resizing, the UNet sample_size and generated samples
BATCH_SIZE = 64    # images per DataLoader batch
NUM_EPOCHS = 500   # default training length
SEED = 42          # seed for the NumPy and PyTorch global RNGs below

# Configuration constants for finetuning -- fill these in before running.
PATH_TO_USER_IMG_DIR = ""   # directory holding the .jpg/.jpeg/.png training images
CHECKPOINT_PATH = ""        # UNet weights to start finetuning from
WANDB_BASE_NAME = ""        # display name of the Weights & Biases run
RESULTS_DIR = ""            # parent directory for checkpoints and sample grids

# Seed the global RNGs at import time so runs are repeatable.
np.random.seed(SEED)
torch.manual_seed(SEED)

class UserImageDataset(Dataset):
    """Dataset over the image files found directly inside one directory.

    Files are selected by a case-insensitive ``.jpg`` / ``.jpeg`` / ``.png``
    extension. Subdirectories are neither returned nor searched. Paths are kept
    in sorted filename order, so the index -> file mapping (and therefore a
    seeded shuffle) is reproducible across runs and machines.

    Args:
        user_img_dir: Directory to scan for images.
        transform: Optional callable applied to each loaded RGB ``PIL.Image``.
            Its return value is what ``__getitem__`` yields.
    """

    # Lower-cased file extensions treated as images.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, user_img_dir, transform=None):
        self.user_img_dir = user_img_dir
        self.transform = transform
        # os.listdir returns entries in filesystem-dependent order, so sort them.
        # The isfile check skips directories whose names end in an image extension.
        self.user_img_paths = [
            os.path.join(user_img_dir, name)
            for name in sorted(os.listdir(user_img_dir))
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
            and os.path.isfile(os.path.join(user_img_dir, name))
        ]
        print(f"Found {len(self.user_img_paths)} user images in {self.user_img_dir}")

    def __len__(self):
        """Return the number of image files found."""
        return len(self.user_img_paths)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and apply ``transform`` if one was given."""
        img_path = self.user_img_paths[idx]
        # convert() returns an independent copy, so the source file can be closed
        # deterministically here instead of waiting for garbage collection.
        with Image.open(img_path) as img:
            user_img = img.convert("RGB")
        if self.transform is not None:
            user_img = self.transform(user_img)
        return user_img

class DiffusionModel(LightningModule):
    """Lightning module that trains an unconditional ``UNet2DModel`` as a DDPM noise predictor.

    Each training step corrupts a batch with ``DDPMScheduler.add_noise`` at
    uniformly drawn timesteps. The UNet's prediction is then regressed onto the
    added noise with an MSE loss.

    Every 10 epochs, and on the trainer's final epoch, global rank 0 draws 16
    samples with a 50-step DDIM sampler and saves them as an image grid in
    ``run_dir``. When a ``WandbLogger`` is attached, the grid is also logged to
    Weights & Biases.

    Args:
        run_dir: Directory that sample grids are written to.
        checkpoint_path: Optional path to initial UNet weights. It may be either
            a plain ``UNet2DModel`` state dict, or a Lightning checkpoint whose
            ``"state_dict"`` entry holds the UNet weights under a ``"model."``
            prefix. ``None`` keeps the random initialisation.
    """

    # The UNet is stored as ``self.model``, so its parameters appear in
    # Lightning checkpoints under this key prefix.
    _CHECKPOINT_PREFIX = "model."

    def __init__(self, run_dir, checkpoint_path=None):
        super().__init__()
        self.run_dir = run_dir
        self.model = UNet2DModel(
            sample_size=IMAGE_SIZE,
            in_channels=3,
            out_channels=3,
            layers_per_block=2,
            block_out_channels=(128, 256, 512, 512),
            norm_num_groups=32
        )
        self.noise_scheduler = DDPMScheduler(
            num_train_timesteps=1000,
            beta_start=1e-4,
            beta_end=0.02,
            beta_schedule="linear"
        )
        if checkpoint_path is not None:
            self._load_unet_weights(checkpoint_path)

    def _load_unet_weights(self, checkpoint_path):
        """Load UNet weights from ``checkpoint_path`` and warn about any key mismatch."""
        print(f"Loading UNet2DModel weights from checkpoint: {checkpoint_path}")
        # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
        state_dict = torch.load(checkpoint_path, map_location='cpu')
        # If checkpoint is a Lightning checkpoint, extract model weights
        if 'state_dict' in state_dict:
            prefix = self._CHECKPOINT_PREFIX
            # Strip only the leading prefix. str.replace would also rewrite any
            # later occurrence of "model." inside a parameter name.
            state_dict = {
                key[len(prefix):]: value
                for key, value in state_dict['state_dict'].items()
                if key.startswith(prefix)
            }
        # strict=False tolerates partial checkpoints. The mismatch is still
        # reported, so a wrong checkpoint cannot silently leave the UNet at its
        # random initialisation.
        incompatible = self.model.load_state_dict(state_dict, strict=False)
        if incompatible.missing_keys:
            print(
                f"Warning: {len(incompatible.missing_keys)} UNet parameters were not found in the "
                f"checkpoint and keep their initial values, e.g. {incompatible.missing_keys[:5]}"
            )
        if incompatible.unexpected_keys:
            print(
                f"Warning: {len(incompatible.unexpected_keys)} checkpoint entries match no UNet "
                f"parameter and were ignored, e.g. {incompatible.unexpected_keys[:5]}"
            )

    def forward(self, x, timesteps):
        """Return the UNet's prediction for noisy images ``x`` at ``timesteps``."""
        return self.model(x, timesteps).sample

    def training_step(self, batch, batch_idx):
        """Return the noise-prediction MSE loss for one batch of images."""
        images = batch
        noise = torch.randn_like(images)
        # Read the step count from the scheduler's config. diffusers deprecates
        # reading config values as plain scheduler attributes.
        timesteps = torch.randint(
            0, self.noise_scheduler.config.num_train_timesteps,
            (images.shape[0],), device=self.device, dtype=torch.long,
        )
        noisy_images = self.noise_scheduler.add_noise(images, noise, timesteps)
        noise_pred = self(noisy_images, timesteps)
        loss = F.mse_loss(noise_pred, noise)
        self.log("train_loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
        return loss

    def configure_optimizers(self):
        """Use AdamW with a fixed learning rate of 1e-4 over all parameters."""
        return torch.optim.AdamW(self.parameters(), lr=1e-4)

    def on_train_epoch_end(self):
        """On rank 0, every 10 epochs and on the final epoch, sample and save an image grid."""
        import gc

        if not self.trainer.is_global_zero:
            return
        # Use the trainer's real epoch budget. The module-level NUM_EPOCHS can
        # differ from the num_epochs the Trainer was built with.
        max_epochs = self.trainer.max_epochs
        is_last_epoch = max_epochs is not None and self.current_epoch == max_epochs - 1
        if self.current_epoch % 10 != 0 and not is_last_epoch:
            return

        self.model.eval()
        try:
            with torch.no_grad():
                # Fixed-seed noise makes grids from different epochs start from the
                # same latents, and leaves the global (training) RNG untouched.
                generator = torch.Generator(device=self.device).manual_seed(SEED)
                sample = torch.randn(
                    (16, 3, IMAGE_SIZE, IMAGE_SIZE), generator=generator, device=self.device
                )
                ddim_scheduler = DDIMScheduler.from_config(self.noise_scheduler.config)
                ddim_scheduler.set_timesteps(50)
                for t in ddim_scheduler.timesteps:
                    noise_pred = self.model(sample, t).sample
                    sample = ddim_scheduler.step(noise_pred, t, sample).prev_sample
                # Map from the [-1, 1] training range back to [0, 1] for saving.
                samples = (sample / 2 + 0.5).clamp(0, 1)
                filename = os.path.join(self.run_dir, f"samples_epoch_{self.current_epoch:04d}.png")
                torchvision.utils.save_image(samples, filename)
                if isinstance(self.logger, WandbLogger):
                    self.logger.experiment.log({
                        "samples": wandb.Image(filename),
                        "epoch": self.current_epoch
                    })
            del samples, sample
        finally:
            # Release sampling memory and always return the UNet to training mode,
            # even if sampling or logging raised.
            torch.cuda.empty_cache()
            gc.collect()
            self.model.train()


def finetune_model(checkpoint_path, user_img_dir, wandb_base_name, results_dir, num_epochs=NUM_EPOCHS,
                   run_name="finetune_old_men_100_epochs"):
    """Finetune a ``DiffusionModel`` on the images in ``user_img_dir``.

    Images are resized to ``IMAGE_SIZE`` x ``IMAGE_SIZE`` and normalised to
    [-1, 1]. Training uses the "ddp" strategy with 16-bit mixed precision on
    the GPUs chosen by ``devices="auto"``. Weights-only checkpoints are saved
    every 50 epochs. Metrics and samples are logged to the Weights & Biases
    project ``celeba-diffusion-fixed``.

    Args:
        checkpoint_path: Initial UNet weights, forwarded to ``DiffusionModel``.
        user_img_dir: Directory of training images.
        wandb_base_name: Display name of the W&B run.
        results_dir: Parent directory of the run directory.
        num_epochs: Number of training epochs.
        run_name: Name of the run directory created under ``results_dir``. It
            receives checkpoints, sample grids and W&B files.

    Raises:
        ValueError: If no training images are found in ``user_img_dir``.
    """
    run_dir = os.path.join(results_dir, run_name)
    os.makedirs(run_dir, exist_ok=True)
    transform = transforms.Compose([
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    dataset = UserImageDataset(user_img_dir, transform=transform)
    if len(dataset) == 0:
        # Fail early with a clear message. Otherwise the shuffling sampler
        # rejects the empty dataset with a much less descriptive error.
        raise ValueError(f"No training images found in {user_img_dir!r}")
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=4,
        pin_memory=True
    )
    model = DiffusionModel(run_dir, checkpoint_path=checkpoint_path)
    wandb_logger = WandbLogger(
        project="celeba-diffusion-fixed",
        name=wandb_base_name,
        save_dir=run_dir
    )
    checkpoint_callback = ModelCheckpoint(
        dirpath=run_dir,
        filename="{epoch}",
        save_top_k=-1,
        every_n_epochs=50,
        monitor=None,
        save_weights_only=True
    )
    trainer = pl.Trainer(
        max_epochs=num_epochs,
        logger=wandb_logger,
        log_every_n_steps=1,
        callbacks=[checkpoint_callback],
        accelerator="gpu",
        devices="auto",
        strategy="ddp",
        precision="16-mixed"
    )
    try:
        trainer.fit(model, dataloader)
    except BaseException:
        # Close the W&B run as failed instead of leaving it dangling, then re-raise.
        wandb.finish(exit_code=1)
        raise
    wandb.finish()

if __name__ == "__main__":
    # Command-line flags override the module-level constants. Running with no
    # flags uses the constants exactly as before.
    parser = argparse.ArgumentParser(
        description="Finetune a UNet2DModel diffusion model on a directory of user images."
    )
    parser.add_argument("--checkpoint-path", default=CHECKPOINT_PATH,
                        help="UNet weights to start from (default: CHECKPOINT_PATH)")
    parser.add_argument("--user-img-dir", default=PATH_TO_USER_IMG_DIR,
                        help="directory of training images (default: PATH_TO_USER_IMG_DIR)")
    parser.add_argument("--results-dir", default=RESULTS_DIR,
                        help="parent directory for run outputs (default: RESULTS_DIR)")
    parser.add_argument("--wandb-name", default=WANDB_BASE_NAME,
                        help="Weights & Biases run name (default: WANDB_BASE_NAME)")
    parser.add_argument("--num-epochs", type=int, default=NUM_EPOCHS,
                        help="number of training epochs (default: NUM_EPOCHS)")
    args = parser.parse_args()

    # The placeholders default to "", which otherwise surfaces later as a
    # confusing file-not-found error from torch.load / os.listdir.
    if not args.checkpoint_path or not args.user_img_dir:
        parser.error(
            "a checkpoint path and a user image directory are required: set "
            "CHECKPOINT_PATH / PATH_TO_USER_IMG_DIR or pass --checkpoint-path / --user-img-dir"
        )

    print(f"Finetuning UNet2DModel from checkpoint: {args.checkpoint_path}")
    print(f"User image directory: {args.user_img_dir}")
    print(f"Results directory: {args.results_dir}")
    print(f"Number of epochs: {args.num_epochs}")

    finetune_model(
        checkpoint_path=args.checkpoint_path,
        user_img_dir=args.user_img_dir,
        wandb_base_name=args.wandb_name,
        results_dir=args.results_dir,
        num_epochs=args.num_epochs
    )