from dscnn_ff import run_dscnn_giff
from dscnn_ff import Opts
def select_pooling_setting(block, ds_poolings, first_depth=2):
    """Return the config key and the candidate pooling list for an FF block count.

    ``block`` may be fractional (e.g. 2.5); it is truncated with ``int()``, so
    2 and 2.5 share the same pooling candidates. ``ds_poolings[0]`` holds the
    candidates for depth ``first_depth``, ``ds_poolings[1]`` those for
    ``first_depth + 1``, and so on.

    Args:
        block: FF block count being swept (int or float).
        ds_poolings: list of candidate-pooling groups, one group per depth.
        first_depth: depth that ``ds_poolings[0]`` corresponds to.

    Returns:
        ``(config_key, candidates)`` where ``config_key`` is
        ``'ds_pooling_<depth>'``.

    Raises:
        ValueError: if no pooling group exists for the truncated depth.
    """
    depth = int(block)
    index = depth - first_depth
    # Explicit range check: a negative index would otherwise silently wrap
    # around to the end of the list instead of failing.
    if not 0 <= index < len(ds_poolings):
        raise ValueError(
            "no pooling candidates for FF block count %r (depth %d); expected "
            "depth in [%d, %d]" % (block, depth, first_depth,
                                   first_depth + len(ds_poolings) - 1)
        )
    return 'ds_pooling_' + str(depth), ds_poolings[index]


if __name__ == "__main__":
    # Sweep FF_block_nums x weight_decay x theta x pooling and launch one
    # run_dscnn_giff run per combination.

    opts = Opts()

    config = dict(
        ds_channels =   [[64,64],[64,64],[64,64],[64,64]],
        ds_pooling_2 =  [],
        ds_pooling_3 =  [],
        ds_pooling_4 =  [],
        ds_pooling_5 =  [],
        FF_block_nums = 0,
    )

    # general settings
    thetas = [4, 6, 8, 12, 16]
    opts.online_visual = 1
    opts.epochs = 100
    opts.device = 'cuda:0'

    # Pooling candidates, one group per FF depth:
    # ds_poolings[0] -> ds_pooling_2 (2 entries per candidate),
    # ds_poolings[1] -> ds_pooling_3 (3 entries),
    # ds_poolings[2] -> ds_pooling_4 (4 entries).
    # Each entry is an [a, b] pair; presumably a pooling kernel size — confirm
    # against dscnn_ff.
    ds_poolings = [
        [
            [[1,1],[1,1]],
            [[1,1],[3,3]],
            [[1,1],[5,5]],
            [[1,1],[7,7]],
            [[3,3],[3,3]],
            [[3,3],[5,5]],
            [[3,3],[7,7]],
        ],
        [
            [[1,1],[1,1],[1,1]],
            [[1,1],[1,1],[3,3]],
            [[1,1],[1,1],[5,5]],
            [[1,1],[1,1],[7,7]],
            [[1,1],[3,3],[3,3]],
            [[1,1],[3,3],[5,5]],
            [[1,1],[3,3],[7,7]],
        ],
        [
            [[1,1],[1,1],[1,1],[1,1]],
            [[3,3],[1,1],[1,1],[1,1]],
            [[3,3],[1,1],[3,3],[1,1]],
            [[1,1],[3,3],[3,3],[1,1]],
            [[3,3],[3,3],[3,3],[3,3]],
            [[3,3],[1,1],[1,1],[3,3]],
            [[7,7],[5,5],[3,3],[1,1]],
        ],
    ]

    # Every FF block count is swept with each weight decay, in this order.
    weight_decays = [5e-3, 5e-2]

    for block in [2, 2.5, 3, 3.5, 4]:
        config['FF_block_nums'] = block
        opts.project_name = "forward-forward-benchmark-temporal-GIFF" + str(config['FF_block_nums'])
        # Fractional counts (2.5, 3.5) reuse the candidates of their integer part.
        config_pooling, ds_pooling = select_pooling_setting(block, ds_poolings)
        for weight_decay in weight_decays:
            opts.weight_decay = weight_decay
            for theta in thetas:
                for pooling in ds_pooling:
                    # NOTE(review): the same config dict is reused for every run,
                    # so ds_pooling_<k> keys set for an earlier depth keep their
                    # last value in later runs — confirm run_dscnn_giff ignores them.
                    config[config_pooling] = pooling
                    opts.theta = theta
                    opts.runtime_name = (
                        str(config['FF_block_nums']) + str(pooling)
                        + "_theta_" + str(theta)
                        + "initial_lr" + str(opts.lr)
                        + "warmup_epochs" + str(opts.warmup_epochs)
                        + "weight_decay" + str(opts.weight_decay)
                    )
                    run_dscnn_giff(opts, config)