import logging
import math

import models
from dataset_utils import STL10_STR
from models.res_adapt import ResNet18_adapt
from train_1st_order import loss_compute
from train_simple_model import evaluate_dataloader
from utils import *
from args import parse_ho_args, dump_args_dict, load_args_dict
from datasets import make_reproducible_dataset


def trainer(args, model, trainloader, epoch_id, criterion, optimizer, scheduler, logfile, num_classes):
    """Train ``model`` for one epoch and return the epoch's average training loss.

    For every mini-batch the model takes one optimizer step in train mode. The
    same batch is then re-run in eval mode to record top-1/top-5 accuracy.
    Progress is written through ``print_and_save`` every 10 batches, and
    ``scheduler.step()`` is called once after the last batch.

    Args:
        args: Run configuration. Reads ``epochs``, ``device``, ``sep_decay`` and ``loss``.
        model: Network whose forward pass returns an indexable result; element ``[0]``
            is what the criterion and the accuracy computation consume.
        trainloader: Iterable of ``(inputs, targets)`` batches.
        epoch_id: Zero-based index of the current epoch (displayed 1-based).
        criterion: Loss applied to ``outputs[0]`` when ``args.sep_decay`` is false.
        optimizer: Optimizer stepped once per batch.
        scheduler: LR scheduler stepped once at the end of the epoch.
        logfile: Open text file that progress messages are written to.
        num_classes: Number of classes; used to one-hot encode targets for the MSE loss.

    Returns:
        ``losses.avg`` — the running average kept by ``AverageMeter``, which is updated
        with each batch's loss weighted by the batch size.

    Raises:
        ValueError: If ``args.sep_decay`` is false and ``args.loss`` is not a supported tag.
    """
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    print_and_save('\nTraining Epoch: [%d | %d] LR: %f' % (epoch_id + 1, args.epochs, scheduler.get_last_lr()[-1]),
                   logfile)
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(args.device), targets.to(args.device)

        model.train()
        outputs = model(inputs)

        if args.sep_decay:
            loss = loss_compute(args, model, criterion, outputs, targets)
        elif args.loss in [CROSS_ENTROPY_TAG, LABEL_SMOOTHING_TAG, LABEL_RELAXATION_TAG]:
            loss = criterion(outputs[0], targets)
        elif args.loss == MSE_TAG:
            # `targets` already lives on args.device; `.float()` keeps the one-hot
            # matrix there instead of round-tripping through a CPU FloatTensor.
            one_hot_targets = nn.functional.one_hot(targets, num_classes).float()
            loss = criterion(outputs[0], one_hot_targets)
        else:
            # Without this branch `loss` would be unbound (or stale from the
            # previous batch) and the failure would surface far from its cause.
            raise ValueError("Unsupported loss type: %r" % (args.loss,))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Measure accuracy in eval mode on the just-updated weights. This pass is
        # bookkeeping only, so no autograd graph is needed.
        model.eval()
        with torch.no_grad():
            eval_outputs = model(inputs)
        prec1, prec5 = compute_accuracy(eval_outputs[0].detach(), targets.detach(), topk=(1, 5))
        losses.update(loss.item(), inputs.size(0))
        top1.update(prec1.item(), inputs.size(0))
        top5.update(prec5.item(), inputs.size(0))

        if batch_idx % 10 == 0:
            print_and_save('[epoch: %d] (%d/%d) | Loss: %.4f | top1: %.4f | top5: %.4f ' %
                           (epoch_id + 1, batch_idx + 1, len(trainloader), losses.avg, top1.avg, top5.avg), logfile)

    scheduler.step()

    return losses.avg


def _checkpoint_path(save_path, epoch):
    """Return the path of the checkpoint written after ``epoch`` (1-based) training epochs."""
    return os.path.join(save_path, "epoch_" + str(epoch).zfill(3) + ".pth")


def _build_model(args, num_classes):
    """Instantiate the network selected by ``args.model`` and move it to ``args.device``."""
    if args.model == "MLP":
        # 3072 == 3 * 32 * 32, presumably flattened 32x32 RGB images (TODO confirm);
        # STL-10 inputs use 3 * 96 * 96 instead.
        input_dim = 3 * 96 * 96 if args.dataset == STL10_STR else 3072
        model = models.__dict__[args.model](hidden=args.width, depth=args.depth,
                                            fc_bias=args.bias, num_classes=num_classes,
                                            input_dim=input_dim)
    elif args.model == "ResNet18_adapt":
        model = ResNet18_adapt(width=args.width, num_classes=num_classes, fc_bias=args.bias)
    else:
        model = models.__dict__[args.model](num_classes=num_classes, fc_bias=args.bias,
                                            ETF_fc=args.ETF_fc, fixdim=args.fixdim, SOTA=args.SOTA)
    return model.to(args.device)


def conduct_run(args, trainloader, test_val_loader, num_classes, val_prefix, save_model=False, test_run=False):
    """Build, train (or load) and evaluate one model; return its evaluation result.

    If the final-epoch checkpoint (``epoch_<args.epochs zero-padded to 3>.pth`` in
    ``args.save_path``) exists and ``args.force_retrain`` is false, that checkpoint
    is loaded instead of training. Otherwise the model is trained for
    ``args.epochs`` epochs, stopping early if an epoch's average loss is NaN.

    Args:
        args: Run configuration (model choice, optimization settings, ``save_path``,
            ``device``, ``force_retrain``, ...).
        trainloader: Training data loader passed to ``trainer``.
        test_val_loader: Loader used for the final evaluation.
        num_classes: Number of output classes.
        val_prefix: Prefix passed to ``evaluate_dataloader`` and used for the
            ``"<prefix>_acc"`` log entry.
        save_model: If true, save a checkpoint after every completed epoch.
        test_run: Kept for backward compatibility; it has no effect.

    Returns:
        The value returned by ``evaluate_dataloader(..., return_top1=True)``
        (presumably top-1 accuracy).
    """
    model = _build_model(args, num_classes)

    # summary(model, input_size=(3, 32, 32), batch_size=1)

    criterion = make_criterion(args, num_classes)
    optimizer = make_optimizer(args, model)
    scheduler = make_scheduler(args, optimizer)

    final_checkpoint = _checkpoint_path(args.save_path, args.epochs)

    # 'w' mode: an existing log is overwritten even when the checkpoint is only
    # loaded. The context manager closes the file even if training/evaluation raises.
    with open('%s/train_log.txt' % (args.save_path), 'w') as logfile:
        if os.path.exists(final_checkpoint) and not args.force_retrain:
            logging.info("Model already exists, loading this model...")
            # map_location lets a checkpoint saved on another device be loaded here.
            model.load_state_dict(torch.load(final_checkpoint, map_location=args.device))
        else:
            print_and_save('# of model parameters: ' + str(count_network_parameters(model)), logfile)
            print_and_save('--------------------- Training -------------------------------', logfile)

            for epoch_id in range(args.epochs):
                loss = trainer(args, model, trainloader, epoch_id, criterion, optimizer, scheduler, logfile,
                               num_classes=num_classes)
                if math.isnan(loss):
                    print_and_save('NaN loss encountered, stopping training...', logfile)
                    break

                if save_model:
                    torch.save(model.state_dict(), _checkpoint_path(args.save_path, epoch_id + 1))

        test_val_acc = evaluate_dataloader(test_val_loader, model, args, criterion, logfile, is_binary=False,
                                           prefix=val_prefix, num_classes=num_classes, return_top1=True)

    acc_str = "{}_acc".format(val_prefix)
    logging.info("{}: {}".format(acc_str, test_val_acc))
    return test_val_acc


def main():
    """Retrain with the best hyperparameters of a Bayesian-optimization run and evaluate on the test set.

    The HO experiment's directory is derived from ``args.uid`` by dropping its last
    ``_``-separated component. The ``lr``, ``gamma``, ``ls_alpha`` and ``lr_alpha``
    values are copied from that run's ``args.json``. The model is trained (or
    loaded) and evaluated with ``conduct_run``, and the final ``args`` are dumped
    with ``dump_args_dict``.

    Raises:
        ValueError: If the script is not invoked for a Bayesian-optimization HO
            experiment, or if ``args.uid`` has no ``_``-separated suffix to strip.
    """
    args, config = parse_ho_args(return_config=True)
    # Explicit raises instead of `assert`, which is stripped under `python -O`.
    if args.use_ho_uid is not True:
        raise ValueError("This script is only for HO experiments - mark the UID by the respective flag.")
    if args.ho != "bayes_opt":
        raise ValueError("This script is only for Bayesian Optimization experiments.")

    uid_parts = args.uid.split("_")
    if len(uid_parts) < 2:
        # Otherwise the HO run id would be "" and we would silently read
        # HO_EXP_PATH/args.json.
        raise ValueError("Cannot derive the HO run UID from uid %r: expected '<ho_uid>_<suffix>'." % (args.uid,))
    best_run_uid = "_".join(uid_parts[:-1])

    set_seed(manualSeed=args.seed)

    args.device = torch.device("cuda:" + str(args.gpu_id) if torch.cuda.is_available() else "cpu")

    # Data for the final model; val_split_prop=None is passed, and the second
    # loader returned is unused here.
    trainloader, _, testloader, num_classes = make_reproducible_dataset(args.dataset, args.data_dir, args.seed,
                                                                        args.save_path, args.batch_size,
                                                                        args.sample_size, SOTA=args.SOTA,
                                                                        val_split_prop=None,
                                                                        label_noise=args.label_noise)

    # Retrieve the tuned hyperparameters from the HO run.
    best_run_args_path = os.path.join(config["PATHS"]["HO_EXP_PATH"], best_run_uid, "args.json")
    best_run_args = load_args_dict(best_run_args_path)

    tuned_keys = ("lr", "gamma", "ls_alpha", "lr_alpha")
    for key in tuned_keys:
        setattr(args, key, vars(best_run_args)[key])

    conduct_run(args, trainloader, testloader, num_classes=num_classes, val_prefix="test", save_model=True,
                test_run=True)

    dump_args_dict(args)


if __name__ == "__main__":
    # Run only when executed as a script; importing this module does not start a run.
    main()
