import argparse
import os

import numpy as np
import torch

from approaches.tag.models import AlexNet, MLP, ResNet18, ResNet18_CUB


# DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'


def apply_mask(mem_y, out, n_classes):
    """
    Apply mask on the predicted outputs based on the given task - assuming task-incremental learning setup
    :param mem_y: Actual labels
    :param out: Predictions
    :param n_classes: Number of classes per task
    :return: Masked predictions
    """
    for i, y in enumerate(mem_y):
        mask = torch.zeros_like(out[i])
        mask[y - (y % n_classes):y - (y % n_classes) + n_classes] = 1
        out[i] = out[i].masked_fill(mask == 0, -1e10)
    return out


def parse_arguments():
    """
    Parse and print the relevant arguments
    """
    parser = argparse.ArgumentParser(description='Argument parser')
    parser.add_argument('--tasks', default=5, type=int, help='total number of tasks')
    parser.add_argument('--epochs-per-task', default=1, type=int, help='epochs per task')
    parser.add_argument('--dataset', default='rot-mnist', type=str, help='dataset. options: rot-mnist, perm-mnist, cifar100')
    parser.add_argument('--batch-size', default=10, type=int, help='batch-size')
    parser.add_argument('--opt', default='', type=str, help='Manual adagrad')
    parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
    parser.add_argument('--gamma', default=0.4, type=float, help='learning rate decay. Use 1.0 for no decay')
    parser.add_argument('--dropout', default=0.25, type=float, help='dropout probability. Use 0.0 for no dropout')
    parser.add_argument('--hiddens', default=256, type=int, help='num of hidden neurons in each layer of a 2-layer MLP')
    parser.add_argument('--compute-eigenspectrum', default=False, type=bool, help='compute eigenvalues/eigenvectors?')
    parser.add_argument('--b', default=None, type=int, help='b')
    parser.add_argument('--seed', default=0, type=int, help='random seed')
    parser.add_argument('--runs', default=3, type=int, help='# runs')
    parser.add_argument('--hyp-gs', default='', type=str, help='Hyperparameter search')
    parser.add_argument('--tag-opt', default='adam', type=str, help='tag opt')
    parser.add_argument('--mem-size', default=1, type=int, help='mem')
    parser.add_argument('--multi', default=0, type=int, help='MTL')
    parser.add_argument('--lambd', default=1, type=int, help='EWC')

    args = parser.parse_args()
    print("Parameters:\n  benchmark=" + str(args.dataset) + "\n  num_tasks=" + str(args.tasks) + "\n  " +
          "runs=" + str(args.runs) + "\n  epochs-per-task=" + str(args.epochs_per_task) +
          "\n  batch_size=" + str(args.batch_size) + "\n  " + "learning_rate=" + str(args.lr) + "\n  learning rate decay(gamma)="
          + str(args.gamma) + "\n  dropout prob=" + str(args.dropout) + "\n  optimizer opt="
          + (str(args.opt) if args.opt != '' else 'sgd'))
    if args.opt == '':
        return args
    if 'er' in args.opt or 'agem' in args.opt:
        print("  mem=" + str(args.mem_size))
    if 'tag' in args.opt:
        print("  tag-opt=" + str(args.tag_opt))
        print("  b=" + str(args.b))
    if 'ewc' in args.opt:
        print("  lambda=" + str(args.lambd))
    return args


def set_seeds(seed):
    """
    Set seeds for reproducibility
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def init_experiment(args):
    """
    Initialize the experiment and logging dictionaries
    """
    print('\n------------------- Experiment-' + str(args.seed) + ' started -----------------')
    set_seeds(args.seed)
    loss_db = {t: [0 for i in range(args.tasks)] for t in range(1, args.tasks + 1)}
    acc_db = {t: [0 for i in range(args.tasks)] for t in range(1, args.tasks + 1)}
    return acc_db, loss_db


def end_experiment(args, acc_db, loss_db, get_bwt=False):
    """
    Print the final metrics
    """
    score = np.mean([acc_db[i][-1] for i in acc_db.keys()])
    forget = np.mean([max(acc_db[i]) - acc_db[i][-1] for i in range(1, args.tasks)]) / 100.0
    la = np.mean([acc_db[i][i - 1] for i in acc_db.keys()])
    print('average accuracy = {}, forget = {}, learning accuracy = {}'.format(score, forget, la))
    if get_bwt:
        bwt = np.mean([acc_db[i][-1] - acc_db[i][i - 1] for i in acc_db.keys()])
        print('backward transfer = {}'.format(bwt))
    return score, forget, la


def get_benchmark_model(args):
    """
    Return the corresponding PyTorch model for experiment
    """
    DEVICE = args.device
    raise NotImplementedError

    if 'mnist' in args.dataset:
        if args.tasks == 20 and args.hiddens < 256:
            print("Warning! the main paper MLP with 256 neurons for experiment with 20 tasks")
        return MLP(args.hiddens, {'dropout': args.dropout, 'total_classes': args.tasks * 10, 'classes': 10}).to(DEVICE)
    elif 'cifar' in args.dataset:
        if args.tasks == 10:
            return AlexNet(config={'input_size': (3, 32, 32), 'total_classes': 100, 'classes': int(100 / args.tasks)}).to(DEVICE)
        return ResNet18(config={'input_size': (3, 32, 32), 'dropout': args.dropout, 'classes': int(100 / args.tasks)}).to(DEVICE)
    elif 'mini_imagenet' in args.dataset:
        return ResNet18(config={'input_size': (3, 84, 84), 'dropout': args.dropout, 'classes': int(100 / args.tasks)}).to(DEVICE)
    elif 'cub' in args.dataset:
        return ResNet18_CUB(config={'input_size': (3, 224, 224), 'dropout': args.dropout, 'classes': int(200 / args.tasks)}).to(DEVICE)
    elif '5data' in args.dataset:
        return ResNet18(config={'input_size': (3, 32, 32), 'dropout': args.dropout, 'classes': int(50 / args.tasks)}).to(DEVICE)
    else:
        raise Exception("Unknown dataset.\n" + "The code supports 'perm-mnist, rot-mnist, and cifar-100.")


def log_metrics(metrics, time, task_id, acc_db, loss_db, p=False):
    """
    Log accuracy and loss at different times of training
    """
    if p:
        print('epoch {}, task:{}, metrics: {}'.format(time, task_id, metrics))
    # log to db
    acc = metrics['accuracy']
    loss = metrics['loss']
    loss_db[task_id][time - 1] = loss
    acc_db[task_id][time - 1] = acc
    return acc_db, loss_db


def print_details(tag, prev_task_id, metrics, alpha):
    """
    Print test accuracy on past task datasets (along with )
    """
    if tag:
        print('\tPrev Task:', prev_task_id, ' \tManual\t', metrics['accuracy'], '   \t', alpha[prev_task_id - 1])
    else:
        print('\tPrev Task:', prev_task_id, ' \tManual\t', metrics['accuracy'])


def hyp_lr(args, avg_runs_exp):
    """
    Grid search over Learning rate
    """
    lrs = (0.1, 0.05, 0.01, 0.001)
    best_hyp, best_acc = 0, 0
    for lr in lrs:
        args.lr = lr
        print(args.lr)
        args.seed = 0
        acc = avg_runs_exp(args.runs, validate=True)
        if acc > best_acc:
            best_acc = acc
            best_hyp = lr
    print('Best [lr]:', best_hyp)


def hyp_tag(args, avg_runs_exp):
    """
    Grid search for TAG: Learning rate and b
    """
    bs = (5)
    lrs = (0.005, 0.001, 0.0005, 0.0001, 0.00005, 0.00001)
    best_hyp, best_acc = 0, 0
    for b in bs:
        args.b = b
        for lr in lrs:
            args.lr = lr
            print(args.lr, args.b)
            args.seed = 0
            acc = avg_runs_exp(args.runs, validate=True)
            if acc > best_acc:
                best_acc = acc
                best_hyp = [lr, b]
    print('Best [lr, b]:', best_hyp)


def hyp_ewc(args, avg_runs_exp):
    """
    Grid search for EWC: Learning rate and Lambda
    """
    ls, lambdas = (0.1, 0.05, 0.01, 0.001), (1, 10, 100)
    best_hyp, best_acc = 0, 0
    for l in ls:
        args.lr = l
        for b in lambdas:
            args.lambd = b
            print(l, b)
            args.seed = 0
            acc = avg_runs_exp(args.runs, validate=True)
            if acc > best_acc:
                best_acc = acc
                best_hyp = [l, b]
    print('Best [lr, lambda]:', best_hyp)


def hyp_stable(args, avg_runs_exp):
    """
    Grid search for Stable SGD: Learning rate, decay and Dropout
    """
    dropouts = (0.0, 0.1, 0.25, 0.5)
    lrs = (0.1, 0.05, 0.01)
    decays = (0.9, 0.8, 0.7)
    best_hyp, best_acc = 0, 0
    for dropout in dropouts:
        args.dropout = dropout
        for lr in lrs:
            args.lr = lr
            for b in decays:
                args.gamma = b
                print(dropout, lr, b)
                args.seed = 0
                acc = avg_runs_exp(args.runs, validate=True)
                if acc > best_acc:
                    best_acc = acc
                    best_hyp = [dropout, lr, b]
    print('Best [dropout, lr, decay]:', best_hyp)
