import os
import json
import numpy as np
import torch
import torchvision
import torch.nn.functional as F
from torch.utils.data import Subset
from torchvision import transforms
from diffusers import DDPMScheduler, UNet2DModel, DDIMScheduler
import pytorch_lightning as pl
from pytorch_lightning import LightningModule, LightningDataModule
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.utilities import rank_zero_only
import torch.distributed as dist
import wandb

# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
SEED = 42            # seed for the global torch / NumPy generators
NUM_RUNS = 1         # number of fine-tuning runs launched by the entry point
NUM_EPOCHS = 110     # passed to the trainer as max_epochs
BATCH_SIZE = 256     # batch size of the training DataLoader
IMAGE_SIZE = 28      # side length of the square, single-channel images

# Mixture sampling: probability that a drawn sample is a TMNIST image.
P_TMNIST = 1.0

# Filesystem locations (left empty here).
PATH_TO_TMNIST = ""   # file holding the TMNIST image tensor
CHECKPOINT_PATH = ""  # checkpoint handed to trainer.fit as ckpt_path
PARENT_DATA_DIR = ""  # root directory for the torchvision MNIST download
RESULTS_DIR = ""      # per-run output directories are created below this

# Seed both global random generators so runs are repeatable.
np.random.seed(SEED)
torch.manual_seed(SEED)

# Mixture Dataset: samples TMNIST image with probability p_tmnist, else MNIST
class MixtureMNISTDataset(torch.utils.data.Dataset):
    """Dataset that randomly mixes TMNIST and MNIST images.

    The index passed to ``__getitem__`` is ignored: every call draws a fresh
    sample. With probability ``p_tmnist`` a uniformly random TMNIST image is
    returned with label ``0``; otherwise a uniformly random MNIST
    ``(image, label)`` pair is returned.

    NOTE(review): sampling uses the global NumPy RNG. With multi-process data
    loading or DDP, confirm that every worker/rank ends up with a distinct
    NumPy seed, otherwise they may draw identical samples.

    Args:
        mnist_dataset: Indexable dataset yielding ``(image_tensor, label)``.
        tmnist_tensor: Tensor whose first dimension indexes TMNIST images.
        p_tmnist: Probability of returning a TMNIST image.
        transform: Optional callable applied to the single-channel image.
        image_size: Side length used when an image must be reshaped to
            ``(1, image_size, image_size)``. Defaults to 28.
    """

    def __init__(self, mnist_dataset, tmnist_tensor, p_tmnist=0.1, transform=None, image_size=28):
        self.mnist_dataset = mnist_dataset
        self.tmnist_tensor = tmnist_tensor
        self.p_tmnist = p_tmnist
        self.transform = transform
        self.image_size = image_size
        self.N_mnist = len(self.mnist_dataset)
        self.N_tmnist = self.tmnist_tensor.shape[0]

    def __len__(self):
        """Return the nominal epoch length: MNIST count plus TMNIST count."""
        return self.N_mnist + self.N_tmnist

    def _prepare(self, img):
        """Bring ``img`` to shape ``(1, H, W)`` and apply the optional transform."""
        if img.dim() == 2:
            # (H, W) -> (1, H, W)
            img = img.unsqueeze(0)
        elif img.shape[0] != 1:
            # Anything else without a leading channel of 1 (e.g. a flat
            # vector) is reshaped; reshape also copes with non-contiguous
            # inputs, which view does not.
            img = img.reshape(1, self.image_size, self.image_size)
        if self.transform is not None:
            img = self.transform(img)
        return img

    def __getitem__(self, idx):
        """Return a random ``(image, label)`` pair; ``idx`` is not used."""
        if np.random.rand() < self.p_tmnist:
            img = self.tmnist_tensor[np.random.randint(0, self.N_tmnist)]
            label = 0  # TMNIST samples always carry label 0
        else:
            img, label = self.mnist_dataset[np.random.randint(0, self.N_mnist)]
        return self._prepare(img), label

class MixtureMNISTDataModule(LightningDataModule):
    """Lightning data module that serves the MNIST/TMNIST mixture for training.

    Args:
        p_tmnist: Probability forwarded to ``MixtureMNISTDataset``.
        path_to_tmnist: File loaded with ``torch.load`` to obtain the TMNIST
            image tensor.
    """

    def __init__(self, p_tmnist=0.1, path_to_tmnist=PATH_TO_TMNIST):
        super().__init__()
        self.p_tmnist = p_tmnist
        self.path_to_tmnist = path_to_tmnist

    @staticmethod
    def _load_mnist(download):
        """Build the MNIST training split under ``PARENT_DATA_DIR``."""
        return torchvision.datasets.MNIST(
            root=PARENT_DATA_DIR,
            train=True,
            download=download,
            transform=transforms.ToTensor(),
        )

    def setup(self, stage=None):
        """Load MNIST and the TMNIST tensor and build the mixture dataset.

        When a process group is initialized, only rank 0 downloads MNIST;
        the other ranks wait at a barrier and then read the files that are
        already on disk.
        """
        distributed = dist.is_available() and dist.is_initialized()
        if distributed:
            if dist.get_rank() == 0:
                # Trigger the download exactly once.
                self._load_mnist(download=True)
            # Nobody reads the files until rank 0 has finished fetching them.
            dist.barrier()
        mnist = self._load_mnist(download=not distributed)

        # Always materialize the tensor on CPU: a file saved from a GPU
        # tensor would otherwise be restored onto that GPU, which does not
        # mix with DataLoader workers and pin_memory.
        # NOTE(review): torch.load unpickles the file; only point it at
        # trusted data.
        tmnist_tensor = torch.load(self.path_to_tmnist, map_location="cpu")
        self.dataset = MixtureMNISTDataset(mnist, tmnist_tensor, self.p_tmnist)
        print(f"Mixture dataset size: {len(self.dataset)}")

    def train_dataloader(self):
        """Return the training DataLoader over the mixture dataset."""
        return torch.utils.data.DataLoader(
            self.dataset,
            batch_size=BATCH_SIZE,
            shuffle=True,
            num_workers=4,
            pin_memory=True,
        )

class DiffusionModel(LightningModule):
    """Noise-prediction diffusion model: a ``UNet2DModel`` with a DDPM schedule.

    Args:
        run_dir: Directory where sample grids are written at the end of
            selected epochs.
    """

    def __init__(self, run_dir):
        super().__init__()
        self.run_dir = run_dir
        self.model = UNet2DModel(
            sample_size=IMAGE_SIZE,
            in_channels=1,
            out_channels=1,
            layers_per_block=2,
            block_out_channels=(32, 64, 128),
            down_block_types=("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
            up_block_types=("AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
            norm_num_groups=8
        )
        self.noise_scheduler = DDPMScheduler(
            num_train_timesteps=1000,
            beta_start=1e-4,
            beta_end=0.02,
            beta_schedule="linear"
        )

    def forward(self, x, timesteps):
        """Return the UNet output (``.sample``) for ``x`` at ``timesteps``."""
        return self.model(x, timesteps).sample

    def training_step(self, batch, batch_idx):
        """Noise the batch at random timesteps and regress the noise with MSE.

        Args:
            batch: ``(images, labels)`` pair; the labels are ignored.
            batch_idx: Index of the batch within the current epoch.

        Returns:
            The scalar MSE loss between predicted and true noise.
        """
        # Report the learning rate once, on the first batch of epoch 100.
        # NOTE(review): presumably epoch 100 is the first epoch after resuming
        # from the pretrained checkpoint - confirm.
        if batch_idx == 0 and self.current_epoch == 100:
            optimizer = self.optimizers()
            lr = optimizer.param_groups[0]['lr']
            print(f"Learning Rate: {lr}")
        images, _ = batch
        noise = torch.randn_like(images)
        # Read the value through `.config`: accessing config entries directly
        # on the scheduler object is deprecated in diffusers.
        timesteps = torch.randint(
            0,
            self.noise_scheduler.config.num_train_timesteps,
            (images.shape[0],),
            device=self.device,
        )
        noisy_images = self.noise_scheduler.add_noise(images, noise, timesteps)
        noise_pred = self.forward(noisy_images, timesteps)
        loss = F.mse_loss(noise_pred, noise)
        self.log("train_loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
        return loss

    def configure_optimizers(self):
        """Return an AdamW optimizer over all parameters with lr 1e-5."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-5)

        return optimizer

    def on_train_epoch_end(self):
        """Every 10th epoch and on the last one, sample 16 images on rank 0.

        The samples are generated with a 50-step DDIM scheduler built from
        the training scheduler's config, saved as a PNG grid in ``run_dir``
        and logged through the experiment logger.
        """
        is_sample_epoch = self.current_epoch % 10 == 0 or self.current_epoch == NUM_EPOCHS - 1
        if not (self.trainer.is_global_zero and is_sample_epoch):
            return

        import gc

        self.model.eval()
        try:
            with torch.no_grad():
                noise = torch.randn((16, 1, IMAGE_SIZE, IMAGE_SIZE), device=self.device)
                ddim_scheduler = DDIMScheduler.from_config(self.noise_scheduler.config)
                ddim_scheduler.set_timesteps(50)
                for t in ddim_scheduler.timesteps:
                    noise_pred = self.model(noise, t).sample
                    noise = ddim_scheduler.step(noise_pred, t, noise).prev_sample
                # NOTE(review): this rescaling maps [-1, 1] onto [0, 1]; the
                # MNIST images reach training as ToTensor output with no
                # further transform - confirm the training data range.
                samples = (noise / 2 + 0.5).clamp(0, 1)
                filename = os.path.join(self.run_dir, f"samples_epoch_{self.current_epoch:04d}.png")
                torchvision.utils.save_image(samples, filename)
                self.logger.experiment.log({
                    "samples": wandb.Image(filename),
                    "epoch": self.current_epoch
                })
            # Drop the references before asking CUDA to release cached blocks.
            del samples, noise
        finally:
            # Never leave the network in eval mode, even if sampling failed.
            self.model.train()
        torch.cuda.empty_cache()
        gc.collect()

class UpdateLRCallback(pl.Callback):
    """Callback that overwrites the first optimizer's learning rate at train start.

    Args:
        new_lr: Learning rate written into every param group.
    """

    def __init__(self, new_lr):
        self.new_lr = new_lr

    def on_train_start(self, trainer, pl_module):
        """Set ``lr`` on each param group of ``trainer.optimizers[0]`` and report it."""
        for group in trainer.optimizers[0].param_groups:
            group['lr'] = self.new_lr
        print(f"Learning Rate updated to: {self.new_lr}")


def finetune_model(run_idx, wandb_base_name, RESULTS_DIR, checkpoint_path, path_to_tmnist, p_tmnist):
    """Run one fine-tuning job and log it to Weights & Biases.

    Args:
        run_idx: Run number; used in the run directory and the W&B run name.
        wandb_base_name: Prefix of the W&B run name.
        RESULTS_DIR: Parent directory under which the run directory is created.
        checkpoint_path: Passed to ``trainer.fit`` as ``ckpt_path``.
        path_to_tmnist: Forwarded to ``MixtureMNISTDataModule``.
        p_tmnist: Forwarded to ``MixtureMNISTDataModule``.
    """
    # Per-run output directory: samples, checkpoints and W&B files go here.
    run_dir = os.path.join(RESULTS_DIR, f"run_{run_idx:04d}")
    os.makedirs(run_dir, exist_ok=True)

    data_module = MixtureMNISTDataModule(p_tmnist=p_tmnist, path_to_tmnist=path_to_tmnist)
    diffusion = DiffusionModel(run_dir)

    logger = WandbLogger(
        project="mnist-diffusion-mixture",
        name=f"{wandb_base_name}_run_{run_idx:04d}",
        save_dir=run_dir,
    )

    callbacks = [
        # Keep a weights-only checkpoint for every epoch.
        ModelCheckpoint(
            dirpath=run_dir,
            filename="{epoch}",
            save_top_k=-1,
            every_n_epochs=1,
            monitor=None,
            save_weights_only=True,
        ),
        # Force the learning rate to 1e-5 when training starts.
        UpdateLRCallback(1e-5),
    ]

    trainer_kwargs = dict(
        max_epochs=NUM_EPOCHS,
        logger=logger,
        log_every_n_steps=1,
        callbacks=callbacks,
        accelerator="gpu",
        devices="auto",
        strategy="ddp",
        precision="16-mixed",
    )
    trainer = pl.Trainer(**trainer_kwargs)

    trainer.fit(diffusion, data_module, ckpt_path=checkpoint_path)
    wandb.finish()

if __name__ == "__main__":
    # The W&B run-name prefix encodes the mixture probability and epoch budget.
    wandb_base_name = f"mnist_diffusion_finetune_mixture_p_tmnist_{P_TMNIST}_lr_1e-5_{NUM_EPOCHS}_epochs"
    print(f"WandB base name: {wandb_base_name}")

    os.makedirs(RESULTS_DIR, exist_ok=True)

    # Launch the configured number of independent fine-tuning runs.
    for run_idx in range(NUM_RUNS):
        finetune_model(
            run_idx=run_idx,
            wandb_base_name=wandb_base_name,
            RESULTS_DIR=RESULTS_DIR,
            checkpoint_path=CHECKPOINT_PATH,
            path_to_tmnist=PATH_TO_TMNIST,
            p_tmnist=P_TMNIST,
        )