import torch.nn as nn
from model.loss.avss_loss import SegmentationLoss
from model.AVESFormer import AVESFormer
from dataset.avss_dataset import V2Dataset
from evaluation import ColorMiou

# ---- Dataset locations (AVSBench-semantic) ----
# Placeholder paths: point these at your local copy of the dataset.
# NOTE: the "cvs" spelling matches V2Dataset's `meta_cvs_path` keyword.
dir_base = "your/path/to/AVSBench-semantic/"
meta_cvs_path = "your/path/to/AVSBench-semantic/metadata.csv"
label_idx_path = "your/path/to/AVSBench-semantic/label2idx.json"

# ---- Dataset / label settings ----
# Forwarded to V2Dataset as `mask_num`; presumably the number of mask frames
# per sample — confirm against V2Dataset.
mask_num = 10
# Shared by the model head and the datasets; presumably includes background —
# confirm against label2idx.json.
num_classes = 71

# ---- Spatial settings ----
crop_img_and_mask = True
crop_size = 224
# Square (height, width) pair derived from crop_size; equals (224, 224).
img_size = (crop_size, crop_size)

# ---- Model width ----
embed_dim = 256

# Model: AVESFormer with a ResNet-18 visual backbone and a VGGish audio
# extractor (frozen via freeze_audio_extractor=True). Replace the placeholder
# checkpoint paths before use.
model = AVESFormer(
    # NOTE(review): literal duplicates crop_size / img_size defined above;
    # keep these in sync if the input resolution changes.
    img_size=224,
    backbone='resnet18',
    pretrained="your/path/to/resnet18.pth",
    # Per-stage feature widths; presumably the four ResNet-18 stage outputs —
    # confirm against AVESFormer's backbone wiring.
    in_channels=[64, 128, 256, 512],
    audio_dim=128,
    embed_dim=embed_dim,
    vggish=dict(
        freeze_audio_extractor=True,
        pretrained_vggish_path="your/path/to/vggish-10086976.pth",
    ),
    num_classes=num_classes,
    query_generator=dict(num_layers=3, num_query=128, embed_dim=embed_dim),
    decoder=dict(
        num_layers=[2, 2, 2],
        layer=dict(
            # Tied to embed_dim (previously a literal 256 that merely
            # happened to match) so changing embed_dim cannot leave the
            # decoder at a stale width. Presumably the decoder width must
            # equal the query embedding size — confirm in AVESFormer.
            dim=embed_dim,
            ffn_dim=1024,
            dropout=0.0,
            activation=nn.ReLU,
        ),
    ),
)

# Constructor arguments shared by every split; only `split` differs below.
# Collected once so the train/val/test datasets cannot silently diverge.
_dataset_kwargs = dict(
    mask_num=mask_num,
    meta_cvs_path=meta_cvs_path,  # keyword spelling is V2Dataset's own
    label_idx_path=label_idx_path,
    num_classes=num_classes,
    dir_base=dir_base,
    crop_img_and_mask=crop_img_and_mask,
    crop_size=crop_size,
)

train_dataset = V2Dataset(split='train', **_dataset_kwargs)

# NOTE(review): validation is built from the 'test' split, identical to
# test_dataset below. If checkpoint selection is driven by val_dataset, that
# selects on test data — confirm whether a separate validation split exists
# in metadata.csv and should be used instead.
val_dataset = V2Dataset(split='test', **_dataset_kwargs)

test_dataset = V2Dataset(split='test', **_dataset_kwargs)

# Remove the helper so it does not leak into the config's module namespace
# (config loaders may collect every top-level name).
del _dataset_kwargs

# Loss: per-term weights passed to SegmentationLoss. The keys presumably name
# the loss components it computes — confirm in model/loss/avss_loss.py.
loss_fn = SegmentationLoss(weight=dict(iou_loss=1.0, aux_loss=0.1))

# Evaluation metric, constructed with its default arguments.
metric = ColorMiou()
