# borrowed from official implementation of Golden Noise: https://github.com/xie-lab-ml/Golden-Noise-for-Diffusion-Models

import os

import torch
import torch.nn as nn

from diffusers.models.normalization import AdaGroupNorm

from . import NoiseTransformer, SVDNoiseUnet


class NPNet(nn.Module):
    """Golden-Noise NPNet: map an initial noise latent to a prompt-conditioned noise.

    The network combines three pretrained sub-modules (built and loaded in
    :meth:`get_model`) with two scalar mixing weights ``_alpha`` / ``_beta``
    read from the same checkpoint.
    """

    # Model families this port has checkpoints/embedding sizes for.
    _SUPPORTED_MODEL_IDS = ('SDXL', 'DreamShaper', 'DiT')

    def __init__(self, model_id, pretrained_path=True, device='cuda') -> None:
        """Build the sub-networks and load pretrained weights.

        Args:
            model_id: one of ``'SDXL'``, ``'DreamShaper'``, ``'DiT'``; selects the
                prompt-embedding width expected by the text-embedding layer.
            pretrained_path: path (``str`` or ``os.PathLike``) to a ``.pth``
                checkpoint. The default ``True`` is kept only for signature
                compatibility; it is not a valid path and raises ``ValueError``.
            device: device the sub-networks (and checkpoint tensors) are placed on.

        Raises:
            ValueError: if ``model_id`` is unsupported or ``pretrained_path`` is
                not a path to a ``.pth`` checkpoint.
        """
        super().__init__()

        # Raise instead of assert so the check survives ``python -O``.
        if model_id not in self._SUPPORTED_MODEL_IDS:
            raise ValueError(
                f"Unsupported model_id {model_id!r}; "
                f"expected one of {list(self._SUPPORTED_MODEL_IDS)}."
            )

        self.model_id = model_id
        self.device = device
        self.pretrained_path = pretrained_path

        (
            self.unet_svd,
            self.unet_embedding,
            self.text_embedding,
            self._alpha,
            self._beta,
        ) = self.get_model()

    def get_model(self):
        """Create the sub-networks on ``self.device`` and load the checkpoint into them.

        Returns:
            Tuple ``(unet_svd, unet_embedding, text_embedding, alpha, beta)``.
            The three modules are float32 on ``self.device``. ``alpha`` and ``beta``
            are the raw ``"alpha"`` / ``"beta"`` entries of the checkpoint.

        Raises:
            ValueError: if ``self.pretrained_path`` is not a ``str``/``os.PathLike``
                containing ``'.pth'``.
        """
        # Validate the path before allocating any network, so a bad argument
        # fails fast with a clear message instead of a later unpacking error.
        path = self.pretrained_path
        if not isinstance(path, (str, os.PathLike)) or '.pth' not in os.fsdecode(path):
            raise ValueError(
                "No Pretrained Weights Found! Expected a path to a '.pth' "
                f"checkpoint, got {path!r}."
            )
        path_str = os.fsdecode(path)

        # Width of the flattened prompt embedding consumed by AdaGroupNorm in
        # forward(). It is presumably 77 tokens x encoder hidden size
        # (1024 for DiT, 2048 otherwise) -- confirm against the text encoder.
        embed_dim = (1024 if self.model_id == 'DiT' else 2048) * 77

        placement = dict(device=self.device, dtype=torch.float32)
        # resolution=128 presumably matches 128x128 latents -- TODO confirm.
        unet_embedding = NoiseTransformer(resolution=128).to(**placement)
        unet_svd = SVDNoiseUnet(resolution=128).to(**placement)
        text_embedding = AdaGroupNorm(embed_dim, 4, 1, eps=1e-6).to(**placement)

        # map_location lets a checkpoint saved on GPU load onto the requested
        # device (e.g. CPU-only hosts), and moves tensor entries there as well.
        checkpoint = torch.load(path_str, map_location=self.device)
        unet_svd.load_state_dict(checkpoint["unet_svd"])
        unet_embedding.load_state_dict(checkpoint["unet_embedding"])
        # The key really is spelled "embeeding"; it must match the checkpoint exactly.
        text_embedding.load_state_dict(checkpoint["embeeding"])
        _alpha = checkpoint["alpha"]
        _beta = checkpoint["beta"]

        print("Load Successfully!")

        return unet_svd, unet_embedding, text_embedding, _alpha, _beta

    def forward(self, initial_noise, prompt_embeds):
        """Compute the prompt-conditioned ("golden") noise.

        Args:
            initial_noise: noise latent. It is presumably (B, 4, H, W), given the
                4-channel AdaGroupNorm -- TODO confirm.
            prompt_embeds: prompt embeddings with a leading batch dimension. They
                are flattened to (B, -1); that width must equal the embedding size
                chosen in :meth:`get_model`.

        Returns:
            ``unet_svd(noise) + (2*sigmoid(alpha) - 1) * text_emb
            + beta * unet_embedding(noise + text_emb)``.
        """
        noise = initial_noise.float()
        # reshape (not view) so non-contiguous embeddings are accepted as well.
        cond = prompt_embeds.float().reshape(prompt_embeds.shape[0], -1)
        text_emb = self.text_embedding(noise, cond)

        # Transformer branch on the text-shifted noise. The sum is formed in the
        # input's dtype and then cast, matching the reference implementation.
        golden_embedding = self.unet_embedding((initial_noise + text_emb).float())

        # 2*sigmoid(alpha) - 1 maps alpha into (-1, 1) to weight the text residual.
        golden_noise = self.unet_svd(noise) + (
            2 * torch.sigmoid(self._alpha) - 1) * text_emb + self._beta * golden_embedding

        return golden_noise
