from util.logger import logger

from typing import Optional, Union, List, Tuple

from pathlib import Path

import numpy as np

import torch

import gc

from util.basic_util import (
    get_global_variable, 
    is_none, 
    get_true_value, 
    get_attr
)
from util.yaml_util import load_yaml

from .model import NPNet


def load_golden_noise_model(
    cfg_yaml_path: Optional[str] = None,
    sd_type: Optional[str] = "SDXL",  # one of "SDXL", "DreamShaper", "DiT"
    device: Optional[str] = "cpu",
) -> NPNet:
    """Build an `NPNet` using the checkpoint configured for `sd_type`.

    The checkpoint path is read from the YAML config under
    `golden_noise -> npnet_model_ckpt_path_dict`.

    Args:
        cfg_yaml_path: Path of the YAML config file. A falsy value (`None`
            or an empty string) selects `./config/model/golden_noise.yaml`.
        sd_type: `"SDXL"`, `"DreamShaper"` or `"DiT"`. Also forwarded to
            `NPNet` as `model_id`.
        device: Forwarded to `NPNet` as `device`.

    Returns:
        The constructed `NPNet` instance.

    Raises:
        ValueError: If `sd_type` is not one of the supported values. Note
            that the config file is loaded before this check runs.
    """
    # Key inside `npnet_model_ckpt_path_dict` for every supported `sd_type`.
    ckpt_key_by_sd_type = {
        "SDXL": "sdxl",
        "DreamShaper": "dream_shaper_xl_v2_turbo",
        "DiT": "hunyuan_dit",
    }

    if not cfg_yaml_path:
        cfg_yaml_path = "./config/model/golden_noise.yaml"
    config = load_yaml(yaml_path=cfg_yaml_path)

    if sd_type not in tuple(ckpt_key_by_sd_type):
        raise ValueError(f"Unsupported `sd_type`, got `{sd_type}`. ")

    ckpt_path_dict = config["golden_noise"]["npnet_model_ckpt_path_dict"]
    ckpt_path = ckpt_path_dict[ckpt_key_by_sd_type[sd_type]]

    model = NPNet(model_id=sd_type, pretrained_path=ckpt_path, device=device)

    # ---------= [Clean Up] =---------
    # The parsed config is no longer needed once the model is built.
    del ckpt_path_dict, config
    gc.collect()

    return model


def optimize_initial_noise(
    initial_noise_list: Union[torch.Tensor, List[torch.Tensor]], 

    npnet: NPNet, 

    prompt_list: Optional[Union[str, List[str]]] = None, 

    pipeline: Optional = None, 
    prompt_emb_list: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None
) -> torch.Tensor:
    """Run `npnet` on the initial noise, conditioned on prompt embeddings.

    The prompt embeddings are either supplied directly through
    `prompt_emb_list`, or computed here by calling
    `pipeline.encode_prompt(...)` on `prompt_list`.

    Args:
        initial_noise_list: A tensor, or a sequence of tensors that is
            combined with `torch.stack` (new leading dimension).
        npnet: The model to call. Its `device` attribute decides where the
            inputs are moved before the call.
        prompt_list: A single prompt or a list of prompts. Only used when
            `prompt_emb_list` is `None`; a non-list value is wrapped in a
            list.
        pipeline: Object providing `encode_prompt(prompt=..., device=...)`
            and a `device` attribute. Only used when `prompt_emb_list` is
            `None`. `encode_prompt` must return exactly four values, the
            first of which is used as the prompt embeddings.
        prompt_emb_list: Precomputed prompt embeddings, as a tensor or a
            sequence of tensors. A sequence is combined with `torch.stack`,
            mirroring `initial_noise_list` (assumes each entry is a single
            un-batched embedding -- TODO confirm against callers).

    Returns:
        Whatever `npnet(initial_noise=..., prompt_embeds=...)` returns.

    Raises:
        ValueError: If `prompt_emb_list` is `None` and `pipeline` or
            `prompt_list` is missing, so no embeddings can be obtained.
    """
    if not isinstance(initial_noise_list, torch.Tensor):
        initial_noise_list = torch.stack(initial_noise_list)

    if (prompt_list is not None) and (not isinstance(prompt_list, list)):
        prompt_list = [prompt_list]

    # Remember whether the embeddings were created here, so that only
    # locally created embeddings are released at the end.
    prompt_emb_list_not_given = prompt_emb_list is None

    if prompt_emb_list_not_given:
        # Fail early with a clear message instead of an AttributeError on
        # `None.encode_prompt`.
        if (pipeline is None) or (prompt_list is None):
            raise ValueError(
                "Either `prompt_emb_list`, or both `pipeline` and "
                "`prompt_list`, must be provided. "
            )

        (
            prompt_emb_list,

            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = pipeline.encode_prompt(
            prompt = prompt_list,
            device = pipeline.device
        )

        # ---------= [Clean Up] =---------
        # Only the first returned value is needed.
        del negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
        torch.cuda.empty_cache()
        gc.collect()
    elif not isinstance(prompt_emb_list, torch.Tensor):
        # The signature accepts a list of tensors; without stacking, the
        # `.to(...)` call below would fail on a plain list.
        prompt_emb_list = torch.stack(prompt_emb_list)

    initial_noise_list = initial_noise_list.to(npnet.device)
    prompt_emb_list = prompt_emb_list.to(npnet.device)

    optimized_initial_noise_list = npnet(
        initial_noise = initial_noise_list, 

        prompt_embeds = prompt_emb_list
    )

    if prompt_emb_list_not_given:
        # ---------= [Clean Up] =---------
        del prompt_emb_list
        torch.cuda.empty_cache()
        gc.collect()

    # `optimize_initial_noise()` done
    return optimized_initial_noise_list
