import timeimport globimport loggingfrom tqdm import tqdmimport osimport numpy as npimport cv2from PIL import Imageimport torchfrom torch import nnimport torchvision.utils as tvufrom sklearn import svmimport pickleimport torch.optim as optimfrom torch.optim import Adamimport randomimport mathimport torch.nn.functional as Fimport scipy.stats as stasfrom scipy.stats import kstest, norm, multivariate_normal, truncnormimport torch.distributions as distfrom numpy import loadtxtimport matplotlib.pyplot as pltfrom sklearn.manifold import TSNEfrom torch.utils.data import DataLoaderfrom pytorch_msssim import ssimfrom models.ddpm.diffusion import DDPMfrom models.improved_ddpm.script_util import i_DDPMfrom models.vae.vae import Encoder, Decoder, VAEfrom utils.text_dic import SRC_TRG_TXT_DICfrom utils.diffusion_utils import get_beta_schedule, denoising_stepfrom datasets.data_utils import get_dataset, get_dataloaderfrom configs.paths_config import DATASET_PATHS, MODEL_PATHS, HYBRID_MODEL_PATHS, HYBRID_CONFIGfrom datasets.imagenet_dic import IMAGENET_DICfrom utils.align_utils import run_alignmentfrom utils.distance_utils import euclidean_distance, cosine_similarityfrom torch.nn.functional import normalizedef compute_radius(x):    x = torch.pow(x, 2)    r = torch.sum(x)    r = torch.sqrt(r)    return rdef slerp_interpolation(z1, z2, num_points=10):    point_a = z1.flatten()    point_b = z2.flatten()    norm_a = torch.norm(point_a)    norm_b = torch.norm(point_b)    unit_a = point_a / norm_a    unit_b = point_b / norm_b    cos_theta = torch.dot(unit_a, unit_b)    theta = torch.acos(cos_theta)    t_values = torch.linspace(1/(num_points+1), num_points/(num_points+1), num_points)    factors_a = torch.sin((1 - t_values) * theta) / torch.sin(theta)    factors_b = torch.sin(t_values * theta) / torch.sin(theta)    slerp_points = (factors_a[:, None] * unit_a) + (factors_b[:, None] * unit_b)    interpolated_magnitudes = (1 - t_values) * norm_a + t_values * norm_b    interpolated_points = slerp_points * interpolated_magnitudes[:, None]    interpolated_points = interpolated_points.view(num_points,3,256,256)    return interpolated_pointsdef sample_spherical_shell(radius, ndim, nsamples):    x = np.random.normal(size=(nsamples, ndim))    norm = np.linalg.norm(x, axis=1)    x *= radius / norm[:, None]    return xdef find_angle(xo_A, xo_B, xo_C):    ### Find the mean direction defined by AB and AC    AB = (xo_B - xo_A) / torch.norm(xo_B - xo_A)    AC = (xo_C - xo_A) / torch.norm(xo_C - xo_A)    dot_p = torch.dot(AB.view(-1), AC.view(-1))    AB_magnitude = torch.norm(AB)    AC_magnitude = torch.norm(AC)    angle = torch.acos( dot_p / (AB_magnitude * AC_magnitude))    angle_degrees = angle * 180 / torch.tensor(3.14159)    return angle_degreesdef noise_estimation_loss(model,                          x0: torch.Tensor,                          t: torch.LongTensor,                          e: torch.Tensor,                          b: torch.Tensor, keepdim=False):    a = (1-b).cumprod(dim=0).index_select(0, t).view(-1, 1, 1, 1)    x = x0 * a.sqrt() + e * (1.0 - a).sqrt()    _, output = model(x, t.float())        if output.shape[1] == 6: # this is for improved DDPMs        output, logvar_learned = torch.split(output, output.shape[1] // 2, dim=1)        if keepdim:            return (e - output).square().sum(dim=(1, 2, 3))        else:            return (e - output).square().sum(dim=(1, 2, 3)).mean(dim=0)    else: # this is for DDPMs        if keepdim:            return (e - output[1]).square().sum(dim=(1, 2, 3))        else:            return (e - output[1]).square().sum(dim=(1, 2, 3)).mean(dim=0)class DiscoveryDiffusion(object):    def __init__(self, args, config, device=None):        self.args = args        self.config = config        mse_loss = nn.MSELoss()        if device is None:            device = torch.device(                "cuda") if torch.cuda.is_available() else torch.device("cpu")        self.device = device        self.model_var_type = config.model.var_type        betas = get_beta_schedule(            beta_start=config.diffusion.beta_start,            beta_end=config.diffusion.beta_end,            num_diffusion_timesteps=config.diffusion.num_diffusion_timesteps        )        self.betas = torch.from_numpy(betas).float().to(self.device)        self.num_timesteps = betas.shape[0]        self.center = torch.zeros((3, 256, 256)).float().to(self.device)        alphas = 1.0 - betas        self.alphas = torch.from_numpy(alphas).float().to(self.device)        alphas_cumprod = np.cumprod(alphas, axis=0)        self.alphas_cumprod = torch.from_numpy(alphas_cumprod).float().to(self.device)        variances = 1.0 - alphas_cumprod        self.variances = torch.from_numpy(variances).float().to(self.device)        alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])        posterior_variance = betas * \                             (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)        if self.model_var_type == "fixedlarge":            self.logvar = np.log(np.append(posterior_variance[1], betas[1:]))        elif self.model_var_type == 'fixedsmall':            self.logvar = np.log(np.maximum(posterior_variance, 1e-20))        if self.args.edit_attr is None:            self.src_txts = self.args.src_txts            self.trg_txts = self.args.trg_txts        else:            self.src_txts = SRC_TRG_TXT_DIC[self.args.edit_attr][0]            self.trg_txts = SRC_TRG_TXT_DIC[self.args.edit_attr][1]    def unseen_reconstruct(self):        print(self.args.exp)        # ----------- Model -----------#        if self.config.data.dataset == "LSUN":            if self.config.data.category == "bedroom":                url = "https://image-editing-test-12345.s3-us-west-2.amazonaws.com/checkpoints/bedroom.ckpt"            elif self.config.data.category == "church_outdoor":                url = "https://image-editing-test-12345.s3-us-west-2.amazonaws.com/checkpoints/church_outdoor.ckpt"        elif self.config.data.dataset == "CelebA_HQ":            url = "https://image-editing-test-12345.s3-us-west-2.amazonaws.com/checkpoints/celeba_hq.ckpt"        elif self.config.data.dataset == "AFHQ":            pass        else:            raise ValueError        if self.config.data.dataset in ["CelebA_HQ", "LSUN"]:            model = DDPM(self.config)            if self.args.model_path:                init_ckpt = torch.load(self.args.model_path)            else:                init_ckpt = torch.hub.load_state_dict_from_url(url, map_location=self.device)            learn_sigma = False            print("Original diffusion Model loaded.")        elif self.config.data.dataset in ["FFHQ", "AFHQ"]:            model = i_DDPM(self.config.data.dataset)            if self.args.model_path:                init_ckpt = torch.load(self.args.model_path)            else:                init_ckpt = torch.load(MODEL_PATHS[self.config.data.dataset])            learn_sigma = True            print("Improved diffusion Model loaded.")        else:            print('Not implemented dataset')            raise ValueError        model.load_state_dict(init_ckpt)        model.to(self.device)        model = torch.nn.DataParallel(model)        model.eval()        # ----------- Precompute Latents -----------#        seq_inv = np.linspace(0, 1, self.args.n_inv_step) * self.args.t_0        seq_inv = [int(s) for s in list(seq_inv)]        seq_inv_next = [-1] + list(seq_inv[:-1])        ### get an arbitrary img        im = './raw.jpg'        img = Image.open(im).convert("RGB")        img = img.resize((self.config.data.image_size, self.config.data.image_size), Image.ANTIALIAS)        img = np.array(img)/255        img = torch.from_numpy(img).type(torch.FloatTensor).permute(2, 0, 1).unsqueeze(dim=0).repeat(n, 1, 1, 1)        img = img.to(self.config.device)        x0 = (img - 0.5) * 2.        tvu.save_image((x0 + 1) * 0.5, "original.png")        with torch.no_grad():            #---------------- Invert Image to Latent in case of Deterministic Inversion process -------------------#            if self.args.deterministic_inv:                seq_inv = np.linspace(0, 1, self.args.n_inv_step) * self.args.t_0                seq_inv = [int(s) for s in list(seq_inv)]                seq_inv_next = [-1] + list(seq_inv[:-1])                x = x0.clone()                with tqdm(total=len(seq_inv), desc=f"Inversion process ") as progress_bar:                    for it, (i, j) in enumerate(zip((seq_inv_next[1:]), (seq_inv[1:]))):                        t = (torch.ones(n) * i).to(self.device)                        t_prev = (torch.ones(n) * j).to(self.device)                        x, _ = denoising_step(x, t=t, t_next=t_prev, models=model,                                           logvars=self.logvar,                                           sampling_type='ddim',                                           b=self.betas,                                           eta=0.0,                                           learn_sigma=learn_sigma,                                           ratio=0,                                           )                        progress_bar.update(1)                    x_lat = x.clone()                    time_in_start = time.time()        with torch.no_grad():            with tqdm(total=len(seq_inv), desc=f"Generative process") as progress_bar:                for it, (i, j) in enumerate(zip(reversed((seq_inv)), reversed((seq_inv_next)))):                    t = (torch.ones(n) * i).to(self.device)                    t_next = (torch.ones(n) * j).to(self.device)                    x_lat, _ = denoising_step(x_lat, t=t, t_next=t_next, models=model,                                       logvars=self.logvar,                                       sampling_type='ddim',                                       b=self.betas,                                       eta=0.0,                                       learn_sigma=learn_sigma,                                       )                    progress_bar.update(1)                x0 = x_lat.clone()                tvu.save_image((x0 + 1) * 0.5, "recons.png")                time_in_end = time.time()                print(f"Reconstruction for 1 image takes {time_in_end - time_in_start:.4f}s")    return    def finetune(self):        print(self.args.exp)        # ----------- Load the pre-trained Models -----------#        if self.config.data.base_model in ["AFHQ"]:            model = i_DDPM(self.config.data.base_model)            if self.config.model.model_path:                init_ckpt = torch.load(self.config.model.model_path)                learn_sigma = True                print("Improved diffusion Model loaded.")        elif self.config.data.base_model in ["CELEBA", "BEDROOM", "CHURCH"]:            model = DDPM(self.config)            sampler = DDPM(self.config)            if self.config.model.model_path:                init_ckpt = torch.load(self.config.model.model_path)                learn_sigma = False                print("Original diffusion Model loaded.")        else:            print('Not implemented dataset')            raise ValueError                    model.load_state_dict(init_ckpt)        model.to(self.device)        model = torch.nn.DataParallel(model)        model.train()        # ----------- Optimizer and Scheduler -----------#        print(f"Setting optimizer with lr={self.args.lr_finetune}")        optim_ft = torch.optim.Adam(model.parameters(), weight_decay=0, lr=self.args.lr_finetune)        init_opt_ckpt = optim_ft.state_dict()        scheduler_ft = torch.optim.lr_scheduler.StepLR(optim_ft, step_size=1, gamma=self.args.sch_gamma)        init_sch_ckpt = scheduler_ft.state_dict()        # ------- Load the tuning dataset ------------- #        train_data, _ = get_dataset(dataset_type = 'CELEBA', dataset_paths=DATASET_PATHS, config=self.config, custome=True)        train_loader = DataLoader(train_data, batch_size=4,             drop_last=True, shuffle=True, sampler=None,             num_workers=self.config.data.num_workers, pin_memory=True,)        # ------------ model training (tuning-based methods) ----------- #        start_epoch, step = 0, 0        for epoch in range(20):            data_start = time.time()            data_time = 0            for i, x in enumerate(train_loader):                time_in_start = time.time()                n = x.size(0)                data_time += time.time() - data_start                model.train()                step += 1                x = x.to(self.device)                e = torch.randn_like(x)                b = self.betas                # antithetic sampling                t = torch.randint(                    low=0, high=self.num_timesteps, size=(n // 2 + 1,)                ).to(self.device)                t = torch.cat([t, self.num_timesteps - t - 1], dim=0)[:n]                loss = noise_estimation_loss(model, x, t, e, b)                logging.info(                    f"step: {step}, loss: {loss.item()}, data time: {data_time / (i+1)}"                )                optim_ft.zero_grad()                loss.backward()        # ------------ sampling test (tuning-based methods) ----------- #        seq_inv = np.linspace(0, 1, self.args.n_inv_step) * self.args.t_0        seq_inv = [int(s) for s in list(seq_inv)]        seq_inv_next = [-1] + list(seq_inv[:-1])        count = 0         model.eval()        while count < 5000:            x_lat = torch.randn(1, 3, 256, 256, device=self.device)                  with torch.no_grad():                with tqdm(total=len(seq_inv), desc=f"Generative process") as progress_bar:                    for it, (i, j) in enumerate(zip(reversed((seq_inv)), reversed((seq_inv_next)))):                        n_samples = 1                         t = (torch.ones(n_samples) * i).to(self.device)                        t_next = (torch.ones(n_samples) * j).to(self.device)                        x_lat, _ = denoising_step(x_lat, t=t, t_next=t_next, models=model,                                           logvars=self.logvar,                                           sampling_type='ddim',                                           b=self.betas,                                           eta=0.0,                                           learn_sigma=learn_sigma,                                           )                        progress_bar.update(1)                    x0 = x_lat.clone()                    save_edit = "sampled_"+str(count)+".png"                    tvu.save_image((x0 + 1) * 0.5, os.path.join("vanilla_tuning/results", save_edit))                    count += 1        return            def inversion(self):        print(self.args.exp)        # ----------- Model -----------#        if self.config.data.dataset == "LSUN":            if self.config.data.category == "bedroom":                url = "https://image-editing-test-12345.s3-us-west-2.amazonaws.com/checkpoints/bedroom.ckpt"            elif self.config.data.category == "church_outdoor":                url = "https://image-editing-test-12345.s3-us-west-2.amazonaws.com/checkpoints/church_outdoor.ckpt"        elif self.config.data.dataset == "CelebA_HQ":            url = "https://image-editing-test-12345.s3-us-west-2.amazonaws.com/checkpoints/celeba_hq.ckpt"        elif self.config.data.dataset == "AFHQ":            pass        else:            raise ValueError        if self.config.data.dataset in ["CelebA_HQ", "LSUN"]:            model = DDPM(self.config)            if self.args.model_path:                init_ckpt = torch.load(self.args.model_path)            else:                init_ckpt = torch.hub.load_state_dict_from_url(url, map_location=self.device)            learn_sigma = False            print("Original diffusion Model loaded.")        elif self.config.data.dataset in ["FFHQ", "AFHQ"]:            model = i_DDPM(self.config.data.dataset)            if self.args.model_path:                init_ckpt = torch.load(self.args.model_path)            else:                init_ckpt = torch.load(MODEL_PATHS[self.config.data.dataset])            learn_sigma = True            print("Improved diffusion Model loaded.")        else:            print('Not implemented dataset')            raise ValueError        model.load_state_dict(init_ckpt)        model.to(self.device)        model = torch.nn.DataParallel(model)        model.eval()        # ----------- Precompute Latents -----------#        print("Prepare identity latent")        seq_inv = np.linspace(0, 1, self.args.n_inv_step) * self.args.t_0        seq_inv = [int(s) for s in list(seq_inv)]        seq_inv_next = [-1] + list(seq_inv[:-1])        imgs_data_ood = './data/data_path/bedroom_data.txt'        imgs_path = open(imgs_data_ood, 'r')        ood_imgs = [line.rstrip() for line in open(imgs_data_ood)]        ood_x_lat = torch.empty(1000,3,256,256)        n = 1        for c, im in enumerate(ood_imgs):            im = im.replace("\n","")            img = Image.open(im).convert("RGB")            img = img.resize((self.config.data.image_size, self.config.data.image_size), Image.ANTIALIAS)            img = np.array(img)/255            img = torch.from_numpy(img).type(torch.FloatTensor).permute(2, 0, 1).unsqueeze(dim=0).repeat(n, 1, 1, 1)            img = img.to(self.config.device)            with torch.no_grad():                x = x0.clone()                with tqdm(total=len(seq_inv), desc=f"Inversion process ") as progress_bar:                    for it, (i, j) in enumerate(zip((seq_inv_next[1:]), (seq_inv[1:]))):                        t = (torch.ones(n) * i).to(self.device)                        t_prev = (torch.ones(n) * j).to(self.device)                        x, _ = denoising_step(x, t=t, t_next=t_prev, models=model,                                           logvars=self.logvar,                                           sampling_type='ddim',                                           b=self.betas,                                           eta=0,                                           learn_sigma=learn_sigma,                                           ratio=0,                                           )                        progress_bar.update(1)                    x_lat = x.clone()                    ood_x_lat[c,:,:,:] = x_lat.detach()                    # break                torch.save(ood_x_lat,'ood_inverted.pt')        return    def unseen_sample(self):        print(self.args.exp)        # ----------- Load the pre-trained Models -----------#        if self.config.data.base_model in ["AFHQ"]:            model = i_DDPM(self.config.data.base_model)            if self.config.model.model_path:                init_ckpt = torch.load(self.config.model.model_path)                learn_sigma = True                print("Improved diffusion Model loaded.")        elif self.config.data.base_model in ["CELEBA", "BEDROOM", "CHURCH"]:            model = DDPM(self.config)            if self.config.model.model_path:                init_ckpt = torch.load(self.config.model.model_path)                learn_sigma = False                print("Original diffusion Model loaded.", self.config.data.base_model)        else:            print('Not implemented dataset')            raise ValueError                    model.load_state_dict(init_ckpt)        model.to(self.device)        model = torch.nn.DataParallel(model)        model.eval()        seq_inv = np.linspace(0, 1, self.args.n_inv_step) * self.args.t_0        seq_inv = [int(s) for s in list(seq_inv)]        seq_inv_next = [-1] + list(seq_inv[:-1])         ### get an arbitrary img        n = 1        #### slerp interpolative sampling        print("---GeoSampming!")        ## Some important reference points        origin = torch.zeros((3,256,256)).type(torch.FloatTensor).to(self.device)        id_x = torch.load('./data/iddpm_dog_id_t1000.pt')        ood_x_church = torch.load('./data/iddpm_dog_ood_church_t1000.pt') # modified according to target domains        ood_x_bedroom = torch.load('./data/iddpm_dog_ood_bedroom_t1000.pt')        ood_x_human = torch.load('./data/iddpm_dog_ood_human_t1000.pt')        ood_x_cat = torch.load('./data/iddpm_dog_ood_cat_t1000.pt')        ood_x_galaxy = torch.load('./data/iddpm_dog_ood_galaxy_t1000.pt')        ood_x_radiation = torch.load('./data/iddpm_dog_ood_radiation_t1000.pt')        ood_x_turbulence = torch.load('./data/iddpm_dog_ood_turbulence_t1000.pt')        id_center = torch.mean(id_x, 0)        num_points = 10        inv_data = id_x.detach().numpy()        flattened_data = inv_data.reshape(inv_data.shape[0], -1)        # # Step 2: Calculate the mean and variance -> multivariance took too much time        mean_vector = np.mean(flattened_data, axis=0)        # variance_vector = np.var(flattened_data, axis=0)        variance_scalar = np.var(flattened_data)        def sample_from_gaussian_univariate(mean, variance, num_samples, dimension):            return np.random.normal(mean, np.sqrt(variance), (num_samples, dimension))        num_samples_to_generate = 1000  # Adjust as needed        dimension = flattened_data.shape[1]        samples_from_gaussian = sample_from_gaussian_univariate(mean_vector, variance_scalar, num_samples_to_generate, dimension)        samples_from_gaussian = torch.from_numpy(samples_from_gaussian).view(num_samples_to_generate, 3, 256, 256).type(torch.float)        # # get any two latent samples and do the interpolation, check t_m = 600-500, eta = 0 - 0.3        img_count = 0        while img_count <= 5000:            count_i = random.randint(0, 999)            count_j = random.randint(0, 999)              z1 = ood_x_bedroom[count_i,:,:,:]            z2 = samples_from_gaussian[count_j,:,:,:]            interp_seq = slerp_interpolation(z1, z2, num_points=20)            with torch.no_grad():                   for k in range(20):                    x_lat_unseen = interp_seq[k,:,:,:].unsqueeze(0).to(self.device)                    n=1                    distance_test = False                    origin_angle_test = False                    pair_angle_test = False                    random_integers = [random.randint(0, 999) for _ in range(10)]                    ### distance criteria                    mean_dis = 0                    for count, i in enumerate(random_integers):                            temp_dis = euclidean_distance(x_lat_unseen, inv_data_original[i,:].unsqueeze(0).to(self.device))                            mean_dis += temp_dis                        mean_dis /= 10                        if abs(mean_dis-611.4) <= 0.3:                            print("pass the distance optimizer:", opt_count, mean_dis)                            distance_test = True                    #### angle criteria 1                    origin_angle = 0                     for count, i in enumerate(random_integers):                        temp_ang = find_angle(origin, x_lat_unseen.squeeze(0), ood_x_radiation[i,:].to(self.device))                        origin_angle += temp_ang                        # print("check tmp_ang:", temp_ang)                    origin_angle /= 10                     print("check the origin_angle:", origin_angle)                               if abs(origin_angle-88.0) <= 0.1:                        print("pass the origin_angle optimizer:", origin_angle)                        origin_angle_test = True                    #### angle criteria 2                    pair_angle = 0                     for count, i in enumerate(random_integers):                        temp_ang = find_angle(x_lat_unseen.squeeze(0), ood_x_radiation[i,:].to(self.device), ood_x_radiation[i+1,:].to(self.device))                        pair_angle += temp_ang                        # print("check tmp_ang:", temp_ang)                    pair_angle /= 10                            if abs(pair_angle-60) <= 0.1:                        print("pass the pair_angle optimizer:", pair_angle)                        pair_angle_test = True                    if distance_test == origin_angle_test == pair_angle_test == True:                                       if opt_count == 0:                            opt_latents = x_lat_unseen                            opt_count += 1                        else:                            opt_latents = torch.cat((opt_latents,x_lat_unseen))                            opt_count += 1                        print("pass the geo optimizer!")                        x_lat = x_lat_unseen                        with tqdm(total=len(seq_inv), desc=f"Generative process") as progress_bar:                            for it, (i, j) in enumerate(zip(reversed((seq_inv)), reversed((seq_inv_next)))):                                t = (torch.ones(n) * i).to(self.device)                                t_next = (torch.ones(n) * j).to(self.device)                                x_lat, h_lat = denoising_step(x_lat, t=t, t_next=t_next, models=model,                                                   logvars=self.logvar,                                                   sampling_type='ddim',                                                   b=self.betas,                                                   eta=self.args.eta,                                                   learn_sigma=learn_sigma,                                                   )                                progress_bar.update(1)                            x0 = x_lat.clone()                            save_unseen = 'sampled_'+str(img_count)+'.png'                            tvu.save_image((x0 + 1) * 0.5, os.path.join('results/',save_unseen))                            img_count += 1                            print("Finish sampling:", img_count, save_unseen)        return