import argparse
from torchvision.transforms import transforms
import torchvision
import errno
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
import os
import numpy as np
import time
import random
import torch.backends.cudnn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data.distributed import DistributedSampler
import sys

sys.path.append('../../')
from iclr23code import utils, dataset, backprop, modules, surrogate, LIF


def is_dist_avail_and_initialized():
    """Report whether torch.distributed is usable right now.

    Returns:
        bool: ``True`` only when the distributed package is available in this
        PyTorch build *and* a default process group has been initialised;
        ``False`` otherwise.
    """
    # Short-circuit: is_initialized() is only consulted when the package exists.
    return dist.is_available() and dist.is_initialized()


def get_world_size():
    """Return the number of processes in the default group (1 when not distributed)."""
    return dist.get_world_size() if is_dist_avail_and_initialized() else 1


def get_rank():
    """Return this process's rank in the default group (0 when not distributed)."""
    return dist.get_rank() if is_dist_avail_and_initialized() else 0


def is_main_process():
    """Return True on rank 0, which is also the only rank outside distributed mode."""
    rank = get_rank()
    return rank == 0


def save_on_master(*args, **kwargs):
    """Call ``torch.save(*args, **kwargs)`` on the main process only.

    On every other rank this is a no-op, so concurrent processes do not all
    write the same file.
    """
    if not is_main_process():
        return
    torch.save(*args, **kwargs)


def _mkdir(path):
    try:
        os.makedirs(path)
    except OSError as e:
        if e.errno != errno.EEXIST:
            raise


def _get_cache_path(filepath):
    import hashlib
    h = hashlib.sha1(filepath.encode()).hexdigest()
    cache_path = os.path.join("~", ".torch", "vision", "datasets", "imagefolder", h[:10] + ".pt")
    cache_path = os.path.expanduser(cache_path)
    return cache_path


def _load_or_build_imagefolder(root, transform, cache_dataset, tag):
    """Return an ImageFolder for *root*, reusing the pickled cache when allowed.

    If *cache_dataset* is truthy and the cache file for *root* (see
    ``_get_cache_path``) exists, the ``(dataset, root)`` tuple stored there is
    loaded and its dataset returned. Note that the cached dataset carries its
    own transforms, so *transform* is ignored in that case.

    Otherwise the ImageFolder is built from disk with *transform*. When
    *cache_dataset* is truthy it is also written to the cache, by the main
    process only. *tag* appears only in the log messages.
    """
    cache_file = _get_cache_path(root)
    if cache_dataset and os.path.exists(cache_file):
        # Attention, as the transforms are also cached!
        print("Loading {} from {}".format(tag, cache_file))
        cached_set, _ = torch.load(cache_file)
        return cached_set

    built = torchvision.datasets.ImageFolder(root, transform)
    if cache_dataset:
        print("Saving {} to {}".format(tag, cache_file))
        _mkdir(os.path.dirname(cache_file))
        save_on_master((built, root), cache_file)
    return built


def load_data(traindir, valdir, cache_dataset, distributed):
    """Build the ImageFolder train/val datasets and their samplers.

    Args:
        traindir: root directory of the training ImageFolder.
        valdir: root directory of the validation ImageFolder.
        cache_dataset: truthy to load/save each dataset from/to a cache file
            derived from its directory path.
        distributed: if truthy, use ``DistributedSampler`` for both splits.
            Otherwise use a random sampler for training and a sequential one
            for validation.

    Returns:
        ``(train_dataset, val_dataset, train_sampler, val_sampler)``.
    """
    print("Loading data")
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    # Training augmentation: random resized 224 crop + horizontal flip.
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    # Evaluation preprocessing: resize to 256, then a 224 center crop.
    eval_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])

    print("Loading training data")
    started = time.time()
    train_set = _load_or_build_imagefolder(traindir, train_transform,
                                           cache_dataset, "dataset_train")
    print("Took", time.time() - started)

    print("Loading validation data")
    val_set = _load_or_build_imagefolder(valdir, eval_transform,
                                         cache_dataset, "dataset_test")

    print("Creating data loaders")
    if not distributed:
        train_sampler = torch.utils.data.RandomSampler(train_set)
        test_sampler = torch.utils.data.SequentialSampler(val_set)
    else:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_set)
        test_sampler = torch.utils.data.distributed.DistributedSampler(val_set)

    return train_set, val_set, train_sampler, test_sampler


def set_seed(seed):
    """Seed Python, NumPy and torch (CPU + CUDA) RNGs and configure cuDNN.

    Args:
        seed: integer seed applied to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn_backend = torch.backends.cudnn
    # NOTE(review): benchmark=True lets cuDNN autotune its algorithm choice,
    # which the PyTorch reproducibility notes say can vary between runs even
    # with deterministic=True. Set it to False if bit-exact reruns matter more
    # than speed.
    cudnn_backend.benchmark = True
    cudnn_backend.deterministic = True


# Wall-clock reference for the "Total use" report printed after each epoch.
start_time = time.time()

parser = argparse.ArgumentParser()
parser.add_argument('--device_ids', type=str, default='0, 1, 2, 3, 4, 5, 6, 7', help='Training Device')
# Rank of this process on its node. Both spellings are accepted:
# - the PyTorch >= 2.0 `torch.distributed.launch` passes `--local-rank`;
# - older launchers pass `--local_rank`;
# - torchrun passes neither and only exports LOCAL_RANK, so that environment
#   variable is used as the default (falling back to -1 when unset).
parser.add_argument('--local_rank', '--local-rank', type=int,
                    default=int(os.environ.get('LOCAL_RANK', -1)), help='DDP parameter')

parser.add_argument('--data_path', default='../../data', type=str, help='Path of data')
parser.add_argument('--dataset', default='IMAGE', type=str, help='Dataset name',
                    choices=['CIFAR10', 'CIFAR100', 'IMAGE', 'AUGDVS'])
parser.add_argument('--batch_size', default=68, type=int, help='Batch size')

parser.add_argument('--neuron', default='lif', type=str, help='Type of neuron',
                    choices=['lif'])
parser.add_argument('--num_step', default=4, type=int, help='Time dimension')
parser.add_argument('--bn', default='tdbn', type=str, help='Type of bn layer in model',
                    choices=['nn', 'straight', 'tdbn'])
parser.add_argument('--spike_func', default='triangle', type=str, help='Surrogate function',
                    choices=['rectangular', 'triangle', 'actan'])
parser.add_argument('--slope', default=1.0, type=float, help='Parameter used in surrogate')
parser.add_argument('--threshold', default=1.0, type=float, help='Membrane threshold')
parser.add_argument('--weak_mem', default=1.0, type=float, help='Weaken rate of mem')
parser.add_argument('--reset_mechanism', default='zero', type=str,
                    help='Membrane reset mechanism', choices=['subtract', 'zero'])

parser.add_argument('--model', default='sewresnet34', type=str, help='Architecture of model',
                    choices=['sewresnet34'])

parser.add_argument('--back_method', default='trainImg', type=str, help='Backward method',
                    choices=['train', 'trainImg'])
parser.add_argument('--pre_process', default='logSoft', type=str, help='Pre process if result',
                    choices=['logSoft', 'soft', 'none'])
parser.add_argument('--adjust', default='bna', type=str, help='adjust class for slope',
                    choices=['d', 'bna'])
parser.add_argument('--key', default='0', type=str, help='idx of inner_list')
parser.add_argument('--momentum', default=0.0, type=float, help='Momentum in adjust class')
parser.add_argument('--limit_k', default=0.2, type=float, help='limit_k in adjust class')

parser.add_argument('--code_method', default='copy', type=str, help='Coding method',
                    choices=['copy', 'dvs'])
parser.add_argument('--updater', default='img', type=str, help='For lr update',
                    choices=['warm', 'img'])

parser.add_argument('--lr', default=1e-1, type=float, help='Max learning rate')
parser.add_argument('--num_epoch', default=160, type=int, help='Training epoch')
parser.add_argument('--best_acc', default=0, type=float, help='Best acc of current model.')
parser.add_argument('--optimizer', default='sgd', type=str,
                    help='If use adam, set a lower lr(0.01)', choices=['adam', 'sgd'])
parser.add_argument('--appendix', default='0', type=str, help='Append information')
parser.add_argument('--workers', default=16, type=int)
args = parser.parse_args()

# Seed each process differently: local rank r uses seed 1000 * (r + 1).
set_seed(1000 * (args.local_rank + 1))

# --device_ids is a comma-separated list of CUDA ordinals, e.g. '0, 1, 2'
# (int() tolerates the surrounding spaces).
device_ids = [int(token) for token in args.device_ids.split(',')]
gpu_num = len(device_ids)
print(f'Current gup size: {gpu_num}')

# Join the NCCL process group, then pin this process to the GPU that
# --device_ids lists at position local_rank.
# NOTE(review): with local_rank == -1 this selects the *last* listed device.
dist.init_process_group(backend='nccl')
device = torch.device(f'cuda:{device_ids[args.local_rank]}')
torch.cuda.set_device(device)

# ---- Resolve model components from the parsed CLI choices -------------------
# Each table maps a CLI choice either to the object itself or to an attribute
# *name*. getattr is applied only to the selected entry, so only the chosen
# attribute is ever looked up.
_neuron_table = {'lif': LIF}
if args.neuron not in _neuron_table:
    raise ValueError("Can't find choice in args.neuron.")
neuron = _neuron_table[args.neuron]

_bn_attr_names = {'nn': 'batch_norm2d', 'straight': 'straight', 'tdbn': 'td_bn'}
if args.bn not in _bn_attr_names:
    raise ValueError("Can't find choice in args.bn.")
bn = getattr(modules, _bn_attr_names[args.bn])

# The surrogate attribute names coincide with the CLI choices.
_spike_choices = ('rectangular', 'triangle', 'actan')
if args.spike_func not in _spike_choices:
    raise ValueError("Can't find choice in args.spike_func.")
spike_func = getattr(surrogate, args.spike_func)

if args.model != 'sewresnet34':
    raise ValueError("Can't fine choice in args.model.")
print("Note: If use model `sewresnet34`, remember to use IF neuron(set weak_mem to 1.0).")
# Number of columns in the per-epoch fire/slope/mean/var recorders.
recorder_size = 36
model = modules.SewResNet34

net = model(
    neuron, args.num_step, bn, spike_func, args.slope, args.threshold,
    args.weak_mem, args.reset_mechanism, args.dataset,
).to(device)

# NOTE(review): placeholder ImageNet locations -- edit these before running.
train_dir = 'yourpath/ImageNet/train'
val_dir = 'yourpath/ImageNet/val'
# load_data's third argument is a boolean cache switch, not a directory.
# The cache file location is derived from each data directory by
# _get_cache_path, so the flag is passed explicitly by keyword.
train_data, test_data, sampler_train, sampler_test = load_data(
    train_dir, val_dir, cache_dataset=True, distributed=True)

print(f'dataset_train:{len(train_data)}, dataset_test:{len(test_data)}')

# drop_last=True: every training batch on every rank has exactly batch_size samples.
train_loader = torch.utils.data.DataLoader(
    train_data, batch_size=args.batch_size,
    sampler=sampler_train,
    num_workers=args.workers, pin_memory=True,
    drop_last=True,
)

test_loader = torch.utils.data.DataLoader(
    test_data, batch_size=args.batch_size,
    sampler=sampler_test,
    num_workers=args.workers, pin_memory=True)

# Alternative (currently disabled) loading path through the project's
# `dataset.dataload`, presumably for the non-ImageNet --dataset choices:
# train_data, test_data = dataset.dataload(args.data_path, args.dataset)
# sampler_train = DistributedSampler(train_data)
# train_loader = DataLoader(train_data, batch_size=args.batch_size, sampler=sampler_train,
#                           drop_last=True, pin_memory=True, num_workers=8)
# sampler_test = DistributedSampler(test_data)
# test_loader = DataLoader(test_data, batch_size=args.batch_size, sampler=sampler_test, pin_memory=True,
#                          num_workers=8)

# Per-batch training routine (called as back_method(net, num_step, data,
# target, optimizer, criterion, pre_process) in the training loop).
if args.back_method == 'train':
    back_method = backprop.train
elif args.back_method == 'trainImg':
    back_method = backprop.trainImg
else:
    raise ValueError("Can't find choice in args.back_method.")

# Output transform handed to back_method alongside the NLLLoss criterion
# (presumably applied to the network output before the loss -- see backprop).
# The argparse choice is lowercase 'none'. Comparing against 'None' could
# never match, so `--pre_process none` used to fall through to the
# ValueError below.
if args.pre_process == 'none':
    pre_process = None
elif args.pre_process == 'logSoft':
    pre_process = nn.LogSoftmax(dim=-1)
elif args.pre_process == 'soft':
    pre_process = nn.Softmax(dim=-1)
else:
    raise ValueError("Can't find choice in args.pre_process.")

# `limit` is passed to the adjust classes and is also embedded in the save-file prefix.
limit = 0
if args.updater == 'warm':
    updater = utils.WarmUpdate(args.num_epoch, args.lr)
elif args.updater == 'img':
    updater = utils.ImgUpdate(args.num_epoch, args.lr)
else:
    raise ValueError("Can't find choice in args.updater.")

if args.adjust == 'd':
    adjust = backprop.Adjust(limit)
elif args.adjust == 'bna':
    adjust = backprop.BNAdjust(limit, args.key, args.momentum, args.limit_k)
else:
    raise ValueError("Can't find choice in args.adjust.")

# Input encoder applied to every batch before utils.fold in the train/test loops.
if args.code_method == 'copy':
    code_method = utils.copy_code
elif args.code_method == 'dvs':
    code_method = utils.dvs_code
else:
    raise ValueError("Can't find choice in args.code_method.")

work_dir = './result/'
os.makedirs(work_dir, exist_ok=True)
# Every hyper-parameter is embedded in the file prefix, so runs with different
# settings write to different files. Two pairs are glued without a '_'
# separator (reset_mechanism+limit, optimizer+lr); existing result files
# depend on that exact naming.
prefix_save_name = work_dir + '_'.join([
    args.dataset, args.code_method, str(args.num_step), args.back_method,
    args.updater, args.pre_process, args.adjust, args.key,
    str(args.momentum), str(args.limit_k), args.model, args.neuron, args.bn,
    args.spike_func, str(args.slope), str(args.threshold), str(args.weak_mem),
    args.reset_mechanism + str(limit),
    args.optimizer + str(args.lr),
    args.appendix,
])
# ==============================================================================================
print(prefix_save_name)
model_save_name = f'{prefix_save_name}_state_dict.pth'
if not os.path.exists(model_save_name):
    print("Training start from scratch.")
else:
    # Resume from the weights saved by the best-accuracy checkpoint logic.
    print('Use trained model.')
    net.load_state_dict(torch.load(model_save_name, map_location=device))
    modules.reset_model_state()

# Convert BatchNorm layers to SyncBatchNorm (PyTorch API), then wrap the model
# for DDP on this rank's GPU.
net = DistributedDataParallel(
    torch.nn.SyncBatchNorm.convert_sync_batchnorm(net).to(device),
    device_ids=[device_ids[args.local_rank]],
    output_device=device_ids[args.local_rank],
    broadcast_buffers=True,
)

def _load_saved_array(path):
    """Return ``np.load(path)`` if *path* exists, otherwise ``None``."""
    if os.path.exists(path):
        return np.load(path)
    return None


def _load_saved_list(path):
    """Return the array stored at *path* as a Python list, or ``[]`` if absent."""
    saved = _load_saved_array(path)
    return [] if saved is None else saved.tolist()


if args.local_rank == 0:
    # Log every buffer whose name contains 'slope'.
    for name, elem in net.module.named_buffers():
        if 'slope' in name:
            print(f"{name}: {elem}")

    # Count of recorded epochs: resumed from the saved fire recorder below
    # and incremented by the training loop.
    actual_epoch = 0

    # Scalar per-epoch histories, resumed from disk when present.
    train_loss_save_name = prefix_save_name + '_train_loss.npy'
    train_acc_save_name = prefix_save_name + '_train_acc.npy'
    test_acc_save_name = prefix_save_name + '_test_acc.npy'
    train_loss_recorder = _load_saved_list(train_loss_save_name)
    train_acc_recorder = _load_saved_list(train_acc_save_name)
    test_acc_recorder = _load_saved_list(test_acc_save_name)

    # Per-epoch statistics for this run: shape (num_epoch, recorder_size).
    # np.empty leaves rows uninitialised. The training loop concatenates the
    # old_* history with these and saves only the first actual_epoch rows.
    fire_recorder = np.empty((args.num_epoch, recorder_size))
    train_fire_save_name = prefix_save_name + '_train_fire.npy'
    old_fire_recorder = _load_saved_array(train_fire_save_name)
    if old_fire_recorder is not None:
        actual_epoch = old_fire_recorder.shape[0]

    slope_recorder = np.empty((args.num_epoch, recorder_size))
    train_slope_save_name = prefix_save_name + '_train_slope.npy'
    old_slope_recorder = _load_saved_array(train_slope_save_name)

    mean_recorder = np.empty((args.num_epoch, recorder_size))
    train_mean_save_name = prefix_save_name + '_train_mean.npy'
    old_mean_recorder = _load_saved_array(train_mean_save_name)

    var_recorder = np.empty((args.num_epoch, recorder_size))
    train_var_save_name = prefix_save_name + '_train_var.npy'
    old_var_recorder = _load_saved_array(train_var_save_name)

# ======================================= Setting =========================================
best_epoch = 0
best_acc = args.best_acc

len_train_loader, len_test_loader = len(train_loader), len(test_loader)
# Training samples consumed per epoch by this rank (the loader uses drop_last=True).
len_train_data = len_train_loader * args.batch_size
# Per-rank validation sample count reported by the test sampler; this is the
# denominator of each rank's test accuracy.
len_test_data = sampler_test.num_samples
print(f"{args.local_rank}, len_train_loader:{len_train_loader}, len_test_loader:{len_test_loader},"
      f" len_train_data:{len_train_data}, len_train_samples:{sampler_train.num_samples}, "
      f"len_test_data:{len(test_data)}, len_test_samples:{len_test_data}")

criterion = nn.NLLLoss()
_optimizer_classes = {'sgd': torch.optim.SGD, 'adam': torch.optim.Adam}
if args.optimizer not in _optimizer_classes:
    raise ValueError("Can't find choices in args.optimizer.")
optimizer = _optimizer_classes[args.optimizer](net.parameters(), lr=args.lr)

# ========================================= Train ==========================================
adjust.reset()

# dist.reduce below SUMS every rank's test accuracy onto rank 0, so the mean
# must divide by the real number of processes in the group. gpu_num
# (len(device_ids)) is just the length of the --device_ids CLI list, which
# defaults to 8 entries, and can differ from the number of launched processes.
world_size = dist.get_world_size()

for epoch in range(args.num_epoch):
    # Reshuffle this rank's shard for the epoch and apply the LR schedule.
    sampler_train.set_epoch(epoch)
    updater.lr_update(epoch, optimizer)
    print(optimizer)

    net.train()
    total_train_loss = 0
    total_train_cor = 0

    for idx, (data, target) in enumerate(train_loader):
        # adjust.start()/adjust.close() bracket only the final batch of the epoch.
        if idx == len_train_loader - 1:
            adjust.start()
        data = utils.fold(code_method(data, args.num_step)).to(device)
        target = target.to(device)
        train_loss, train_cor = back_method(net, args.num_step, data, target, optimizer,
                                            criterion, pre_process)

        # This rank's correct count only; it is not reduced across ranks.
        total_train_cor += train_cor

        if args.local_rank == 0:
            # Only device 0 accumulates the loss and prints progress.
            total_train_loss += train_loss
            print(f"\rEpoch: {epoch + 1}/{args.num_epoch}, Batch: {idx + 1}/{len_train_loader}, "
                  f"Train loss: {train_loss}", end='')

        if idx == len_train_loader - 1:
            if args.local_rank == 0:
                print()
                # Records this epoch's row of the fire/slope/mean/var recorders
                # (rank 0's statistics only).
                backprop.store_info(args.batch_size, args.num_step, epoch, fire_recorder,
                                    slope_recorder, mean_recorder, var_recorder)
                actual_epoch += 1
            adjust.close()

    net.eval()
    with torch.no_grad():
        test_cor = 0
        for (data, target) in test_loader:
            data = utils.fold(code_method(data, args.num_step)).to(device)
            target = target.to(device)
            test_cor += backprop.test(net, args.num_step, data, target)  # per-rank count

    # Accuracy over this rank's shard of the validation set.
    test_acc = 100 * test_cor / len_test_data
    test_acc = torch.tensor(test_acc, dtype=torch.float, device=device)
    print(f"{args.local_rank}'s result: {test_cor}, {test_acc}")
    dist.barrier()
    # Sum all ranks' accuracies into rank 0's tensor.
    dist.reduce(test_acc, dst=0)
    print(f"{args.local_rank}'s result: {test_cor}, {test_acc}")

    # NOTE(review): local_rank == 0 coincides with the reduce destination
    # (global rank 0) only for single-node runs.
    if args.local_rank == 0:
        train_loss_recorder.append(total_train_loss)
        train_acc = 100 * total_train_cor / len_train_data
        train_acc_recorder.append(train_acc)

        # Mean of the per-rank accuracies.
        test_acc = (test_acc / world_size).item()
        print("Final write test_acc in zero device:", test_acc)
        test_acc_recorder.append(test_acc)

        if best_acc < test_acc:
            best_acc = test_acc
            best_epoch = epoch + 1
            torch.save(net.module.state_dict(), model_save_name)
        print(f"Current train acc: {train_acc}, Current test acc: {test_acc}, Best test acc: "
              f"{best_acc}")

    # Step and reset are separate calls so the slope can be synchronised
    # before the final update_grad.
    adjust.step(epoch)
    modules.reset_model_state()
    adjust.reset()

    if args.local_rank == 0:
        print(f"Best test acc: {best_acc}, Epoch: {best_epoch}, actual_epoch: {actual_epoch}")

        np.save(train_loss_save_name, train_loss_recorder)
        np.save(train_acc_save_name, train_acc_recorder)
        np.save(test_acc_save_name, test_acc_recorder)

        # Append this run's rows to any resumed history and keep only the
        # actual_epoch rows that have been filled.
        if old_fire_recorder is not None:
            save_fire_recorder = np.concatenate((old_fire_recorder, fire_recorder), axis=0)
        else:
            save_fire_recorder = fire_recorder
        np.save(train_fire_save_name, save_fire_recorder[: actual_epoch])

        if old_slope_recorder is not None:
            save_slope_recorder = np.concatenate((old_slope_recorder, slope_recorder), axis=0)
        else:
            save_slope_recorder = slope_recorder
        np.save(train_slope_save_name, save_slope_recorder[: actual_epoch])

        if old_mean_recorder is not None:
            save_mean_recorder = np.concatenate((old_mean_recorder, mean_recorder), axis=0)
        else:
            save_mean_recorder = mean_recorder
        np.save(train_mean_save_name, save_mean_recorder[: actual_epoch])

        if old_var_recorder is not None:
            save_var_recorder = np.concatenate((old_var_recorder, var_recorder),
                                               axis=0)
        else:
            save_var_recorder = var_recorder
        np.save(train_var_save_name, save_var_recorder[: actual_epoch])

        end_time = time.time()
        print(f"Total use: {(end_time - start_time) / 60} minutes.")