from tools.helper import joint_vi_vae_encode_decode

class ScalesInjector():
    def __init__(self, args, vae, scale_schedule, tgt_h, tgt_w):
        self.inject_scales, self.inject_scales_bits = self._get_inject_scales(args.inject_scales, args.inject_scales_path, vae, scale_schedule, tgt_h, tgt_w)
        self.apply_spatial_patchify = args.apply_spatial_patchify
        
    
    def _get_inject_scales(self, inject_scales, img_path, vae, scale_schedule, tgt_h, tgt_w):
        match inject_scales:
            case 1: inject_scales = [2,3]
            case 0: inject_scales = []
            case _: raise(NotImplementedError) 
        encoding_bit_indices = None
        if inject_scales:
             _, _, encoding_bit_indices, *_ = joint_vi_vae_encode_decode(
                vae, img_path, scale_schedule, "cuda", tgt_h=tgt_h, tgt_w=tgt_w, apply_spatial_patchify=args.apply_spatial_patchify
             )

        return inject_scales, encoding_bit_indices