import os
import json
import numpy as np
import torch
import torchvision
import torch.nn.functional as F
from torch.utils.data import Subset
from torchvision import transforms
from diffusers import DDPMScheduler, UNet2DModel, DDIMScheduler
import pytorch_lightning as pl
from pytorch_lightning import LightningModule, LightningDataModule
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.utilities import rank_zero_only
import time
from PIL import Image

import wandb

# Configuration constants
NUM_RUNS = 1  # Number of models trained one after another by __main__
IMAGE_SIZE = 64  # Side length in pixels of the square training images
BATCH_SIZE = 64  # Per-device batch size; on 8 GPUs this is an effective batch size of 512
NUM_EPOCHS = 1000
PARENT_DATA_DIR = ""  # `root` passed to torchvision.datasets.CelebA
CELEBA_DIR = ""  # Default `data_dir` of MixtureCelebADataModule (stored, not read when loading CelebA)
INDICES_PATH = ""  # .npy array of CelebA train-split indices defining the training subset
SEED = 42  # Random seed for reproducibility
P_USER_IMAGE = 0.1  # Probability that a dataset item is a user image (an "atom" of the mixture) instead of CelebA
PATH_TO_USER_IMG_DIR = ""  # Directory of .jpg/.jpeg/.png images that become atoms of the model distribution

# Seed Python `random`, NumPy and torch in this process. With strategy="ddp" this
# module runs in every rank's process, so all ranks receive the same seed here.
# `workers=True` makes Lightning attach a worker_init_fn that mixes the global rank
# and worker id into each DataLoader worker's seeds; without it, workers on
# different ranks draw identical np.random streams, so MixtureCelebADataset would
# select user images in lockstep across GPUs.
pl.seed_everything(SEED, workers=True)

class MixtureCelebADataset(torch.utils.data.Dataset):
    """Mixture dataset: a random user image with probability ``p_user``, else a CelebA item.

    The length equals ``len(celeba_dataset)``. ``idx`` only selects the CelebA item;
    it is ignored when a user image is drawn. User images are paired with an
    all-zeros tensor shaped and typed like the CelebA target, so both kinds of item
    have matching target tensors.

    Args:
        celeba_dataset: Indexable dataset returning ``(image, attr)`` pairs.
        user_img_dir: Directory scanned (non-recursively) for .jpg/.jpeg/.png files.
        p_user: Probability, per ``__getitem__`` call, of returning a user image.
        transform: Optional transform applied to user images only.

    Raises:
        ValueError: If ``p_user > 0`` but ``user_img_dir`` contains no images.
    """

    _IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, celeba_dataset, user_img_dir, p_user=0.5, transform=None):
        self.celeba_dataset = celeba_dataset
        self.user_img_dir = user_img_dir
        self.p_user = p_user
        self.transform = transform
        # Sorted: os.listdir order is arbitrary and filesystem-dependent, which would
        # make the image picked for a given NumPy seed non-reproducible.
        self.user_img_paths = sorted(
            os.path.join(self.user_img_dir, f)
            for f in os.listdir(self.user_img_dir)
            if f.lower().endswith(self._IMAGE_EXTENSIONS)
        )
        print(f"Found {len(self.user_img_paths)} user images in {self.user_img_dir}")
        if self.p_user > 0 and not self.user_img_paths:
            # Fail at construction instead of with an opaque np.random.choice error
            # raised later inside a DataLoader worker.
            raise ValueError(
                f"p_user={self.p_user} but no .jpg/.jpeg/.png images found in {self.user_img_dir}"
            )
        celeb_attr = self.celeba_dataset[0][1]
        self.dummy_attr = torch.zeros_like(celeb_attr)

    def __len__(self):
        return len(self.celeba_dataset)

    def __getitem__(self, idx):
        if np.random.rand() < self.p_user:
            img_path = np.random.choice(self.user_img_paths)
            # Close the file handle; convert() returns an independent in-memory image.
            with Image.open(img_path) as img:
                user_img = img.convert("RGB")
            if self.transform:
                user_img = self.transform(user_img)
            return user_img, self.dummy_attr
        return self.celeba_dataset[idx]

class MixtureCelebADataModule(LightningDataModule):
    """Lightning data module serving a CelebA subset mixed with user-supplied images.

    Args:
        run_dir: Output directory of the run (stored on the module).
        image_indices_path: Path to a ``.npy`` array of CelebA train-split indices;
            only those images are used.
        user_img_dir: Directory of user images handed to ``MixtureCelebADataset``.
        p_user: Per-item probability of returning a user image instead of CelebA.
        data_dir: Stored as ``self.data_dir`` but not read; CelebA is loaded from
            ``PARENT_DATA_DIR``.
    """

    def __init__(self, run_dir, image_indices_path, user_img_dir, p_user=0.1, data_dir=CELEBA_DIR):
        super().__init__()
        # CelebA preprocessing: center-crop to 140 px, resize to IMAGE_SIZE, random
        # horizontal flip, then map pixel values from [0, 1] to [-1, 1].
        self.transform = transforms.Compose([
            transforms.CenterCrop(140),
            transforms.Resize(IMAGE_SIZE),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])
        self.run_dir = run_dir
        # NOTE(review): data_dir is never read; _load_celeba uses PARENT_DATA_DIR as
        # the CelebA root. Confirm which directory is intended.
        self.data_dir = data_dir
        self.image_indices_path = image_indices_path
        self.user_img_dir = user_img_dir
        self.p_user = p_user

    def _load_celeba(self, download):
        """Construct the CelebA train split (attribute targets) with ``self.transform``."""
        return torchvision.datasets.CelebA(
            root=PARENT_DATA_DIR,
            split='train',
            target_type='attr',
            transform=self.transform,
            download=download,
        )

    def prepare_data(self):
        """Download CelebA if it is not already on disk.

        Lightning calls this hook from a single process per node before ``setup``,
        so DDP ranks never download into the same files concurrently.
        """
        self._load_celeba(download=True)

    def setup(self, stage=None):
        """Build the mixture dataset over the indexed CelebA subset (runs on every rank)."""
        # The files were fetched in prepare_data; do not download again from each rank.
        celeba = self._load_celeba(download=False)
        image_indices = np.load(self.image_indices_path).tolist()
        subset = Subset(celeba, image_indices)
        # User images are resized straight to IMAGE_SIZE x IMAGE_SIZE and normalized
        # to [-1, 1]; unlike CelebA images they are not center-cropped or flipped.
        transform_for_user_specified = transforms.Compose([
            transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])
        self.dataset = MixtureCelebADataset(subset, self.user_img_dir, self.p_user, transform=transform_for_user_specified)
        print(f"Mixture dataset size: {len(self.dataset)}")

    def train_dataloader(self):
        # persistent_workers keeps the 4 worker processes alive across epochs instead
        # of re-spawning them at the start of each of the NUM_EPOCHS epochs.
        return torch.utils.data.DataLoader(
            self.dataset,
            batch_size=BATCH_SIZE,
            shuffle=True,
            num_workers=4,
            pin_memory=True,
            persistent_workers=True,
        )

class DiffusionModel(LightningModule):
    """Unconditional DDPM (UNet2DModel) trained to predict the added noise with MSE.

    At the end of selected epochs the global-zero rank draws a grid of DDIM samples,
    saves it as ``samples_epoch_XXXX.png`` in ``run_dir`` and logs it to W&B.

    Args:
        run_dir: Directory where sample grids are written.
    """

    # Sample when (epoch + 1) is a multiple of this, matching the epochs at which
    # ModelCheckpoint(every_n_epochs=10) in train_single_model saves weights, so each
    # sample grid corresponds to a saved checkpoint. Change both together.
    SAMPLE_EVERY_N_EPOCHS = 10
    # Size of the sample grid and number of DDIM denoising steps used to draw it.
    NUM_SAMPLE_IMAGES = 16
    NUM_DDIM_STEPS = 50

    def __init__(self, run_dir):
        super().__init__()
        self.run_dir = run_dir

        self.model = UNet2DModel(
            sample_size=IMAGE_SIZE,
            in_channels=3,
            out_channels=3,
            layers_per_block=2,
            block_out_channels=(128, 256, 512, 512),
            norm_num_groups=32,
        )

        # Linear beta schedule from 1e-4 to 0.02 over 1000 training timesteps.
        self.noise_scheduler = DDPMScheduler(
            num_train_timesteps=1000,
            beta_start=1e-4,
            beta_end=0.02,
            beta_schedule="linear",
        )

    def forward(self, x, timesteps):
        """Return the UNet's noise prediction for ``x`` at ``timesteps``."""
        return self.model(x, timesteps).sample

    def training_step(self, batch, batch_idx):
        """Noise a batch at random timesteps and regress the UNet output onto the noise."""
        images, _ = batch  # attribute targets are not used
        # NOTE(review): all ranks are seeded with the same SEED, so these draws may
        # coincide across ranks until their RNG streams diverge; consider a per-rank offset.
        noise = torch.randn_like(images)
        timesteps = torch.randint(
            0,
            self.noise_scheduler.config.num_train_timesteps,
            (images.shape[0],),
            device=self.device,
        )

        noisy_images = self.noise_scheduler.add_noise(images, noise, timesteps)
        noise_pred = self(noisy_images, timesteps)
        loss = F.mse_loss(noise_pred, noise)

        self.log("train_loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=1e-4)

    def _should_sample(self, epoch):
        """Return True if a sample grid should be produced after 0-based ``epoch``.

        Epoch 0 is kept as an early sanity check; the other epochs line up with the
        checkpoint epochs; the final epoch is always included.
        """
        on_checkpoint_epoch = (epoch + 1) % self.SAMPLE_EVERY_N_EPOCHS == 0
        return epoch == 0 or on_checkpoint_epoch or epoch == NUM_EPOCHS - 1

    @torch.no_grad()
    def _sample_images(self):
        """Draw NUM_SAMPLE_IMAGES images with DDIM and return them clamped to [0, 1]."""
        # DDIM reuses the DDPM training config but needs far fewer denoising steps.
        ddim_scheduler = DDIMScheduler.from_config(self.noise_scheduler.config)
        ddim_scheduler.set_timesteps(self.NUM_DDIM_STEPS)

        x = torch.randn((self.NUM_SAMPLE_IMAGES, 3, IMAGE_SIZE, IMAGE_SIZE), device=self.device)
        for t in ddim_scheduler.timesteps:
            noise_pred = self.model(x, t).sample
            x = ddim_scheduler.step(noise_pred, t, x).prev_sample
        # Undo Normalize(mean=0.5, std=0.5): [-1, 1] -> [0, 1].
        return (x / 2 + 0.5).clamp(0, 1)

    def on_train_epoch_end(self):
        """On the global-zero rank, save and log a sample grid on sampling epochs."""
        if not (self.trainer.is_global_zero and self._should_sample(self.current_epoch)):
            return
        import gc

        self.model.eval()
        try:
            samples = self._sample_images()
            filename = os.path.join(self.run_dir, f"samples_epoch_{self.current_epoch:04d}.png")
            torchvision.utils.save_image(samples, filename)

            self.logger.experiment.log({
                "samples": wandb.Image(filename),
                "epoch": self.current_epoch,
            })
            del samples
        finally:
            # Always release sampling memory and return to train mode, even if
            # sampling, saving, or logging raised.
            torch.cuda.empty_cache()
            gc.collect()
            self.model.train()

def train_single_model(run_idx, image_indices_path, wandb_base_name, RESULTS_DIR, user_img_dir, p_user):
    """Train one diffusion model on the CelebA/user-image mixture, logging to W&B.

    Every output of the run (checkpoints, sample grids, W&B files) goes into
    ``RESULTS_DIR/run_<run_idx as 4 digits>``.

    Args:
        run_idx: Index of this run; used in the run directory and W&B run name.
        image_indices_path: ``.npy`` file of CelebA train-split indices to train on.
        wandb_base_name: Prefix of the W&B run name.
        RESULTS_DIR: Parent directory of the per-run directories.
        user_img_dir: Directory of user images mixed into the training data.
        p_user: Probability that a dataset item is a user image.
    """
    run_dir = os.path.join(RESULTS_DIR, f"run_{run_idx:04d}")
    os.makedirs(run_dir, exist_ok=True)

    datamodule = MixtureCelebADataModule(run_dir, image_indices_path, user_img_dir, p_user)
    diffusion_model = DiffusionModel(run_dir)

    logger = WandbLogger(
        project="celeba-diffusion-fixed",
        name=f"{wandb_base_name}_run_{run_idx:04d}",
        save_dir=run_dir,
    )

    # Keep every checkpoint (save_top_k=-1, no monitored metric), weights only,
    # written every 10 epochs.
    checkpointing = ModelCheckpoint(
        dirpath=run_dir,
        filename="{epoch}",
        monitor=None,
        save_top_k=-1,
        every_n_epochs=10,
        save_weights_only=True,
    )

    trainer_config = dict(
        max_epochs=NUM_EPOCHS,
        accelerator="gpu",
        devices="auto",
        strategy="ddp",
        precision="16-mixed",
        logger=logger,
        callbacks=[checkpointing],
        log_every_n_steps=1,
    )
    trainer = pl.Trainer(**trainer_config)
    trainer.fit(diffusion_model, datamodule)
    wandb.finish()

if __name__ == "__main__":
    # The base name records the mixture probability; it is also the results subdirectory.
    wandb_base_name = f"celeba_diffusion_10k_celeba_samples_p_{P_USER_IMAGE}_old_men_lr_1e-4"
    print(f"WandB base name: {wandb_base_name}")

    results_root = "/nobackup/users/scarv/data-attribution-diffusion/celeba_results"
    RESULTS_DIR = os.path.join(results_root, wandb_base_name)
    os.makedirs(RESULTS_DIR, exist_ok=True)

    for run_idx in range(NUM_RUNS):
        train_single_model(
            run_idx,
            INDICES_PATH,
            wandb_base_name,
            RESULTS_DIR,
            PATH_TO_USER_IMG_DIR,
            P_USER_IMAGE,
        )