# ---------------------------------------------------------------
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the NVIDIA Source Code License
# for I2SB. To view a copy of this license, see the LICENSE file.
# ---------------------------------------------------------------

import os
import pickle
import torch

from inpainters.src_cddb.guided_diffusion_.script_util import create_model

from . import util
from .ckpt_util import (
    I2SB_IMG256_UNCOND_PKL,
    I2SB_IMG256_UNCOND_CKPT,
    I2SB_IMG256_COND_PKL,
    I2SB_IMG256_COND_CKPT,
    I2SB_CELEBA256_CKPT,
    I2SB_CELEBA256_PKL
)

import logging
log = logging.getLogger(__name__)


class Image256Net(torch.nn.Module):
    """Adapter that lets a ``create_model`` network be called with step indices.

    The wrapped network takes ``(x, t)``. This class looks ``t`` up in a table
    of noise levels using integer ``steps`` and, when built in conditional
    mode, concatenates a conditioning tensor to ``x`` along dim 1 first.
    """

    def __init__(self, noise_levels, cond, model_kwargs):
        """Build the backbone and store the noise-level table.

        Args:
            noise_levels: Tensor indexed by ``steps`` in ``forward``. It is
                registered as a buffer, so it moves with the module on
                ``.to(...)`` and is part of ``state_dict()``.
            cond: Truthy to run in conditional mode (``forward`` then needs a
                ``cond`` tensor).
            model_kwargs: Keyword arguments forwarded verbatim to
                ``create_model``.
        """
        super().__init__()

        # NOTE: an earlier comment here said use_fp16 is disabled for now.
        # Nothing in this class enforces that: whatever `model_kwargs`
        # contains is passed straight through to `create_model`.
        self.diffusion_model = create_model(**model_kwargs)
        log.info(f"[Net] Initialized network! Size={util.count_parameters(self.diffusion_model)}!")
        # The backbone is put in eval mode at construction time; callers that
        # train it have to switch it back themselves.
        self.diffusion_model.eval()
        self.cond = cond
        if cond:
            log.info("[Net] Using conditional version")
        self.register_buffer("noise_levels", noise_levels)

    def forward(self, x, steps, cond=None):
        """Run the backbone at the noise levels selected by ``steps``.

        Args:
            x: Input batch; dim 0 is the batch dimension.
            steps: Index into ``noise_levels``. It must select exactly one
                level per batch element.
            cond: Tensor concatenated to ``x`` along dim 1. Required when the
                module was built with ``cond`` truthy, ignored otherwise.

        Returns:
            Whatever the wrapped network returns for ``(x, t)``.

        Raises:
            TypeError: Conditional mode is on but ``cond`` is ``None``.
            AssertionError: ``steps`` does not select one level per batch item.
        """
        # Fail early with a readable message; torch.cat would otherwise raise
        # a TypeError that only says it got a NoneType element.
        if self.cond and cond is None:
            raise TypeError(
                "Image256Net was built in conditional mode, so forward() "
                "needs a `cond` tensor to concatenate with `x`."
            )

        t = self.noise_levels[steps].detach()
        assert t.dim() == 1 and t.shape[0] == x.shape[0], (
            f"expected one noise level per batch element, got "
            f"t.shape={tuple(t.shape)} for x.shape={tuple(x.shape)}"
        )

        if self.cond:
            x = torch.cat([x, cond], dim=1)
        return self.diffusion_model(x, t)


class Image32Net(torch.nn.Module):
    """Small fixed-architecture network for 1x32x32 (MNIST) inputs.

    Wraps a ``create_model`` backbone so it can be called with integer step
    indices, optionally concatenating a conditioning tensor to the input.
    """

    def __init__(self, log, noise_levels, use_fp16=False, ckpt_dir="data/", cond=False):
        """Build the backbone and store the noise-level table.

        Args:
            log: Logger-like object with an ``info`` method. It shadows the
                module-level logger inside this constructor.
            noise_levels: Tensor indexed by ``steps`` in ``forward``.
            use_fp16: Forwarded to ``create_model``.
            ckpt_dir: Not used by this class; kept so existing call sites
                keep working.
            cond: Truthy to run in conditional mode (``forward`` then needs a
                ``cond`` tensor).
        """
        super().__init__()

        # initialize model
        # MNIST model (1x32x32)
        self.diffusion_model = create_model(
            image_size=32,
            num_channels=32,
            num_res_blocks=2,
            attention_resolutions="8,4,2",
            num_heads=2,
            num_head_channels=16,
            # In conditional mode forward() concatenates `cond` to `x` on
            # dim 1, so the backbone must accept the extra channel. This
            # mirrors VolumeNet and assumes `cond` is single-channel like `x`
            # -- TODO confirm against the caller.
            in_channels=2 if cond else 1,
            out_channels=1,
            use_fp16=use_fp16,
        )
        log.info(f"[Net] Initialized network! Size={util.count_parameters(self.diffusion_model)}!")

        # The backbone is put in eval mode at construction time; callers that
        # train it have to switch it back themselves.
        self.diffusion_model.eval()
        self.cond = cond
        # A buffer follows the module across .to()/.cuda(), where a plain
        # attribute would be left on its original device. persistent=False
        # keeps it out of state_dict(), so existing checkpoints still load.
        self.register_buffer("noise_levels", noise_levels, persistent=False)

    def forward(self, x, steps, cond=None):
        """Run the backbone at the noise levels selected by ``steps``.

        Args:
            x: Input batch; dim 0 is the batch dimension.
            steps: Index into ``noise_levels``. It must select exactly one
                level per batch element.
            cond: Tensor concatenated to ``x`` along dim 1. Required when the
                module was built with ``cond`` truthy, ignored otherwise.

        Returns:
            Whatever the wrapped network returns for ``(x, t)``.

        Raises:
            TypeError: Conditional mode is on but ``cond`` is ``None``.
            AssertionError: ``steps`` does not select one level per batch item.
        """
        # Fail early with a readable message; torch.cat would otherwise raise
        # a TypeError that only says it got a NoneType element.
        if self.cond and cond is None:
            raise TypeError(
                "Image32Net was built in conditional mode, so forward() "
                "needs a `cond` tensor to concatenate with `x`."
            )

        t = self.noise_levels[steps].detach()
        assert t.dim() == 1 and t.shape[0] == x.shape[0], (
            f"expected one noise level per batch element, got "
            f"t.shape={tuple(t.shape)} for x.shape={tuple(x.shape)}"
        )

        if self.cond:
            x = torch.cat([x, cond], dim=1)
        return self.diffusion_model(x, t)

class VolumeNet(torch.nn.Module):
    """Network for single-channel 3-D volumes (backbone built with ``dims=3``).

    Wraps a ``create_model`` backbone so it can be called with integer step
    indices, optionally concatenating a conditioning tensor to the input.
    """

    def __init__(
            self, noise_levels, use_fp16=False, ckpt_dir="data/", cond=False,
            cube_size=64,
            num_channels=32,
            num_res_blocks=3,
            # NOTE(review): "3" breaks the pattern of the other entries;
            # confirm it is intentional. Left unchanged because this default
            # defines the architecture.
            attention_resolutions="16,8,3",
            num_heads=1,
            num_head_channels=32,
            use_new_attention_order=False,
            use_flash_attention=False,
            dropout=0.0
        ):
        """Build the 3-D backbone and store the noise-level table.

        Args:
            noise_levels: Tensor indexed by ``steps`` in ``forward``.
            use_fp16: Forwarded to ``create_model``.
            ckpt_dir: Not used by this class; kept so existing call sites
                keep working.
            cond: Truthy to run in conditional mode. The backbone is then
                built with two input channels instead of one, and ``forward``
                needs a ``cond`` tensor.
            cube_size: Passed to ``create_model`` as ``image_size``.
            num_channels, num_res_blocks, attention_resolutions, num_heads,
            num_head_channels, use_new_attention_order, use_flash_attention,
            dropout: Forwarded unchanged to ``create_model``.
        """
        super().__init__()

        # initialize model
        self.diffusion_model = create_model(
            image_size=cube_size,
            # One extra channel for `cond`, which forward() concatenates to
            # `x` on dim 1 in conditional mode.
            in_channels=2 if cond else 1,
            out_channels=1,
            dims=3,
            use_checkpoint=True, # Less memory usage
            num_channels=num_channels,
            num_res_blocks=num_res_blocks,
            attention_resolutions=attention_resolutions,
            num_heads=num_heads,
            num_head_channels=num_head_channels,
            use_fp16=use_fp16,
            use_new_attention_order=use_new_attention_order,
            use_flash_attention=use_flash_attention,
            dropout=dropout
        )

        # The backbone is put in eval mode at construction time; callers that
        # train it have to switch it back themselves.
        self.diffusion_model.eval()
        self.cond = cond
        # A buffer follows the module across .to()/.cuda(), where a plain
        # attribute would be left on its original device. persistent=False
        # keeps it out of state_dict(), so existing checkpoints still load.
        self.register_buffer("noise_levels", noise_levels, persistent=False)

    def forward(self, x, steps, cond=None):
        """Run the backbone at the noise levels selected by ``steps``.

        Args:
            x: Input batch; dim 0 is the batch dimension.
            steps: Index into ``noise_levels``. It must select exactly one
                level per batch element.
            cond: Tensor concatenated to ``x`` along dim 1. Required when the
                module was built with ``cond`` truthy, ignored otherwise.

        Returns:
            Whatever the wrapped network returns for ``(x, t)``.

        Raises:
            TypeError: Conditional mode is on but ``cond`` is ``None``.
            AssertionError: ``steps`` does not select one level per batch item.
        """
        # Fail early with a readable message; torch.cat would otherwise raise
        # a TypeError that only says it got a NoneType element.
        if self.cond and cond is None:
            raise TypeError(
                "VolumeNet was built in conditional mode, so forward() "
                "needs a `cond` tensor to concatenate with `x`."
            )

        t = self.noise_levels[steps].detach()
        assert t.dim() == 1 and t.shape[0] == x.shape[0], (
            f"expected one noise level per batch element, got "
            f"t.shape={tuple(t.shape)} for x.shape={tuple(x.shape)}"
        )

        if self.cond:
            x = torch.cat([x, cond], dim=1)
        return self.diffusion_model(x, t)