from tqdm import tqdm
import torch
import numpy as np

from diffusion_util import LinearModel, Diffusion

# Sampling script: restores a trained LinearModel denoiser, draws
# class-conditional samples from it through Diffusion.sample, and stores the
# result as a single NumPy array on disk.

# Use the first CUDA device when one is present; otherwise fall back to the
# CPU so the script still runs (slowly) on machines without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Number of samples drawn per call to diffusion.sample.
BATCH_SIZE = 1000

# Denoiser network.
# NOTE(review): the architecture arguments presumably have to match the ones
# used when "weight/model.pt" was trained -- confirm against the training script.
dm = LinearModel(
    z_dim=187,
    time_dim=128,

    random_fourier_features=True,
    learned_sinusoidal_dim=32,

    if_cfg=True,
    cond_scale=3,
    num_classes=2,
    class_dim=128,
    unit_dims=[512, 768, 1024, 768, 512],
)

# map_location remaps the stored tensors onto `device`, so a checkpoint that
# was saved from a different GPU index (or loaded on a CPU-only host) still
# deserializes instead of failing on a missing device.
dm.load_state_dict(torch.load("weight/model.pt", map_location=device))
dm.to(device)
dm.eval()  # inference only

# Sampler wrapped around the denoiser; `dim` equals the model's z_dim above.
# NOTE(review): the remaining arguments look like noise-schedule / sampler
# settings -- see diffusion_util.Diffusion for their exact meaning.
diffusion = Diffusion(
    dm,
    dim=187,
    P_mean=-1.2,
    P_std=1.2,
    sigma_data=0.003,

    num_sample_steps=32,
    sigma_min=0.02,
    sigma_max=80,
    rho=7,
)

# Total sample count is num_1000 * BATCH_SIZE.
num_1000 = 11

# Labels: the first half of all samples is class 0, the second half class 1.
# Batches are consecutive slices of this vector, so the middle batch contains
# both classes when num_1000 is odd.
half = num_1000 * BATCH_SIZE // 2
labels = torch.cat([
    torch.zeros(half, dtype=torch.long),
    torch.ones(half, dtype=torch.long),
]).to(device)

out = []
# No gradients are needed for sampling; disabling autograd keeps the sampler
# from retaining intermediate activations across its steps.
with torch.no_grad():
    for b in tqdm(range(num_1000), desc='Sampling...'):
        batch_labels = labels[b * BATCH_SIZE:(b + 1) * BATCH_SIZE]
        sampled_seq = diffusion.sample(batch_size=BATCH_SIZE, labels=batch_labels)
        # Move each finished batch to host memory right away so the GPU never
        # has to hold every batch plus the concatenated copy at once.
        out.append(sampled_seq.detach().cpu())

out_seq = torch.cat(out).numpy()

# np.save appends the ".npy" suffix, so the file written is "EHRDiff.npy".
np.save("EHRDiff", out_seq)


