import torch.nn as nn
from model.AVESFormer import AVESFormer
from model.loss.avs_s4_loss import SegmentationLoss
from dataset.avs_s4_dataset import S4Dataset
from evaluation import MaskIoU

# Shared embedding width. It is passed to the model, the query generator and the
# decoder layers so that all three stay in sync if this value is changed.
embed_dim = 256

model = AVESFormer(
    img_size=224,
    backbone='resnet18',
    # Placeholder path: point this at a local ResNet-18 checkpoint.
    pretrained="your/path/to/resnet18.pth",
    # NOTE(review): presumably the per-stage output channels of the resnet18
    # backbone -- must be revisited if `backbone` is changed.
    in_channels=[64, 128, 256, 512],
    # NOTE(review): presumably the width of the VGGish audio embedding -- confirm.
    audio_dim=128,
    embed_dim=embed_dim,
    vggish=dict(
        freeze_audio_extractor=True,
        # Placeholder path: point this at a local VGGish checkpoint.
        pretrained_vggish_path="your/path/to/vggish-10086976.pth"
    ),
    num_classes=1,
    query_generator=dict(num_layers=3, num_query=16, embed_dim=embed_dim),
    decoder=dict(num_layers=[2, 2, 2],
                 layer=dict(
                     # Derived from embed_dim instead of a literal 256 so the
                     # decoder width cannot drift from the rest of the model.
                     dim=embed_dim,
                     ffn_dim=1024,
                     dropout=0.0,
                     # The activation is passed as a class (nn.ReLU), not an instance.
                     activation=nn.ReLU)),
)

# Placeholder locations of the AVSBench-object single-source (S4) data; point
# these at a local copy of the dataset.
anno_csv = "your/path/to/AVSBench-object/Single-source/s4_meta_data.csv"
dir_img = "your/path/to/AVSBench-object/Single-source/s4_data/visual_frames/"
dir_audio_log_mel = "your/path/to/AVSBench-object/Single-source/s4_data/audio_log_mel/"
dir_mask = "your/path/to/AVSBench-object/Single-source/s4_data/gt_masks/"

# All three splits read from the same annotation file and data directories;
# only the `split` argument differs between them.
_s4_shared_kwargs = dict(
    anno_csv=anno_csv,
    dir_img=dir_img,
    dir_audio_log_mel=dir_audio_log_mel,
    dir_mask=dir_mask,
)

train_dataset = S4Dataset(split='train', **_s4_shared_kwargs)
val_dataset = S4Dataset(split='val', **_s4_shared_kwargs)
test_dataset = S4Dataset(split='test', **_s4_shared_kwargs)

# Remove the temporary helper so the module's top-level names are unchanged.
del _s4_shared_kwargs

# Per-term loss weights, keyed by loss name.
# NOTE(review): how SegmentationLoss combines these terms is defined in
# model.loss.avs_s4_loss -- confirm there before retuning.
loss_fn = SegmentationLoss(
    weight=dict(
        iou_loss=1.8,
        dice_loss=1.0,
        aux_loss=0.95,
    )
)

# Evaluation metric: the MaskIoU object from the project's `evaluation` module.
metric = MaskIoU()
