import os, sys, glob
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import numpy as np
import time
import global_v as glv
import load_data
import logging
import argparse
from torch.autograd import Variable
import torchvision

from parser import parse
from architect import Architect
import loss_f
from network import RSNN
from utils import *


def train_arch(train_loader, val_loader, model, architect, eta, optimizer_search, grad_clip=None):
    """Run one epoch of architecture search interleaved with weight updates.

    For every training batch, ``architect.step`` is called with that batch and
    the next validation batch. The network weights then take one optimizer
    step on the training batch, with gradient-norm clipping.

    Args:
        train_loader: iterable of ``(inputs, labels)`` batches used for weight
            updates. Its ``.dataset`` length is shown in the progress line.
        val_loader: iterable of ``(inputs, labels)`` batches handed to
            ``architect.step``. It is restarted when exhausted, so it may yield
            fewer batches than ``train_loader``.
        model: provides ``model_loss(inputs, labels)`` returning
            ``(loss, outputs, correct, total)`` and ``get_parameters()``.
        architect: provides ``step(inputs, labels, inputs_val, labels_val, eta)``.
        eta: value forwarded unchanged to ``architect.step``.
        optimizer_search: optimizer over the model weights.
        grad_clip: max gradient norm for ``clip_grad_norm_``. When ``None``, it
            falls back to the module-level ``network_config['grad_clip']``.

    Returns:
        ``(acc, loss_avg)``. ``acc`` is the percentage of correct predictions
        over the epoch. ``loss_avg`` is the sum of ``loss.item()`` divided by the
        summed ``total`` counts. Both are ``0.0`` if ``train_loader`` yields no
        batches.
        NOTE(review): if ``model_loss`` returns a per-batch *mean*, ``loss_avg``
        is additionally divided by the batch size — confirm intent.

    Raises:
        ValueError: if ``val_loader`` yields no batches at all.
    """
    if grad_clip is None:
        grad_clip = network_config['grad_clip']

    total_num = 0
    correct_num = 0
    total_samples = 0
    total_loss = 0
    # Defined up front so an empty loader returns zeros instead of raising
    # UnboundLocalError at the return statement.
    acc, loss_avg = 0.0, 0.0
    val_loader_iterator = iter(val_loader)

    for inputs, labels in train_loader:
        inputs = inputs.cuda()
        labels = labels.cuda()

        val_batch = next(val_loader_iterator, None)
        if val_batch is None:
            # The validation split ran out before the training split: start a
            # new pass over it instead of letting StopIteration abort the epoch.
            val_loader_iterator = iter(val_loader)
            val_batch = next(val_loader_iterator, None)
            if val_batch is None:
                raise ValueError('val_loader yielded no batches')
        inputs_val, labels_val = val_batch
        inputs_val = inputs_val.cuda()
        labels_val = labels_val.cuda()

        architect.step(inputs, labels, inputs_val, labels_val, eta)

        optimizer_search.zero_grad()
        loss, _outputs, correct, total = model.model_loss(inputs, labels)

        loss.backward()
        nn.utils.clip_grad_norm_(model.get_parameters(), grad_clip)
        optimizer_search.step()

        total_loss += loss.item()
        total_num += total
        correct_num += correct

        acc = correct_num * 100 / total_num
        loss_avg = total_loss / total_num

        total_samples += int(labels.shape[0])

        logging.info('| train_arch | {:5d}/{:5d} samples | train_acc {:3.2f}% | train_loss {:.6f}'.format(total_samples, len(train_loader.dataset), acc, loss_avg))
        # ANSI "cursor up two lines" so the next progress line overwrites this one.
        print('\033[2A')
    return acc, loss_avg

def test(test_loader, model):
    total_num = 0
    correct_num = 0
    total_samples = 0
    total_loss = 0
    with torch.no_grad():
        for batch_idx, (inputs, labels) in enumerate(test_loader):
            inputs = Variable(inputs, requires_grad=False).cuda()
            labels = Variable(labels, requires_grad=False).cuda()

            loss, outputs, correct, total = model.model_loss(inputs, labels)

            total_num += total 
            correct_num += correct

            total_loss += loss.item()

            acc = correct_num * 100 / total_num
            loss_avg = total_loss / total_num
            total_samples += int(labels.shape[0])
            logging.info('| test | {:5d}/{:5d} samples | test_acc {:3.2f}% | test_loss {:.6f}'.format(total_samples, len(test_loader.dataset), acc, loss_avg))
            print('\033[2A')
    return acc, loss_avg

def train(trainloader, model, optimizer, grad_clip=None):
    """Train the model weights for one epoch over ``trainloader``.

    Args:
        trainloader: iterable of ``(inputs, labels)`` batches. Its ``.dataset``
            length is shown in the progress line.
        model: provides ``model_loss(inputs, labels)`` returning
            ``(loss, outputs, correct, total)`` and ``get_parameters()``.
        optimizer: optimizer over ``model.get_parameters()``.
        grad_clip: max gradient norm for ``clip_grad_norm_``. When ``None``, it
            falls back to the module-level ``network_config['grad_clip']``.

    Returns:
        ``(acc, loss_avg)``. ``acc`` is the percentage of correct predictions.
        ``loss_avg`` is the sum of ``loss.item()`` divided by the summed
        ``total`` counts. Both are ``0.0`` if the loader yields no batches.
    """
    if grad_clip is None:
        grad_clip = network_config['grad_clip']

    total_num = 0
    correct_num = 0
    total_samples = 0
    total_loss = 0
    # Defined up front so an empty loader returns zeros instead of raising
    # UnboundLocalError at the return statement.
    acc, loss_avg = 0.0, 0.0

    for inputs, labels in trainloader:
        inputs = inputs.cuda()
        labels = labels.cuda()

        optimizer.zero_grad()
        loss, _outputs, correct, total = model.model_loss(inputs, labels)

        loss.backward()
        nn.utils.clip_grad_norm_(model.get_parameters(), grad_clip)
        optimizer.step()

        total_loss += loss.item()
        total_num += total
        correct_num += correct

        acc = correct_num * 100 / total_num
        loss_avg = total_loss / total_num

        total_samples += int(labels.shape[0])

        # Use the loader passed in, not a module-level global of a similar name.
        logging.info('| train_model | {:5d}/{:5d} samples | train_acc {:3.2f}% | train_loss {:.6f}'.format(total_samples, len(trainloader.dataset), acc, loss_avg))
        # ANSI "cursor up two lines" so the next progress line overwrites this one.
        print('\033[2A')
    return acc, loss_avg

def sequential_search(train_loader, val_loader, model, architect, network_config, epochs, search_lr):
    """Alternate architecture search and validation for ``epochs`` epochs.

    Network weights use Adam (``network_config['lr']`` and
    ``network_config['weight_decay']``) with a cosine-annealed learning rate
    that decays towards ``network_config['lr_min']``. The current learning rate
    is also handed to ``train_arch`` as ``eta``. The architect's optimizer is
    initialized with ``search_lr``.

    After each epoch the model is evaluated on ``val_loader``, and
    ``model.record_best_arch()`` is called whenever validation accuracy
    strictly improves on the best seen so far.
    """
    weight_opt = torch.optim.Adam(
        model.get_parameters(),
        lr=network_config['lr'],
        weight_decay=network_config['weight_decay'],
    )
    cosine = torch.optim.lr_scheduler.CosineAnnealingLR(weight_opt, epochs, eta_min=network_config['lr_min'])
    architect.init_arch_optimizer(search_lr)

    rule = '-' * 96

    def _log_boxed(line):
        # Surround a summary line with horizontal rules.
        logging.info(rule)
        logging.info(line)
        logging.info(rule)

    top_train, top_val = 0, 0

    for epoch in range(epochs):
        current_lr = cosine.get_last_lr()[0]

        # Joint architecture / weight update on the arch-train split.
        started = time.time()
        train_acc, train_loss = train_arch(train_loader, val_loader, model, architect, current_lr, weight_opt)
        top_train = max(top_train, train_acc)
        _log_boxed('train_arch | end of epoch {:3d} | time: {:5.2f}s | train acc {:3.2f}% ({:3.2f}%) | train loss {:.6f} | '.format(epoch, (time.time() - started), train_acc, top_train, train_loss))

        # Evaluate on the arch-val split and remember the best architecture.
        started = time.time()
        val_acc, val_loss = test(val_loader, model)
        if val_acc > top_val:
            model.record_best_arch()
            top_val = val_acc
        _log_boxed('val_arch | end of epoch {:3d} | time: {:5.2f}s | val acc {:3.2f}% ({:3.2f}%) | val loss {:.6f} | '.format(epoch, (time.time() - started), val_acc, top_val, val_loss))

        cosine.step()


if __name__ == '__main__':
    # initialize parser
    parser = argparse.ArgumentParser(description = "Neural architecture search for recurrent spiking neural networks")
    parser.add_argument('-config', action='store', help='The path of config file')
    parser.add_argument('-save_path', type=str, default=None, help='The path to save model and log')
    parser.add_argument('-gpu', type=int, default=0, help='GPU device to use (default: 0)')
    parser.add_argument('-seed', type=int, default=0, help='random seed (default: time)')
    parser.add_argument('-skip_search', action='store_true', help='tune weights without searching architect (default: False)')
    # NOTE(review): -arch_path is parsed but not read anywhere in this script.
    parser.add_argument('-arch_path', type=str, default=None, help='structure path')
    try:
        args = parser.parse_args()
    except SystemExit as err:
        # argparse prints its own message and raises SystemExit: code 0 after
        # `-h`, non-zero for invalid arguments. Add the full help only on
        # errors, and re-raise so a failure is not reported as success.
        if err.code:
            parser.print_help()
        raise

    if args.config is None:
        raise Exception('Unrecognized config file.')
    config_path = args.config

    # load config file
    params = parse(config_path)
    network_config = params['Network']
    layer_config = params['Layer']

    # create saving directory (default name: config file stem + timestamp)
    if args.save_path is None:
        config_name = os.path.splitext(os.path.basename(config_path))[0]
        args.save_path = 'search-{}-{}'.format(config_name, time.strftime("%Y%m%d-%H%M%S"))
    create_exp_dir(args.save_path, scripts_to_save=glob.glob('*.py'), config_file=config_path)

    # initialize logging: stdout plus <save_path>/log.txt
    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(stream=sys.stdout, level=logging.INFO, format=log_format, datefmt='%m/%d %I:%M:%S %p')
    fh = logging.FileHandler(os.path.join(args.save_path, 'log.txt'))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)
    logging.info('Args: {}'.format(args))

    # check GPU
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    # set GPU
    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = True
    cudnn.enabled = True
    # A seed of 0 means "derive from the clock". Draw it once so every RNG
    # gets the same value. The CPU torch generator is seeded as well because
    # DataLoader shuffling and CPU-side initialization draw from it.
    seed = args.seed if args.seed != 0 else int(time.time())
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    logging.info('gpu device = %d' % args.gpu)
    logging.info('seed = %d' % seed)

    # load data
    train_set, arch_train_set, arch_val_set, test_set = load_data.loader(network_config)

    # init loss function
    criterion = loss_f.SpikeLoss(network_config).cuda()

    model = RSNN(network_config, layer_config, criterion).cuda()

    # init architect
    architect = Architect(model, network_config)

    # init global variables
    glv.init(network_config['n_steps'], network_config['tau_s'])

    def log_param_count():
        """Log the element count of the parameters currently exposed by the model."""
        total_params = sum(p.numel() for p in model.get_parameters())
        logging.info('Model total parameters: {}'.format(total_params))

    train_loader = torch.utils.data.DataLoader(arch_train_set, batch_size=network_config['arch_batch_size'], shuffle=True, num_workers=4)
    val_loader = torch.utils.data.DataLoader(arch_val_set, batch_size=network_config['arch_batch_size'], shuffle=False, num_workers=4)

    # search architecture, one phase at a time:
    # (config flag, model mode, epochs key, architect-lr key, log message)
    search_phases = (
        ('cell_size', 'search_cell_size', 'cell_size_epochs', 'cell_size_lr', 'search cell size'),
        ('connection_type', 'search_connection_type', 'connection_type_epochs', 'connection_type_lr', 'search connection type'),
    )
    if args.skip_search:
        logging.info('skip architecture search')
    else:
        for flag_key, mode, epochs_key, lr_key, message in search_phases:
            if not network_config[flag_key]:
                continue
            logging.info(message)
            model.set_mode(mode)
            model.init_weights()
            log_param_count()

            sequential_search(train_loader, val_loader, model, architect, network_config,
                              network_config[epochs_key], network_config[lr_key])

    best_train = 0
    best_test = 0

    logging.info('start weight training')
    model.set_mode("finetune")
    model.init_weights()
    log_param_count()

    optimizer_finetune = torch.optim.Adam(model.get_parameters(), lr=network_config['lr'], weight_decay=network_config['weight_decay'])
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=network_config['batch_size'], shuffle=True, num_workers=4)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=network_config['batch_size'], shuffle=False, num_workers=4)
    for epoch in range(network_config['epochs']):
        epoch_start_time = time.time()

        # train
        train_acc, train_loss = train(train_loader, model, optimizer_finetune)
        best_train = max(best_train, train_acc)
        logging.info('-' * 96)
        logging.info('train_model | end of epoch {:3d} | time: {:5.2f}s | train acc {:3.2f}% ({:3.2f}%) | train loss {:.6f} | '.format(epoch, (time.time() - epoch_start_time), train_acc, best_train, train_loss))
        logging.info('-' * 96)

        epoch_start_time = time.time()
        # test
        test_acc, test_loss = test(test_loader, model)
        best_test = max(best_test, test_acc)
        logging.info('-' * 96)
        logging.info('test_model | end of epoch {:3d} | time: {:5.2f}s | test acc {:3.2f}% ({:3.2f}%) | test loss {:.6f} | '.format(epoch, (time.time() - epoch_start_time), test_acc, best_test, test_loss))
        logging.info('-' * 96)

    logging.info('end weight training')
