import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from spikingjelly.clock_driven import functional, layer
import surrogate
import neuron
import encoder as encoder
from torch.utils.tensorboard import SummaryWriter
import time
import numpy as np

from spikingjelly.datasets import play_frame
import matplotlib.pyplot as plt
import models
import os
import json
import init
import datasets
import utils
from collections import defaultdict

from tqdm import tqdm

# Reproducibility: force deterministic cuDNN kernels and disable cuDNN autotuning
# (benchmark mode may select non-deterministic algorithms).
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

STATIC_DATASETS = ["mnist", 'cifar10']  # image datasets; every other dataset must use the 'event' encoder
DATASET_DIR = datasets.DATASET_DIR  # dataset name -> data directory (looked up in main)

# hyperparameter

dataset = 'mnist'  # mnist nmnist cifar10 cifar10dvs
encode_func = 'poisson'  # img, poisson, integrate, event, normedpoisson
encoder_param = "{}"  # JSON kwargs passed to the encoder constructor

T = 20  # number of time steps; accumulated spike counts are divided by T to get rates
thres = 1.0  # firing threshold, passed as v_threshold to the default models
lam = 0.2  # passed as `lam` to the default models and to the weight initializer
k = 0.8  # k = 1 - 1/tau, k \in (0,1]

runs = "runs"  # root directory holding one folder per run
suffix = 'runs'  # prefix of each run folder name
# device = 'cuda:0'
batch_size = 64
learning_rate = 1e-3
random_seed = False  # True: draw a random seed in [0, 10000); False: use manual_seed
manual_seed = 0
train_epoch = 2
init_mode = 'xavier_normal'  # name of an initializer in init.init, e.g. xavier_normal
init_param = "{}"  # JSON kwargs for the initializer (thres, lam and k are added by main)
opt = 'adam'  # sgd adam
# adam requires beta1, beta2, weight_decay; sgd requires momentum, weight_decay
opt_param = "{\"beta1\":0.9,\"beta2\":0.99,\"weight_decay\":0}"
scheduler = "CosineAnnealingLR"  # class name in torch.optim.lr_scheduler
scheduler_param = "{\"T_max\":300}"
# sample scheduler_param = "{\"T_max\":64}"
grad_func = 'sigmoid'  # sigmoid tanh arctan hardtanh; resolved as surrogate.EXP_<grad_func>
grad_param = "{}"  # empty is default
# for sigmoid sample grad_param = "{\"alpha\":1.0}"
# for atan sample grad_param = "{\"alpha\":2.0}"
# ...
inspector = True  # log per-layer weight/grad/activity statistics to TensorBoard on every batch
inspector_interval = 10  # with inspector_save_full, dump full histograms every N training batches
inspector_save_full = False  # True: save TensorBoard + inspector log file, False: save TensorBoard only
inspector_param = "{\"bin_num\":500,\"need_batch\":50}"  # "{\"bin_num\":50,\"need_batch\":5}" "{\"bin_num\":500,\"need_batch\":50}"
inspector_times = 0  # number of inspector-log dumps written so far
PERCENT = list(range(0, 105, 5))  # percentiles (0, 5, ..., 100) stored per statistic in the inspector log

# module-level training state shared with the hooks
writer = None
log_dir = ""
train_times = 0  # global training step (TensorBoard x-axis)
global_id_module_name_dict = defaultdict(str)  # module id -> module name
global_forward_t = defaultdict(int)  # per-module forward step counter (only ever reset in this file)
global_backward_t = defaultdict(int)  # per-module backward step counter (only ever reset in this file)
global_accu_linear_output = defaultdict(float)  # layer name -> summed Linear/Conv2d output
global_accu_grad_rate_input = defaultdict(float)  # neuron name -> summed grad_output[0] (see backward_hook)
global_accu_grad_rate_output = defaultdict(float)  # neuron name -> summed grad_input[0] (see backward_hook)
net = None  # preset network; if None, main() builds the default model for `dataset`
if_save_model = True  # save the best-test-accuracy model as net_max_acc.pt


# end of hyperparameter

def reset_training():
    """Prepare for a fresh training run.

    Drops the TensorBoard writer, rewinds the global step counter to zero and
    clears every hook accumulator via ``reset_hook_param``.
    """
    global writer, train_times
    writer, train_times = None, 0
    reset_hook_param()


def reset_hook_param():
    """Rebind all per-batch hook accumulators to fresh, empty dictionaries.

    Step counters default to ``0`` (int); value accumulators default to ``0.0``
    (float) so the first ``+=`` of a NumPy array yields that array.
    """
    global global_forward_t, global_backward_t
    global global_accu_linear_output, global_accu_grad_rate_input, global_accu_grad_rate_output
    global_forward_t, global_backward_t = defaultdict(int), defaultdict(int)
    global_accu_grad_rate_input = defaultdict(float)
    global_accu_grad_rate_output = defaultdict(float)
    global_accu_linear_output = defaultdict(float)


def forward_hook(module, input, output):
    """Forward hook that accumulates the raw output of Linear/Conv2d modules.

    Each call adds ``output`` (moved to CPU as a NumPy array) to
    ``global_accu_linear_output[name]``, where ``name`` is looked up in
    ``global_id_module_name_dict``. Callers divide the sum by ``T`` to get the
    average layer output. Modules of any other class are ignored.
    """
    # Lookup happens first on purpose: the defaultdict registers unseen modules.
    tag = global_id_module_name_dict[id(module)]
    if module.__class__.__name__ not in ('Linear', 'Conv2d'):
        return
    global_accu_linear_output[tag] += output.detach().cpu().numpy()


def backward_hook(module, grad_input, grad_output):
    """Backward hook that accumulates gradients seen by spiking-neuron modules.

    Only modules exposing ``v_threshold`` are recorded, keyed by their name in
    ``global_id_module_name_dict``:

    * ``grad_output[0]`` (PyTorch: gradient w.r.t. the module's output) is summed
      into ``global_accu_grad_rate_input``;
    * ``grad_input[0]`` (PyTorch: gradient w.r.t. the module's input) is summed
      into ``global_accu_grad_rate_output``.

    NOTE(review): the accumulator names are crossed relative to PyTorch's naming;
    kept as-is because the logged tag names depend on them.
    """
    # Lookup happens first on purpose: the defaultdict registers unseen modules.
    tag = global_id_module_name_dict[id(module)]
    if not hasattr(module, 'v_threshold'):
        return
    from_output = grad_output[0].data.to('cpu').numpy()
    from_input = grad_input[0].data.to('cpu').numpy()
    global_accu_grad_rate_input[tag] += from_output
    global_accu_grad_rate_output[tag] += from_input


def _register_inspector_hooks(net):
    """Attach ``forward_hook``/``backward_hook`` to every Linear, Conv2d and *LIFNode* module of ``net``.

    Returns:
        List of hook handles; pass it to ``_remove_hooks`` to detach them again.
    """
    handles = []
    for _, m in net.named_modules():
        cls_name = m.__class__.__name__
        if cls_name in ['Linear', 'Conv2d'] or 'LIFNode' in cls_name:
            handles.append(m.register_forward_hook(forward_hook))
            # NOTE(review): register_backward_hook is deprecated in recent PyTorch and may only
            # report the last op's gradients for multi-op modules; kept for compatibility.
            handles.append(m.register_backward_hook(backward_hook))
    return handles


def _remove_hooks(handles):
    """Detach every hook in ``handles`` (as returned by ``_register_inspector_hooks``)."""
    for handle in handles:
        handle.remove()


def _log_tensor_stats(writer, tag, name, tensor, step):
    """Write ``<tag>/<name>-mean``, ``-std`` and ``-histo`` of ``tensor`` to TensorBoard at ``step``."""
    writer.add_scalar('%s/%s-mean' % (tag, name), torch.mean(tensor), global_step=step)
    writer.add_scalar('%s/%s-std' % (tag, name), torch.std(tensor), global_step=step)
    writer.add_histogram('%s/%s-histo' % (tag, name), tensor, global_step=step)


def _log_param_stats(net, writer, step):
    """Log weight/bias statistics of every Linear/Conv2d layer (called before the forward pass)."""
    for _, m in net.named_modules():
        if m.__class__.__name__ in ['Linear', 'Conv2d']:
            tag = global_id_module_name_dict[id(m)]
            _log_tensor_stats(writer, tag, 'weight', m.weight.data, step)
            # Layers built with bias=False still have a ``bias`` attribute, set to None.
            if m.bias is not None:
                _log_tensor_stats(writer, tag, 'bias', m.bias.data, step)


def _log_grad_and_activity_stats(net, writer, step):
    """Log gradients, time-averaged layer outputs, firing rates and hook-accumulated neuron gradients.

    Must run after ``loss.backward()`` and before the hook accumulators are reset.
    Statistics whose accumulator was never filled by a hook are skipped.
    """
    for _, m in net.named_modules():
        if m.__class__.__name__ in ['Linear', 'Conv2d']:
            tag = global_id_module_name_dict[id(m)]
            if m.weight.grad is not None:
                _log_tensor_stats(writer, tag, 'weight_grad', m.weight.grad.data, step)
            if m.bias is not None and m.bias.grad is not None:
                _log_tensor_stats(writer, tag, 'bias_grad', m.bias.grad.data, step)
            # forward_hook summed this layer's output over every call of the forward pass.
            linear_sum = global_accu_linear_output.get(tag)
            if isinstance(linear_sum, np.ndarray):
                _log_tensor_stats(writer, tag, 'avg_linear_output', torch.from_numpy(linear_sum) / T, step)

        if hasattr(m, 'v_threshold'):
            tag = global_id_module_name_dict[id(m)]
            # Sum of the monitored spike records divided by T gives the firing rate
            # (presumably one record per time step -- TODO confirm in the neuron module).
            spike_sum = np.sum(np.stack(m.monitor['s']), axis=0)
            _log_tensor_stats(writer, tag, 'rate', torch.from_numpy(spike_sum) / T, step)
            for name, store in (('accu_grad_rate_input', global_accu_grad_rate_input),
                                ('accu_grad_rate_output', global_accu_grad_rate_output)):
                grad_sum = store.get(tag)
                if isinstance(grad_sum, np.ndarray):
                    _log_tensor_stats(writer, tag, name, torch.from_numpy(grad_sum), step)


def _print_progress(stage, done, total, elapsed):
    """Overwrite the console line with ``[stage] pct | done/total [elapsed<estimated total, s/it]``."""
    print('\r[%s] %3.2f%% | %6d/%6d [%s<%s, %.2fs/it]' % (
        stage, float(done) / int(total) * 100, done, int(total),
        progress_bar_time(elapsed),
        progress_bar_time(elapsed / done * int(total)),
        elapsed / done), end='')


def _evaluate(net, test_data_loader, enc):
    """Return the top-1 accuracy of ``net`` over ``test_data_loader`` with gradients disabled."""
    correct_sum = 0
    test_sum = 0
    total_time = 0
    n_batches = len(test_data_loader)
    with torch.no_grad():
        for i, (img, label) in enumerate(test_data_loader):
            start_time = time.time()
            label = label.cuda()
            encoding = enc.encode(img.cuda())
            out_spikes_counter_frequency = net(encoding)
            enc.reset()

            correct_sum += (out_spikes_counter_frequency.max(1)[1] == label).float().sum().item()
            test_sum += label.numel()
            functional.reset_net(net)

            total_time += time.time() - start_time
            _print_progress('evaluating', i + 1, n_batches, total_time)
    return correct_sum / test_sum


def main():
    """Train and evaluate a spiking network configured by the module-level hyperparameters.

    Reads its configuration from module globals (``dataset``, ``T``, ``k``, ``lam``,
    ``thres``, optimizer/scheduler/encoder settings, inspector options, ...). Side effects:

    * rebinds the module globals ``net``, ``writer``, ``log_dir`` and advances ``train_times``;
    * fills ``global_id_module_name_dict`` with a name for every sub-module of ``net``;
    * creates ``<runs>/<log_dir>`` containing ``args.json``, TensorBoard event files and,
      when ``if_save_model`` is set, the best-test-accuracy model ``net_max_acc.pt``.

    Raises:
        AssertionError: if ``encode_func`` does not fit the dataset kind, or required
            optimizer parameters are missing from ``opt_param``.
        ValueError: for an unsupported ``grad_func``, ``opt`` or ``encode_func``, or an
            unsupported ``dataset`` when no ``net`` was preset.
    """
    # Only these module-level names are rebound here; every other setting is only read.
    global net, writer, log_dir, train_times

    # prepare
    # Image datasets must not use the event encoder; event datasets must use it.
    if dataset in STATIC_DATASETS:
        assert (encode_func.lower() != 'event'), 'event encoding is only valid for event datasets'
    else:
        assert (encode_func.lower() == 'event'), 'event datasets require encode_func="event"'

    # k = 1 - 1/tau  =>  tau = 1 / (1 - k); k == 1 means an infinite time constant.
    tau = (1 / (1 - k)) if np.abs(k - 1) > 1e-3 else np.inf

    dataset_dir = DATASET_DIR[dataset.lower()]

    seed = np.random.randint(0, 10000) if random_seed else manual_seed
    torch.manual_seed(seed)
    np.random.seed(seed)
    # end of prepare

    d = {
        'suffix': suffix,  # 'device': device,
        'dataset': dataset,
        'dataset_dir': dataset_dir,
        'encode_func': encode_func,
        'encoder_param': encoder_param,
        'batch_size': batch_size,
        'learning_rate': learning_rate,
        'T': T,
        'tau': tau,
        'k': k,
        'lam': lam,
        'threshold': thres,
        'init_mode': init_mode,
        'init_param': init_param,
        'train_epoch': train_epoch,
        'opt': opt,
        'opt_param': opt_param,
        'random_seed': random_seed,
        'seed': seed,
        "scheduler": scheduler,
        'scheduler_param': scheduler_param,
        'grad_func': grad_func,
        'grad_param': grad_param,
        'inspector': inspector,
        'inspector_param': inspector_param,
        'inspector_save_full': inspector_save_full,
        'inspector_interval': inspector_interval,
        # 'net': ('basic5layer' if net is None else net.__class__.__name__)
    }

    opt_param_dict = json.loads(opt_param)
    scheduler_param_dict = json.loads(scheduler_param)
    encoder_param_dict = json.loads(encoder_param)
    # Parsed once up front instead of at every inspector interval.
    inspector_param_dict = json.loads(inspector_param) if inspector else {}

    # data loaders
    train_data_loader, test_data_loader = datasets.load_dataset(dataset, dataset_dir, batch_size, T=T)

    grad_name = 'EXP_' + grad_func
    if not hasattr(surrogate, grad_name):
        raise ValueError('unknown grad_func: %r' % grad_func)
    grad_f = getattr(surrogate, grad_name)

    # network: use the preset one if given, otherwise the default model for the dataset
    if net is None:
        default_models = {
            'mnist': 'MNIST_5layer',
            'nmnist': 'NMNIST_5layer',
            'cifar10': 'CIFAR10_5layer',
            'cifar10dvs': 'CIFAR10DVS_5layer',
        }
        if dataset not in default_models:
            raise ValueError('no default network for dataset %r' % dataset)
        net = getattr(models, default_models[dataset])(
            k=k, lam=lam, v_threshold=thres, T=T, grad=grad_f, grad_kargs=grad_param).cuda()
    else:
        net.cuda()

    init_kargs = json.loads(init_param)
    init_kargs.update(thres=thres, lam=lam, k=k)
    module_id = 0
    for n, m in net.named_modules():
        cls_name = m.__class__.__name__
        # '<id>-<qualified name>-<class>'; the id only advances after weight layers and neurons.
        global_id_module_name_dict[id(m)] = '%d-%s-%s' % (module_id, n, cls_name)
        if hasattr(m, 'weight') and cls_name in ['Linear', 'Conv2d']:
            module_id += 1
            getattr(init.init, init_mode)(m, **init_kargs)
        elif hasattr(m, 'v_threshold'):
            module_id += 1
        elif 'BatchNorm' in cls_name and getattr(m, 'weight', None) is not None:
            # better without batchnorm; affine=False BatchNorm has no weight/bias to init
            nn.init.normal_(m.weight.data, 1.0)
            nn.init.constant_(m.bias.data, 0.0)

    from datetime import datetime
    current_time = datetime.now().strftime('%b%d%H%M%S')
    log_dir = '%s-%s-%s-k%.3f-lam%.3f-thres%.3f-%s-%s-[%s]' % (
        suffix, dataset, init_mode.upper(), k, lam, thres, encode_func, opt.upper(), current_time)
    run_dir = os.path.join(runs, log_dir)
    os.makedirs(run_dir, exist_ok=True)

    with open(os.path.join(run_dir, 'args.json'), 'w') as fp:
        json.dump(d, fp)

    writer = SummaryWriter(log_dir=run_dir)
    print('log_dir:', run_dir)
    print('param:')
    for key, value in d.items():
        print('{key}:{value}'.format(key=key, value=value))

    # optimizer
    if opt.lower() == 'adam':
        assert ('beta1' in opt_param_dict and 'beta2' in opt_param_dict and 'weight_decay' in opt_param_dict)
        # torch.optim.Adam expects a single `betas` tuple
        opt_param_dict['betas'] = (opt_param_dict.pop('beta1'), opt_param_dict.pop('beta2'))
        optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate, **opt_param_dict)
    elif opt.lower() == 'sgd':
        assert ('momentum' in opt_param_dict and 'weight_decay' in opt_param_dict)
        optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate, **opt_param_dict)
    else:
        raise ValueError('unknown opt: %r' % opt)

    lr_scheduler = getattr(torch.optim.lr_scheduler, scheduler)(optimizer, **scheduler_param_dict)

    encoder_classes = {
        'poisson': 'PoissonEncoder',
        'img': 'ImgEncoder',
        'integrate': 'IntegratingEncoder',
        'event': 'EventEncoder',
        'normedpoisson': 'NormalizedPoissonEncoder',
    }
    encoder_name = encoder_classes.get(encode_func.lower())
    if encoder_name is None:
        raise ValueError('unknown encode_func: %r' % encode_func)
    enc = getattr(encoder, encoder_name)(T=T, **encoder_param_dict)

    max_test_accuracy = 0
    for epoch in range(train_epoch):
        hook_handles = _register_inspector_hooks(net) if inspector else []

        net.train()
        if inspector:
            functional.set_monitor(net, True)
        t_start = time.perf_counter()
        total_time = 0
        n_train_batches = len(train_data_loader)
        for i, (img, label) in enumerate(train_data_loader):
            start_time = time.time()
            # `inspector` is tested first so inspector_interval is never used as a divisor when off.
            if inspector and inspector_save_full and i % inspector_interval == 0:
                # histo_for_training_data installs its own hooks; detach ours meanwhile so each
                # hook fires once per module call and the dumped statistics are not doubled.
                _remove_hooks(hook_handles)
                histo_for_training_data(net, train_data_loader, enc, **inspector_param_dict)
                hook_handles = _register_inspector_hooks(net)

            img = img.cuda()
            label = label.cuda()
            label_one_hot = F.one_hot(label, 10).float()

            if inspector:
                reset_hook_param()  # start this batch with empty hook accumulators
                _log_param_stats(net, writer, train_times)

            optimizer.zero_grad()

            encoding = enc.encode(img)
            out_spikes_counter_frequency = net(encoding)
            enc.reset()

            loss = F.mse_loss(out_spikes_counter_frequency, label_one_hot)
            loss.backward()

            if inspector:
                _log_grad_and_activity_stats(net, writer, train_times)
                reset_hook_param()

            optimizer.step()

            functional.reset_net(net)

            accuracy = (out_spikes_counter_frequency.max(1)[1] == label).float().mean().item()
            writer.add_scalar('train_accuracy', accuracy, train_times)
            writer.add_scalar('train_loss', loss.item(), train_times)
            train_times += 1

            total_time += time.time() - start_time
            _print_progress('training', i + 1, n_train_batches, total_time)

        t_train = time.perf_counter() - t_start
        _remove_hooks(hook_handles)

        lr_scheduler.step()

        print('')
        net.eval()
        if inspector:
            functional.set_monitor(net, False)
        t_start = time.perf_counter()
        test_accuracy = _evaluate(net, test_data_loader, enc)
        t_test = time.perf_counter() - t_start
        writer.add_scalar('test_accuracy', test_accuracy, epoch)
        if max_test_accuracy < test_accuracy:
            max_test_accuracy = test_accuracy
            if if_save_model:
                print('saving net...')
                torch.save(net, os.path.join(run_dir, 'net_max_acc.pt'))
                print('saved')
        print('Epoch %d/%d TrainTime %.3f TestTime %.3f TestAcc %.2f MaxAcc %.2f' % (
            epoch, train_epoch, t_train, t_test, test_accuracy, max_test_accuracy))
        writer.flush()


def progress_bar_time(total_time):
    """Format a duration given in seconds as ``H:MM:SS`` (fractional seconds are truncated)."""
    remaining = int(total_time)
    hours, remaining = divmod(remaining, 3600)
    minutes, seconds = divmod(remaining, 60)
    return '%d:%02d:%02d' % (hours, minutes, seconds)


def histo_for_training_data(net, train_data_loader, enc, bin_num=500, need_batch=50, logfile='inspector.log'):
    """Histogram layer statistics over ``need_batch`` training batches and append them to a log file.

    For every Linear/Conv2d layer it histograms the weight, bias (if any), weight/bias
    gradients and time-averaged layer output (``forward_hook`` sum divided by ``T``). For
    every module with ``v_threshold`` it histograms the firing rate from ``m.monitor['s']``
    and the gradients accumulated by ``backward_hook``. Per-batch histograms are merged
    per statistic key. Each key is written as one CSV line to ``<runs>/<log_dir>/<logfile>``:
    ``inspector_times, key, mean, std, <len(PERCENT) percentiles>, <bin_num counts>, <bin_num + 1 edges>``.
    The first dump (``inspector_times == 0``) truncates the file and writes ``bin_num`` as a header line.

    The optimizer is never stepped. Parameter ``.grad`` buffers are left holding the
    last batch's gradients, so the caller must zero them before its own update.
    Hook accumulators are reset after each batch, and ``inspector_times`` is incremented.

    Args:
        net: network to inspect (moved to CUDA by the caller).
        train_data_loader: iterable of ``(img, label)`` batches.
        enc: encoder with ``encode``/``reset`` used to feed ``net``.
        bin_num: number of histogram bins.
        need_batch: number of batches to sample.
        logfile: log file name inside the run directory.
    """
    global inspector_times

    print('\n')
    hist_dict = defaultdict(list)  # statistic key -> [(bin_center, count), ...] over all batches

    def add_hist(key, values):
        # Histogram `values` (tensor or ndarray) into `bin_num` bins and keep (center, count) pairs.
        if torch.is_tensor(values):
            values = values.detach().cpu().numpy()
        counts, edges = np.histogram(values, bins=bin_num)
        centers = (edges[:-1] + edges[1:]) / 2
        hist_dict[key].extend(zip(centers, counts))

    handles = []
    for _, m in net.named_modules():
        cls_name = m.__class__.__name__
        if cls_name in ['Linear', 'Conv2d'] or 'LIFNode' in cls_name:
            handles.append(m.register_backward_hook(backward_hook))
            handles.append(m.register_forward_hook(forward_hook))

    net.train()
    functional.set_monitor(net, True)
    total_time = 0
    try:
        for i, (img, label) in enumerate(train_data_loader):
            if i == need_batch:
                break
            start_time = time.time()
            img = img.cuda()
            label = label.cuda()
            label_one_hot = F.one_hot(label, 10).float()

            reset_hook_param()
            for _, m in net.named_modules():
                if m.__class__.__name__ in ['Linear', 'Conv2d']:
                    tag = global_id_module_name_dict[id(m)]
                    add_hist('%s/weight' % tag, m.weight)
                    # bias-less layers keep a `bias` attribute set to None
                    if m.bias is not None:
                        add_hist('%s/bias' % tag, m.bias)

            # Drop gradients left by the caller or by the previous batch so the gradient
            # histograms below describe this batch only (no optimizer step happens here).
            net.zero_grad()

            encoding = enc.encode(img)
            out_spikes_counter_frequency = net(encoding)
            enc.reset()

            loss = F.mse_loss(out_spikes_counter_frequency, label_one_hot)
            loss.backward()

            for _, m in net.named_modules():
                if m.__class__.__name__ in ['Linear', 'Conv2d']:
                    tag = global_id_module_name_dict[id(m)]
                    add_hist('%s/weight_grad' % tag, m.weight.grad)
                    if m.bias is not None:
                        add_hist('%s/bias_grad' % tag, m.bias.grad)
                    add_hist('%s/avg_linear_output' % tag,
                             torch.from_numpy(global_accu_linear_output[tag]) / T)

                if hasattr(m, 'v_threshold'):
                    tag = global_id_module_name_dict[id(m)]
                    # spike records summed and divided by T -> firing rate
                    rate = torch.from_numpy(np.sum(np.stack(m.monitor['s']), axis=0)) / T
                    add_hist('%s/rate' % tag, rate)
                    add_hist('%s/accu_grad_rate_input' % tag, global_accu_grad_rate_input[tag])
                    add_hist('%s/accu_grad_rate_output' % tag, global_accu_grad_rate_output[tag])

            reset_hook_param()
            functional.reset_net(net)

            total_time += time.time() - start_time
            print('\r[saving statistic] %3.2f%% | %6d/%6d [%s<%s, %.2fs/it]' % (
                float(i + 1) / int(need_batch) * 100, i + 1, int(need_batch),
                progress_bar_time(total_time),
                progress_bar_time(total_time / (i + 1) * int(need_batch)),
                total_time / (i + 1)), end='')
    finally:
        # Always detach the hooks, even if a batch raised.
        for handle in handles:
            handle.remove()

    log_path = os.path.join(runs, log_dir, logfile)
    if inspector_times == 0:
        # first dump of this run: truncate the file and write the bin-count header
        with open(log_path, 'w') as fp:
            fp.write('%s\n' % (bin_num))

    print('\nsaving to %s ...' % logfile)
    total_time = 0
    key_len = len(hist_dict)
    with open(log_path, 'a') as fp:
        for f_idx, key in enumerate(hist_dict):
            start_time = time.time()
            centers, counts = zip(*hist_dict[key])
            # Expand the merged histograms back into binned samples (each center repeated count times).
            samples = np.repeat(np.asarray(centers), np.asarray(counts))
            mean = np.mean(samples)
            std = np.std(samples)

            pv = []
            for p in PERCENT:
                if p == 0:
                    pv.append(np.min(samples))
                elif p == 100:
                    pv.append(np.max(samples))
                else:
                    pv.append(np.percentile(samples, p))
            merged_counts, merged_edges = np.histogram(samples, bins=bin_num)
            fields = [inspector_times, key, mean, std, *pv, *merged_counts, *merged_edges]
            fp.write('%s\n' % ','.join(str(x) for x in fields))

            total_time += time.time() - start_time
            print('\r[saving file] %3.2f%% | %6d/%6d [%s<%s, %.2fs/it]' % (
                float(f_idx + 1) / int(key_len) * 100, f_idx + 1, int(key_len),
                progress_bar_time(total_time),
                progress_bar_time(total_time / (f_idx + 1) * int(key_len)),
                total_time / (f_idx + 1)), end='')

    print('\n')
    inspector_times += 1


def load_inspector_log(log_dir='runs-mnist-SNN_INIT_1_UNIFORM-k0.800-lam0.200-thres1.000-poisson-ADAM-[Jun26145946]',
                       logfile='inspector.log', plot_name='6-fc.3-Linear/avg_linear_output'):
    """Print (and optionally plot) the records of an inspector log written by ``histo_for_training_data``.

    The first line of the file holds ``bin_num``. Every further line is one CSV record:
    ``inspector_times, key, mean, std, <len(PERCENT) percentiles>, <bin_num counts>, <bin_num + 1 edges>``.

    Args:
        log_dir: run folder name under the module-level ``runs`` directory.
        logfile: name of the log file inside ``log_dir``.
        plot_name: record key whose histogram is plotted with matplotlib; ``None`` disables
            plotting. The default is the key this function previously hard-coded.

    Raises:
        Warning: if a record does not have the expected number of fields.
    """
    n_percent = len(PERCENT)
    # `with` guarantees the file is closed even when a corrupted record raises.
    with open(os.path.join(runs, log_dir, logfile), 'r') as f:
        header = f.readline()
        if not header:
            return
        bin_num = int(header)
        expected_fields = 4 + n_percent + bin_num * 2 + 1
        counts_start = 4 + n_percent
        edges_start = counts_start + bin_num

        for info in f:
            fields = info.split(',')
            if len(fields) != expected_fields:
                raise Warning('some data has corrupted...')
            record_times = int(fields[0])
            name = fields[1]
            mean = float(fields[2])
            std = float(fields[3])
            percentile = [float(x) for x in fields[4:counts_start]]
            counts = [int(float(x)) for x in fields[counts_start:edges_start]]
            bin_edges = [float(x) for x in fields[edges_start:]]

            print()
            print(record_times)
            print(name)
            print(mean)
            print(std)
            print(len(percentile))

            if plot_name is not None and name == plot_name:
                plt.plot(bin_edges[1:], counts)
                plt.show()
            print(len(counts))
            print(len(bin_edges))


if __name__ == '__main__':
    # Train using the module-level hyperparameters; swap in load_inspector_log()
    # to print/plot a previously saved inspector log instead.
    main()
    # load_inspector_log()
