# Test how well the pretrained autoencoder performs on titanium alloy.
import torch
import pytorch_lightning as pl
import torch.nn.functional as F
from contextlib import contextmanager
import yaml

from taming.modules.vqvae.quantize import VectorQuantizer as VectorQuantizer

from ldm.modules.diffusionmodules.model import Encoder, Decoder
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution

from ldm.util import instantiate_from_config

import argparse, os, sys, datetime, glob, importlib, csv


def get_parser(**parser_kwargs):
    """Build the command-line parser used by this script.

    Args:
        **parser_kwargs: forwarded unchanged to ``argparse.ArgumentParser``.

    Returns:
        argparse.ArgumentParser with the name/resume/base/train/no-test/project/
        debug/seed/postfix/logdir/scale_lr options registered.
    """
    truthy = ("yes", "true", "t", "y", "1")
    falsy = ("no", "false", "f", "n", "0")

    def str2bool(v):
        """Convert a yes/no style string to bool; real bools pass through unchanged."""
        if isinstance(v, bool):
            return v
        lowered = v.lower()
        if lowered in truthy:
            return True
        if lowered in falsy:
            return False
        raise argparse.ArgumentTypeError("Boolean value expected.")

    parser = argparse.ArgumentParser(**parser_kwargs)
    add = parser.add_argument

    # Optional-value string options: a bare flag yields the ``const`` (True).
    add("-n", "--name", type=str, const=True, default="", nargs="?",
        help="postfix for logdir")
    add("-r", "--resume", type=str, const=True, default="", nargs="?",
        help="resume from logdir or checkpoint in logdir")
    add("-b", "--base", nargs="*", metavar="base_config.yaml",
        help="paths to base configs. Loaded from left-to-right. "
             "Parameters can be overwritten or added with command-line options of the form `--key value`.",
        default=list())
    # Optional-value boolean switches: a bare flag means True, otherwise parsed by str2bool.
    add("-t", "--train", type=str2bool, const=True, default=False, nargs="?",
        help="train")
    add("--no-test", type=str2bool, const=True, default=False, nargs="?",
        help="disable test")
    add("-p", "--project",
        help="name of new or path to existing project")
    add("-d", "--debug", type=str2bool, nargs="?", const=True, default=False,
        help="enable post-mortem debugging")
    add("-s", "--seed", type=int, default=23,
        help="seed for seed_everything")
    add("-f", "--postfix", type=str, default="",
        help="post-postfix for default name")
    add("-l", "--logdir", type=str, default="logs",
        help="directory for logging dat shit")
    add("--scale_lr", type=str2bool, nargs="?", const=True, default=True,
        help="scale base-lr by ngpu * batch_size * n_accumulate")
    return parser


class VQModel(pl.LightningModule):
    """VQ autoencoder: encoder -> vector quantizer -> decoder, trained adversarially.

    Two optimizers are returned by ``configure_optimizers``: index 0 updates the
    encoder, decoder, quantizer and the two 1x1 projection convs; index 1 updates
    ``self.loss.discriminator``.

    Args:
        ddconfig: keyword arguments for both ``Encoder`` and ``Decoder``; must
            contain ``"z_channels"``.
        lossconfig: config passed to ``instantiate_from_config`` to build the loss.
        n_embed: codebook size (first argument of ``VectorQuantizer``).
        embed_dim: codebook entry dimensionality (second argument of ``VectorQuantizer``).
        ckpt_path: optional checkpoint whose ``"state_dict"`` is loaded non-strictly.
        ignore_keys: state-dict key prefixes dropped before loading ``ckpt_path``.
        image_key: key of the image tensor in each batch dict.
        colorize_nlabels: if given, registers a random ``(3, n, 1, 1)`` buffer
            used by ``to_rgb``.
        monitor: stored as ``self.monitor`` when given.
        batch_resize_range: optional ``(lower, upper)`` sizes for random per-batch
            resizing in ``get_input``.
        scheduler_config: optional config whose instance's ``schedule`` drives a
            ``LambdaLR`` on both optimizers.
        lr_g_factor: multiplier on ``learning_rate`` for the autoencoder optimizer.
        remap, sane_index_shape: accepted for config compatibility but NOT
            currently forwarded to the quantizer.
        use_ema: keep an EMA copy of the weights (``LitEma``).
    """

    def __init__(self,
                 ddconfig,
                 lossconfig,
                 n_embed,
                 embed_dim,
                 ckpt_path=None,
                 ignore_keys=None,
                 image_key="image",
                 colorize_nlabels=None,
                 monitor=None,
                 batch_resize_range=None,
                 scheduler_config=None,
                 lr_g_factor=1.0,
                 remap=None,
                 sane_index_shape=False,  # tell vector quantizer to return indices as bhw
                 use_ema=False
                 ):
        super().__init__()
        self.embed_dim = embed_dim
        self.n_embed = n_embed
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.loss = instantiate_from_config(lossconfig)
        # NOTE: remap / sane_index_shape are intentionally not passed (see class docstring).
        self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25)
        self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        if colorize_nlabels is not None:
            assert type(colorize_nlabels) == int
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor
        self.batch_resize_range = batch_resize_range
        if self.batch_resize_range is not None:
            print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")

        self.use_ema = use_ema
        if self.use_ema:
            # Local import: LitEma is not imported at module level and is only needed here.
            from ldm.modules.ema import LitEma
            self.model_ema = LitEma(self)
            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
        self.scheduler_config = scheduler_config
        self.lr_g_factor = lr_g_factor

    @contextmanager
    def ema_scope(self, context=None):
        """Temporarily swap in the EMA weights (no-op when ``use_ema`` is False).

        The live weights are stored on entry and restored on exit, even if the
        body raises.
        """
        if self.use_ema:
            self.model_ema.store(self.parameters())
            self.model_ema.copy_to(self)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def init_from_ckpt(self, path, ignore_keys=None):
        """Load ``state_dict`` from checkpoint ``path`` non-strictly.

        Keys starting with any prefix in ``ignore_keys`` are dropped first. Each
        key is deleted at most once, even if several prefixes match it.
        """
        ignore_keys = tuple(ignore_keys or ())
        sd = torch.load(path, map_location="cpu")["state_dict"]
        for k in list(sd.keys()):
            if any(k.startswith(ik) for ik in ignore_keys):
                print("Deleting key {} from state_dict.".format(k))
                del sd[k]
        missing, unexpected = self.load_state_dict(sd, strict=False)
        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
        if len(missing) > 0:
            print(f"Missing Keys: {missing}")
        if len(unexpected) > 0:
            print(f"Unexpected Keys: {unexpected}")

    def on_train_batch_end(self, *args, **kwargs):
        """Update the EMA weights after every training batch when enabled."""
        if self.use_ema:
            self.model_ema(self)

    def encode(self, x):
        """Encode ``x`` and quantize; returns the quantizer's ``(quant, emb_loss, info)``."""
        h = self.encoder(x)
        h = self.quant_conv(h)
        quant, emb_loss, info = self.quantize(h)
        return quant, emb_loss, info

    def encode_to_prequant(self, x):
        """Encode ``x`` up to (but not through) the quantizer."""
        h = self.encoder(x)
        h = self.quant_conv(h)
        return h

    def decode(self, quant):
        """Decode an already-quantized latent back to image space."""
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant)
        return dec

    def decode_code(self, code_b):
        """Decode codebook indices via ``self.quantize.embed_code``."""
        quant_b = self.quantize.embed_code(code_b)
        dec = self.decode(quant_b)
        return dec

    def forward(self, input, return_pred_indices=False):
        """Reconstruct ``input``; returns ``(dec, diff)`` or ``(dec, diff, indices)``."""
        quant, diff, (_, _, ind) = self.encode(input)
        dec = self.decode(quant)
        if return_pred_indices:
            return dec, diff, ind
        return dec, diff

    def get_input(self, batch, k):
        """Fetch ``batch[k]`` and convert it from channels-last to a float NCHW tensor.

        A rank-3 input gets a trailing singleton channel axis first. When
        ``batch_resize_range`` is set, the batch is bicubically resized to a
        random multiple-of-16 size within the range. The first few global steps
        always use the upper size.
        """
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
        if self.batch_resize_range is not None:
            lower_size = self.batch_resize_range[0]
            upper_size = self.batch_resize_range[1]
            if self.global_step <= 4:
                # do the first few batches with max size to avoid later oom
                new_resize = upper_size
            else:
                # Local import: numpy is not imported at module level.
                import numpy as np
                new_resize = np.random.choice(np.arange(lower_size, upper_size + 16, 16))
            if new_resize != x.shape[2]:
                x = F.interpolate(x, size=new_resize, mode="bicubic")
            x = x.detach()
        return x

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Return the autoencoder loss (optimizer 0) or discriminator loss (optimizer 1)."""
        # https://github.com/pytorch/pytorch/issues/37142
        # try not to fool the heuristics
        x = self.get_input(batch, self.image_key)
        xrec, qloss, ind = self(x, return_pred_indices=True)

        if optimizer_idx == 0:
            # autoencode
            aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
                                            last_layer=self.get_last_layer(), split="train",
                                            predicted_indices=ind)

            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            return aeloss

        if optimizer_idx == 1:
            # discriminator
            discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
                                                last_layer=self.get_last_layer(), split="train")
            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            return discloss

    def validation_step(self, batch, batch_idx):
        """Validate with the live weights, then again under the EMA weights (suffix ``_ema``)."""
        log_dict = self._validation_step(batch, batch_idx)
        with self.ema_scope():
            # Called for its logging side effects only.
            self._validation_step(batch, batch_idx, suffix="_ema")
        return log_dict

    @staticmethod
    def _version_at_least(version_str, minimum):
        """Return True if the leading numeric release of ``version_str`` is >= ``minimum``.

        Pre-release/local suffixes (e.g. ``rc1``, ``+cu118``) are ignored.
        """
        import re
        match = re.match(r"\d+(?:\.\d+)*", version_str)
        if match is None:
            return False
        release = [int(part) for part in match.group(0).split(".")]
        release += [0] * (len(minimum) - len(release))
        return tuple(release) >= tuple(minimum)

    def _validation_step(self, batch, batch_idx, suffix=""):
        """Compute and log autoencoder and discriminator validation losses under ``val{suffix}``."""
        x = self.get_input(batch, self.image_key)
        xrec, qloss, ind = self(x, return_pred_indices=True)
        aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
                                        self.global_step,
                                        last_layer=self.get_last_layer(),
                                        split="val" + suffix,
                                        predicted_indices=ind
                                        )

        discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
                                            self.global_step,
                                            last_layer=self.get_last_layer(),
                                            split="val" + suffix,
                                            predicted_indices=ind
                                            )
        rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
        self.log(f"val{suffix}/rec_loss", rec_loss,
                 prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
        self.log(f"val{suffix}/aeloss", aeloss,
                 prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
        # rec_loss was already logged above; on newer Lightning versions it is removed
        # from the dict so it is not logged a second time. `packaging.version` was never
        # imported, so the comparison is done with a stdlib helper.
        if self._version_at_least(pl.__version__, (1, 4, 0)):
            del log_dict_ae[f"val{suffix}/rec_loss"]
        self.log_dict(log_dict_ae)
        self.log_dict(log_dict_disc)
        return self.log_dict

    def configure_optimizers(self):
        """Create Adam optimizers for autoencoder (lr * lr_g_factor) and discriminator (lr).

        When ``scheduler_config`` is set, both optimizers get a per-step
        ``LambdaLR`` driven by the instantiated config's ``schedule``.
        """
        lr_d = self.learning_rate
        lr_g = self.lr_g_factor * self.learning_rate
        print("lr_d", lr_d)
        print("lr_g", lr_g)
        opt_ae = torch.optim.Adam(list(self.encoder.parameters()) +
                                  list(self.decoder.parameters()) +
                                  list(self.quantize.parameters()) +
                                  list(self.quant_conv.parameters()) +
                                  list(self.post_quant_conv.parameters()),
                                  lr=lr_g, betas=(0.5, 0.9))
        opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
                                    lr=lr_d, betas=(0.5, 0.9))

        if self.scheduler_config is None:
            return [opt_ae, opt_disc], []

        # Local import: LambdaLR is not imported at module level.
        from torch.optim.lr_scheduler import LambdaLR

        schedule_obj = instantiate_from_config(self.scheduler_config)
        print("Setting up LambdaLR scheduler...")
        schedulers = [
            {
                'scheduler': LambdaLR(opt, lr_lambda=schedule_obj.schedule),
                'interval': 'step',
                'frequency': 1
            }
            for opt in (opt_ae, opt_disc)
        ]
        return [opt_ae, opt_disc], schedulers

    def get_last_layer(self):
        """Return the decoder's final conv weight (passed to the loss as ``last_layer``)."""
        return self.decoder.conv_out.weight

    def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
        """Return a dict of input and reconstruction tensors for image logging.

        Inputs with more than 3 channels are projected to RGB via ``to_rgb``.
        """
        log = dict()
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        if only_inputs:
            log["inputs"] = x
            return log
        model_input = x  # keep the raw input; `x` may be replaced by its RGB projection
        xrec, _ = self(model_input)
        if x.shape[1] > 3:
            # colorize with random projection
            assert xrec.shape[1] > 3
            x = self.to_rgb(x)
            xrec = self.to_rgb(xrec)
        log["inputs"] = x
        log["reconstructions"] = xrec
        if plot_ema:
            with self.ema_scope():
                # Feed the raw input, not the RGB projection, to the model.
                xrec_ema, _ = self(model_input)
                if model_input.shape[1] > 3:
                    xrec_ema = self.to_rgb(xrec_ema)
                log["reconstructions_ema"] = xrec_ema
        return log

    def to_rgb(self, x):
        """Project a multi-channel segmentation tensor to 3 channels scaled to [-1, 1]."""
        assert self.image_key == "segmentation"
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
        return x


class VQModelInterface(VQModel):
    """``VQModel`` variant whose encode/decode boundary is the pre-quantization latent.

    ``encode`` stops after ``quant_conv``. ``decode`` runs the quantizer itself
    (unless ``force_not_quantize``) before ``post_quant_conv`` and the decoder.
    """

    def __init__(self, embed_dim, *args, **kwargs):
        super().__init__(*args, embed_dim=embed_dim, **kwargs)
        self.embed_dim = embed_dim

    def encode(self, x):
        """Return the projected encoder output, without quantization."""
        return self.quant_conv(self.encoder(x))

    def decode(self, h, force_not_quantize=False):
        """Quantize ``h`` (unless ``force_not_quantize``) and decode it to image space."""
        if force_not_quantize:
            quant = h
        else:
            # also go through quantization layer; only the quantized tensor is kept
            quant, _, _ = self.quantize(h)
        return self.decoder(self.post_quant_conv(quant))


class AutoencoderKL(pl.LightningModule):
    """Autoencoder whose encoder output parameterises a ``DiagonalGaussianDistribution``.

    ``forward`` samples from (or takes the mode of) that distribution and decodes
    the result. Training alternates between an autoencoder optimizer (index 0)
    and a discriminator optimizer (index 1).

    Args:
        ddconfig: keyword arguments for ``Encoder``/``Decoder``; ``double_z`` must
            be truthy and ``"z_channels"`` present.
        lossconfig: config passed to ``instantiate_from_config`` to build the loss.
        embed_dim: latent channel count fed to ``post_quant_conv``.
        ckpt_path: optional checkpoint whose ``"state_dict"`` is loaded non-strictly.
        ignore_keys: state-dict key prefixes dropped before loading ``ckpt_path``.
        image_key: key of the image tensor in each batch dict.
        colorize_nlabels: if given, registers a random ``(3, n, 1, 1)`` buffer
            used by ``to_rgb``.
        monitor: stored as ``self.monitor`` when given.
    """

    def __init__(self,
                 ddconfig,
                 lossconfig,
                 embed_dim,
                 ckpt_path=None,
                 ignore_keys=None,
                 image_key="image",
                 colorize_nlabels=None,
                 monitor=None,
                 ):
        super().__init__()
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.loss = instantiate_from_config(lossconfig)
        assert ddconfig["double_z"]
        # Encoder emits 2*z_channels ("moments"); projected to 2*embed_dim for the Gaussian.
        self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        self.embed_dim = embed_dim
        if colorize_nlabels is not None:
            assert type(colorize_nlabels) == int
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

    def init_from_ckpt(self, path, ignore_keys=None):
        """Load ``state_dict`` from checkpoint ``path`` non-strictly.

        Keys starting with any prefix in ``ignore_keys`` are dropped first. Each
        key is deleted at most once, even if several prefixes match it.
        """
        ignore_keys = tuple(ignore_keys or ())
        sd = torch.load(path, map_location="cpu")["state_dict"]
        for k in list(sd.keys()):
            if any(k.startswith(ik) for ik in ignore_keys):
                print("Deleting key {} from state_dict.".format(k))
                del sd[k]
        missing, unexpected = self.load_state_dict(sd, strict=False)
        print(f"Restored from {path}")
        # Report what non-strict loading skipped, matching VQModel.init_from_ckpt.
        if len(missing) > 0:
            print(f"Missing Keys: {missing}")
        if len(unexpected) > 0:
            print(f"Unexpected Keys: {unexpected}")

    def encode(self, x):
        """Encode ``x`` into a ``DiagonalGaussianDistribution`` over latents."""
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def decode(self, z):
        """Decode latent ``z`` back to image space."""
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        return dec

    def forward(self, input, sample_posterior=True):
        """Reconstruct ``input``; returns ``(dec, posterior)``.

        The latent is sampled when ``sample_posterior`` is True, otherwise the
        posterior's mode is used.
        """
        posterior = self.encode(input)
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        dec = self.decode(z)
        return dec, posterior

    def get_input(self, batch, k):
        """Fetch ``batch[k]`` and convert it from channels-last to a float NCHW tensor.

        A rank-3 input gets a trailing singleton channel axis first.
        """
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
        return x

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Return the autoencoder loss (optimizer 0) or discriminator loss (optimizer 1)."""
        inputs = self.get_input(batch, self.image_key)
        reconstructions, posterior = self(inputs)

        if optimizer_idx == 0:
            # train encoder+decoder+logvar
            aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
                                            last_layer=self.get_last_layer(), split="train")
            self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
            return aeloss

        if optimizer_idx == 1:
            # train the discriminator
            discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
                                                last_layer=self.get_last_layer(), split="train")

            self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
            return discloss

    def validation_step(self, batch, batch_idx):
        """Compute and log autoencoder and discriminator validation losses under ``val``."""
        inputs = self.get_input(batch, self.image_key)
        reconstructions, posterior = self(inputs)
        aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
                                        last_layer=self.get_last_layer(), split="val")

        discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
                                            last_layer=self.get_last_layer(), split="val")

        self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
        self.log_dict(log_dict_ae)
        self.log_dict(log_dict_disc)
        return self.log_dict

    def configure_optimizers(self):
        """Create Adam optimizers for the autoencoder and discriminator, both at ``learning_rate``."""
        lr = self.learning_rate
        opt_ae = torch.optim.Adam(list(self.encoder.parameters()) +
                                  list(self.decoder.parameters()) +
                                  list(self.quant_conv.parameters()) +
                                  list(self.post_quant_conv.parameters()),
                                  lr=lr, betas=(0.5, 0.9))
        opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
                                    lr=lr, betas=(0.5, 0.9))
        return [opt_ae, opt_disc], []

    def get_last_layer(self):
        """Return the decoder's final conv weight (passed to the loss as ``last_layer``)."""
        return self.decoder.conv_out.weight

    @torch.no_grad()
    def log_images(self, batch, only_inputs=False, **kwargs):
        """Return a dict of inputs, reconstructions and random-latent samples for logging."""
        log = dict()
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        if not only_inputs:
            xrec, posterior = self(x)
            if x.shape[1] > 3:
                # colorize with random projection
                assert xrec.shape[1] > 3
                x = self.to_rgb(x)
                xrec = self.to_rgb(xrec)
            # Decode standard-normal noise shaped like a posterior sample.
            log["samples"] = self.decode(torch.randn_like(posterior.sample()))
            log["reconstructions"] = xrec
        log["inputs"] = x
        return log

    def to_rgb(self, x):
        """Project a multi-channel segmentation tensor to 3 channels scaled to [-1, 1]."""
        assert self.image_key == "segmentation"
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
        return x


class IdentityFirstStage(torch.nn.Module):
    """No-op first stage: ``encode``, ``decode`` and ``forward`` return their input unchanged.

    Extra positional/keyword arguments are accepted and ignored. With
    ``vq_interface=True``, ``quantize`` mimics a quantizer's
    ``(quant, loss, info)`` return shape.
    """

    def __init__(self, *args, vq_interface=False, **kwargs):
        # nn.Module must be initialised before any attribute is assigned on it.
        super().__init__()
        self.vq_interface = vq_interface  # TODO: Should be true by default but check to not break older stuff

    def encode(self, x, *args, **kwargs):
        """Return ``x`` unchanged."""
        return x

    def decode(self, x, *args, **kwargs):
        """Return ``x`` unchanged."""
        return x

    def quantize(self, x, *args, **kwargs):
        """Return ``x``, or ``(x, None, [None, None, None])`` when ``vq_interface`` is set."""
        if self.vq_interface:
            return x, None, [None, None, None]
        return x

    def forward(self, x, *args, **kwargs):
        """Return ``x`` unchanged."""
        return x



if __name__ == '__main__':

    # `open` does not expand "~", so resolve home-relative paths explicitly.
    conf_path = os.path.expanduser(
        "~/Project/pytorch/stable-diffusion-tai/configs/latent-diffusion/cin-ldm-vq-f8.yaml")
    # Read-only access is sufficient; the context manager closes the file.
    with open(conf_path, 'r') as conf_f:
        conf_data = yaml.safe_load(conf_f)
    # print(conf_data)
    first_stage_params = conf_data['model']['params']['first_stage_config']['params']
    ddconfig = first_stage_params['ddconfig']
    lossconfig = first_stage_params['lossconfig']
    n_embed = first_stage_params['n_embed']
    embed_dim = first_stage_params['embed_dim']

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    md_path = os.path.expanduser(
        "~/Project/pytorch/stable-diffusion-tai/autoencoder_checkpoint/vq-f8-n256_model.ckpt")

    model = VQModelInterface(ddconfig=ddconfig, lossconfig=lossconfig, n_embed=n_embed, embed_dim=embed_dim).to(device)
    # NOTE(review): checkpoint loading is commented out, so `model` has randomly initialised weights.
    # model.load_state_dict(torch.load(md_path))
    print(model)
    # AutoencoderKL(
    #   (encoder): Encoder(
    #     (conv_in): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #     (down): ModuleList(
    #       (0): Module(
    #         (block): ModuleList(
    #           (0): ResnetBlock(
    #             (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
    #             (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #             (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
    #             (dropout): Dropout(p=0.0, inplace=False)
    #             (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #           )
    #           (1): ResnetBlock(
    #             (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
    #             (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #             (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
    #             (dropout): Dropout(p=0.0, inplace=False)
    #             (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #           )
    #         )
    #         (attn): ModuleList()
    #         (downsample): Downsample(
    #           (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2))
    #         )
    #       )
    #       (1): Module(
    #         (block): ModuleList(
    #           (0): ResnetBlock(
    #             (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
    #             (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #             (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
    #             (dropout): Dropout(p=0.0, inplace=False)
    #             (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #           )
    #           (1): ResnetBlock(
    #             (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
    #             (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #             (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
    #             (dropout): Dropout(p=0.0, inplace=False)
    #             (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #           )
    #         )
    #         (attn): ModuleList()
    #         (downsample): Downsample(
    #           (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2))
    #         )
    #       )
    #       (2): Module(
    #         (block): ModuleList(
    #           (0): ResnetBlock(
    #             (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
    #             (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #             (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)
    #             (dropout): Dropout(p=0.0, inplace=False)
    #             (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #             (nin_shortcut): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1))
    #           )
    #           (1): ResnetBlock(
    #             (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)
    #             (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #             (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)
    #             (dropout): Dropout(p=0.0, inplace=False)
    #             (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #           )
    #         )
    #         (attn): ModuleList()
    #         (downsample): Downsample(
    #           (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2))
    #         )
    #       )
    #       (3): Module(
    #         (block): ModuleList(
    #           (0): ResnetBlock(
    #             (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)
    #             (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #             (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)
    #             (dropout): Dropout(p=0.0, inplace=False)
    #             (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #           )
    #           (1): ResnetBlock(
    #             (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)
    #             (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #             (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)
    #             (dropout): Dropout(p=0.0, inplace=False)
    #             (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #           )
    #         )
    #         (attn): ModuleList()
    #         (downsample): Downsample(
    #           (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2))
    #         )
    #       )
    #       (4): Module(
    #         (block): ModuleList(
    #           (0): ResnetBlock(
    #             (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)
    #             (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #             (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)
    #             (dropout): Dropout(p=0.0, inplace=False)
    #             (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #             (nin_shortcut): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1))
    #           )
    #           (1): ResnetBlock(
    #             (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)
    #             (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #             (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)
    #             (dropout): Dropout(p=0.0, inplace=False)
    #             (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #           )
    #         )
    #         (attn): ModuleList(
    #           (0): AttnBlock(
    #             (norm): GroupNorm(32, 512, eps=1e-06, affine=True)
    #             (q): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
    #             (k): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
    #             (v): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
    #             (proj_out): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
    #           )
    #           (1): AttnBlock(
    #             (norm): GroupNorm(32, 512, eps=1e-06, affine=True)
    #             (q): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
    #             (k): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
    #             (v): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
    #             (proj_out): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
    #           )
    #         )
    #         (downsample): Downsample(
    #           (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2))
    #         )
    #       )
    #       (5): Module(
    #         (block): ModuleList(
    #           (0): ResnetBlock(
    #             (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)
    #             (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #             (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)
    #             (dropout): Dropout(p=0.0, inplace=False)
    #             (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #           )
    #           (1): ResnetBlock(
    #             (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)
    #             (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #             (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)
    #             (dropout): Dropout(p=0.0, inplace=False)
    #             (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #           )
    #         )
    #         (attn): ModuleList(
    #           (0): AttnBlock(
    #             (norm): GroupNorm(32, 512, eps=1e-06, affine=True)
    #             (q): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
    #             (k): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
    #             (v): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
    #             (proj_out): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
    #           )
    #           (1): AttnBlock(
    #             (norm): GroupNorm(32, 512, eps=1e-06, affine=True)
    #             (q): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
    #             (k): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
    #             (v): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
    #             (proj_out): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
    #           )
    #         )
    #       )
    #     )
    #     (mid): Module(
    #       (block_1): ResnetBlock(
    #         (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)
    #         (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #         (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)
    #         (dropout): Dropout(p=0.0, inplace=False)
    #         (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #       )
    #       (attn_1): AttnBlock(
    #         (norm): GroupNorm(32, 512, eps=1e-06, affine=True)
    #         (q): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
    #         (k): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
    #         (v): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
    #         (proj_out): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
    #       )
    #       (block_2): ResnetBlock(
    #         (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)
    #         (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #         (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)
    #         (dropout): Dropout(p=0.0, inplace=False)
    #         (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #       )
    #     )
    #     (norm_out): GroupNorm(32, 512, eps=1e-06, affine=True)
    #     (conv_out): Conv2d(512, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #   )
    #   (decoder): Decoder(
    #     (conv_in): Conv2d(64, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #     (mid): Module(
    #       (block_1): ResnetBlock(
    #         (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)
    #         (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #         (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)
    #         (dropout): Dropout(p=0.0, inplace=False)
    #         (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #       )
    #       (attn_1): AttnBlock(
    #         (norm): GroupNorm(32, 512, eps=1e-06, affine=True)
    #         (q): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
    #         (k): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
    #         (v): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
    #         (proj_out): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
    #       )
    #       (block_2): ResnetBlock(
    #         (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)
    #         (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #         (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)
    #         (dropout): Dropout(p=0.0, inplace=False)
    #         (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #       )
    #     )
    #     (up): ModuleList(
    #       (0): Module(
    #         (block): ModuleList(
    #           (0): ResnetBlock(
    #             (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
    #             (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #             (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
    #             (dropout): Dropout(p=0.0, inplace=False)
    #             (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #           )
    #           (1): ResnetBlock(
    #             (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
    #             (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #             (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
    #             (dropout): Dropout(p=0.0, inplace=False)
    #             (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #           )
    #           (2): ResnetBlock(
    #             (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
    #             (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #             (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
    #             (dropout): Dropout(p=0.0, inplace=False)
    #             (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #           )
    #         )
    #         (attn): ModuleList()
    #       )
    #       (1): Module(
    #         (block): ModuleList(
    #           (0): ResnetBlock(
    #             (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)
    #             (conv1): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #             (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
    #             (dropout): Dropout(p=0.0, inplace=False)
    #             (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #             (nin_shortcut): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))
    #           )
    #           (1): ResnetBlock(
    #             (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
    #             (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #             (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
    #             (dropout): Dropout(p=0.0, inplace=False)
    #             (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #           )
    #           (2): ResnetBlock(
    #             (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
    #             (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #             (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
    #             (dropout): Dropout(p=0.0, inplace=False)
    #             (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #           )
    #         )
    #         (attn): ModuleList()
    #         (upsample): Upsample(
    #           (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #         )
    #       )
    #       (2): Module(
    #         (block): ModuleList(
    #           (0): ResnetBlock(
    #             (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)
    #             (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #             (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)
    #             (dropout): Dropout(p=0.0, inplace=False)
    #             (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #           )
    #           (1): ResnetBlock(
    #             (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)
    #             (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #             (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)
    #             (dropout): Dropout(p=0.0, inplace=False)
    #             (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #           )
    #           (2): ResnetBlock(
    #             (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)
    #             (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #             (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)
    #             (dropout): Dropout(p=0.0, inplace=False)
    #             (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #           )
    #         )
    #         (attn): ModuleList()
    #         (upsample): Upsample(
    #           (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #         )
    #       )
    #       (3): Module(
    #         (block): ModuleList(
    #           (0): ResnetBlock(
    #             (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)
    #             (conv1): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #             (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)
    #             (dropout): Dropout(p=0.0, inplace=False)
    #             (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #             (nin_shortcut): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
    #           )
    #           (1): ResnetBlock(
    #             (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)
    #             (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #             (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)
    #             (dropout): Dropout(p=0.0, inplace=False)
    #             (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #           )
    #           (2): ResnetBlock(
    #             (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)
    #             (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #             (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)
    #             (dropout): Dropout(p=0.0, inplace=False)
    #             (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #           )
    #         )
    #         (attn): ModuleList()
    #         (upsample): Upsample(
    #           (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #         )
    #       )
    #       (4): Module(
    #         (block): ModuleList(
    #           (0): ResnetBlock(
    #             (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)
    #             (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #             (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)
    #             (dropout): Dropout(p=0.0, inplace=False)
    #             (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #           )
    #           (1): ResnetBlock(
    #             (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)
    #             (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #             (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)
    #             (dropout): Dropout(p=0.0, inplace=False)
    #             (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #           )
    #           (2): ResnetBlock(
    #             (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)
    #             (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #             (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)
    #             (dropout): Dropout(p=0.0, inplace=False)
    #             (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #           )
    #         )
    #         (attn): ModuleList(
    #           (0): AttnBlock(
    #             (norm): GroupNorm(32, 512, eps=1e-06, affine=True)
    #             (q): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
    #             (k): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
    #             (v): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
    #             (proj_out): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
    #           )
    #           (1): AttnBlock(
    #             (norm): GroupNorm(32, 512, eps=1e-06, affine=True)
    #             (q): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
    #             (k): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
    #             (v): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
    #             (proj_out): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
    #           )
    #           (2): AttnBlock(
    #             (norm): GroupNorm(32, 512, eps=1e-06, affine=True)
    #             (q): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
    #             (k): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
    #             (v): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
    #             (proj_out): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
    #           )
    #         )
    #         (upsample): Upsample(
    #           (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #         )
    #       )
    #       (5): Module(
    #         (block): ModuleList(
    #           (0): ResnetBlock(
    #             (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)
    #             (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #             (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)
    #             (dropout): Dropout(p=0.0, inplace=False)
    #             (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #           )
    #           (1): ResnetBlock(
    #             (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)
    #             (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #             (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)
    #             (dropout): Dropout(p=0.0, inplace=False)
    #             (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #           )
    #           (2): ResnetBlock(
    #             (norm1): GroupNorm(32, 512, eps=1e-06, affine=True)
    #             (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #             (norm2): GroupNorm(32, 512, eps=1e-06, affine=True)
    #             (dropout): Dropout(p=0.0, inplace=False)
    #             (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #           )
    #         )
    #         (attn): ModuleList(
    #           (0): AttnBlock(
    #             (norm): GroupNorm(32, 512, eps=1e-06, affine=True)
    #             (q): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
    #             (k): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
    #             (v): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
    #             (proj_out): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
    #           )
    #           (1): AttnBlock(
    #             (norm): GroupNorm(32, 512, eps=1e-06, affine=True)
    #             (q): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
    #             (k): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
    #             (v): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
    #             (proj_out): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
    #           )
    #           (2): AttnBlock(
    #             (norm): GroupNorm(32, 512, eps=1e-06, affine=True)
    #             (q): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
    #             (k): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
    #             (v): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
    #             (proj_out): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
    #           )
    #         )
    #         (upsample): Upsample(
    #           (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #         )
    #       )
    #     )
    #     (norm_out): GroupNorm(32, 128, eps=1e-06, affine=True)
    #     (conv_out): Conv2d(128, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #   )
    #   (loss): LPIPSWithDiscriminator(
    #     (perceptual_loss): LPIPS(
    #       (scaling_layer): ScalingLayer()
    #       (net): vgg16(
    #         (slice1): Sequential(
    #           (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #           (1): ReLU(inplace=True)
    #           (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #           (3): ReLU(inplace=True)
    #         )
    #         (slice2): Sequential(
    #           (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    #           (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #           (6): ReLU(inplace=True)
    #           (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #           (8): ReLU(inplace=True)
    #         )
    #         (slice3): Sequential(
    #           (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    #           (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #           (11): ReLU(inplace=True)
    #           (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #           (13): ReLU(inplace=True)
    #           (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #           (15): ReLU(inplace=True)
    #         )
    #         (slice4): Sequential(
    #           (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    #           (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #           (18): ReLU(inplace=True)
    #           (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #           (20): ReLU(inplace=True)
    #           (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #           (22): ReLU(inplace=True)
    #         )
    #         (slice5): Sequential(
    #           (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    #           (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #           (25): ReLU(inplace=True)
    #           (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #           (27): ReLU(inplace=True)
    #           (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    #           (29): ReLU(inplace=True)
    #         )
    #       )
    #       (lin0): NetLinLayer(
    #         (model): Sequential(
    #           (0): Dropout(p=0.5, inplace=False)
    #           (1): Conv2d(64, 1, kernel_size=(1, 1), stride=(1, 1), bias=False)
    #         )
    #       )
    #       (lin1): NetLinLayer(
    #         (model): Sequential(
    #           (0): Dropout(p=0.5, inplace=False)
    #           (1): Conv2d(128, 1, kernel_size=(1, 1), stride=(1, 1), bias=False)
    #         )
    #       )
    #       (lin2): NetLinLayer(
    #         (model): Sequential(
    #           (0): Dropout(p=0.5, inplace=False)
    #           (1): Conv2d(256, 1, kernel_size=(1, 1), stride=(1, 1), bias=False)
    #         )
    #       )
    #       (lin3): NetLinLayer(
    #         (model): Sequential(
    #           (0): Dropout(p=0.5, inplace=False)
    #           (1): Conv2d(512, 1, kernel_size=(1, 1), stride=(1, 1), bias=False)
    #         )
    #       )
    #       (lin4): NetLinLayer(
    #         (model): Sequential(
    #           (0): Dropout(p=0.5, inplace=False)
    #           (1): Conv2d(512, 1, kernel_size=(1, 1), stride=(1, 1), bias=False)
    #         )
    #       )
    #     )
    #     (discriminator): NLayerDiscriminator(
    #       (main): Sequential(
    #         (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    #         (1): LeakyReLU(negative_slope=0.2, inplace=True)
    #         (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    #         (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    #         (4): LeakyReLU(negative_slope=0.2, inplace=True)
    #         (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    #         (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    #         (7): LeakyReLU(negative_slope=0.2, inplace=True)
    #         (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1), bias=False)
    #         (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    #         (10): LeakyReLU(negative_slope=0.2, inplace=True)
    #         (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
    #       )
    #     )
    #   )
    #   (quant_conv): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))
    #   (post_quant_conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
    # )

    # Call KL_8


