''' 
This script does conditional image generation on MNIST, using a diffusion model

This code is modified from,
https://github.com/cloneofsimo/minDiffusion

Diffusion model is based on DDPM,
https://arxiv.org/abs/2006.11239

The conditioning idea is taken from 'Classifier-Free Diffusion Guidance',
https://arxiv.org/abs/2207.12598

This technique also features in Imagen, 'Photorealistic Text-to-Image Diffusion Models with Deep Language Understanding',
https://arxiv.org/abs/2205.11487

'''
import random
# import time
import warnings
import torch.backends.cudnn as cudnn
from typing import Dict, Tuple
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import models, transforms
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
from torchvision.datasets import MNIST
from torchvision.utils import save_image, make_grid
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter
import numpy as np

import wandb
from utils import AverageMeter, CustomDataset  # Import the AverageMeter class
from unet import ContextUnet, ddpm_schedules  # Import the ContextUnet class

import os
import argparse
from PIL import Image

from supervised_mnist import train_supervised, test
from train_mnist_lenet import LeNet5
from fid import calculate_fid, compute_metrics
from utils import CustomDataset, CustomDataset_idx
from filter import filter_function

# DDP
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
import torch.distributed as dist
from fld.metrics.FLD import FLD
from fid import calculate_features
from fld.metrics.AuthPct import AuthPct
from fld.metrics.CTTest import CTTest
from fld.metrics.FID import FID
from fld.metrics.KID import KID
from fld.metrics.PrecisionRecall import PrecisionRecall

def parse_arguments(argv=None):
    """Build the command-line parser for the DDPM MNIST script and parse it.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) lets argparse read ``sys.argv[1:]``, which is what the
            script did before this parameter existed.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser(description='PyTorch DDPM MNIST')

    # --- data / optimisation -------------------------------------------------
    parser.add_argument('--dataset', default='mnist', help='dataset setting')
    parser.add_argument('--lr', default=1e-4, type=float, help='learning rate')
    parser.add_argument('-b', '--batch-size', default=256, type=int,
                        metavar='N',
                        help='mini-batch size')
    parser.add_argument('--epochs', default=20, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--root_log', type=str, default='log')
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--exp_str', default='0', type=str,
                        help='number to indicate which experiment it is')
    parser.add_argument('--resume', '-r', action='store_true',
                        help='resume from checkpoint')
    parser.add_argument('--root_model', type=str, default='runs')

    # --- run modes -----------------------------------------------------------
    parser.add_argument('--log_results', action='store_true',
                        help='To log results on wandb')
    parser.add_argument('--evaluate_only', action='store_true',
                        help='Only evaluate previously generated data found in '
                             '--model_dir (no training)')
    parser.add_argument('--filter_only', action='store_true',
                        help='Only filter and evaluate previously generated data '
                             'found in --model_dir, then exit')
    parser.add_argument('--save_model', action='store_true',
                        help='Save intermediate model checkpoints during training')

    # --- logging -------------------------------------------------------------
    # The training loop compares this value against the epoch counter, so the
    # unit is epochs, not batches.
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many epochs to wait between sample-image dumps '
                             'and intermediate checkpoints')
    parser.add_argument('-p', '--print-freq', default=100, type=int,
                        metavar='N', help='print frequency (default: 100)')

    # --- diffusion / generations ---------------------------------------------
    parser.add_argument('-T', '--diffusion-steps', default=500, type=int,
                        help='Number of diffusion steps')
    parser.add_argument('-G', '--num_generations', default=10, type=int,
                        help='Number of generations to train')
    parser.add_argument('--start_gen', default=-1, type=int,
                        help='Index of the last finished generation; training '
                             'starts at start_gen + 1 (default: -1)')
    parser.add_argument('--feature_dim', default=256, type=int,
                        help='Feature Dim')
    parser.add_argument('--num_sampled_images', default=60000, type=int,
                        help='Number of sampled images')
    parser.add_argument('--sample_batch_size', default=200, type=int,
                        help='Batch Size for sampling')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training. ')

    # --- filtering -----------------------------------------------------------
    # The three timestep options are not read anywhere in the visible part of
    # this file; presumably they are consumed by the filter code -- confirm.
    parser.add_argument("--start_timestep", type=int, default=100)
    parser.add_argument("--end_timestep", type=int, default=200)
    parser.add_argument("--num_timesteps", type=int, default=5)
    parser.add_argument("--filter_type", type=str, default="recon_loss",
                        help="Sample filtering strategy, e.g. random, label, "
                             "likelihood, recon_loss, variance_x0")
    parser.add_argument("--model_dir", type=str, default="",
                        help="Directory with saved checkpoints / generated data "
                             "used by --evaluate_only and --filter_only")

    return parser.parse_args(argv)

class DDPM(nn.Module):
    """Conditional denoising-diffusion wrapper around a noise-prediction network.

    The wrapped network is always called as
    ``nn_model(x_t, c, t, context_mask)``. The schedule tensors returned by
    ``ddpm_schedules`` are stored as buffers and indexed by integer timestep.
    In the sampling code a ``context_mask`` value of 1 marks the
    "context free" half of the batch used for classifier-free guidance.
    """

    def __init__(self, nn_model, betas, n_T, device, drop_prob=0.1):
        """
        Args:
            nn_model: noise-prediction network (moved to ``device``).
            betas: pair whose two entries are forwarded to ``ddpm_schedules``.
            n_T: number of diffusion steps.
            device: device used for the network and for the timesteps / masks
                created during training.
            drop_prob: Bernoulli probability used to build the training
                ``context_mask``.
        """
        super().__init__()
        self.nn_model = nn_model.to(device)

        # Registering each schedule tensor as a buffer makes it reachable as an
        # attribute (e.g. self.sqrtab) without turning it into a parameter.
        for k, v in ddpm_schedules(betas[0], betas[1], n_T).items():
            self.register_buffer(k, v)

        self.n_T = n_T
        self.device = device
        self.drop_prob = drop_prob
        self.loss_mse = nn.MSELoss()

    def pred_x0_given_eps(self, x_t, eps, t):
        """
        Predict x_0 given x_t and eps by inverting
        x_t = sqrtab[t] * x_0 + sqrtmab[t] * eps.

        ``t`` indexes the schedule buffers; the three ``None`` axes broadcast the
        per-sample coefficient over the remaining dimensions of ``x_t``.
        """
        return (x_t - eps * self.sqrtmab[t, None, None, None]) / self.sqrtab[t, None, None, None]

    def forward(self, x, c):
        """
        Training step: sample a timestep and noise per example, and return the
        MSE between the true noise and the network's prediction.

        Args:
            x: batch of clean images (first dimension is the batch).
            c: per-example context (labels), same batch size as ``x``.

        Returns:
            Scalar loss tensor.
        """
        # Integer timesteps drawn uniformly from {1, ..., n_T}
        # (torch.randint's upper bound is exclusive).
        _ts = torch.randint(1, self.n_T+1, (x.shape[0],)).to(self.device)
        noise = torch.randn_like(x)  # eps ~ N(0, 1)

        # x_t = sqrt(alphabar) * x_0 + sqrt(1 - alphabar) * eps
        x_t = (
            self.sqrtab[_ts, None, None, None] * x
            + self.sqrtmab[_ts, None, None, None] * noise
        )

        # Per-example Bernoulli(drop_prob) mask used for context dropout.
        context_mask = torch.bernoulli(torch.zeros_like(c)+self.drop_prob).to(self.device)

        # The network receives the timestep normalised by n_T.
        return self.loss_mse(noise, self.nn_model(x_t, c, _ts / self.n_T, context_mask))

    def sample_from_forward_process(self, x, c, t):
        """
        Noise ``x`` to timestep(s) ``t`` with freshly drawn Gaussian noise.

        Returns:
            Tuple ``(x_t, noise, context_mask)``. The mask is drawn with
            probability 0, i.e. it is all zeros (context is never dropped here).
        """
        noise = torch.randn_like(x)
        # x_t = sqrt(alphabar) * x_0 + sqrt(1 - alphabar) * eps
        x_t = (
            self.sqrtab[t, None, None, None] * x
            + self.sqrtmab[t, None, None, None] * noise
        )
        context_mask = torch.bernoulli(torch.zeros_like(c).float()).to(self.device)
        return x_t, noise, context_mask

    def sample(self, n_sample, size, device, guide_w = 0.0):
        """
        Generate ``n_sample`` images whose labels cycle through 0..9.

        For ``n_sample`` that is a multiple of 10 the labels are exactly
        ``[0..9] * (n_sample // 10)``, as before; other values now also get one
        label per image instead of a label batch that is too short.

        Args:
            n_sample: number of images to generate.
            size: per-image shape, e.g. ``(1, 28, 28)``.
            device: device to sample on.
            guide_w: guidance weight ``w`` (see ``conditional_sample``).

        Returns:
            Tensor of shape ``(n_sample, *size)``.
        """
        # context for us just cycles through the mnist labels
        c_i = (torch.arange(n_sample) % 10).to(device)
        return self.conditional_sample(c_i, n_sample, size, device, guide_w=guide_w)

    def conditional_sample(self, c_i, n_sample, size, device, guide_w = 0.0):
        """
        Reverse-diffusion sampling for the given labels, following the guidance
        scheme of 'Classifier-Free Diffusion Guidance'.

        The batch is doubled on every step: the first half keeps its context
        (mask 0), the second half is context free (mask 1). The two noise
        predictions are mixed as ``(1 + w) * eps_cond - w * eps_free``; w > 0
        means more guidance.

        Args:
            c_i: 1-D label tensor; callers are expected to pass ``n_sample``
                labels already placed on ``device``.
            n_sample: number of images to generate.
            size: per-image shape with three dimensions, e.g. ``(1, 28, 28)``.
            device: device to sample on.
            guide_w: guidance weight ``w``.

        Returns:
            Tensor of shape ``(n_sample, *size)``.
        """
        x_i = torch.randn(n_sample, *size).to(device)  # x_T ~ N(0, 1), initial noise

        # don't drop context at test time for the first half of the batch ...
        context_mask = torch.zeros_like(c_i).to(device)

        # ... and make the second (duplicated) half context free
        c_i = c_i.repeat(2)
        context_mask = context_mask.repeat(2)
        context_mask[n_sample:] = 1.

        for i in range(self.n_T, 0, -1):
            # normalised timestep, one entry per element of the doubled batch
            t_is = torch.tensor([i / self.n_T]).to(device)
            t_is = t_is.repeat(2 * n_sample, 1, 1, 1)

            # double batch
            x_i = x_i.repeat(2, 1, 1, 1)

            # no noise is added on the final step
            z = torch.randn(n_sample, *size).to(device) if i > 1 else 0

            # split predictions and compute weighting
            eps = self.nn_model(x_i, c_i, t_is, context_mask)
            eps_cond = eps[:n_sample]
            eps_free = eps[n_sample:]
            eps = (1+guide_w)*eps_cond - guide_w*eps_free

            x_i = x_i[:n_sample]
            x_i = (
                self.oneover_sqrta[i] * (x_i - eps * self.mab_over_sqrtmab[i])
                + self.sqrt_beta_t[i] * z
            )

        return x_i

def train_epoch(ddpm, dataloader, optim, device, epoch, args):
    """Run one optimisation pass over ``dataloader`` and return the mean loss.

    Args:
        ddpm: module whose call ``ddpm(images, labels)`` returns a scalar loss.
        dataloader: iterable of ``(images, labels)`` batches.
        optim: optimizer stepping the parameters of ``ddpm``.
        device: device the batches are moved to.
        epoch: current epoch index (printed, and forwarded to the sampler).
        args: namespace providing ``local_rank``; only rank 0 prints.

    Returns:
        ``loss_meter.avg`` -- the AverageMeter's average of the per-batch losses.
    """
    if args.local_rank == 0:
        print(f'epoch {epoch}')
    ddpm.train()

    # A shuffling DistributedSampler derives its permutation from the epoch it
    # was last told about; without this call every epoch replays the same
    # order. Samplers without set_epoch are left alone.
    set_epoch = getattr(getattr(dataloader, "sampler", None), "set_epoch", None)
    if callable(set_epoch):
        set_epoch(epoch)

    loss_meter = AverageMeter()

    for images, labels in dataloader:
        optim.zero_grad()
        # Rescale pixels: x -> 2x - 1 (assumes inputs in [0, 1] -- TODO confirm).
        images = 2*images - 1
        images = images.to(device)
        labels = labels.to(device)

        loss = ddpm(images, labels)
        loss.backward()
        # Read the scalar once; each .item() call copies from the device.
        loss_meter.update(loss.item())
        optim.step()

    return loss_meter.avg

# Define the sample function
def sample_parallel(ddpm, num_samples, device, ws_test=0.0, args=None):
    """Generate labelled samples on every process and gather them everywhere.

    Each process draws batches of ``args.sample_batch_size`` images with
    ``ddpm.conditional_sample`` (labels cycle through 0..9). After every batch
    the labels and images of all processes are all-gathered, until at least
    ``num_samples`` have been collected in total.

    Args:
        ddpm: model exposing ``eval()`` and ``conditional_sample(...)``.
        num_samples: total number of samples to return.
        device: device to sample on.
        ws_test: guidance weight forwarded as ``guide_w``.
        args: namespace providing ``sample_batch_size``, ``image_shape`` and
            ``local_rank``. Required despite the ``None`` default.

    Returns:
        Tuple ``(samples, labels)``: ``samples`` is a uint8 numpy array obtained
        from ``(x + 1) * 127.5`` clamped to [0, 255] and permuted with
        ``(0, 2, 3, 1)``; ``labels`` is the matching numpy array. Both are cut
        to ``num_samples`` entries.

    Raises:
        ValueError: if ``args`` is not supplied.
    """
    if args is None:
        raise ValueError("sample_parallel requires `args` "
                         "(sample_batch_size, image_shape, local_rank)")

    world_size = dist.get_world_size()
    group = dist.group.WORLD
    batch = args.sample_batch_size

    # One label per generated image, cycling through the mnist labels. Equal to
    # arange(10).repeat(batch // 10) when batch is a multiple of 10, and still
    # `batch` entries long otherwise.
    c_i = (torch.arange(batch) % 10).to(device)

    sample_chunks, label_chunks = [], []
    generated = 0
    with torch.no_grad():
        ddpm.eval()
        while generated < num_samples:
            x_gen = ddpm.conditional_sample(c_i, batch, args.image_shape, device, guide_w=ws_test)

            # all_gather fills one pre-allocated tensor per process.
            gathered_samples = [torch.zeros_like(x_gen) for _ in range(world_size)]
            gathered_labels = [torch.zeros_like(c_i) for _ in range(world_size)]
            dist.all_gather(gathered_labels, c_i, group)
            label_chunks.append(torch.cat(gathered_labels).detach().cpu())
            dist.all_gather(gathered_samples, x_gen, group)
            sample_chunks.append(torch.cat(gathered_samples).detach().cpu())

            generated += len(c_i) * world_size
            if args.local_rank == 0 and generated % 10000 == 0:
                print(f"Generated {generated} samples")

        samples = torch.cat(sample_chunks, dim=0)
        # [-1, 1] floats -> [0, 255] bytes, then apply the (0, 2, 3, 1) permutation.
        samples = ((samples + 1) * 127.5).clamp(0, 255).to(torch.uint8)
        samples = samples.permute(0, 2, 3, 1).numpy()[:num_samples]
        labels = torch.cat(label_chunks).numpy()[:num_samples]
        return samples, labels



def evaluate_mnist(X, Y, gen, args, filtered=False):
    """Evaluate a generated MNIST-style dataset against the real one.

    Steps, in order:
      1. Score the generated data with the pretrained LeNet5 stored in
         ``mnist_lenet5.pth`` (after a sanity check on the real test set).
      2. Compute FID via ``calculate_fid``. With ``args.filter_only`` the
         function returns right after printing it, before the FID is logged.
      3. Compute feature-based metrics (AuthPct, CTTest, FID, precision /
         recall, KID).
      4. Train a fresh LeNet5 on the generated data for 30 epochs, save it to
         ``args.save_dir`` and report its accuracy on the real test/train sets.
      5. Report that model's accuracy on subsets of real data ranked by the
         c-scores in ``data/mnist_cscores.npy``.

    Args:
        X: generated images, passed to ``CustomDataset``.
        Y: generated labels, passed to ``CustomDataset``.
        gen: generation index; used in wandb logs and the checkpoint name.
        args: namespace providing ``device``, ``dataset``, ``batch_size``,
            ``workers``, ``log_results``, ``filter_only`` and ``save_dir``.
        filtered: whether ``X``/``Y`` are the filtered samples; only affects
            the names of the logged metrics.

    Raises:
        ValueError: if ``args.dataset`` is not ``"mnist"``.
    """
    device = args.device
    if args.dataset != "mnist":
        raise ValueError("Dataset not supported")

    log_to_wandb = args.log_results
    # Suffix appended to metric names when evaluating filtered samples.
    filtered_true = '' if not filtered else "_filtered"

    # ----- datasets / loaders ------------------------------------------------
    tf = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
    original_train_dataset = MNIST("./data", train=True, download=True, transform=tf)
    test_dataset = MNIST('data', train=False,
                         transform=tf)
    generated_train_dataset = CustomDataset(X, Y, tf)

    loader_kwargs = dict(batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    original_train_dataloader = DataLoader(original_train_dataset, **loader_kwargs)
    generated_mnist_dataloader = DataLoader(generated_train_dataset, **loader_kwargs)
    test_dataloader = DataLoader(test_dataset, **loader_kwargs)

    # ----- pretrained classifier --------------------------------------------
    eval_model = LeNet5().to(args.device)
    # Load onto the CPU first so the checkpoint does not depend on the device
    # it was saved from; load_state_dict copies into the model's own tensors.
    ckpt = torch.load("mnist_lenet5.pth", map_location="cpu")
    eval_model.load_state_dict(ckpt)
    eval_model.eval()
    print("Checkpoint Loaded")
    sanity_check_acc = test(eval_model, device, test_dataloader)
    print(f"Sanity Check Accuracy: {sanity_check_acc}")

    eval_model_acc_on_generated_samples = test(eval_model, device, generated_mnist_dataloader)
    print(f"Accuracy of the Generated MNIST dataloader on a pretrained LeNet: {eval_model_acc_on_generated_samples}")
    if log_to_wandb:
        wandb.log({'acc-gen_mnist_lenet' + filtered_true: eval_model_acc_on_generated_samples, "gen": gen})

    # ----- FID ---------------------------------------------------------------
    fid, _, _ = calculate_fid(original_train_dataloader,
                              generated_mnist_dataloader,
                              eval_model,
                              args)
    print(f"Filtered={filtered} FID",fid)
    if args.filter_only:
        # Filter-only runs stop here: nothing below (including the FID wandb
        # log) is executed.
        return
    if log_to_wandb:
        wandb.log({'fid' + filtered_true: fid, "gen": gen})

    # ----- feature-based metrics --------------------------------------------
    train_feat = torch.from_numpy(calculate_features(original_train_dataloader, eval_model, args.batch_size))
    test_feat = torch.from_numpy(calculate_features(test_dataloader, eval_model, args.batch_size))
    gen_feat = torch.from_numpy(calculate_features(generated_mnist_dataloader, eval_model, args.batch_size))

    # By default on 10k samples
    auth_pct = AuthPct().compute_metric(train_feat, test_feat, gen_feat)
    ct_test = CTTest().compute_metric(train_feat, test_feat, gen_feat)
    print(f"Auth PCT (10k samples): {auth_pct}")
    print(f"CT Test: {ct_test}")
    fid_2 = FID().compute_metric(train_feat, None, gen_feat)
    print(f"FID: {fid_2}")
    test_fid = FID(ref_feat = "test").compute_metric(None, test_feat, gen_feat)

    prec = PrecisionRecall(mode="Precision").compute_metric(train_feat, None, gen_feat) # Default precision
    rec = PrecisionRecall(mode="Recall", num_neighbors=5).compute_metric(train_feat, None, gen_feat) # Recall with k=5
    print(f"Precision: {prec}")
    print(f"Recall: {rec}")

    # Like FID, can get either Train or Test KID.
    # The first KID object exists only to print its default ref_size.
    kid_probe = KID(ref_feat="test")
    print(kid_probe.ref_size)
    test_kid = KID(ref_feat="test", ref_size=len(gen_feat)).compute_metric(None, test_feat, gen_feat)
    print(f"Test KID: {test_kid}")
    train_kid = KID(ref_feat="train", ref_size=len(gen_feat)).compute_metric(train_feat, None, gen_feat)
    print(f"train_kid: {train_kid}")

    if log_to_wandb:
        wandb.log({'auth_pct'+filtered_true:auth_pct,
                   "ct_test"+filtered_true:ct_test,
                   'precision'+filtered_true:prec,
                   'recall'+filtered_true:rec,
                   'fid_2'+filtered_true:fid_2,
                   'test_fid'+filtered_true:test_fid,
                   'test_kid'+filtered_true:test_kid,
                   'train_kid'+filtered_true:train_kid,
                   "gen":gen})

    # ----- train a classifier on the generated data --------------------------
    model = LeNet5().to(device)
    LR = 0.001
    NUM_EPOCHS = 30
    generated_mnist_dataloader = DataLoader(generated_train_dataset, batch_size=128, shuffle=True, num_workers=args.workers)
    optimizer = optim.Adam(model.parameters(), lr=LR)

    for epoch in range(1, NUM_EPOCHS + 1):
        train_supervised(model, device, generated_mnist_dataloader, optimizer, epoch, args)
        # Only the accuracy after the final epoch is reported below.
        acc = test(model, device, test_dataloader)
    torch.save(model.state_dict(), os.path.join(args.save_dir, f"gen_{gen}_trained_sup_model.pth"))
    print(f"Test Accuracy: {acc:.4f}")
    model.eval()
    train_acc = test(model, device, original_train_dataloader)
    print("Train Accuracy on Real Data", train_acc)
    if log_to_wandb:
        wandb.log({f'test-acc_trained_on_gen_mnist_{filtered}':acc, f"train-acc_trained_on_gen_mnist_{filtered}":train_acc, "gen":gen})

    # ----- C-Score analysis ---------------------------------------------------
    mnist_tfds_order = np.load("mnist_tfds.npz")
    mnist_tfds_order_dataset = CustomDataset(mnist_tfds_order['images'], mnist_tfds_order['labels'], tf)

    c_score = np.load("data/mnist_cscores.npy")
    print("C-Score Shape: ",c_score.shape)

    # Sort once instead of twice per fraction.
    # NOTE(review): np.argsort is ascending, so the "Top" split below holds the
    # LOWEST c-scores and "Bottom" the highest -- confirm this naming is intended.
    ascending = np.argsort(c_score)
    descending = ascending[::-1]

    top_k_ranges = [0.1, 0.3, 0.5, 0.6, 0.8, 0.9]
    for k in top_k_ranges:
        subset_size = int(k*len(c_score))

        top_k_dataset = torch.utils.data.Subset(mnist_tfds_order_dataset, ascending[:subset_size])
        top_k_dataloader = DataLoader(top_k_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
        acc = test(model, device, top_k_dataloader)

        bottom_k_dataset = torch.utils.data.Subset(mnist_tfds_order_dataset, descending[:subset_size])
        bottom_k_dataloader = DataLoader(bottom_k_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
        bottom_k_acc = test(model, device, bottom_k_dataloader)

        print(f"Top {k*100}% C-Score Samples Accuracy: {acc}")
        print(f"Bottom {k*100}% C-Score Samples Accuracy: {bottom_k_acc}")
        if log_to_wandb:
            wandb.log({f'top_{k*100}_acc_c_score'+filtered_true:acc, f'bottom_{k*100}_acc_c_score'+filtered_true:bottom_k_acc,"gen":gen})
    

def train_mnist(args):

    # hardcoding these here
    n_T = args.diffusion_steps # 500
    device = args.device
    # print(args.device)
    n_feat = args.feature_dim # 128 ok, 256 better (but slower)
    save_model = args.save_model
    save_dir = f"runs/{args.store_name}"
    args.save_dir = save_dir
    ws_test = [0.0, 0.5, 2.0] # strength of generative guidance

    #num_generations = 2
    guide_w = 0.0
    if args.filter_only:
        nn_model=ContextUnet(in_channels=1, n_feat=n_feat, n_classes=args.n_classes).to(device)
        nn_model = DDP(nn_model, device_ids=[args.local_rank], output_device=args.local_rank)
        # Load weights
        eval_gen = 0
        ckpt_path = f"{args.model_dir}/gen_{eval_gen}_ddpm_final.pth"
        nn_model_ckpt_path = f"{args.model_dir}/gen_{eval_gen}_nn_model_final.pth"
        nn_model.load_state_dict(torch.load(nn_model_ckpt_path))
        ddpm = DDPM(nn_model, betas=(1e-4, 0.02), n_T=n_T, device=device, drop_prob=0.1).to(device)
        ddpm.load_state_dict(torch.load(ckpt_path))

        data = np.load(f"{args.model_dir}/gen_{eval_gen}_generated_data_w_{guide_w}.npz")
        samples = data['X_all']
        labels = data['Y_all']
        evaluate_mnist(samples, labels, eval_gen, args, filtered=False)
        custom_dataset_idx = CustomDataset_idx(samples, labels, transform=transforms.ToTensor())
        print("Length of Custom Set: ",len(custom_dataset_idx))
        sample_data_loader_no_shuffle = DataLoader(
            custom_dataset_idx,
            batch_size=400,
            shuffle=False,
            num_workers=4,
            pin_memory=True,
        )
        # if args.local_rank==0:
        print("Filtering")
        all_vals, all_idx = filter_function(nn_model, ddpm, sample_data_loader_no_shuffle, args)

        print(len(all_vals), len(all_idx), all_idx[:10])
        new_labels = labels[all_idx]
        filtered_indices = []
        for class_idx in range(10):
            indices_of_class = np.where(new_labels == class_idx)[0]
            print(f"Class {class_idx}: {len(indices_of_class)}")
            vals_class = all_vals[indices_of_class]
            # Sort vals_class and find topk
            sorted_indices = vals_class.argsort()
            topk_indices = sorted_indices[:args.num_sampled_images//10]
            filtered_indices.extend(indices_of_class[topk_indices].tolist())
    
        filtered_indices = np.array(filtered_indices)
        print(len(filtered_indices))
        filtered_sampled_images = np.asarray(samples[filtered_indices])
        filtered_labels = labels[filtered_indices]
        print(filtered_sampled_images.shape, filtered_labels.shape)
        evaluate_mnist(filtered_sampled_images, filtered_labels, eval_gen, args, filtered=True)
        exit()

    if args.evaluate_only:
        print("Evaluating only...")
        for eval_gen in range(0, args.num_generations):
            try:
                data = np.load(f"{args.model_dir}/gen_{eval_gen}_generated_data_w_{guide_w}.npz")
                X = data['X']
                Y = data['Y']
                evaluate_mnist(X, Y, eval_gen, args, filtered=True)
            except:
                print(f"Error in evaluating generation {eval_gen}")
            try:
                X_all = data['X_all']
                Y_all = data['Y_all']
                evaluate_mnist(X_all, Y_all, eval_gen, args, filtered=False)
            except:
                print("No X_all and Y_all found")
        # eval_gen = 1
        # data = np.load(f"{args.model_dir}/gen_{eval_gen}_generated_data_w_{guide_w}.npz")
        # X = data['X']
        # Y = data['Y']
        # evaluate_mnist(X, Y, eval_gen, args)
        return

    for gen in range(args.start_gen+1, args.num_generations):
        if args.local_rank==0:
            print(f'generation {gen}')
        nn_model=ContextUnet(in_channels=1, n_feat=n_feat, n_classes=args.n_classes).to(device)
        nn_model = DDP(nn_model, device_ids=[args.local_rank], output_device=args.local_rank)
        ddpm = DDPM(nn_model, betas=(1e-4, 0.02), n_T=n_T, device=device, drop_prob=0.1).to(device)

        tf = transforms.Compose([transforms.ToTensor()]) 

        if gen==0:
            dataset = MNIST("./data", train=True, download=True, transform=tf)
        else:
            guide_w = 0.0
            data = np.load(f"{save_dir}/gen_{gen-1}_generated_data_w_{guide_w}.npz")
            X = data['X']
            Y = data['Y']
            dataset = CustomDataset(X, Y, tf)
        
        train_sampler = DistributedSampler(dataset)
        dataloader = DataLoader(dataset, batch_size=args.batch_size//args.gpu_count, shuffle=False, sampler=train_sampler, pin_memory=True, num_workers=args.workers)
        optim = torch.optim.Adam(ddpm.parameters(), lr=args.lr)


        for epoch in range(args.epochs):

            optim.param_groups[0]['lr'] = args.lr * (1 - epoch / args.epochs)

            avg_loss = train_epoch(ddpm, dataloader, optim, device, epoch, args)

            if args.local_rank==0:
                print(f"Average loss: {avg_loss:.4f}")
                if args.log_results:
                    wandb.log({"train_loss": avg_loss, "epoch":epoch})

            if (epoch+1) % args.log_interval == 0 and args.local_rank==0:
                # continue
                ddpm.eval()
                with torch.no_grad():
                    n_sample = 10 * args.n_classes
                    for _, w in enumerate(ws_test):
                        x_gen = ddpm.sample(n_sample, (1, 28, 28), device, guide_w=w)
                        grid = make_grid(((x_gen + 1)/2).clamp(0.0, 1.0), nrow=10)
                        image_path = f"{save_dir}/gen_{gen}_image_ep{epoch}_w{w}.png"
                        save_image(grid, image_path)
                        # Log images to WandB
                        if args.log_results:
                            wandb.log({f"gen_{gen}_image_w{w}": [wandb.Image(image_path, caption=f"Epoch{epoch}-{w}")], "epoch":epoch})
                        print('saved image at ' + save_dir + f"/image_ep{epoch}_w{w}.png")

            # optionally save model
            if save_model and (epoch % args.log_interval==0 or epoch==int(args.epochs)-1) and args.local_rank==0:
                torch.save(ddpm.state_dict(), save_dir + f"gen_{gen}_model_{epoch}.pth")
                print('saved model at ' + save_dir + f"gen_{gen}_model_{epoch}.pth")
        
        if args.local_rank==0:
            torch.save(ddpm.state_dict(), os.path.join(save_dir, f"gen_{gen}_ddpm_final.pth"))
            torch.save(nn_model.state_dict(), os.path.join(save_dir, f"gen_{gen}_nn_model_final.pth"))
            print('saved model at ' + save_dir + f"/gen_{gen}_nn_model_final.pth")

        # Now sample 60k images from the model and save them
        # Probably need to batch this to avoid memory issues
        ddpm.eval()
        nn_model.eval()
        # with torch.no_grad():
        if args.epochs<=5:
            extra = 200
        else:
            extra = 20000
        samples, labels = sample_parallel(ddpm, args.num_sampled_images+extra, device, ws_test=0.0, args=args)
        custom_dataset_idx = CustomDataset_idx(samples, labels, transform=transforms.ToTensor())
        print("Length of Custom Set: ",len(custom_dataset_idx))
        sampler = DistributedSampler(custom_dataset_idx, shuffle=False) if ngpus > 1 else None
        sample_data_loader_no_shuffle = DataLoader(
            custom_dataset_idx,
            batch_size=400,
            shuffle=False,
            # sampler=sampler,
            num_workers=4,
            pin_memory=True,
        )
        if args.local_rank==0:
            print("Filtering")
        if args.filter_type!='random' and args.filter_type!='label':
            if args.local_rank==0:
                all_vals, all_idx = filter_function(nn_model, ddpm, sample_data_loader_no_shuffle, args)
        # Gather all_vals torch.gather ddp
        
        if args.local_rank==0:
            if args.filter_type=='random':
                #random_indices = np.random.choice(len(samples), args.num_sampled_images, replace=False)
                filtered_indices = []
                for label in np.unique(labels):
                    # Find indices of all samples with the current label
                    indices_of_label = np.where(labels == label)[0]

                    # Randomly select k indices from these
                    selected_indices = np.random.choice(indices_of_label, args.num_sampled_images//10, replace=False)

                    # Append selected indices to the list
                    filtered_indices.extend(selected_indices.tolist())
                filtered_indices = np.array(filtered_indices)
                filtered_sampled_images = np.asarray(samples[filtered_indices])
                filtered_labels = np.asarray(labels[filtered_indices])
                np.savez(os.path.join(save_dir,f"gen_{gen}_{args.filter_type}_indices_epoch_{args.epochs}-timesteps_{args.diffusion_steps}.npz"), filtered_indices)
            elif args.filter_type=='likelihood':
                # sorted_indices = all_vals.argsort()
                new_labels = labels[all_idx]
                dtype = [('values', float), ('indices', int), ('labels', int)]
                structured_array = np.array(list(zip(all_vals, all_idx, new_labels)), dtype=dtype)
                sorted_array = np.sort(structured_array, order=['values', 'labels'])
                all_top_k_indices = np.array([], dtype=int)
                unique_labels = np.unique(labels)
                for label in unique_labels:
                    # Here, we filter by label and then take the last k elements since we're interested in the highest values
                    class_elements = sorted_array[sorted_array['labels'] == label][:args.num_sampled_images//10]
                    all_top_k_indices = np.concatenate((all_top_k_indices, class_elements['indices']))
                # filtered_indices = all_idx[sorted_indices[:args.num_sampled_images]]
                filtered_indices = all_top_k_indices
                print(len(filtered_indices))
                filtered_sampled_images = np.asarray(samples[filtered_indices])
                filtered_labels = labels[filtered_indices]
                np.savez(os.path.join(save_dir,f"gen_{gen}_{args.filter_type}_values_epoch_{args.epochs}-timesteps_{args.diffusion_steps}.npz"), all_vals)
                np.savez(os.path.join(save_dir,f"gen_{gen}_{args.filter_type}_indices_epoch_{args.epochs}-timesteps_{args.diffusion_steps}.npz"), all_idx)
                np.savez(os.path.join(save_dir,f"gen_{gen}_{args.filter_type}_filtered_indices_{args.epochs}-timesteps_{args.diffusion_steps}.npz"), filtered_indices)
            elif args.filter_type=='recon_loss' or args.filter_type=='variance_x0':
                # sorted_indices = all_vals.argsort()
                # new_labels = labels[all_idx]
                # dtype = [('values', float), ('indices', int), ('labels', int)]
                # structured_array = np.array(list(zip(all_vals, all_idx, new_labels)), dtype=dtype)
                # sorted_array = np.sort(structured_array, order=['labels', 'values'])
                # # sorted_array = np.sort(structured_array, order=['labels', 'values'])
                # all_top_k_indices = np.array([], dtype=int)
                # unique_labels = np.unique(labels)
                # for label in unique_labels:
                #     # Here, we filter by label and then take the last k elements since we're interested in the highest values
                #     class_elements = sorted_array[sorted_array['labels'] == label][:args.num_sampled_images//10]
                #     all_top_k_indices = np.concatenate((all_top_k_indices, class_elements['indices']))
                # # filtered_indices = all_idx[sorted_indices[:args.num_sampled_images]]
                # filtered_indices = all_top_k_indices
                # print(len(filtered_indices))
                # filtered_sampled_images = np.asarray(samples[filtered_indices])
                # filtered_labels = labels[filtered_indices]
                # all_vals, all_idx = filter_function(nn_model, ddpm, sample_data_loader_no_shuffle, args)
                # print(len(all_vals), len(all_idx), all_idx[:10])
                print(len(all_vals), len(all_idx), all_idx[:10])
                new_labels = labels[all_idx]
                filtered_indices = []
                for class_idx in range(10):
                    indices_of_class = np.where(new_labels == class_idx)[0]
                    print(f"Class {class_idx}: {len(indices_of_class)}")
                    vals_class = all_vals[indices_of_class]
                    # Sort vals_class and find topk
                    sorted_indices = vals_class.argsort()
                    topk_indices = sorted_indices[:args.num_sampled_images//10]
                    filtered_indices.extend(indices_of_class[topk_indices].tolist())
            
                filtered_indices = np.array(filtered_indices)
                print(len(filtered_indices))
                filtered_sampled_images = np.asarray(samples[filtered_indices])
                filtered_labels = labels[filtered_indices]

                np.savez(os.path.join(save_dir,f"gen_{gen}_{args.filter_type}_values_epoch_{args.epochs}-timesteps_{args.diffusion_steps}.npz"), all_vals)
                np.savez(os.path.join(save_dir,f"gen_{gen}_{args.filter_type}_indices_epoch_{args.epochs}-timesteps_{args.diffusion_steps}.npz"), all_idx)
                np.savez(os.path.join(save_dir,f"gen_{gen}_{args.filter_type}_filtered_indices_{args.epochs}-timesteps_{args.diffusion_steps}.npz"), filtered_indices)
        if args.local_rank==0:
            print(f"Samples shape: {samples.shape}, Labels shape: {len(labels)}")
            np.savez(f"{save_dir}/gen_{gen}_generated_data_w_0.0.npz", X=filtered_sampled_images, X_all=samples, Y=filtered_labels, Y_all=labels)

        # with torch.no_grad():
        #     guide_w = 0.0
        #     n_sample = args.num_sampled_images
        #     total_num_samples = args.num_sampled_images#//args.n_classes
        #     #print(n_samples_per_class)
        #     data_dict = {'X': [], 'Y': []}
        #     n_sample_batch = 1000
        #     for i in range(total_num_samples//n_sample_batch):
        #          print(i)
        #          #c_i = (torch.ones(n_samples_per_class) * i).to(device)
        #          c_i = torch.arange(0,10).to(device) # context for us just cycles throught the mnist labels
        #          c_i = c_i.repeat(int(n_sample_batch/c_i.shape[0]))
        #          x_gen = ddpm.conditional_sample(c_i,  n_sample_batch, (1, 28, 28), device, guide_w=0.0)
        #          x_gen = ((x_gen + 1) * 127.5).clamp(0, 255).to(torch.uint8).permute(0,2,3,1).squeeze().cpu().numpy()
        #          #print(x_gen.shape)
        #          data_dict['X'].append(x_gen)
        #          data_dict['Y'].extend(c_i.cpu().numpy())
        #     data_dict['X'] = np.vstack(data_dict['X']).astype(np.uint8)
        #     data_dict['Y'] = np.asarray(data_dict['Y'])
        #     #print(data_dict['X'].shape)
        #     np.savez(f"{save_dir}/gen_{gen}_generated_data_w_{guide_w}.npz", **data_dict)
        
        # Get Test Accuracy of MNIST with generated images
        # Load the generated images
        if args.local_rank==0:
            print(f"Loading generated images for evaluation")
            data = np.load(f"{save_dir}/gen_{gen}_generated_data_w_0.0.npz")
            X = data['X']
            X_all = data['X_all']
            Y = data['Y']
            Y_all = data['Y_all']
            #print(Y.shape, X.shape)
            try:
                evaluate_mnist(X, Y, gen, args, filtered=True)

                evaluate_mnist(X_all, Y_all, gen, args)
            except:
                print("Error in evaluation")


def ddp_setup(args):
    """Join the NCCL process group and bind this process to its local GPU.

    The rank is read from the ``LOCAL_RANK`` environment variable (expected to
    be provided by the launcher, e.g. torchrun -- confirm against how the
    script is started).

    Args:
        args: Run configuration; ``local_rank`` and ``device`` are written
            onto it in place.

    Raises:
        KeyError: If ``LOCAL_RANK`` is not present in the environment.
    """
    init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    args.local_rank = local_rank
    torch.cuda.set_device(local_rank)
    args.device = torch.device('cuda', local_rank)

def main(args):
    """Set up the execution context, optionally start wandb, and run train_mnist.

    When ``args.evaluate_only`` is set, no process group is created and the
    process behaves as rank 0 on the default CUDA device. Otherwise
    ``ddp_setup`` initialises distributed state, and the process group is
    destroyed once ``train_mnist`` returns.

    Args:
        args: Run configuration. ``device`` and ``local_rank`` are filled in
            here (directly or through ``ddp_setup``).
    """
    if args.evaluate_only:
        # Single-process evaluation: behave as rank 0 on the default GPU.
        args.local_rank = 0
        args.device = torch.device('cuda')
    else:
        ddp_setup(args)

    # Only rank 0 reports to wandb, so a multi-process run creates one wandb run.
    should_log = args.log_results and args.local_rank==0
    if should_log:
        wandb.init(project="synthetic", entity="neurips", name=args.store_name)
        wandb.config.update(args)
        wandb.run.log_code(".")

    train_mnist(args)

    # No process group was created in evaluate-only mode, so nothing to tear down.
    if args.evaluate_only:
        return
    destroy_process_group()

if __name__ == "__main__":
    # Script entry point: parse the CLI, fill in dataset-specific settings,
    # optionally seed the RNGs, derive the run name / output directory, check
    # that CUDA is usable, and hand over to main().
    args = parse_arguments()
    if args.dataset=="mnist":
        args.n_classes = 10
        args.image_shape = (1, 28, 28)
    else:
        raise ValueError("Dataset not supported")
    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                        'This will turn on the CUDNN deterministic setting, '
                        'which can slow down your training considerably! '
                        'You may see unexpected behavior when restarting '
                        'from checkpoints.')
    if args.evaluate_only:
        # Name the run after the last component of model_dir. Trailing slashes
        # are stripped first: "path/to/model/".split('/')[-1] would be an empty
        # string, making every such evaluation run share the name "_eval_<exp_str>".
        model_name = args.model_dir.rstrip('/').split('/')[-1]
        args.store_name = model_name + "_eval" + "_" + args.exp_str
    else:
        # Training run name encodes the main hyper-parameters.
        args.store_name = '_'.join([args.dataset, 'ddpm', 'T', str(args.diffusion_steps),'UNet', str(args.feature_dim), 'bs', str(args.batch_size), 'epochs', str(args.epochs), 'gen', str(args.num_generations), str(args.filter_type), str(args.start_timestep), str(args.end_timestep), 'seed', str(args.seed),args.exp_str])
    os.makedirs("runs/" + args.store_name, exist_ok=True)
    # `device` and `ngpus` stay module-level names in case other code reads them.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    if device == 'cpu':
        raise ValueError("Cuda not available")
    ngpus = torch.cuda.device_count()
    args.gpu_count = ngpus
    main(args)
