from models.arch.dsit_final import DSIT_Final


def dsit_final_m(in_channels=3, out_channels=3):
    """Build the "M" configuration of ``DSIT_Final``.

    The model is configured for 384x384 inputs with a window size of 12.
    It uses five encoder stages with ``[12, 8, 4, 2, 2]`` blocks and five
    decoder stages with two blocks each.

    NOTE(review): ``in_channels`` and ``out_channels`` are accepted but are
    not passed to ``DSIT_Final``, so any non-default value is silently
    ignored. Confirm the keyword names ``DSIT_Final`` expects before
    forwarding them.

    Returns:
        A freshly constructed ``DSIT_Final`` instance.
    """
    config = dict(
        input_resolution=(384, 384),
        window_size=12,
        enc_blk_nums=[12, 8, 4, 2, 2],
        dec_blk_nums=[2] * 5,
    )
    return DSIT_Final(**config)


if __name__ == '__main__':
    # Smoke test: build the model and run one forward pass on a dummy input.
    import torch

    # Fall back to CPU so the smoke test also runs on machines without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Spatial size matches the factory's input_resolution=(384, 384).
    x = torch.ones(1, 3, 384, 384, device=device)
    model = dsit_final_m(3, 3).to(device)
    # Inference-mode layers (e.g. dropout/normalization) for a forward check.
    model.eval()
    print(model)
    # No gradients are needed; skipping autograd bookkeeping saves memory.
    with torch.no_grad():
        print(model(x))
