# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Sample new sequences from a pre-trained time-series DiT (timeDiT) and save them as a NumPy array.
"""
import argparse
import os

import numpy as np
import torch
# Allow TF32 for CUDA matmuls and cuDNN convolutions (faster on supporting GPUs).
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
from torchvision.utils import save_image
from diffusers.models import AutoencoderKL

from diffusion import create_diffusion
from models import timeDiT
from utils import find_model, load_yaml_config



# Command-line interface: a single option naming the YAML config file
# (default: config.yaml).
parser = argparse.ArgumentParser()
parser.add_argument(
    "--config",
    type=str,
    default="config.yaml",
)
args = parser.parse_args()


# Load the YAML config. os.path.join keeps the "./<name>" form for relative
# paths but, unlike string concatenation, honours absolute paths
# (f"./{path}" would turn "/abs/cfg.yaml" into ".//abs/cfg.yaml", which is
# resolved relative to the current working directory).
config = load_yaml_config(os.path.join(".", args.config))

# Setup PyTorch: seed torch's RNGs with sample.seed so the sampling noise is
# repeatable for a given seed, disable autograd (inference only), and use the
# GPU when one is available.
torch.manual_seed(config['sample']['seed'])
torch.set_grad_enabled(False)
device = "cuda" if torch.cuda.is_available() else "cpu"


# Load model: instantiate timeDiT from the `model` and `type` config sections
# (both unpacked as keyword arguments), move it to the selected device, and
# restore the weights from the checkpoint at sample.model_path.
model = timeDiT(**config['model'], **config['type'])
model = model.to(device)
state_dict = find_model(config['sample']['model_path'])
model.load_state_dict(state_dict)
# Switch to eval mode before sampling (important: affects any
# train/eval-dependent layers the model may contain).
model.eval()

# Respace the diffusion process to the requested number of sampling steps;
# the step count is written back into the config as a string, then the
# diffusion object is built from the `diffusion` config section.
respaced_steps = str(config['sample']['num_sampling_steps'])
config['diffusion']['timestep_respacing'] = respaced_steps
diffusion = create_diffusion(**config['diffusion'])

# Labels to condition the model with (feel free to change).
# Class ids are cycled 0, 1, ..., num_classes - 1 until exactly num_samples
# labels exist, so the requested sample count is honoured even when it is not
# a multiple of num_classes (a fixed-length repeat would silently drop the
# remainder). sample.num_classes is optional and defaults to 3.
num_classes = config['sample'].get('num_classes', 3)
num_samples = int(config['sample']['num_samples'])
class_labels = [i % num_classes for i in range(num_samples)]

# Create sampling noise: one (in_channels, length) noise tensor per label.
n = len(class_labels)
z = torch.randn(n, config['model']['in_channels'], config['sample']['length'], device=device)
y = torch.tensor(class_labels, device=device)

# Setup classifier-free guidance: duplicate the noise and pair the copy with
# the null label. The null label is num_classes (3 by default).
# NOTE(review): assumes timeDiT reserves label index num_classes for the
# unconditional embedding -- confirm against the model definition.
z = torch.cat([z, z], 0)
y_null = torch.tensor([num_classes] * n, device=device)
y = torch.cat([y, y_null], 0)

model_kwargs = dict(y=y, cfg_scale=config['sample']['scale'])


# Sample sequences: run the diffusion sampling loop starting from the noise
# `z`, calling model.forward_with_cfg with the labels and cfg_scale in
# model_kwargs (how the guidance is combined is defined in models, not here).
samples = diffusion.p_sample_loop(
    model.forward_with_cfg,
    z.shape,
    z,
    clip_denoised=False,
    model_kwargs=model_kwargs,
    progress=True,
    device=device,
)
# The batch was doubled for classifier-free guidance (real labels first, null
# labels second); keep only the first half.
samples, _ = samples.chunk(2, dim=0)
samples_np = samples.detach().cpu().numpy()

# The output name embeds the --config value verbatim, and np.save appends
# ".npy" (e.g. "fake_data_config.yaml.npy"). np.save does not create missing
# directories, so create the parent directory first; otherwise a fresh
# checkout (or a config path containing "/") fails only after sampling.
output_path = f'./generated_data/fake_data_{args.config}'
os.makedirs(os.path.dirname(output_path), exist_ok=True)
np.save(output_path, samples_np)