import os
import json
import numpy as np
import torch
import torchvision
import torch.nn.functional as F
from torch.utils.data import Subset
from torchvision import transforms
from diffusers import DDPMScheduler, UNet2DModel, DDIMScheduler
import pytorch_lightning as pl
from pytorch_lightning import LightningModule, LightningDataModule
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.utilities import rank_zero_only
import torch.distributed as dist
import wandb

# Configuration Constants
# All values below are module-level settings read by the data module, the
# model, and the training entry point further down in this file. The empty
# string paths are left blank here; presumably they are filled in before a run.

# Number of training runs launched sequentially from the __main__ block.
NUM_RUNS = 1
# Side length used for the UNet `sample_size` and for the preview noise shape.
IMAGE_SIZE = 28
# Per-process DataLoader batch size.
BATCH_SIZE = 256 # author's note: on 4 GPUs, this is an effective batch size of 1024
# Passed to the Trainer as `max_epochs`; also used to detect the final epoch
# when deciding whether to render preview samples.
NUM_EPOCHS = 100
N_SAMPLES = 100000 # cap on examples kept per mixture component; larger sources are randomly subsampled
# Root directory handed to torchvision's MNIST loader.
PARENT_DATA_DIR = ""
# Not referenced anywhere in the visible part of this file.
MNIST_DIR = ""
SEED = 42  # Seed applied to the torch and NumPy global RNGs below
P_TMNIST = 0.0  # Probability of sampling a TMNIST image (0.0 means MNIST only)
PATH_TO_TMNIST = ""  # File loaded with torch.load to obtain the TMNIST image tensor
# Parent directory under which each run creates its own `run_XXXX` folder.
RESULTS_DIR = ""

# Set random seed for reproducibility
# Runs at import time and seeds only the torch and NumPy global generators.
torch.manual_seed(SEED)
np.random.seed(SEED)

# Mixture Dataset: each item is a TMNIST image with probability p_tmnist, otherwise an MNIST image
class MixtureMNISTDataset(torch.utils.data.Dataset):
    """Random two-source mixture over an MNIST-style dataset and a TMNIST tensor.

    Every ``__getitem__`` call draws a fresh random example; the requested
    index is ignored. All randomness comes from NumPy's global RNG
    (``np.random``), so results depend on how that generator is seeded in the
    calling process.

    Args:
        mnist_dataset: Indexable, sized collection yielding ``(image, label)``
            pairs where ``image`` is a tensor.
        tmnist_tensor: Tensor whose first dimension indexes TMNIST images.
        p_tmnist: Probability that an item is taken from ``tmnist_tensor``.
        transform: Optional callable applied to the image after it has been
            brought to channel-first shape.
    """

    def __init__(self, mnist_dataset, tmnist_tensor, p_tmnist=0.1, transform=None):
        self.mnist_dataset = mnist_dataset
        self.tmnist_tensor = tmnist_tensor
        self.p_tmnist = p_tmnist
        self.transform = transform
        # Source sizes are cached once; they bound the random index draws.
        self.N_mnist = len(mnist_dataset)
        self.N_tmnist = tmnist_tensor.shape[0]

    def __len__(self):
        """Return the combined size of both sources (sets the nominal epoch length)."""
        return self.N_mnist + self.N_tmnist

    def _finalize(self, image):
        """Bring ``image`` to shape (1, 28, 28) and apply the optional transform."""
        if image.dim() == 2:
            # (28, 28) -> (1, 28, 28)
            image = image.unsqueeze(0)
        elif image.shape[0] != 1:
            # Any other layout (e.g. a flat vector) is reinterpreted as one 28x28 channel.
            image = image.view(1, 28, 28)
        return self.transform(image) if self.transform else image

    def __getitem__(self, idx):
        """Return a randomly drawn ``(image, label)`` pair; ``idx`` is not used.

        TMNIST items carry the dummy label ``0``; MNIST items keep their own label.
        """
        pick_tmnist = np.random.rand() < self.p_tmnist
        if not pick_tmnist:
            image, label = self.mnist_dataset[np.random.randint(0, self.N_mnist)]
            return self._finalize(image), label
        image = self.tmnist_tensor[np.random.randint(0, self.N_tmnist)]
        return self._finalize(image), 0  # dummy label

class MixtureMNISTDataModule(LightningDataModule):
    """Lightning data module that builds the MNIST/TMNIST mixture training set.

    Args:
        p_tmnist: Probability that a training item is drawn from TMNIST.
        path_to_tmnist: File passed to ``torch.load`` to obtain the TMNIST
            image tensor.
    """

    def __init__(self, p_tmnist=0.1, path_to_tmnist=PATH_TO_TMNIST):
        super().__init__()
        self.p_tmnist = p_tmnist
        self.path_to_tmnist = path_to_tmnist

    @staticmethod
    def _build_mnist(download):
        """Construct the MNIST training split, optionally downloading it first."""
        return torchvision.datasets.MNIST(
            root=PARENT_DATA_DIR,
            train=True,
            download=download,
            transform=transforms.ToTensor()
        )

    def _indices_path(self):
        """Return the file name used to record the TMNIST subsample indices.

        The name is derived from the final extension only. The previous
        ``str.replace(".pt", ...)`` rewrote every ".pt" occurrence in the path
        and, when the path contained none, returned the data path itself, so
        the indices were saved over the TMNIST file.
        """
        stem, ext = os.path.splitext(self.path_to_tmnist)
        return f"{stem}{N_SAMPLES}_indices{ext or '.pt'}"

    def setup(self, stage=None):
        """Load both sources, cap each at ``N_SAMPLES``, and build ``self.dataset``.

        Args:
            stage: Lightning stage name; not used here.
        """
        distributed = dist.is_available() and dist.is_initialized()
        is_rank_zero = not distributed or dist.get_rank() == 0

        if distributed:
            # Only rank 0 downloads; the other ranks wait at the barrier and
            # then read the files that rank 0 left on disk.
            if is_rank_zero:
                self._build_mnist(download=True)
            dist.barrier()  # Synchronize all processes
            mnist = self._build_mnist(download=False)
        else:
            mnist = self._build_mnist(download=True)

        # Subsample MNIST to N_SAMPLES
        if len(mnist) > N_SAMPLES:
            mnist = Subset(mnist, np.random.choice(len(mnist), N_SAMPLES, replace=False))

        # NOTE: torch.load unpickles the file, so it must only be pointed at
        # trusted data. map_location keeps the tensor on the CPU regardless of
        # the device it was saved from, since it is indexed inside DataLoader
        # worker processes.
        tmnist_tensor = torch.load(self.path_to_tmnist, map_location="cpu")  # expected (N, 1, 28, 28)

        # Subsample TMNIST to N_SAMPLES
        if tmnist_tensor.shape[0] > N_SAMPLES:
            # Every rank draws the permutation itself so that all ranks
            # subsample in the same way.
            # NOTE(review): this relies on each rank having the same torch RNG
            # state at this point -- confirm.
            tmnist_indices = torch.randperm(tmnist_tensor.shape[0])[:N_SAMPLES]
            tmnist_tensor = tmnist_tensor[tmnist_indices]
            # Record the chosen indices once; letting every rank write the
            # same file concurrently can corrupt it.
            if is_rank_zero:
                torch.save(tmnist_indices, self._indices_path())

        # NOTE(review): MNIST is loaded with ToTensor only and no transform is
        # passed to the mixture dataset, so MNIST images reach the model in
        # [0, 1]. The preview sampling in DiffusionModel rescales with
        # x / 2 + 0.5, which suggests [-1, 1] inputs were intended -- confirm
        # the intended range (and the range stored in the TMNIST file) before
        # adding a normalisation step.
        self.dataset = MixtureMNISTDataset(mnist, tmnist_tensor, self.p_tmnist)
        print(f"Mixture dataset size: {len(self.dataset)}")

    def train_dataloader(self):
        """Return the training DataLoader over the mixture dataset."""
        return torch.utils.data.DataLoader(
            self.dataset,
            batch_size=BATCH_SIZE,
            shuffle=True,
            num_workers=4,
            pin_memory=True
        )

class DiffusionModel(LightningModule):
    """Unconditional DDPM for single-channel images with periodic DDIM previews.

    A ``UNet2DModel`` is trained to predict the noise added by a linear-beta
    ``DDPMScheduler``. On selected epochs the global-zero process renders a
    grid of samples, saves it under ``run_dir`` and logs it to the experiment
    logger.

    Args:
        run_dir: Directory in which preview sample images are written.
    """

    def __init__(self, run_dir):
        super().__init__()
        self.run_dir = run_dir
        self.model = UNet2DModel(
            sample_size=IMAGE_SIZE,
            in_channels=1,
            out_channels=1,
            layers_per_block=2,
            block_out_channels=(32, 64, 128),  # 3 blocks
            down_block_types=("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
            up_block_types=("AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
            norm_num_groups=8
        )
        self.noise_scheduler = DDPMScheduler(
            num_train_timesteps=1000,
            beta_start=1e-4,
            beta_end=0.02,
            beta_schedule="linear"
        )

    def forward(self, x, timesteps):
        """Return the UNet output for inputs ``x`` at the given ``timesteps``."""
        return self.model(x, timesteps).sample

    def training_step(self, batch, batch_idx):
        """Compute the noise-prediction MSE loss for one batch.

        Args:
            batch: ``(images, labels)`` pair; the labels are ignored.
            batch_idx: Index of the batch within the epoch (unused).

        Returns:
            Scalar MSE between the predicted and the actual added noise.
        """
        images, _ = batch
        # NOTE(review): images are expected as (B, 1, 28, 28). An earlier
        # comment here said they are "already normalized to [-1, 1]", but the
        # visible data pipeline applies ToTensor only (which yields [0, 1])
        # and passes no transform to the dataset -- confirm the intended range.
        noise = torch.randn_like(images)
        # Read the step count from the scheduler config; accessing it as an
        # attribute on the scheduler object itself is deprecated in diffusers.
        timesteps = torch.randint(
            0,
            self.noise_scheduler.config.num_train_timesteps,
            (images.shape[0],),
            device=images.device,
            dtype=torch.long,
        )
        noisy_images = self.noise_scheduler.add_noise(images, noise, timesteps)
        noise_pred = self(noisy_images, timesteps)
        loss = F.mse_loss(noise_pred, noise)
        self.log("train_loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
        return loss

    def configure_optimizers(self):
        """Return an AdamW optimizer over all parameters with lr 1e-4."""
        return torch.optim.AdamW(self.parameters(), lr=1e-4)

    def on_train_epoch_end(self):
        """Render and log 16 preview samples every 10th epoch and on the last epoch.

        Only the global-zero process samples. It runs 50 DDIM steps using the
        training scheduler's configuration, writes the grid to ``run_dir`` and
        logs the image through ``self.logger.experiment``.
        """
        import gc

        is_preview_epoch = self.current_epoch % 10 == 0 or self.current_epoch == NUM_EPOCHS - 1
        if not (self.trainer.is_global_zero and is_preview_epoch):
            return

        self.model.eval()
        try:
            with torch.no_grad():
                latents = torch.randn((16, 1, IMAGE_SIZE, IMAGE_SIZE), device=self.device)
                ddim_scheduler = DDIMScheduler.from_config(self.noise_scheduler.config)
                ddim_scheduler.set_timesteps(50)
                for t in ddim_scheduler.timesteps:
                    noise_pred = self.model(latents, t).sample
                    latents = ddim_scheduler.step(noise_pred, t, latents).prev_sample
                # Maps [-1, 1] to [0, 1] for saving.
                # NOTE(review): this assumes the model was trained on
                # [-1, 1] data -- confirm against the data pipeline.
                samples = (latents / 2 + 0.5).clamp(0, 1)
                filename = os.path.join(self.run_dir, f"samples_epoch_{self.current_epoch:04d}.png")
                torchvision.utils.save_image(samples, filename)
                self.logger.experiment.log({
                    "samples": wandb.Image(filename),
                    "epoch": self.current_epoch
                })
            # Drop the sampling tensors and release cached GPU memory before
            # training resumes.
            del samples, latents
            torch.cuda.empty_cache()
            gc.collect()
        finally:
            # Always return to train mode, even if sampling or logging failed.
            self.model.train()


def train_single_model(run_idx, wandb_base_name, RESULTS_DIR, p_tmnist):
    """Train one diffusion model and write its artefacts under a per-run folder.

    Args:
        run_idx: Zero-based run number; used in the folder name, the W&B run
            name and the seed offset.
        wandb_base_name: Prefix of the W&B run name.
        RESULTS_DIR: Parent directory for the ``run_XXXX`` folder (this
            parameter shadows the module-level constant of the same name).
        p_tmnist: Probability of drawing a TMNIST image for each training item.
    """
    run_dir = os.path.join(RESULTS_DIR, f"run_{run_idx:04d}")
    os.makedirs(run_dir, exist_ok=True)

    # Every process seeds its global RNGs identically at import time, and the
    # mixture dataset draws items from the global NumPy RNG instead of using
    # the sampler's index. Without per-rank / per-worker seeding, DDP ranks
    # can therefore draw the same batches. `workers=True` makes Lightning
    # derive distinct torch / NumPy / random seeds for each dataloader worker
    # on each rank. Offsetting by run_idx keeps repeated runs reproducible but
    # different from one another.
    pl.seed_everything(SEED + run_idx, workers=True)

    dm = MixtureMNISTDataModule(p_tmnist=p_tmnist, path_to_tmnist=PATH_TO_TMNIST)
    model = DiffusionModel(run_dir)
    wandb_run_name = f"{wandb_base_name}_run_{run_idx:04d}"
    wandb_logger = WandbLogger(
        project="mnist-diffusion-mixture",
        name=wandb_run_name,
        save_dir=run_dir
    )
    # Full checkpoints (save_weights_only=False) are written into the run
    # folder every 10 epochs.
    checkpoint_callback = ModelCheckpoint(
        dirpath=run_dir,
        filename="{epoch}",
        save_top_k=1,
        every_n_epochs=10,
        monitor=None,
        save_weights_only=False
    )
    trainer = pl.Trainer(
        max_epochs=NUM_EPOCHS,
        logger=wandb_logger,
        log_every_n_steps=1,
        callbacks=[checkpoint_callback],
        accelerator="gpu",
        devices="auto",
        strategy="ddp",
        precision="16-mixed"
    )
    try:
        trainer.fit(model, dm)
    finally:
        # Close the W&B run even when training fails, so that it is not left
        # open and a later run does not log into it.
        wandb.finish()


if __name__ == "__main__":
    # Script entry point: derive the shared W&B name prefix, make sure the
    # results directory exists, then launch NUM_RUNS training runs in sequence.
    wandb_base_name = f"mnist_diffusion_mixture_p_tmnist_{P_TMNIST}_lr_1e-4_100_epochs_incl_optim_state"
    print(f"WandB base name: {wandb_base_name}")

    os.makedirs(RESULTS_DIR, exist_ok=True)

    for run_idx in range(NUM_RUNS):
        train_single_model(
            run_idx=run_idx,
            wandb_base_name=wandb_base_name,
            RESULTS_DIR=RESULTS_DIR,
            p_tmnist=P_TMNIST,
        )