import torch
from torch import nn
import torch.distributed as dist
import torch.multiprocessing as mp
import os
import random
import numpy as np

import argparse

from Preprocess.getdataloader import GetImageNet
from Models.VGG import vgg16
from Models.ResNet import resnet34
from funcs import eval_snn, eval_ann, seed_all
from utils import replace_activation_by_slip, replace_activation_by_neuron, replace_maxpool2d_by_avgpool2d, search_fold_and_remove_bn



def main_tester(rank, gpus, args):
    """Evaluate the converted SNN on the ImageNet test loader on one DDP rank.

    Builds a ResNet-34 and swaps MaxPool2d for AvgPool2d. It then installs the
    SlipReLU activations so the module tree matches the checkpoint, and loads
    the weights. Finally it swaps the activations for neurons, folds BatchNorm
    and evaluates for ``args.t`` time steps. Rank 0 prints the result.

    Args:
        rank: index of this process; also used as the CUDA device index.
        gpus: total number of processes (DDP world size).
        args: parsed CLI namespace. Reads ``batchsize``, ``t``, ``a``,
            ``shift1``, ``shift2``, ``seed``, ``checkpoint`` and ``mid``.
    """
    # Honour an externally configured rendezvous (e.g. another job already
    # uses port 12355); otherwise fall back to the original defaults.
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '12355')

    dist.init_process_group(backend='gloo', rank=rank, world_size=gpus)
    try:
        device = f'cuda:{rank}'
        torch.cuda.set_device(device)
        seed_all(args.seed)

        # The global batch size is split evenly across processes.
        batchsize = int(args.batchsize / gpus)
        _, test = GetImageNet(batchsize)

        # Model preparation: must match the architecture the checkpoint was saved from.
        model = resnet34(num_classes=1000)
        model = replace_maxpool2d_by_avgpool2d(model)
        model = replace_activation_by_slip(model, args.t, args.a, args.shift1, args.shift2, a_learnable=False)

        savename = os.path.join(args.checkpoint, args.mid)
        # map_location='cpu' stops every rank from materialising the tensors on
        # the device they were saved from before the model is moved to its own GPU.
        state_dict = torch.load(savename + '.pth', map_location='cpu')
        # Accept checkpoints saved from a DDP-wrapped model, whose keys are
        # prefixed with "module.".
        state_dict = {
            (key[len('module.'):] if key.startswith('module.') else key): value
            for key, value in state_dict.items()
        }
        model.load_state_dict(state_dict)

        model = replace_activation_by_neuron(model, args.shift2)
        search_fold_and_remove_bn(model)

        model.cuda(device)
        model = nn.parallel.DistributedDataParallel(model, device_ids=[rank])

        acc = eval_snn(test, model, device=device, sim_len=args.t, rank=rank)

        # Sum the per-rank results. 50000 is the divisor the original code uses;
        # presumably the ImageNet validation-set size, and this assumes
        # eval_snn returns a tensor of correct counts for this rank's shard.
        # TODO confirm.
        dist.all_reduce(acc)
        acc /= 50000
        if rank == 0:
            print(acc)
    finally:
        # Always release the process group, even if evaluation fails.
        dist.destroy_process_group()

def main_anntester(rank, gpus, args):
    """Evaluate the SlipReLU ANN on the ImageNet test loader on one DDP rank.

    Builds the same ResNet-34 and SlipReLU model that ``main_tester`` builds
    before conversion, so both modes read the same checkpoint. Weights are
    loaded into the unwrapped model and only then wrapped in DDP. Rank 0
    prints the result.

    Args:
        rank: index of this process; also used as the CUDA device index.
        gpus: total number of processes (DDP world size).
        args: parsed CLI namespace. Reads ``batchsize``, ``t``, ``a``,
            ``shift1``, ``shift2``, ``seed``, ``checkpoint`` and ``mid``.
    """
    # Honour an externally configured rendezvous; otherwise use the defaults.
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '12355')

    dist.init_process_group(backend='gloo', rank=rank, world_size=gpus)
    try:
        device = f'cuda:{rank}'
        torch.cuda.set_device(device)
        seed_all(args.seed)

        # The global batch size is split evenly across processes.
        batchsize = int(args.batchsize / gpus)
        _, test = GetImageNet(batchsize)

        # Model preparation. Pass the same arguments as main_tester so the ANN
        # honours --shift1/--shift2.
        model = resnet34(num_classes=1000)
        model = replace_maxpool2d_by_avgpool2d(model)
        model = replace_activation_by_slip(model, args.t, args.a, args.shift1, args.shift2, a_learnable=False)

        savename = os.path.join(args.checkpoint, args.mid)
        # Load before wrapping in DDP. A DDP wrapper's state_dict keys gain a
        # "module." prefix, so loading after wrapping would reject the
        # unprefixed checkpoints that main_tester accepts. Prefixed checkpoints
        # are normalised here too. map_location='cpu' stops every rank from
        # materialising the tensors on the device they were saved from.
        state_dict = torch.load(savename + '.pth', map_location='cpu')
        state_dict = {
            (key[len('module.'):] if key.startswith('module.') else key): value
            for key, value in state_dict.items()
        }
        model.load_state_dict(state_dict)

        model.cuda(device)
        model = nn.parallel.DistributedDataParallel(model, device_ids=[rank])

        criterion = nn.CrossEntropyLoss()
        acc = eval_ann(test, model, criterion, device=device, rank=rank)

        # Sum the per-rank results. 50000 is the divisor the original code uses;
        # presumably the ImageNet validation-set size. TODO confirm that
        # eval_ann returns a tensor of correct counts for this rank's shard.
        dist.all_reduce(acc)
        acc /= 50000
        if rank == 0:
            print(acc)
    finally:
        # Always release the process group, even if evaluation fails.
        dist.destroy_process_group()


# Entry point: parse the CLI and spawn one evaluation process per GPU.
if __name__ == "__main__":

    parser = argparse.ArgumentParser()

    parser.add_argument('--gpus', default=1, type=int, help='GPU number to use.')
    parser.add_argument('--batchsize', default=128, type=int, help='Batchsize')
    parser.add_argument('--t', default=16, type=int, help='Time step')
    # Both testers read args.a and args.seed; without these declarations
    # either mode crashed with AttributeError.
    parser.add_argument('--a', required=True, type=float,
                        help='The `a` coefficient passed to replace_activation_by_slip; '
                             'must match the value the checkpoint was trained with')
    parser.add_argument('--seed', default=42, type=int, help='Random seed passed to seed_all')

    parser.add_argument('--shift1', default=0.0, type=float, help='The Shift of the threshold-ReLU function')
    parser.add_argument('--shift2', default=0.5, type=float, help='The Shift of the Step function')

    # Restrict to the two supported modes instead of silently treating any
    # other string as 'ann'.
    parser.add_argument('--mode', default='snn', type=str, choices=('snn', 'ann'), help='Test ann or snn')
    # Required: os.path.join(checkpoint, None) would raise TypeError in every worker.
    parser.add_argument('--mid', required=True, type=str, help='Model identifier')
    parser.add_argument('--checkpoint', default='./saved_models', type=str, help='Directory for saving models')

    args = parser.parse_args()

    # Fail fast in the parent instead of crashing inside each spawned worker.
    if args.gpus < 1:
        parser.error('--gpus must be at least 1')
    if args.gpus > torch.cuda.device_count():
        parser.error(f'--gpus={args.gpus} exceeds the {torch.cuda.device_count()} visible CUDA device(s)')
    if args.batchsize < args.gpus:
        parser.error('--batchsize must be at least --gpus (per-process batch size would be 0)')

    worker = main_tester if args.mode == 'snn' else main_anntester
    mp.spawn(worker, nprocs=args.gpus, args=(args.gpus, args))
