from peft import LoraConfig
from omegaconf import OmegaConf

def create_lora_config(cfg):
    """Build the LoRA target list and per-module rank/alpha overrides for the denoising UNet.

    Every module name listed in ``cfg.lora.denoise_unet.target_modules`` is mapped
    to ``cfg.lora.denoise_unet.rank`` / ``cfg.lora.denoise_unet.alpha``. When
    ``cfg.lora.denoise_unet.is_conv`` is present and truthy, a fixed set of
    convolution module names is appended to the targets and mapped to
    ``conv_rank`` / ``conv_alpha`` (overriding any earlier entry with the same name).

    Args:
        cfg: OmegaConf config containing a ``lora.denoise_unet`` section.

    Returns:
        A tuple ``(rank_pattern, alpha_pattern, target_modules)``: two dicts keyed
        by module name, and a freshly built list of target module names.
    """
    unet_lora = cfg.lora.denoise_unet
    modules = list(unet_lora.target_modules)

    ranks = {}
    for name in modules:
        ranks[name] = unet_lora.rank

    alphas = {}
    for name in modules:
        alphas[name] = unet_lora.alpha

    # `select` returns None for an absent key, so configs without `is_conv`
    # take the early exit instead of raising.
    if not (OmegaConf.select(cfg, "lora.denoise_unet.is_conv") and unet_lora.is_conv):
        return ranks, alphas, modules

    conv_names = ["conv_in", "conv1", "conv", "conv2", "conv_out"]
    modules.extend(conv_names)
    for name in conv_names:
        ranks[name] = unet_lora.conv_rank
    for name in conv_names:
        alphas[name] = unet_lora.conv_alpha
    return ranks, alphas, modules

def load_denoise_unet_lora(cfg, denoising_unet=None):
    """Build the denoising-UNet ``LoraConfig`` and optionally attach it to a model.

    Target selection:
      * ``lora.denoise_unet.only_mm`` truthy: a regex restricted to module names
        containing ``motion_modules`` followed by one of the target names.
      * otherwise, ``unet_additional_kwargs.use_motion_module`` truthy: a regex
        that excludes any module name containing ``motion_modules``.
      * otherwise: the plain list returned by ``create_lora_config``.

    The per-module ``rank_pattern`` / ``alpha_pattern`` from ``create_lora_config``
    are passed in every mode, so ``conv_rank`` / ``conv_alpha`` also apply when
    the targets are given as a regex.

    Args:
        cfg: OmegaConf config with a ``lora.denoise_unet`` section (``rank``,
            ``alpha``, ``dropout``, ``target_modules``; optional ``only_mm``,
            ``is_conv``, ``conv_rank``, ``conv_alpha``) and an optional
            ``unet_additional_kwargs.use_motion_module`` flag.
        denoising_unet: if not None, ``add_adapter`` is called on it with the
            built config.

    Returns:
        The ``LoraConfig``, or None when ``lora.denoise_unet.rank`` is missing
        or falsy (e.g. 0), in which case nothing is attached.
    """
    if not OmegaConf.select(cfg, "lora.denoise_unet.rank"):
        return None

    unet_lora = cfg.lora.denoise_unet
    rank_pattern, alpha_pattern, target_modules = create_lora_config(cfg)

    # Optional flags: `select` yields None (treated as False) when a key is absent,
    # consistent with how `rank` and `is_conv` are read.
    only_mm = OmegaConf.select(cfg, "lora.denoise_unet.only_mm")
    use_motion_module = OmegaConf.select(cfg, "unet_additional_kwargs.use_motion_module")

    if only_mm or use_motion_module:
        alternatives = "|".join(f"({name})" for name in target_modules)
        # NOTE(review): the trailing `.*` lets a name match anywhere after the
        # prefix (e.g. 'conv' also matches 'conv_shortcut') - confirm intended.
        if only_mm:
            # Only layers nested under a `motion_modules` submodule.
            target = ".*motion_modules.*(" + alternatives + ").*$"
        else:
            # Spatial layers only: skip anything under `motion_modules`.
            target = "^(?!.*motion_modules).*(" + alternatives + ").*$"
    else:
        target = target_modules

    denoise_unet_lora_config = LoraConfig(
        r=unet_lora.rank,
        lora_alpha=unet_lora.alpha,
        init_lora_weights="gaussian",
        lora_dropout=unet_lora.dropout,
        target_modules=target,
        # Passed in every mode so conv_rank/conv_alpha are not silently dropped
        # when target_modules is a regex string.
        rank_pattern=rank_pattern,
        alpha_pattern=alpha_pattern,
    )
    if denoising_unet is not None:
        denoising_unet.add_adapter(denoise_unet_lora_config)

    return denoise_unet_lora_config

if __name__ == "__main__":
    # Dev harness: build the denoising-UNet LoRA config from a sample training
    # config (without attaching it to a model) and inspect the result.
    cfg = OmegaConf.load("dev/configs/train/stage_2_t2i_lora.yaml")
    denoise_unet_lora_config = load_denoise_unet_lora(cfg, denoising_unet=None)
    print(denoise_unet_lora_config)
    # Built-in `breakpoint()` (PEP 553) honours PYTHONBREAKPOINT, so a
    # non-interactive run can set PYTHONBREAKPOINT=0 to skip the debugger.
    breakpoint()




