import argparse
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
import os
import numpy as np
import time
import random
import torch.backends.cudnn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data.distributed import DistributedSampler
import sys
sys.path.append('../../')
from iclr23code import utils, dataset, backprop, modules, surrogate, LIF, NLIF


def set_seed(seed):
    """Seed every RNG used by the script and force deterministic cuDNN.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch (CPU, current CUDA
    device and all CUDA devices) with the same ``seed``. It then disables cuDNN
    autotuning and requests deterministic cuDNN kernels.

    Args:
        seed: Integer seed shared by all generators.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn_backend = torch.backends.cudnn
    cudnn_backend.benchmark = False  # autotuning may pick different kernels per run
    cudnn_backend.deterministic = True


start_time = time.time()  # wall-clock start, used for the "Total use" report

parser = argparse.ArgumentParser()
# ---- distributed launch ----
parser.add_argument('--device_ids', type=str, default='0, 1', help='Training Device')
# torch.distributed.launch passes `--local_rank` before PyTorch 2.0 and `--local-rank` from
# 2.0 on; accept both spellings (both are stored as `args.local_rank`).
parser.add_argument('--local_rank', '--local-rank', type=int, default=-1, help='DDP parameter')

# ---- data ----
parser.add_argument('--data_path', default='/data_smr/dataset/cifar-dvs', type=str, help='Path of data')
parser.add_argument('--dataset', default='AUGDVS', type=str, help='Dataset name',
                    choices=['CIFAR10', 'CIFAR100', 'IMAGE', 'AUGDVS'])
parser.add_argument('--batch_size', default=36, type=int, help='Batch size')

# ---- neuron / layer configuration ----
# 'nlif' (the default) must be listed, otherwise an explicit `--neuron nlif` is rejected.
parser.add_argument('--neuron', default='nlif', type=str, help='Type of neuron',
                    choices=['lif', 'nlif'])
parser.add_argument('--num_step', default=10, type=int, help='Time dimension')
parser.add_argument('--bn', default='tdbn', type=str, help='Type of bn layer in model',
                    choices=['nn', 'straight', 'tdbn'])
parser.add_argument('--spike_func', default='triangle', type=str, help='Surrogate function',
                    choices=['rectangular', 'triangle', 'actan'])
parser.add_argument('--slope', default=1.0, type=float, help='Parameter used in surrogate')
parser.add_argument('--threshold', default=1.0, type=float, help='Membrane threshold')
parser.add_argument('--weak_mem', default=0.5, type=float, help='Weaken rate of mem')
parser.add_argument('--reset_mechanism', default='zero', type=str,
                    help='Membrane reset mechanism', choices=['subtract', 'zero'])

# ---- architecture ----
parser.add_argument('--model', default='vgg11', type=str, help='Architecture of model',
                    choices=['vgg16', 'tdresnet19c10', 'sewresnet34', 'dspike', 'vgg11'])

# ---- training procedure ----
parser.add_argument('--back_method', default='trainEn', type=str, help='Backward method',
                    choices=['train', 'trainImg', 'trainEn'])
parser.add_argument('--pre_process', default='logSoft', type=str, help='Pre-process of result',
                    choices=['logSoft', 'soft', 'none'])
parser.add_argument('--adjust', default='bna', type=str, help='adjust class for slope',
                    choices=['d', 'bna'])
parser.add_argument('--key', default='0', type=str, help='idx of inner_list')
parser.add_argument('--momentum', default=0.0, type=float, help='Momentum in adjust class')
parser.add_argument('--limit_k', default=0.2, type=float, help='limit_k in adjust class')

parser.add_argument('--code_method', default='dvs', type=str, help='Coding method',
                    choices=['copy', 'dvs'])
parser.add_argument('--updater', default='img', type=str, help='For lr update',
                    choices=['warm', 'img'])

# ---- optimisation ----
parser.add_argument('--lr', default=1e-3, type=float, help='Max learning rate')
parser.add_argument('--num_epoch', default=200, type=int, help='Training epoch')
parser.add_argument('--best_acc', default=0, type=float, help='Best acc of current model.')
parser.add_argument('--optimizer', default='adam', type=str,
                    help='If use adam, set a lower lr(0.01)', choices=['adam', 'sgd'])
parser.add_argument('--appendix', default='0', type=str, help='Append information')
args = parser.parse_args()

# torchrun (and torch.distributed.launch --use-env) pass the local rank only through the
# LOCAL_RANK environment variable. Fall back to it when no CLI value was given, so that
# every process does not end up with the default -1 (i.e. the last listed device).
if args.local_rank == -1 and 'LOCAL_RANK' in os.environ:
    args.local_rank = int(os.environ['LOCAL_RANK'])

# Per-rank seed: ranks differ, and DDP later broadcasts rank 0's parameters anyway.
set_seed((args.local_rank + 1) * 2000)

# `--device_ids` is a comma-separated list; int() tolerates the surrounding spaces.
device_ids = [int(device_id) for device_id in args.device_ids.split(',')]
gpu_num = len(device_ids)
print(f'Current gpu size: {gpu_num}')

dist.init_process_group(backend='nccl')
# This process drives the device listed at position `local_rank` in `--device_ids`.
device = torch.device(f'cuda:{device_ids[args.local_rank]}')
torch.cuda.set_device(device)

def _pick_choice(table, choice, arg_name):
    """Return ``table[choice]``; raise ValueError mentioning ``args.<arg_name>`` otherwise."""
    if choice not in table:
        raise ValueError(f"Can't find choice in args.{arg_name}.")
    return table[choice]


# Neuron implementation module.
neuron = _pick_choice({'lif': LIF, 'nlif': NLIF}, args.neuron, 'neuron')

# The tables below map a CLI choice to an attribute *name*, so only the selected
# attribute of `modules` / `surrogate` is ever looked up.
bn = getattr(modules, _pick_choice(
    {'nn': 'batch_norm2d', 'straight': 'straight', 'tdbn': 'td_bn'},
    args.bn, 'bn'))

spike_func = getattr(surrogate, _pick_choice(
    {'rectangular': 'rectangular', 'triangle': 'triangle', 'actan': 'actan'},
    args.spike_func, 'spike_func'))

# choice -> (class name in `modules`, column count of the per-epoch fire/slope/mean/var
# recorders; presumably one column per recorded layer of that architecture).
_MODEL_SPECS = {
    'vgg16': ('NBAvgVgg16', 15),
    'tdresnet19c10': ('TdResNet19C10', 18),
    'sewresnet34': ('SewResNet34', 36),
    'dspike': ('DspikeModel', 19),
    'vgg11': ('VGG11', 8),
}
if args.model not in _MODEL_SPECS:
    raise ValueError("Can't find choice in args.model.")
if args.model == 'sewresnet34':
    print("Note: If use model `sewresnet34`, remember to use IF neuron(set weak_mem to 1.0).")
model_name, recorder_size = _MODEL_SPECS[args.model]
model = getattr(modules, model_name)  # resolved lazily: only the chosen class is looked up

net = model(neuron, args.num_step, bn, spike_func, args.slope, args.threshold, args.weak_mem,
            args.reset_mechanism, args.dataset).to(device)

train_data, test_data = dataset.dataload(args.data_path, args.dataset)

# Each split is sharded across the processes of the default group by its own
# DistributedSampler; the loaders then iterate only this rank's shard.
sampler_train = DistributedSampler(train_data)
train_loader = DataLoader(
    train_data,
    batch_size=args.batch_size,
    sampler=sampler_train,
    drop_last=True,  # every training batch holds exactly batch_size samples
    pin_memory=True,
)

sampler_test = DistributedSampler(test_data)
test_loader = DataLoader(
    test_data,
    batch_size=args.batch_size,
    sampler=sampler_test,
    pin_memory=True,
)

# Training step implementation from `backprop`.
if args.back_method == 'train':
    back_method = backprop.train
elif args.back_method == 'trainImg':
    back_method = backprop.trainImg
elif args.back_method == 'trainEn':
    back_method = backprop.train_en
else:
    raise ValueError("Can't find choice in args.back_method.")

# Transform handed to `back_method` together with the NLLLoss criterion (presumably
# applied to the network output before the loss; see backprop).
# The literal must match the lower-case 'none' offered by `--pre_process`'s choices;
# comparing against 'None' made that branch unreachable, so `--pre_process none` raised.
if args.pre_process == 'none':
    pre_process = None
elif args.pre_process == 'logSoft':
    pre_process = nn.LogSoftmax(dim=-1)
elif args.pre_process == 'soft':
    pre_process = nn.Softmax(dim=-1)
else:
    raise ValueError("Can't find choice in args.pre_process.")

# Fixed at 0; passed to the adjust class and embedded in the result file names.
limit = 0

# Each choice maps to a zero-argument factory so only the selected object is built.
_updater_factories = {
    'warm': lambda: utils.WarmUpdate(args.num_epoch, args.lr),
    'img': lambda: utils.ImgUpdate(args.num_epoch, args.lr),
}
if args.updater not in _updater_factories:
    raise ValueError("Can't find choice in args.updater.")
updater = _updater_factories[args.updater]()

_adjust_factories = {
    'd': lambda: backprop.Adjust(limit),
    'bna': lambda: backprop.BNAdjust(limit, args.key, args.momentum, args.limit_k),
}
if args.adjust not in _adjust_factories:
    raise ValueError("Can't find choice in args.adjust.")
adjust = _adjust_factories[args.adjust]()

# Input coding function, looked up by name on `utils`.
_code_method_names = {'copy': 'copy_code', 'dvs': 'dvs_code'}
if args.code_method not in _code_method_names:
    raise ValueError("Can't find choice in args.code_method.")
code_method = getattr(utils, _code_method_names[args.code_method])

# All artifacts of a run live in ./result/ under a prefix that encodes the run's
# hyper-parameters, so relaunching with identical arguments finds the previous
# checkpoint and histories. Two field pairs are glued without '_':
# reset_mechanism+limit and optimizer+lr.
work_dir = './result/'
os.makedirs(work_dir, exist_ok=True)
_name_fields = [
    args.dataset, args.code_method, str(args.num_step), args.back_method, args.updater,
    args.pre_process, args.adjust, args.key, str(args.momentum), str(args.limit_k),
    args.model, args.neuron, args.bn, args.spike_func, str(args.slope), str(args.threshold),
    str(args.weak_mem), args.reset_mechanism + str(limit), args.optimizer + str(args.lr),
    args.appendix,
]
prefix_save_name = work_dir + '_'.join(_name_fields)
# ==============================================================================================
print(prefix_save_name)
model_save_name = f'{prefix_save_name}_state_dict.pth'
if not os.path.exists(model_save_name):
    print("Training start from scratch.")
else:
    # Resume from the checkpoint an earlier run saved (it is written when test acc improves).
    print('Use trained model.')
    net.load_state_dict(torch.load(model_save_name, map_location=device))
    modules.reset_model_state()

# Swap BatchNorm layers for SyncBatchNorm, then wrap the model for DDP on this rank's GPU.
local_gpu = device_ids[args.local_rank]
net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net).to(device)
net = DistributedDataParallel(net, device_ids=[local_gpu], output_device=local_gpu,
                              broadcast_buffers=True)

def _load_list_if_exists(path):
    """Return the array saved at ``path`` as a Python list, or ``[]`` when the file is missing."""
    return np.load(path).tolist() if os.path.exists(path) else []


def _load_array_if_exists(path):
    """Return the array saved at ``path``, or ``None`` when the file is missing."""
    return np.load(path) if os.path.exists(path) else None


if args.local_rank == 0:
    # Only rank 0 persists metrics, so every recorder name below exists only on rank 0.
    for name, elem in net.module.named_buffers():
        if 'slope' in name:
            print(f"{name}: {elem}")

    # Number of epochs recorded so far, including any restored from a previous run.
    actual_epoch = 0

    # Scalar per-epoch histories, restored from disk when a previous run left them.
    train_loss_save_name = prefix_save_name + '_train_loss.npy'
    train_acc_save_name = prefix_save_name + '_train_acc.npy'
    test_acc_save_name = prefix_save_name + '_test_acc.npy'
    train_loss_recorder = _load_list_if_exists(train_loss_save_name)
    train_acc_recorder = _load_list_if_exists(train_acc_save_name)
    test_acc_recorder = _load_list_if_exists(test_acc_save_name)

    # Per-epoch statistics for this run: (num_epoch, recorder_size), left uninitialized
    # (np.empty) and filled during training. old_*_recorder hold the rows saved by a
    # previous run, or None when starting fresh.
    fire_recorder = np.empty((args.num_epoch, recorder_size))
    train_fire_save_name = prefix_save_name + '_train_fire.npy'
    old_fire_recorder = _load_array_if_exists(train_fire_save_name)
    if old_fire_recorder is not None:
        # The length of the saved fire history determines how many epochs are already done.
        actual_epoch = old_fire_recorder.shape[0]

    slope_recorder = np.empty((args.num_epoch, recorder_size))
    train_slope_save_name = prefix_save_name + '_train_slope.npy'
    old_slope_recorder = _load_array_if_exists(train_slope_save_name)

    mean_recorder = np.empty((args.num_epoch, recorder_size))
    train_mean_save_name = prefix_save_name + '_train_mean.npy'
    old_mean_recorder = _load_array_if_exists(train_mean_save_name)

    var_recorder = np.empty((args.num_epoch, recorder_size))
    train_var_save_name = prefix_save_name + '_train_var.npy'
    old_var_recorder = _load_array_if_exists(train_var_save_name)

# ======================================= Setting =========================================
# Best test accuracy so far, seeded from `--best_acc` (e.g. when resuming a run).
best_acc, best_epoch = args.best_acc, 0

len_train_loader, len_test_loader = len(train_loader), len(test_loader)
# Both sizes are per rank: each loader iterates this rank's DistributedSampler shard.
len_train_data = args.batch_size * len_train_loader  # drop_last=True -> only full batches
len_test_data = sampler_test.num_samples
print(f"{args.local_rank}, len_train_loader:{len_train_loader}, len_test_loader:{len_test_loader},"
      f" len_train_data:{len_train_data}, len_train_samples:{sampler_train.num_samples}, "
      f"len_test_data:{len(test_data)}, len_test_samples:{len_test_data}")

criterion = nn.NLLLoss()
_optimizer_classes = {'sgd': torch.optim.SGD, 'adam': torch.optim.Adam}
if args.optimizer not in _optimizer_classes:
    raise ValueError("Can't find choices in args.optimizer.")
optimizer = _optimizer_classes[args.optimizer](net.parameters(), lr=args.lr)

# ========================================= Train ==========================================
adjust.reset()

# Number of processes whose test accuracies dist.reduce sums onto rank 0 below. Read it
# from the process group instead of len(device_ids): the two differ when fewer processes
# are launched than devices listed (e.g. one process with the default `--device_ids
# '0, 1'`), and dividing by the device count then under-reports the accuracy.
world_size = dist.get_world_size()


def _save_history(path, old, new, num_epochs):
    """Save ``old`` (if not None) followed by ``new``, truncated to the first ``num_epochs`` rows."""
    merged = new if old is None else np.concatenate((old, new), axis=0)
    np.save(path, merged[:num_epochs])


for epoch in range(args.num_epoch):
    sampler_train.set_epoch(epoch)  # makes the distributed shuffle differ per epoch
    updater.lr_update(epoch, optimizer)
    print(optimizer)

    # ------------------------------------ train ------------------------------------
    net.train()
    total_train_loss = 0
    total_train_cor = 0

    for idx, (data, target) in enumerate(train_loader):
        is_last_batch = idx == len_train_loader - 1
        if is_last_batch:
            # `adjust` is opened only around the final batch of the epoch (closed below).
            adjust.start()
        data = utils.fold(code_method(data, args.num_step)).to(device)
        target = target.to(device)
        train_loss, train_cor = back_method(net, args.num_step, data, target, optimizer,
                                            criterion, pre_process)

        # NOTE: accumulated per rank and never reduced, so the train accuracy recorded on
        # rank 0 covers rank 0's shard only.
        total_train_cor += train_cor

        if args.local_rank == 0:
            # Only rank 0 accumulates the loss and prints progress.
            total_train_loss += train_loss
            print(f"\rEpoch: {epoch + 1}/{args.num_epoch}, Batch: {idx + 1}/{len_train_loader}, "
                  f"Train loss: {train_loss}", end='')

        if is_last_batch:
            if args.local_rank == 0:
                print()
                # Statistics are recorded on rank 0 only (they could be gathered across ranks).
                backprop.store_info(args.batch_size, args.num_step, epoch, fire_recorder,
                                    slope_recorder, mean_recorder, var_recorder)
                actual_epoch += 1
            adjust.close()

    # ----------------------------------- evaluate -----------------------------------
    net.eval()
    with torch.no_grad():
        test_cor = 0
        for (data, target) in test_loader:
            data = utils.fold(code_method(data, args.num_step)).to(device)
            target = target.to(device)
            test_cor += backprop.test(net, args.num_step, data, target)

    # Accuracy over this rank's test shard, then summed onto rank 0 (default op is SUM);
    # only rank 0's copy holds the total after the reduce.
    test_acc = 100 * test_cor / len_test_data
    test_acc = torch.tensor(test_acc, dtype=torch.float, device=device)
    print(f"{args.local_rank}'s result: {test_cor}, {test_acc}")
    dist.barrier()
    dist.reduce(test_acc, dst=0)
    print(f"{args.local_rank}'s result: {test_cor}, {test_acc}")

    if args.local_rank == 0:
        train_loss_recorder.append(total_train_loss)
        train_acc = 100 * total_train_cor / len_train_data
        train_acc_recorder.append(train_acc)

        # Mean of the per-rank accuracies.
        test_acc = (test_acc / world_size).item()
        print("Final write test_acc in zero device:", test_acc)
        test_acc_recorder.append(test_acc)

        if best_acc < test_acc:
            best_acc = test_acc
            best_epoch = epoch + 1
            torch.save(net.module.state_dict(), model_save_name)
        print(f"Current train acc: {train_acc}, Current test acc: {test_acc}, Best test acc: "
              f"{best_acc}")

    # step and reset are separate calls so the slope can be synchronized before the
    # final update_grad.
    adjust.step(epoch)
    modules.reset_model_state()
    adjust.reset()

    if args.local_rank == 0:
        print(f"Best test acc: {best_acc}, Epoch: {best_epoch}, actual_epoch: {actual_epoch}")

        np.save(train_loss_save_name, train_loss_recorder)
        np.save(train_acc_save_name, train_acc_recorder)
        np.save(test_acc_save_name, test_acc_recorder)

        # Previous-run rows followed by this run's rows, cut at the epochs actually done.
        _save_history(train_fire_save_name, old_fire_recorder, fire_recorder, actual_epoch)
        _save_history(train_slope_save_name, old_slope_recorder, slope_recorder, actual_epoch)
        _save_history(train_mean_save_name, old_mean_recorder, mean_recorder, actual_epoch)
        _save_history(train_var_save_name, old_var_recorder, var_recorder, actual_epoch)

        end_time = time.time()
        print(f"Total use: {(end_time - start_time) / 60} minutes.")
