from util.logger import logger

from typing import Optional, Union, List

import torch

from util.torch_util import tsfm_to_1d_array


def prepare_timestep(
    pipeline,

    timesteps: Union[List[int], torch.Tensor],

    target_length: int,

    timestep_list: Optional[Union[List[int], torch.Tensor]] = None,
    timestep_idx_list: Optional[Union[List[int], torch.Tensor]] = None
) -> torch.Tensor:
    """Build the timestep tensor handed to `tsfm_to_1d_array`.

    The timestep values come from exactly one of two sources:

    * `timestep_list`: explicit timestep values, used as they are.
    * `timestep_idx_list`: indices into `timesteps`. An index that is not
      smaller than `pipeline._num_timesteps` yields a zero (with the dtype /
      device of `timesteps[0]`) instead of a lookup.

    When both are given, a warning is logged and `timestep_list` wins.

    Args:
        pipeline: Object exposing `_num_timesteps`; only read when
            `timestep_idx_list` is used.
        timesteps: Schedule that `timestep_idx_list` indexes into.
        target_length: Forwarded to `tsfm_to_1d_array` as `target_length`.
        timestep_list: Explicit timestep values (tensor, list of numbers, or
            list of tensors).
        timestep_idx_list: Indices into `timesteps` (tensor, list, or a single
            index).

    Returns:
        The result of `tsfm_to_1d_array`; per the original author's note its
        shape is `(target_length, )` -- NOTE(review): confirm against
        `tsfm_to_1d_array`.

    Raises:
        ValueError: If neither `timestep_list` nor `timestep_idx_list` is
            provided.
    """
    if (timestep_list is not None) and (timestep_idx_list is not None):
        logger(
            f"Both `timestep_list` and `timestep_idx_list` are provided, "
            f"`timestep_list` prioritizes. ",

            log_type = "warning"
        )

    if timestep_list is not None:
        if isinstance(timestep_list, torch.Tensor):
            timestep = timestep_list
        elif (
            isinstance(timestep_list, (list, tuple))
            and len(timestep_list) > 0
            and all(isinstance(item, torch.Tensor) for item in timestep_list)
        ):
            # A list of tensors: flatten each one first so that 0-d entries
            # can be concatenated as well.
            timestep = torch.cat(
                [item.reshape(-1) for item in timestep_list]
            )
        else:
            # Plain numbers (the annotated `List[int]` case). The previous
            # code discarded the converted value and kept the raw list,
            # which broke on `timestep.dtype` below.
            timestep = torch.as_tensor(timestep_list)
    elif timestep_idx_list is not None:
        # Normalise a single index, a list or a tensor into a flat list of
        # Python numbers.
        indices = torch.as_tensor(timestep_idx_list).reshape(-1).tolist()
        num_timesteps = pipeline._num_timesteps

        # `torch.as_tensor` is a no-op for tensor entries and lifts plain
        # ints to tensors, so `timesteps` may be a tensor or a list.
        timestep = torch.stack([
            torch.as_tensor(timesteps[idx])
            if idx < num_timesteps
            else torch.zeros_like(torch.as_tensor(timesteps[0]))
            for idx in indices
        ])
    else:
        raise ValueError(
            "Either `timestep_list` or `timestep_idx_list` should be provided. "
        )

    # timestep.shape = (target_length, )
    timestep = tsfm_to_1d_array(
        array = timestep,
        target_length = target_length,

        dtype = timestep.dtype,
        device = timestep.device
    )

    return timestep
