import datetime
import os
import time
import torch
from torch.utils.data import DataLoader

import torch.nn as nn
import torch.nn.functional as F

from torch.utils.tensorboard import SummaryWriter
import sys
from torch.cuda import amp
from models import spiking_cnn_opzo
from modules import neuron, surrogate, functional
import argparse
import torch.utils.data as data
import torchvision.transforms as transforms
from datasets.augmentation import ToPILImage, Resize, Padding, RandomCrop, ToTensor, Normalize
from datasets.cifar10_dvs import CIFAR10DVS
import math
from utils import Bar, Logger, AverageMeter, accuracy, mkdir_p, savefig

import random
#_seed_ = 2022
#random.seed(_seed_)

#torch.manual_seed(_seed_)  # use torch.manual_seed() to seed the RNG for all devices (both CPU and CUDA)
#torch.cuda.manual_seed_all(_seed_)
#torch.backends.cudnn.deterministic = True
#torch.backends.cudnn.benchmark = False

import numpy as np
#np.random.seed(_seed_)

def _build_arg_parser():
    """Return the command-line parser for the DVS-CIFAR10 spike-rate evaluation."""
    parser = argparse.ArgumentParser(description='Classify DVS-CIFAR10')
    parser.add_argument('-T', default=10, type=int, help='simulating time-steps')
    parser.add_argument('-tau', default=2., type=float)
    parser.add_argument('-b', default=128, type=int, help='batch size')
    parser.add_argument('-j', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('-data_dir', type=str, default=None)

    parser.add_argument('-resume', type=str, help='resume from the checkpoint path')

    parser.add_argument('-model', type=str, default='opzo_spiking_cnnws')
    parser.add_argument('-drop_rate', type=float, default=0.1)
    parser.add_argument('-cnf', type=str)
    parser.add_argument('-loss_lambda', type=float, default=0.001)

    parser.add_argument('-not_momentum_feedback', action='store_true', help='not use momentum feedback')
    parser.add_argument('-momentum_fb', default=0.99999, type=float, help='momentum for feedback connections')
    parser.add_argument('-local_loss', action='store_true', help='use local loss')
    parser.add_argument('-DFA', action='store_true', help='use direct feedback alignment')
    parser.add_argument('-p_scale', type=float, default=0.2)
    parser.add_argument('-p_type', type=str, default='Gaussian')

    parser.add_argument('-gpu-id', default='0', type=str, help='gpu id')
    return parser


def _select_feedback_mode(args):
    """Map the feedback CLI flags onto the model's ``feedback_mode`` string.

    ``-DFA`` takes precedence over ``-not_momentum_feedback``. With neither
    flag set, the mode is ``'PZO'``.
    """
    if args.DFA:
        return 'DFA'
    if args.not_momentum_feedback:
        return 'ZO'
    return 'PZO'


def _build_test_loader(args):
    """Build the (non-shuffled) CIFAR10-DVS test loader with 48x48 frames."""
    transform_test = transforms.Compose([
        ToPILImage(),
        Resize(48),
        ToTensor(),
        Normalize((0.2728, 0.1295), (0.2225, 0.1290)),
    ])
    testset = CIFAR10DVS(args.data_dir, train=False, use_frame=True, frames_num=args.T,
                         split_by='number', normalization=None, transform=transform_test)
    return data.DataLoader(testset, batch_size=args.b, shuffle=False, num_workers=args.j)


def _measure_spike_rates(net, loader, time_steps):
    """Run ``net`` over ``loader`` and return ``(rates, dims)`` per spiking layer.

    For every layer tensor returned by ``net.get_spike``, the mean over dim 1 is
    summed over the batch and accumulated across all samples and time steps. It
    is then divided by ``num_samples * time_steps``. ``dims[i]`` is
    ``shape[1]`` of layer ``i``'s tensor.

    NOTE(review): this normalisation assumes each layer tensor is 2-D
    (batch, neurons). With extra (e.g. spatial) dims the mean over dim 1 would
    not average them. Confirm against ``get_spike``.

    Raises:
        RuntimeError: if ``loader`` yields no batches.
    """
    spike_sums = None  # float64 device accumulators, one per layer
    dims = None
    num_samples = 0
    with torch.no_grad():
        for frame, label in loader:
            frame = frame.float().cuda()
            for t in range(time_steps):
                # assumes frame is indexed (batch, time, ...) — TODO confirm
                spikes_batch = net.get_spike(frame[:, t])
                if spike_sums is None:
                    dims = [layer.shape[1] for layer in spikes_batch]
                    spike_sums = [torch.zeros((), dtype=torch.float64, device=layer.device)
                                  for layer in spikes_batch]
                # Accumulate on device in float64: avoids a host sync per step
                # while matching the original Python-float accumulation.
                for acc, layer in zip(spike_sums, spikes_batch):
                    acc += torch.sum(torch.mean(layer, dim=1))
            # Clear neuron state before the next, independent batch.
            functional.reset_net(net)
            num_samples += label.numel()

    if spike_sums is None:
        raise RuntimeError('test data loader is empty; cannot compute spike rates')

    denom = num_samples * time_steps
    rates = [acc.item() / denom for acc in spike_sums]
    return rates, dims


def main():
    """Report per-layer and overall spike rates of an OPZO spiking CNN on the
    CIFAR10-DVS test split.

    The model is optionally initialised from ``-resume`` (its ``'net'`` state
    dict) and evaluated in ``eval`` mode over ``-T`` time steps per sample. The
    overall rate is the mean of the per-layer rates weighted by each layer's
    ``dims`` entry.
    """
    args = _build_arg_parser().parse_args()
    # Set before any CUDA call so only the selected GPU is visible.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id

    test_data_loader = _build_test_loader(args)

    num_classes = 10
    net = spiking_cnn_opzo.__dict__[args.model](
        spiking_neuron=neuron.OPZOLIFNode, tau=args.tau,
        surrogate_function=surrogate.SigmoidOPZO(alpha=4.), c_in=2,
        num_classes=num_classes, fc_hw=1, v_reset=None,
        feedback_mode=_select_feedback_mode(args), momentum_fb=args.momentum_fb,
        p_scale=args.p_scale, h_in=48, w_in=48, drop_rate=args.drop_rate,
        local_loss=args.local_loss)
    net.cuda()

    if args.resume:
        # Only weights are restored; optimizer/scheduler state is irrelevant here.
        checkpoint = torch.load(args.resume, map_location='cpu')
        net.load_state_dict(checkpoint['net'])

    net.eval()
    rates, dims = _measure_spike_rates(net, test_data_loader, args.T)

    total_rate = sum(rate * dim for rate, dim in zip(rates, dims)) / sum(dims)

    for layer_idx, rate in enumerate(rates, start=1):
        print(f'layer={layer_idx}, spike_rate={rate}')
    print(f'total_spike_rate={total_rate}')

if __name__ == "__main__":
    # Entry point: evaluate only when run as a script, never on import.
    main()
