import torch
from torch import nn
import torch.distributed as dist
import torch.multiprocessing as mp
import os

import argparse

from Models import modelpool
from Preprocess.getdataloader import GetImageNet
from funcs import train_ann, seed_all
from utils import replace_activation_by_slip, replace_activation_by_neuron, replace_maxpool2d_by_avgpool2d


def main_worker(rank, gpus, args):
    """Train one data-parallel replica of the SLIP-activated ANN on ImageNet.

    Meant to run once per GPU process (presumably launched with
    ``torch.multiprocessing.spawn`` -- confirm at the call site). The process
    joins an NCCL process group and builds the model with max-pooling replaced
    by average-pooling and activations replaced by SLIP modules. It then wraps
    the model in ``DistributedDataParallel`` and hands it to ``train_ann``.

    Args:
        rank: Index of this process. Also used as the local CUDA device index,
            so this assumes a single-node launch.
        gpus: Total number of processes (``world_size`` of the group).
        args: Parsed command-line namespace. Reads ``seed``, ``batchsize``
            (global batch size, split evenly across ranks), ``model``, ``l``,
            ``a``, ``shift1``, ``shift2``, ``a_learnable``, ``checkpoint``,
            ``mid``, ``epochs``, ``lr`` and ``wd``.

    Raises:
        ValueError: If the global batch size is smaller than ``gpus``, which
            would leave each rank with an empty batch.
    """
    # Split the global batch evenly across ranks. Integer division matches the
    # old int(a / b) result for positive sizes without a float round trip.
    batchsize = int(args.batchsize) // gpus
    if batchsize < 1:
        # Fail fast, before joining the process group, instead of letting the
        # DataLoader reject batch_size=0 later with a less obvious error.
        raise ValueError(
            f'global batch size {args.batchsize} is smaller than the number '
            f'of GPUs ({gpus}); each rank needs at least one sample per batch'
        )

    # Rendezvous endpoint for the process group. Honor values already set in
    # the environment; otherwise use the single-node defaults.
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '12355')

    device = f'cuda:{rank}'
    # Bind this process to its own GPU *before* creating the NCCL group so
    # NCCL uses this rank's device rather than defaulting to cuda:0.
    torch.cuda.set_device(device)

    dist.init_process_group(backend='nccl', rank=rank, world_size=gpus)
    try:
        seed_all(args.seed)

        train, test = GetImageNet(batchsize)

        # Model preparation: swap max-pooling for average-pooling, then
        # replace activations with SLIP modules configured from the CLI.
        model = modelpool(args.model)
        model = replace_maxpool2d_by_avgpool2d(model)
        model = replace_activation_by_slip(
            model, args.l, args.a, args.shift1, args.shift2, args.a_learnable
        )

        criterion = nn.CrossEntropyLoss()

        model.cuda(device)
        model = nn.parallel.DistributedDataParallel(model, device_ids=[rank])

        savename = os.path.join(args.checkpoint, args.mid)

        # The positional True is passed straight through to train_ann;
        # presumably it enables its distributed code path -- confirm there.
        train_ann(train, test, model, criterion, device, args.epochs,
                  args.lr, args.wd, savename, True, rank)
    finally:
        # Always tear the group down, even if data loading or training
        # raises, so the NCCL communicator is not leaked.
        dist.destroy_process_group()
    
    
    
