# ---------------------------------------------------------------
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the NVIDIA Source Code License
# for OT-Bridge. To view a copy of this license, see the LICENSE file.
# ---------------------------------------------------------------

import os
import pickle
import torch

from guided_diffusion.script_util import create_model

from . import util
from .ckpt_util import (
    OTB_IMG256_UNCOND_PKL,
    OTB_IMG256_UNCOND_CKPT,
    OTB_IMG256_COND_PKL,
    OTB_IMG256_COND_CKPT,
)

from ipdb import set_trace as debug


class Image256Net(torch.nn.Module):
    """Wrapper around a guided-diffusion ADM UNet with optional extra inputs.

    The UNet input is the channel-wise concatenation ``[x, cond, seg]``, where
    ``cond`` (3 channels) is included only when built with ``cond=True`` and
    ``seg`` (``seg_channels`` channels) only when built with ``use_seg=True``.
    Integer ``steps`` are mapped to the UNet's time input via ``noise_levels``.
    """

    def __init__(self, log, noise_levels, use_fp16=False, cond=False,
                 use_seg=False, seg_channels=1, pretrained_adm=True, ckpt_dir="data/",
                 image_size=256):
        """Build the UNet and optionally load pretrained ADM weights.

        Args:
            log: logger object exposing ``info(str)``.
            noise_levels: indexable by ``steps`` in ``forward``; the indexed
                result must be a 1-D tensor with one entry per batch element.
            use_fp16: forwarded to ``create_model``.
            cond: add a 3-channel conditioning input and use the conditional
                ADM config/checkpoint.
            use_seg: add ``seg_channels`` segmentation input channels.
            seg_channels: number of segmentation channels when ``use_seg``.
            pretrained_adm: load pretrained ADM weights from ``ckpt_dir``.
                Forced off when ``image_size == 512`` (no such checkpoint).
            ckpt_dir: directory holding the ``.pkl`` model kwargs and ``.pt``
                weights named by the ``OTB_IMG256_*`` constants.
            image_size: spatial resolution passed to ``create_model``.
        """
        super().__init__()

        self.cond = cond
        self.use_seg = use_seg
        self.seg_channels = seg_channels
        self.image_size = image_size

        if image_size == 512:
            pretrained_adm = False
            log.info("[Net] Training 512x512 from scratch (no pretrained model available)")

        # Input = primary image (3) [+ conditioning image (3)] [+ segmentation].
        in_channels = 3
        if cond:
            in_channels += 3
        if use_seg:
            in_channels += seg_channels

        # NOTE(review): pickle.load can execute arbitrary code; ckpt_dir must
        # point at trusted files.
        ckpt_pkl = os.path.join(ckpt_dir, OTB_IMG256_COND_PKL if cond else OTB_IMG256_UNCOND_PKL)
        with open(ckpt_pkl, "rb") as f:
            kwargs = pickle.load(f)
        kwargs["use_fp16"] = use_fp16
        kwargs["in_channels"] = in_channels
        kwargs["image_size"] = image_size

        self.diffusion_model = create_model(**kwargs)
        log.info(f"[Net] Initialized network with {in_channels=}, seg={use_seg}!")

        # Load whenever requested -- including the plain unconditional
        # 3-channel model, which previously skipped its pretrained weights.
        if pretrained_adm:
            self._load_pretrained(log, ckpt_dir, in_channels)

        self.diffusion_model.eval()
        self.noise_levels = noise_levels

    def _load_pretrained(self, log, ckpt_dir, in_channels):
        """Load the ADM checkpoint into ``self.diffusion_model``.

        The checkpoint's first conv is widened to ``in_channels`` if needed.
        """
        ckpt_pt = os.path.join(ckpt_dir, OTB_IMG256_COND_CKPT if self.cond else OTB_IMG256_UNCOND_CKPT)
        try:
            out = torch.load(ckpt_pt, map_location="cpu", weights_only=True)
        except TypeError:
            # Older torch releases do not accept ``weights_only``.
            out = torch.load(ckpt_pt, map_location="cpu")

        self._extend_input_conv(out, in_channels)

        # strict=False tolerates missing/unexpected keys (shape mismatches
        # still raise); report what was skipped instead of hiding it.
        result = self.diffusion_model.load_state_dict(out, strict=False)
        if result.missing_keys:
            log.info(f"[Net] Missing keys in pretrained adm: {result.missing_keys}")
        if result.unexpected_keys:
            log.info(f"[Net] Unexpected keys in pretrained adm: {result.unexpected_keys}")
        log.info(f"[Net] Loaded pretrained adm {ckpt_pt=} ({in_channels=})!")

    def _extend_input_conv(self, state_dict, new_in_channels, init_scale=0.1):
        """Widen the first conv weight in ``state_dict`` in place.

        Pretrained input channels keep their weights. Every added channel is
        initialised to ``init_scale`` times the mean over the pretrained input
        channels. The dtype and device of the original weight are preserved.
        This is a no-op if the key is absent or already has
        ``new_in_channels`` inputs.

        Raises:
            ValueError: if the checkpoint has more input channels than
                ``new_in_channels``.
        """
        input_conv_key = "input_blocks.0.0.weight"
        old_weight = state_dict.get(input_conv_key)
        if old_weight is None:
            return

        old_in_ch = old_weight.shape[1]
        if new_in_channels == old_in_ch:
            return
        if new_in_channels < old_in_ch:
            raise ValueError(
                f"cannot shrink {input_conv_key} from {old_in_ch} to {new_in_channels} input channels"
            )

        new_weight = old_weight.new_empty(
            (old_weight.shape[0], new_in_channels, *old_weight.shape[2:])
        )
        new_weight[:, :old_in_ch] = old_weight
        # The (out, 1, ...) mean broadcasts across all added input channels.
        new_weight[:, old_in_ch:] = old_weight.mean(dim=1, keepdim=True) * init_scale
        state_dict[input_conv_key] = new_weight

    def forward(self, x, steps, cond=None, seg=None):
        """Run the UNet on ``[x, cond, seg]`` at the noise levels for ``steps``.

        Args:
            x: primary input; placed first in the channel concatenation.
            steps: indices into ``noise_levels``; must yield one level per
                batch element of ``x``.
            cond: conditioning input; required when built with ``cond=True``.
            seg: segmentation input; required when built with ``use_seg=True``.

        Raises:
            ValueError: if a configured extra input is missing. The first conv
                has a fixed input width, so dropping it would fail later with
                a less clear channel-mismatch error.
        """
        t = self.noise_levels[steps].detach()
        assert t.dim() == 1 and t.shape[0] == x.shape[0]

        inputs = [x]
        if self.cond:
            if cond is None:
                raise ValueError("Image256Net was built with cond=True but `cond` is None")
            inputs.append(cond)
        if self.use_seg:
            if seg is None:
                raise ValueError("Image256Net was built with use_seg=True but `seg` is None")
            inputs.append(seg)

        if len(inputs) > 1:
            x = torch.cat(inputs, dim=1)
        return self.diffusion_model(x, t)