# Person-specific encoder (PersonSpecificEncoder): model type, constructor
# arguments and the checkpoint file path.
person_specific:
    type: Transformer
    args:
        # NOTE(review): 58 matches the 3DMM dims used elsewhere in this file
        # (`_3dmm_dim`, `coeff_3dmm_dim`) - presumably the same feature; confirm.
        in_features: 58
        embed_dim: 512
        num_heads: 4
        num_layers: 4
        mlp_dim: 1024
        seq_len: 750
        proj_dim: 512
        proj_head: mlp
        drop_prob: 0.1
        max_len: 1000
        pos_encoding: absolute
        embed_layer: linear
    checkpoint_path: ./checkpoints/person_specific/exp_4/epoch_009_checkpoint.pth

# Audio encoder: model type, constructor arguments and optional checkpoint.
audio_encoder:
    type: AudioEmbedder
    args:
        # By default the data normalization is skipped.
        skip_norm: true
    # No checkpoint by default (explicit null instead of a bare empty value).
    checkpoint_path: null

# Latent embedder for emotion: model type, constructor arguments and the
# checkpoint file path.
latent_embedder:
    type: AutoencoderRNN_VAE_v2
    args:
        emotion_dim: 25
        coeff_3dmm_dim: 58
        emb_dims: [128, 128]
        num_layers: 2
        hidden_dim: 512
        z_dim: 512
        rnn_type: gru
        dropout: 0.0
        window_size: 50
        seq_len: 750
    checkpoint_path: ./checkpoints/latent_embedder/exp_2/epoch_499_checkpoint.pth

# Latent embedder used only for 3DMM encodings: model type, constructor
# arguments and the checkpoint file path.
latent_3dmm_embedder:
    type: AutoencoderRNN_VAE_v1
    args:
        _3dmm_dim: 58
        coeff_emotion_dim: 25
        emb_dims: [128, 128]
        num_layers: 2
        hidden_dim: 512
        z_dim: 512
        rnn_type: gru
        dropout: 0.0
        window_size: 50
        seq_len: 750
    checkpoint_path: ./checkpoints/latent_3dmm_embedder/exp_1/epoch_899_checkpoint.pth

# Diffusion prior: model type, constructor arguments and noise-scheduler
# settings.
diffusion_prior:
    type: LatentMLPMatcher
    args:
        emb_preprocessing: normalize
        freeze_encoder: true
        audio_dim: 78
        window_size: 50
        token_len: 750
        _3dmm_dim: 58
        speaker_emb_dim: 512
        latent_dim: 512
        depth: 4
        num_time_layers: 2
        num_time_embeds: 1
        num_time_emb_channels: 64
        time_last_act: false
        use_learned_query: true
        # Condition drop probabilities for, in order: speaker audio encodings,
        # speaker latent embedding, speaker 3dmm encodings.
        s_audio_cond_drop_prob: 0.0
        s_latentemb_cond_drop_prob: 0.0
        s_3dmm_cond_drop_prob: 0.0
        guidance_scale: 1.0
        dim_head: 64
        heads: 8
        ff_mult: 4
        norm_in: false
        norm_out: true
        attn_dropout: 0.0
        ff_dropout: 0.0
        final_proj: true
        normformer: false
        rotary_emb: true
    scheduler:
        noise_schedule: cosine
        # One of: leading (default), linspace, trailing.
        timestep_spacing: leading
        # TODO: set the diffusion steps (training and inference).
        num_train_timesteps: 1000
        num_inference_timesteps: 50
        # One of: start_x, epsilon.
        predict: start_x
        var_type: fixed_large
        rescale_timesteps: false
        noise_std: 1
        # k appropriate generations.
        k: 1

# Diffusion decoder: model type, constructor arguments and noise-scheduler
# settings.
diffusion_decoder:
    type: TransformerDenoiser
    args:
        emb_preprocessing: normalize
        freeze_encoder: true
        window_size: 50
        token_len: 750
        # Whether to encode raw emotion to encodings.
        encode_emotion: false
        # Whether to encode raw 3dmm to encodings.
        encode_3dmm: false
        ablation_skip_connection: true
        nfeats: 25
        latent_dim: 512
        ff_size: 1024
        num_layers: 7
        num_heads: 4
        dropout: 0.1
        normalize_before: false
        activation: gelu
        flip_sin_to_cos: true
        return_intermediate_dec: false
        position_embedding: learned
        # One of: trans_enc, trans_dec.
        arch: trans_dec
        freq_shift: 0
        time_encoded_dim: 64
        # Encoded dim of the speaker's audio feature.
        s_audio_dim: 78
        # s_audio_scale: # scale of speaker's audio feature
        # Encoded dim of the speaker's emotion encodings.
        s_emotion_dim: 25
        l_embed_dim: 512
        s_embed_dim: 512
        personal_emb_dim: 512
        # Encoded dim of the speaker's 3dmm feature.
        s_3dmm_dim: 58
        # One of: concat_first, concat_last.
        concat: concat_first
        # Classifier-free guidance scale.
        guidance_scale: 7.5
        # Drop probabilities per conditioning input.
        l_latent_embed_drop_prob: 1.0  # listener_latent_embed
        l_personal_embed_drop_prob: 1.0  # listener_personal_embed
        s_audio_enc_drop_prob: 0.2  # speaker_audio_encodings
        s_latent_embed_drop_prob: 0.2  # speaker_latent_embed
        s_3dmm_enc_drop_prob: 0.2  # speaker_3dmm_encodings
        s_emotion_enc_drop_prob: 1.0  # speaker_emotion_encodings
        past_l_emotion_drop_prob: 1.0  # past_listener_emotion
        # Whether to use past frames.
        use_past_frames: false
    scheduler:
        # One of: linear, cosine.
        noise_schedule: cosine
        # One of: leading (default), linspace, trailing.
        timestep_spacing: leading
        # TODO: set the diffusion steps (training and inference).
        num_train_timesteps: 1000
        num_inference_timesteps: 50
        # One of: start_x, epsilon.
        predict: start_x
        # One of: fixed_large, fixed_small.
        var_type: fixed_large
        rescale_timesteps: false
        noise_std: 1
        # k appropriate generations.
        k: 1

# Loss: loss type and its arguments. The three `losses_*` lists each have two
# entries.
loss:
    type: DiffusionLoss
    args:
        # Available choices: MSELossWithAct | MSELoss | L1Loss.
        losses_type: [MSELoss, MSELoss]
        # TODO: loss weight.
        losses_multipliers: [0, 1]
        losses_decoded: [false, true]
        k: 1
        temporal_loss_w: 0.0

# Trainer: seed, epoch range, resume checkpoint, output/log directories and
# save/validation periods.
trainer:
    seed: 1234
    start_epoch: 0
    epochs: 100
    model: LatentMatcher
    clip_grad: false
    resume: exp_6530/epoch_450_checkpoint.pth
    num_workers: 16
    # Output locations for logs, TensorBoard files and results.
    log_dir: ./log/train_rewrite_weight
    tb_dir: ./tb_logs/train_rewrite_weight
    out_dir: ./results/train_rewrite_weight
    # NOTE(review): `checkpoint_dir` and `saving_checkpoint_dir` point to
    # different directories - presumably load vs. save locations; confirm.
    checkpoint_dir: ./checkpoints/train_main
    saving_checkpoint_dir: ./checkpoints/train_rewrite_weight
    save_period: 10
    val_period: 10

# Our main model used to modify the weights: model type, Modifier Network
# arguments, and the optimizers for the hyper-network and the main network.
main_model:
    # (default) MainNetUnified
    # or (deprecated) MainNet (for linear layer in feed-forward layer)
    # or (deprecated) MainNetCrossAttn (for mapping layer in cross-attn layer)
    type: MainNetUnified
    # Arguments for our Modifier Network.
    args:
        input_dim: 512
        latent_dim: 1024
        # Embed dim in the Transformer decoder.
        embed_dim: 512
        # TODO(debug): regularize the generated weight deviation.
        regularization: false
        # Regularization weight.
        regular_w: 0.0
        # Number of layers in the shared encoder.
        num_shared_layers: 2
        modified_layers:
            - diffusion_decoder.model.decoder.layers.4.multihead_attn
            - diffusion_decoder.model.decoder.layers.6.multihead_attn
            - diffusion_decoder.model.to_emotion_feat
            # Alternatives kept for reference:
            # [diffusion_decoder.model.to_emotion_embed,]
            # [diffusion_decoder.model.decoder.layers.5.multihead_attn, diffusion_decoder.model.decoder.layers.6.multihead_attn]
            # [diffusion_decoder.model.decoder.layers.5.multihead_attn.out_proj, diffusion_decoder.model.decoder.layers.6.multihead_attn.out_proj]
            # [diffusion_decoder.model.decoder.layers.6.linear1, diffusion_decoder.model.decoder.layers.6.linear2]
        # Hyper-net outputs shift [w' = w + delta_w] or offset [w' = w * (1 + delta_w)].
        predict: shift
        # Modify the cross-attention's all [W_q & W_k & W_v] or kv [W_k & W_v].
        modify: all
        # TODO: resume or load model_hypernet and optimizer_hypernet.
        resume: exp_7351/epoch_050_checkpoint.pth
    optimizer_hypernet:
        type: sgd
        args:
            lr: 0.001
            # Written with a decimal point on purpose: YAML 1.1 loaders
            # (e.g. PyYAML) resolve a bare `1e-4` as a string, not a float.
            weight_decay: 1.0e-4
            momentum: 0.9
    optimizer_mainnet:
        type: sgd
        args:
            lr: 0.001
            # Written with a decimal point on purpose: YAML 1.1 loaders
            # (e.g. PyYAML) resolve a bare `1e-4` as a string, not a float.
            weight_decay: 1.0e-4
            momentum: 0.9

# Dataset settings for the `train` split.
dataset:
    batch_size: 1
    shuffle: true
    num_workers: 16
    dataset_path: ../phd_data_all/react_clean
    split: train
    num_person: 16
    num_sample: 4
    img_size: 256
    crop_size: 224
    clip_length: 750
    fps: 25
    # Which modalities to load (suffix _s / _l presumably speaker / listener,
    # as in the comments of `diffusion_decoder` - confirm).
    load_audio: true
    load_video_s: false
    load_video_l: false
    load_emotion_l: true
    load_emotion_s: true
    load_3dmm_l: true
    load_3dmm_s: true
    load_ref: true
    k_appro: 1

# Dataset settings for the `val` split (no shuffling).
validation_dataset:
    batch_size: 1
    shuffle: false
    num_workers: 16
    dataset_path: ../phd_data_all/react_clean
    split: val
    num_person: 16
    num_sample: 4
    img_size: 256
    crop_size: 224
    clip_length: 750
    fps: 25
    # Which modalities to load (suffix _s / _l presumably speaker / listener,
    # as in the comments of `diffusion_decoder` - confirm).
    load_audio: true
    load_video_s: false
    load_video_l: false
    load_emotion_l: true
    load_emotion_s: true
    load_3dmm_l: true
    load_3dmm_s: true
    load_ref: true
    k_appro: 1

# Dataset settings for the `test` split (no shuffling, `load_ref` disabled).
test_dataset:
    batch_size: 1
    shuffle: false
    num_workers: 16
    dataset_path: ../phd_data_all/react_clean
    split: test
    img_size: 256
    crop_size: 224
    clip_length: 750
    fps: 25
    # Which modalities to load (suffix _s / _l presumably speaker / listener,
    # as in the comments of `diffusion_decoder` - confirm).
    load_audio: true
    load_video_s: false
    load_video_l: false
    load_emotion_l: true
    load_emotion_s: true
    load_3dmm_l: true
    load_3dmm_s: true
    load_ref: false
    k_appro: 1
    threads: 32
