
class Diffusion_Reconstruction(Diffusion):
    """Train a conditional diffusion model on the outputs of a frozen network.

    The frozen "unlearn" network (``model_dict[config.unlearn.arch]``,
    optionally loaded from ``config.unlearn.model_path``) maps each training
    image to a conditioning tensor ``c``. An ``Unlearn_Conditional_Model`` is
    trained through ``loss_registry_conditional`` with ``c`` as its condition,
    presumably so that sampling with ``c`` reconstructs the input image.
    Periodic snapshots save a checkpoint and a grid of
    (original, generated) image pairs.
    """

    # Value written to ``config.model.hidden_dim`` per conditioning
    # architecture; other architectures leave the config untouched.
    _FEATURE_DIMS = {'resnet18': 512, 'vit': 768}

    def __init__(self, args, config):
        super().__init__(args, config)
        # Key into ``model_dict`` for the frozen conditioning network.
        self.arch = config.unlearn.arch
        # Optional checkpoint for that network; falsy means "do not load".
        self.model_path = config.unlearn.model_path

    # ------------------------------------------------------------------
    # Visualization
    # ------------------------------------------------------------------
    @staticmethod
    def _module_device(module, default):
        """Return the device of ``module``'s first parameter, or ``default``."""
        param = next(module.parameters(), None)
        return default if param is None else param.device

    def _save_pair_grid(self, x_orig, x_gen, name):
        """Save each ``x_orig[i]`` stacked above ``x_gen[i]`` as ``sample-{name}.png``.

        Eight pairs per grid row. The file goes to ``config.log_dir`` or, if
        the config has no such attribute, to ``args.ckpt_folder``.
        """
        # torch.cat needs both halves on one device; the loader batch may
        # still be on the CPU while generated samples live on self.device.
        x_orig = x_orig.to(x_gen.device)
        # (N, C, H, W) x 2 -> (N, C, 2*H, W): original on top, sample below.
        paired_images = torch.cat([x_orig, x_gen], dim=2)
        # NOTE(review): x_gen went through inverse_data_transform while x_orig
        # is not transformed here; normalize=True rescales the whole grid by
        # its global min/max. Confirm both halves share a value range.
        grid = tvu.make_grid(paired_images, nrow=8, normalize=True, padding=2)
        try:
            out_dir = self.config.log_dir
        except AttributeError:
            out_dir = self.args.ckpt_folder
        tvu.save_image(grid, os.path.join(out_dir, f"sample-{name}.png"))

    def _reconstruct_and_save(self, model, unlearn_model, x_orig, name, cond_scale):
        """Sample images conditioned on ``unlearn_model(x_orig)`` and save the pair grid."""
        config = self.config
        with torch.no_grad():
            # Run the frozen network wherever it lives, so this works whether
            # or not the caller moved it to self.device.
            cond_device = self._module_device(unlearn_model, self.device)
            c = unlearn_model(x_orig.to(cond_device))

            n = c.size(0)
            z = torch.randn(
                n,
                config.data.channels,
                config.data.image_size,
                config.data.image_size,
                device=self.device,
            )
            x_gen = self.sample_image(z, model, c, cond_scale)
            x_gen = inverse_data_transform(config, x_gen)

            self._save_pair_grid(x_orig, x_gen, name)

    def sample_visualization(self, model, unlearn_model, name, cond_scale):
        """Reconstruct the first training batch and save it as ``sample-{name}.png``.

        Args:
            model: conditional diffusion model passed to ``self.sample_image``.
            unlearn_model: frozen network that produces the condition.
            name: file-name suffix (this class passes the training step).
            cond_scale: forwarded unchanged to ``self.sample_image``.
        """
        # NOTE(review): config.training.visualization_samples used to be read
        # here but was never applied; the whole first batch is visualized.
        D_train_iter = cycle(get_dataset(self.args, self.config))
        x_orig, _ = next(D_train_iter)
        self._reconstruct_and_save(model, unlearn_model, x_orig, name, cond_scale)

    def forget_visualization(self, model, unlearn_model, name, forget_set, cond_scale):
        """Reconstruct the first batch of ``forget_set`` and save it as ``sample-{name}.png``.

        Args are as in ``sample_visualization``; ``forget_set`` is any iterable
        yielding ``(images, labels)`` batches.
        """
        x_orig, _ = next(cycle(forget_set))
        self._reconstruct_and_save(model, unlearn_model, x_orig, name, cond_scale)

    # ------------------------------------------------------------------
    # Training
    # ------------------------------------------------------------------
    def _build_unlearn_model(self):
        """Build the frozen conditioning network and move it to ``self.device``.

        Loads ``checkpoint["state_dict"]`` from ``self.model_path`` (non-strict)
        when set. With ``config.unlearn.use_feature``, the classifier layer
        (``fc`` for resnet18, ``heads`` otherwise) is copied to ``head`` and
        then replaced by ``Identity`` (resnet18 / vit only), presumably so that
        the forward pass returns pre-classifier features. All parameters are
        frozen and the module is put in eval mode.
        """
        config = self.config
        arch = config.unlearn.arch
        unlearn_model: torch.nn.Module = model_dict[self.arch](num_classes=config.data.n_classes)

        if self.model_path:
            checkpoint = torch.load(self.model_path, map_location=self.device)
            unlearn_model.load_state_dict(checkpoint["state_dict"], strict=False)

        if config.unlearn.use_feature:
            if not hasattr(unlearn_model, 'head'):
                classifier = unlearn_model.fc if arch == 'resnet18' else unlearn_model.heads
                unlearn_model.register_module('head', copy.deepcopy(classifier))
            if arch == 'resnet18':
                unlearn_model.register_module('fc', torch.nn.Identity())
            elif arch == 'vit':
                unlearn_model.register_module('heads', torch.nn.Identity())

        unlearn_model.requires_grad_(False)
        unlearn_model.eval()
        # Previously left on the CPU, which forced a CPU forward pass per step.
        return unlearn_model.to(self.device)

    def _setup_training(self):
        """Create the trainable model, the frozen conditioner and the optimizer.

        Returns:
            ``(model, unlearn_model, optimizer)``; ``model`` is the
            ``Unlearn_Conditional_Model`` moved to ``self.device`` and wrapped
            in ``torch.nn.DataParallel``.
        """
        config = self.config
        hidden_dim = self._FEATURE_DIMS.get(config.unlearn.arch)
        if hidden_dim is not None:
            # Presumably matches the conditioner's output width.
            config.model.hidden_dim = hidden_dim

        model = Unlearn_Conditional_Model(config)
        unlearn_model = self._build_unlearn_model()

        # Optimizer is built from the bare module's parameters, before wrapping.
        optimizer = get_optimizer(config, model.parameters())
        model.to(self.device)
        model = torch.nn.DataParallel(model)
        return model, unlearn_model, optimizer

    def _make_ema(self, model):
        """Return an ``EMAHelper`` registered on ``model``, or None when EMA is off."""
        if not self.config.model.ema:
            return None
        ema_helper = EMAHelper(mu=self.config.model.ema_rate)
        ema_helper.register(model)
        return ema_helper

    def _snapshot(self, model, unlearn_model, optimizer, ema_helper, step):
        """Write ``ckpt.pth`` and a reconstruction grid for ``step``."""
        states = [model.state_dict(), optimizer.state_dict(), step]
        if ema_helper is not None:
            states.append(ema_helper.state_dict())
        torch.save(states, os.path.join(self.config.ckpt_dir, "ckpt.pth"))

        # Visualize with ema_helper.ema_copy(model) when EMA is enabled.
        test_model = ema_helper.ema_copy(model) if ema_helper is not None else copy.deepcopy(model)
        test_model.eval()
        self.sample_visualization(test_model, unlearn_model, step, self.args.cond_scale)
        del test_model

    def _train_loop(self, model, unlearn_model, optimizer, ema_helper, D_train_iter, steps):
        """Run one optimization step for each value in ``steps``.

        Logs loss and elapsed time every ``training.log_freq`` steps and calls
        ``_snapshot`` every ``training.snapshot_freq`` steps.
        """
        config = self.config
        # A missing or None grad_clip disables clipping. This replaces a
        # blanket ``except Exception: pass`` that also hid genuine errors.
        grad_clip = getattr(config.optim, "grad_clip", None)

        start = time.time()
        for step in steps:
            model.train()
            x, _ = next(D_train_iter)
            x = x.to(self.device)
            with torch.no_grad():
                # The condition is computed from the batch before data_transform.
                c = unlearn_model(x)
            n = x.size(0)
            x = data_transform(config, x)
            e = torch.randn_like(x)
            b = self.betas

            # Antithetic sampling: pair each t with num_timesteps - t - 1.
            t = torch.randint(
                low=0, high=self.num_timesteps, size=(n // 2 + 1,)
            ).to(self.device)
            t = torch.cat([t, self.num_timesteps - t - 1], dim=0)[:n]
            loss = loss_registry_conditional[config.model.type](model, x, t, c, e, b)

            if (step + 1) % config.training.log_freq == 0:
                end = time.time()
                logging.info(
                    f"step: {step}, loss: {loss.item()}, time: {end-start}"
                )
                start = time.time()

            optimizer.zero_grad()
            loss.backward()
            if grad_clip is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            optimizer.step()

            if ema_helper is not None:
                ema_helper.update(model)

            if (step + 1) % config.training.snapshot_freq == 0:
                self._snapshot(model, unlearn_model, optimizer, ema_helper, step)

    def train_recon(self):
        """Train the conditional model from scratch for ``training.n_iters`` steps."""
        D_train_iter = cycle(get_dataset(self.args, self.config))
        model, unlearn_model, optimizer = self._setup_training()
        ema_helper = self._make_ema(model)
        self._train_loop(
            model, unlearn_model, optimizer, ema_helper, D_train_iter,
            range(self.config.training.n_iters),
        )

    def resume_recon(self):
        """Optionally restore ``config.unlearn.ckpt_dir/ckpt.pth``, then train further.

        Runs exactly ``training.n_iters`` steps, starting right after the
        restored step. When ``config.unlearn.resume`` is false this is
        equivalent to ``train_recon`` (steps ``0 .. n_iters-1``).
        """
        config = self.config
        D_train_iter = cycle(get_dataset(self.args, config))
        model, unlearn_model, optimizer = self._setup_training()

        # -1 so that a fresh run starts at step 0 (was previously unbound).
        start_step = -1
        ema_state = None
        if config.unlearn.resume:
            ckpt_path = os.path.join(config.unlearn.ckpt_dir, "ckpt.pth")
            logging.info(f"Resuming training from: {ckpt_path}")
            ckpt = torch.load(ckpt_path, map_location=self.device)
            model.load_state_dict(ckpt[0])
            optimizer.load_state_dict(ckpt[1])
            start_step = ckpt[2]
            if len(ckpt) > 3:
                ema_state = ckpt[3]

        # Register EMA only after the weights are restored, so the shadow
        # starts from the checkpoint (not random init) when it has no EMA state.
        ema_helper = self._make_ema(model)
        if ema_helper is not None and ema_state is not None:
            ema_helper.load_state_dict(ema_state)

        first_step = start_step + 1
        self._train_loop(
            model, unlearn_model, optimizer, ema_helper, D_train_iter,
            range(first_step, first_step + config.training.n_iters),
        )