import argparse
import torch
import torch.nn as nn
import random
import os
import numpy as np
from resnet import resnet18, resnet34
from dataprocess import PreProcess_Cifar10, PreProcess_Cifar100, load_ImageNet_dataset
from utils import *
from torch.cuda import amp
from timm.data import Mixup


def eval_one_epoch(model, test_dataloader, sim_len, use_dvs=False):
    """Time the model's GPU forward passes over one pass of ``test_dataloader``.

    Each batch is fed to the model one time step at a time for ``sim_len``
    steps, and ``reset_model(model)`` is called after every batch (presumably
    to clear per-batch neuron state -- TODO confirm against ``utils``).

    Args:
        model: network already placed on the GPU; called once per time step
            as ``model(img[t])``. Its outputs are discarded.
        test_dataloader: iterable of ``(img, label)`` pairs; labels are ignored.
        sim_len: number of simulation time steps run per batch.
        use_dvs: if truthy, ``img`` already carries a time axis in dim 1
            (presumably ``(B, T, ...)`` -- TODO confirm) and is transposed so
            time comes first. Otherwise the 4-D static batch is repeated
            ``sim_len`` times along a new leading time axis.

    Returns:
        Average forward time per batch in milliseconds, summed over all
        ``sim_len`` steps (``torch.cuda.Event.elapsed_time`` reports ms).

    Raises:
        ValueError: if ``test_dataloader`` yields no batches, since no average
            can be computed.
    """
    device = torch.device('cuda')
    model.eval()
    starter = torch.cuda.Event(enable_timing=True)
    ender = torch.cuda.Event(enable_timing=True)
    tot_time, num_batches = 0., 0
    with torch.no_grad():
        for img, _label in test_dataloader:
            img = img.to(device, non_blocking=True)
            if use_dvs:
                img = img.transpose(0, 1).contiguous()
            else:
                # Static images: replicate the batch for every time step.
                img = img.unsqueeze(0).repeat(sim_len, 1, 1, 1, 1)

            for t in range(sim_len):
                starter.record()
                model(img[t])
                ender.record()
                # Events are recorded asynchronously on the stream; wait for
                # them to complete before reading the elapsed time.
                torch.cuda.synchronize()
                tot_time += starter.elapsed_time(ender)

            reset_model(model)
            # Count batches actually seen rather than trusting len(), which
            # may be missing or approximate for iterable-style loaders.
            num_batches += 1

    if num_batches == 0:
        raise ValueError('test_dataloader yielded no batches; cannot compute average inference time')
    return tot_time / num_batches


if __name__ == '__main__':
    # Benchmark entry point: build a spiking ResNet for the chosen dataset and
    # report its average GPU inference time per test batch (see eval_one_epoch).

    parser = argparse.ArgumentParser()

    parser.add_argument('--dataset', type=str, default='CIFAR10', help='Dataset name')
    parser.add_argument('--datadir', type=str, default='/home/cifar100/', help='Directory where the dataset is saved')
    parser.add_argument('--net_arch', type=str, default='resnet18', help='Network Architecture')
    parser.add_argument('--batchsize', type=int, default=64, help='Batchsize')
    parser.add_argument('--time_step', type=int, default=4, help='Training Time-steps for SNNs')
    parser.add_argument('--seed', type=int, default=42, help='Random seed')
    parser.add_argument('--dev', type=str, default='0')
    parser.add_argument('--use_eca', type=int, default=0, help='Use ECA Attention')
    parser.add_argument('--use_mem_bn', action='store_true', help='Use Membrane BatchNorm')
    parser.add_argument('--use_parallel', action='store_true', help='Use Parallel Block')

    args = parser.parse_args()
    # Must be set before the first CUDA call so only the selected GPU(s) are
    # visible; torch initializes CUDA lazily, so setting it here still works.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.dev

    # Seed every RNG in use, require deterministic cuDNN kernels and disable
    # cuDNN autotuning so repeated runs are reproducible.
    _seed_ = args.seed
    random.seed(_seed_)
    os.environ['PYTHONHASHSEED'] = str(_seed_)
    torch.manual_seed(_seed_)
    torch.cuda.manual_seed(_seed_)
    torch.cuda.manual_seed_all(_seed_)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(_seed_)

    # Only static-image datasets are wired up below, so the DVS path is off.
    dvs_data = False
    if args.dataset == 'CIFAR10':
        train_dataloader, test_dataloader, train_sampler, test_sampler = PreProcess_Cifar10(args.datadir, args.batchsize, False)
        cls = 10
        input_size = (3, 32, 32)
    elif args.dataset == 'CIFAR100':
        train_dataloader, test_dataloader, train_sampler, test_sampler = PreProcess_Cifar100(args.datadir, args.batchsize, False)
        cls = 100
        input_size = (3, 32, 32)
    elif args.dataset == 'ImageNet-1k':
        train_dataloader, test_dataloader, train_sampler, test_sampler = load_ImageNet_dataset(args.batchsize, os.path.join(args.datadir, 'train'), os.path.join(args.datadir, 'val'), False)
        cls = 1000
        input_size = (3, 224, 224)
    else:
        # Fail fast: without a dataset, `cls` and `test_dataloader` are undefined.
        parser.error('unable to find dataset ' + args.dataset)

    if args.net_arch == 'resnet18':
        model = resnet18(args.time_step, num_classes=cls, use_dvs=dvs_data, use_resnet19=False, use_eca=args.use_eca, mem_bn=args.use_mem_bn, parallel_mode=args.use_parallel)
    elif args.net_arch == 'resnet34':
        model = resnet34(args.time_step, num_classes=cls, use_dvs=dvs_data, use_resnet19=False, use_eca=args.use_eca, mem_bn=args.use_mem_bn, parallel_mode=args.use_parallel)
    else:
        # Fail fast: without a model there is nothing to benchmark.
        parser.error('unable to find model ' + args.net_arch)

    print(model)
    print_model_param_info(model)
    reset_BConv(model)
    model.cuda()
    # Average per-batch forward time in milliseconds, summed over time steps.
    tot_time = eval_one_epoch(model, test_dataloader, args.time_step, dvs_data)

    print(f"SNN Inference Time: {tot_time}")
