#!/usr/bin/env python
# -*- coding: utf-8 -*-

# Python 2-3 compatible
from __future__ import print_function
from __future__ import division
from __future__ import absolute_import

from data.core50_data_loader import CORE50
import copy
import os
import json
from models.mobilenet import MyMobilenetV1
from models import replace_bn_with_brn, change_brn_pars, change_bn_momentum, freeze_conv_dw
from utils.utils import *
import configparser
import argparse
from pprint import pprint
from torch.utils.tensorboard import SummaryWriter

# --------------------------------- Setup --------------------------------------

# recover exp configuration name (a section of the cfg file below)
parser = argparse.ArgumentParser(description='Run CL experiments')
parser.add_argument('--name', dest='exp_name', default='DEFAULT',
                    help='name of the experiment you want to run.')
args = parser.parse_args()

# set cuda device (based on your hardware)
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# recover config file for the experiment
config_path = "params_ar1_no_replay_nc.cfg"
config = configparser.ConfigParser()
# ConfigParser.read() silently skips files it cannot open and returns the list
# of files it did parse. Without this check a missing cfg only surfaces later as
# a confusing KeyError on the first option lookup (the DEFAULT section always
# exists, so config['DEFAULT'] itself never fails).
if not config.read(config_path):
    parser.error("could not read config file '{}'".format(config_path))
if args.exp_name not in config:
    parser.error("unknown experiment '{}'; available sections: {}".format(
        args.exp_name, ", ".join(config.sections()) or "<none>"))
exp_config = config[args.exp_name]
print("Experiment name:", args.exp_name)
pprint(dict(exp_config))

# Pull every hyper-parameter out of the selected cfg section. Each cfg value is
# stored as a Python expression (e.g. strings are quoted) and is evaluated.
# NOTE(review): eval() runs arbitrary code from the cfg file; this is only
# acceptable while the cfg is a trusted local file. If every value is a plain
# literal, ast.literal_eval would be a safer drop-in — confirm before switching.
def _from_cfg(option):
    """Evaluate the expression stored under ``option`` in ``exp_config``.

    Passing ``globals()`` alone makes eval use the module namespace for both
    globals and locals, exactly as a module-level ``eval`` call would.
    """
    return eval(exp_config[option], globals())


exp_name = _from_cfg('exp_name')
comment = _from_cfg('comment')
use_cuda = _from_cfg('use_cuda')
init_lr = _from_cfg('init_lr')
inc_lr = _from_cfg('inc_lr')
mb_size = _from_cfg('mb_size')
init_train_ep = _from_cfg('init_train_ep')
inc_train_ep = _from_cfg('inc_train_ep')
init_update_rate = _from_cfg('init_update_rate')
inc_update_rate = _from_cfg('inc_update_rate')
max_r_max = _from_cfg('max_r_max')
max_d_max = _from_cfg('max_d_max')
inc_step = _from_cfg('inc_step')
rm_sz = _from_cfg('rm_sz')
momentum = _from_cfg('momentum')
l2 = _from_cfg('l2')
freeze_below_layer = _from_cfg('freeze_below_layer')
latent_layer_num = _from_cfg('latent_layer_num')
ewc_lambda = _from_cfg('ewc_lambda')

# TensorBoard output for this experiment lives under logs_NC/<exp_name>.
log_dir = '/'.join(['logs_NC', exp_name])
writer = SummaryWriter(log_dir)

# Record the raw cfg section as a JSON text entry (step 0) so every TB run
# carries the settings it was produced with.
hyper = json.dumps(dict(exp_config))
writer.add_text("parameters", hyper, 0)

# Number of independent repetitions of the whole experiment.
n_runs = 1

# One dict per run, mapping incremental-batch index -> test accuracy.
test_acc_results = [{} for _ in range(n_runs)]

try:
    for run in range(n_runs):
        # Global TensorBoard step counter for training scalars.
        tot_it_step = 0

        # Create the dataset object.
        # NOTE(review): `root` is a placeholder path, and `run=2` is hard-coded,
        # i.e. it does not follow the loop variable `run` — confirm intended.
        dataset = CORE50(root='your/path/here', scenario="nc", run=2)
        preproc = preprocess_imgs

        # Get the fixed test set
        test_x, test_y = dataset.get_test_set()

        # Model setup
        model = MyMobilenetV1(pretrained=True, latent_layer_num=latent_layer_num)
        change_bn_momentum(model, momentum=0.001)

        # Per-class example counters, one entry for each class id 0..49.
        model.saved_weights = {}
        model.past_j = {c: 0 for c in range(50)}
        model.cur_j = {c: 0 for c in range(50)}

        # The regularisation state only exists when ewc_lambda is non-zero;
        # every use below is guarded by the same flag.
        use_ewc = ewc_lambda != 0
        ewcData, synData = None, None
        if use_ewc:
            ewcData, synData = create_syn_data(model)

        # Optimizer setup (only parameters that still require gradients)
        optimizer = torch.optim.SGD(
            filter(lambda par: par.requires_grad, model.parameters()),
            lr=init_lr, momentum=momentum, weight_decay=l2
        )
        criterion = torch.nn.CrossEntropyLoss()

        # ------------------------------- Training -------------------------------

        # loop over the training incremental batches
        for i, train_batch in enumerate(dataset):

            if use_ewc:
                init_batch(model, ewcData, synData)

            # From the second batch on, switch to the incremental learning rate.
            if i == 1:
                optimizer = torch.optim.SGD(
                    filter(lambda p: p.requires_grad, model.parameters()),
                    lr=inc_lr, momentum=momentum, weight_decay=l2
                )

            train_x, train_y = train_batch
            train_x = preproc(train_x)

            cur_class = [int(o) for o in set(train_y)]
            model.cur_j = examples_per_class(train_y)

            print("----------- batch {0} -------------".format(i))
            print("train_x shape: {}, train_y shape: {}"
                  .format(train_x.shape, train_y.shape))

            model.train()

            reset_weights(model, cur_class)

            (train_x, train_y), it_x_ep = pad_data([train_x, train_y], mb_size)
            shuffle_in_unison([train_x, train_y], in_place=True)

            model = maybe_cuda(model, use_cuda=use_cuda)
            acc = None
            ave_loss = 0

            train_x = torch.from_numpy(train_x).type(torch.FloatTensor)
            train_y = torch.from_numpy(train_y).type(torch.LongTensor)

            train_ep = init_train_ep if i == 0 else inc_train_ep

            for ep in range(train_ep):

                print("training ep: ", ep)
                correct_cnt, tot_loss, seen = 0, 0.0, 0

                if i > 0:
                    # Sizing leaves room for `n2inject` replay samples per
                    # minibatch. This script never injects any, so with
                    # rm_sz > 0 each minibatch just holds mb_size - n2inject
                    # new examples.
                    cur_sz = train_x.size(0) // ((train_x.size(0) + rm_sz) // mb_size)
                    it_x_ep = train_x.size(0) // cur_sz
                    n2inject = max(0, mb_size - cur_sz)
                else:
                    n2inject = 0
                print("total sz:", train_x.size(0) + rm_sz)
                print("n2inject", n2inject)
                print("it x ep: ", it_x_ep)

                for it in range(it_x_ep):

                    if use_ewc:
                        pre_update(model, synData)

                    start = it * (mb_size - n2inject)
                    end = (it + 1) * (mb_size - n2inject)

                    optimizer.zero_grad()

                    x_mb = maybe_cuda(train_x[start:end], use_cuda=use_cuda)
                    y_mb = maybe_cuda(train_y[start:end], use_cuda=use_cuda)

                    logits = model(x_mb)

                    _, pred_label = torch.max(logits, 1)
                    correct_cnt += (pred_label == y_mb).sum().item()
                    seen += y_mb.size(0)

                    loss = criterion(logits, y_mb)
                    if use_ewc:
                        loss = loss + compute_ewc_loss(model, ewcData, lambd=ewc_lambda)

                    # `loss` is already a per-minibatch mean (CrossEntropyLoss
                    # default reduction), so average over iterations. The old
                    # in-place `ave_loss /= (it+1)*bs` re-divided the running
                    # value every step and did not yield an average.
                    tot_loss += loss.item()

                    loss.backward()
                    optimizer.step()

                    if use_ewc:
                        post_update(model, synData)

                    acc = correct_cnt / seen
                    ave_loss = tot_loss / (it + 1)

                    if it % 10 == 0:
                        print(
                            '==>>> it: {}, avg. loss: {:.6f}, '
                            'running train acc: {:.3f}'
                            .format(it, ave_loss, acc)
                        )

                    # Log scalar values (scalar summary) to TB
                    tot_it_step += 1
                    writer.add_scalar('classifier_train_loss', ave_loss, tot_it_step)
                    writer.add_scalar('classifier_train_accuracy', acc, tot_it_step)

            consolidate_weights(model, cur_class, cur_class)
            if use_ewc:
                update_ewc_data(model, ewcData, synData, clip_to=0.001, c=0.000001)

            set_consolidate_weights(model)
            ave_loss, acc, accs, cf = get_accuracy_conf_matrix(
                model, criterion, mb_size, test_x, test_y, preproc=preproc
            )

            # Log scalar values (scalar summary) to TB
            writer.add_scalar('classifier_test_loss', ave_loss, i)
            writer.add_scalar('classifier_test_accuracy', acc, i)
            test_acc_results[run][i] = acc

            # update number examples encountered over time
            for c, n in model.cur_j.items():
                model.past_j[c] += n

            print("---------------------------------")
            print("Accuracy: ", acc)
            print("---------------------------------")

            # ewcData only exists when EWC is enabled; calling this
            # unconditionally raised NameError for ewc_lambda == 0.
            if use_ewc:
                weight_stats(model, ewcData, clip_to=0.001)
finally:
    # Flush and close the TensorBoard writer even if training aborts.
    writer.close()
