from sacred import Experiment

# Shared Sacred experiment; the config scopes below register themselves on it.
ex = Experiment(name="CMPT")

@ex.config
def config():
    """Base configuration for CMPT runs.

    Sacred records every local variable assigned in this function as a config
    entry, and comments next to an assignment are shown as that entry's
    description (e.g. by ``print_config``). The ``@ex.named_config`` functions
    in this file override subsets of these defaults per task.
    """
    exp_name = "CMPT"
    seed = 0
    batch_size = 16
    save_dir = './output'
    device = 'cuda'
    datasets = ["Food101"]

    # Wandb configuration (the YOUR_* values are placeholders to replace before logging)
    use_wandb = True
    wandb_exp_name = "YOUR_EXP_NAME"
    wandb_project_name = "YOUR_PROJECT_NAME"

    # Evaluation configuration; all unset (None) by default.
    # NOTE(review): presumably filled in for evaluation runs - confirm against the eval entry point.
    test_ratio = None
    test_type = None
    test_exp_name = None

    # Missing-modality configuration
    missing_ratio = {'train': 0.0, 'val': 0.0, 'test': 0.0}          # Per-split missing ratio; 0.0 means complete modalities
    missing_type = {'train': 'both', 'val': 'both', 'test': 'both'}  # 'text', 'image' or 'both' in VL tasks; not implemented for AV tasks
    both_ratio = 0.5                                                 # Missing-both ratio
    missing_table_root = '/Path/to/missing_tables/'                  # Placeholder: directory containing the missing tables
    simulate_missing = False                                         # Enable dropout during training
    model_path = ''                                                  # Checkpoint path, used for evaluation

    # Image settings
    train_transform_keys = []
    val_transform_keys = []
    image_size = 224
    max_image_len = -1
    draw_false_image = 1

    # Text settings
    max_text_len = 40
    draw_false_text = 0

    # Optimizer and LR-scheduler settings.
    # NOTE(review): warmup, warmup_ratio and warmup_steps all describe warmup;
    # which one the scheduler actually reads is decided outside this file - confirm.
    optim_type = "adamw"
    learning_rate = 1e-4
    weight_decay = 2e-2
    scheduler = 'warmuppolylr'
    power = 0.9
    warmup = 5
    warmup_ratio = 0.1
    decay_power = 1
    max_epoch = 100
    max_steps = 25000
    warmup_steps = 2500
    end_lr = 0
    lr_mult = 1

    # Downstream setting: number of output classes (each task's named config overrides it)
    class_num = 23

    # The parameters below vary with the environment
    data_root = "./datasets"
    num_gpus = 1
    gpu_ids = [0]
    num_workers = 2
    precision = 16
    amp = True
    only_paired = False
    text_model = 'bert-base-uncased'

    # LoRA and CL configuration
    r = 1
    lora_alpha = 1
    lora_dropout = 0.1

    # The value below is the UPMC Food-101 setting. MM-IMDb uses
    #     ["query", "key", "value", "output.dense"]
    # instead, so an MM-IMDb run must override this entry (e.g. in its named
    # config) or swap the value here.
    # BERT modules that receive LoRA adapters
    bert_target_modules = ["query", "key", "value", "attention.output.dense"]

    vit_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
    ast_target_modules = ["query", "key", "value", "attention.output.dense"]
    enable_lora = True

    model_name = 'vit'                  # Only used for unimodal tasks: "vit", "bert", "ast" or "video"
    enable_mt = False                   # Enable mask tokens
    mt_alignment_loss_weight = 0.20     # Lambda: weight of the mask-token alignment loss
    data_dir = None                     # Data directory for AV datasets
    
# UPMC Food-101
@ex.named_config
def task_finetune_food101():
    """Fine-tune on UPMC Food-101 (101 classes)."""
    model_name = 'bert'                         # For unimodal training ['vit', 'bert', 'ast', 'video']
    simulate_missing = True                     # Enable dropout during training
    enable_lora = True                          # Enable LoRA fine-tuning
    enable_mt = True                            # Enable mask tokens
    mt_alignment_loss_weight = 0.20             # Lambda value
    # Pinned here so this task does not depend on which default is active in config().
    bert_target_modules = ["query", "key", "value", "attention.output.dense"]  # BERT LoRA targets (Food-101)
    use_wandb = False                           # Enable Wandb for logging
    wandb_exp_name = "FOOD-101"                 # Experiment name shown in Wandb
    exp_name = "finetune_food101"               # Experiment name used inside the code
    save_dir = './output/Food101/'              # Where checkpoints are saved
    datasets = ["Food101"]
    batch_size = 8
    max_epoch = 10
    max_steps = -1                              # Max steps per epoch, -1 for whole epoch
    draw_false_image = 0
    learning_rate = 1e-3
    max_text_len = 40
    class_num = 101
    model_path = '/Path/to/Saved/Checkpoint.pth'                    # Checkpoint path, used for evaluation
    missing_ratio = {'train': 0.0, 'val': 0.0, 'test': 0.0}         # 0.0 for complete modality, other values for different modality combinations
    missing_type = {'train': 'both', 'val': 'both', 'test': 'both'} # 'text', 'image' or 'both' in VL tasks
    both_ratio = 0.5                                                # Missing-both ratio
    missing_table_root = '/Path/to/missing_tables/'
    seed = 0
    data_dir = '/Path/to/Dataset/Root/'
    

# MM-IMDb
@ex.named_config
def task_finetune_mmimdb():
    """Fine-tune on MM-IMDb (23 classes)."""
    model_name = 'bert'                         # For unimodal training ['vit', 'bert', 'ast', 'video']
    simulate_missing = True                     # Enable dropout during training
    enable_lora = True                          # Enable LoRA fine-tuning
    enable_mt = True                            # Enable mask tokens
    mt_alignment_loss_weight = 0.20             # Lambda value
    # The base config() defaults to the Food-101 list
    # ["query", "key", "value", "attention.output.dense"] and documents this one as
    # the MM-IMDb setting; pin it here so MM-IMDb runs do not inherit the Food-101 value.
    bert_target_modules = ["query", "key", "value", "output.dense"]  # BERT LoRA targets (MM-IMDb)
    use_wandb = False                           # Enable Wandb for logging
    wandb_exp_name = "MM-IMDb"                  # Experiment name shown in Wandb
    exp_name = "finetune_mmimdb"                # Experiment name used inside the code
    save_dir = './output/MM-IMDb/'              # Where checkpoints are saved
    datasets = ["mmimdb"]
    batch_size = 8
    max_epoch = 10
    max_steps = -1                              # Max steps per epoch, -1 for whole epoch
    draw_false_image = 0
    learning_rate = 1e-3
    max_text_len = 256
    class_num = 23
    model_path = '/Path/to/Saved/Checkpoint.pth'                    # Checkpoint path, used for evaluation
    missing_ratio = {'train': 0.0, 'val': 0.0, 'test': 0.0}         # 0.0 for complete modality, other values for different modality combinations
    missing_type = {'train': 'both', 'val': 'both', 'test': 'both'} # 'text', 'image' or 'both' in VL tasks
    both_ratio = 0.5                                                # Missing-both ratio
    missing_table_root = '/Path/to/missing_tables/'
    seed = 0
    data_dir = '/Path/to/Dataset/Root/'


# Kinetics-Sounds (KS)
@ex.named_config
def task_finetune_kinetics_sound():
    """Fine-tune on Kinetics-Sounds (KS), 31 classes."""
    # Run identity, logging and paths
    exp_name = "finetune_ks"
    wandb_exp_name = "KS"
    use_wandb = False
    save_dir = "./output/KS/"
    model_path = "/Path/to/Saved/Checkpoint.pth"   # Checkpoint path (placeholder)
    data_dir = "/Path/to/Dataset/Root/"            # Dataset root (placeholder)

    # Model and method switches
    model_name = "video"
    enable_lora = True
    enable_mt = True                               # Mask tokens
    mt_alignment_loss_weight = 0.2                 # Lambda value
    simulate_missing = True

    # Training schedule and task size
    seed = 0
    batch_size = 4
    learning_rate = 0.00005
    max_epoch = 100
    max_steps = 2000
    class_num = 31

    # Missing-modality knobs, kept for interface parity with the VL tasks
    missing_ratio = dict(train=0.0, val=0.0, test=0.0)                # Not implemented for AV tasks
    missing_type = dict(train="both", val="both", test="both")        # Not implemented for AV tasks
    both_ratio = 0.5                                                  # Not implemented for AV tasks
    

# AVE (Audio-Visual Event dataset)
@ex.named_config
def task_finetune_ave():
    """Fine-tune on AVE, 28 classes."""
    # Run identity, logging and paths
    exp_name = "finetune_ave"
    wandb_exp_name = "AVE"
    use_wandb = False
    save_dir = "./output/AVE/"
    model_path = "/Path/to/Saved/Checkpoint.pth"   # Checkpoint path (placeholder)
    data_dir = "/Path/to/Dataset/Root/"            # Dataset root (placeholder)

    # Model and method switches
    model_name = "video"
    enable_lora = True
    enable_mt = True                               # Mask tokens
    mt_alignment_loss_weight = 0.2                 # Lambda value
    simulate_missing = True

    # Training schedule and task size
    seed = 0
    batch_size = 4
    learning_rate = 0.00005
    max_epoch = 100
    max_steps = -1
    class_num = 28

    # Missing-modality knobs, kept for interface parity with the VL tasks
    missing_ratio = dict(train=0.0, val=0.0, test=0.0)                # Not implemented for AV tasks
    missing_type = dict(train="both", val="both", test="both")        # Not implemented for AV tasks
    both_ratio = 0.5                                                  # Not implemented for AV tasks

    
# CREMA-D
@ex.named_config
def task_finetune_cremad():
    """Fine-tune on CREMA-D, 6 classes."""
    # Run identity, logging and paths
    exp_name = "finetune_cremad"
    wandb_exp_name = "CREMA-D"
    use_wandb = False
    save_dir = "./output/CREMAD/"
    model_path = "/Path/to/Saved/Checkpoint.pth"   # Checkpoint path (placeholder)
    data_dir = "/Path/to/Dataset/Root/"            # Dataset root (placeholder)

    # Model and method switches
    model_name = "video"
    enable_lora = True
    enable_mt = True                               # Mask tokens
    mt_alignment_loss_weight = 0.2                 # Lambda value
    simulate_missing = True

    # Training schedule and task size
    seed = 0
    batch_size = 4
    learning_rate = 0.00005
    max_epoch = 100
    max_steps = -1
    class_num = 6

    # Missing-modality knobs, kept for interface parity with the VL tasks
    missing_ratio = dict(train=0.0, val=0.0, test=0.0)                # Not implemented for AV tasks
    missing_type = dict(train="both", val="both", test="both")        # Not implemented for AV tasks
    both_ratio = 0.5                                                  # Not implemented for AV tasks
