#!/usr/bin/env python
# -*- coding: utf-8 -*-


# Python 2-3 compatible
from __future__ import print_function
from __future__ import division
from __future__ import absolute_import

from data_loader import CORE50
import os
import json
from utils import *
import configparser
import argparse
from pprint import pprint
from torch.utils.tensorboard import SummaryWriter
import pandas as pd
from PIL import Image
from torchvision.transforms import ToTensor
from models.mobilenet import MyMobilenetV1

# --------------------------------- Setup --------------------------------------

# Recover the name of the experiment (i.e. the section of the cfg file) to run.
parser = argparse.ArgumentParser(description='Run CL experiments')
parser.add_argument('--name', dest='exp_name', default='DEFAULT',
                    help='name of the experiment you want to run.')
args = parser.parse_args()

# set cuda device (based on your hardware): order devices by PCI bus id and
# expose only device "0" to this process.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Recover the config file for the experiment. The path is relative to the
# current working directory.
config = configparser.ConfigParser()
# ConfigParser.read() silently skips files it cannot open and returns the list
# of files it actually parsed, so an empty result means the cfg is missing.
if not config.read("params_no_replay.cfg"):
    raise IOError(
        "Could not read config file 'params_no_replay.cfg' "
        "(looked in: {})".format(os.getcwd())
    )
# Fail with a readable message (instead of a bare KeyError) when the requested
# experiment is not a section of the cfg file.
if args.exp_name not in config:
    parser.error(
        "unknown experiment '{}'; available: {}".format(
            args.exp_name, ", ".join(config.sections()) or "<none>"
        )
    )
exp_config = config[args.exp_name]
print("Experiment name:", args.exp_name)
pprint(dict(exp_config))

# Recover the hyper-parameters of the selected experiment from the cfg file.
def _cfg(key):
    """Return the cfg entry `key` of the current experiment as a Python value.

    The raw string stored in `exp_config` is evaluated as a Python expression
    in the module namespace.

    NOTE(review): eval() executes whatever expression the cfg file contains,
    so the cfg file must be trusted.
    """
    return eval(exp_config[key], globals())


# experiment identification
exp_name = _cfg('exp_name')
comment = _cfg('comment')
use_cuda = _cfg('use_cuda')
# learning rates, mini-batch size and epochs ("init_*" values are used for the
# first training batch, "inc_*" values for the following ones)
init_lr = _cfg('init_lr')
inc_lr = _cfg('inc_lr')
mb_size = _cfg('mb_size')
init_train_ep = _cfg('init_train_ep')
inc_train_ep = _cfg('inc_train_ep')
# Batch Renormalization settings
init_update_rate = _cfg('init_update_rate')
inc_update_rate = _cfg('inc_update_rate')
max_r_max = _cfg('max_r_max')
max_d_max = _cfg('max_d_max')
inc_step = _cfg('inc_step')
# SGD settings
momentum = _cfg('momentum')
l2 = _cfg('l2')
# model / regularization settings
freeze_below_layer = _cfg('freeze_below_layer')
latent_layer_num = _cfg('latent_layer_num')
reg_lambda = _cfg('reg_lambda')

# TensorBoard: every experiment logs into its own sub-directory of logs_NIC/
log_dir = 'logs_NIC/' + exp_name
writer = SummaryWriter(log_dir=log_dir)

# Store the experiment parameters (as a JSON string) next to the curves
hyper = json.dumps(dict(exp_config))
writer.add_text(tag="parameters", text_string=hyper, global_step=0)

# Global counter of the training iterations performed so far
tot_it_step = 0

# Dataset object and pre-processing function applied to the images
dataset = CORE50(root='your/path/here', scenario="nicv2_391", run=0)
preproc = preprocess_imgs

# Fixed test set, reused for the evaluation after every training batch
test_x, test_y = dataset.get_test_set()

# Build the model
model = MyMobilenetV1(pretrained=True, latent_layer_num=latent_layer_num)

# BN layers are replaced with Batch Renormalization layers
replace_bn_with_brn(
    model,
    momentum=init_update_rate,
    r_d_max_inc_step=inc_step,
    max_r_max=max_r_max,
    max_d_max=max_d_max,
)

# Extra state attached to the model: saved weights and, for each of the 50
# classes, the number of examples seen in the past / in the current batch
model.saved_weights = dict()
model.past_j = dict.fromkeys(range(50), 0)
model.cur_j = dict.fromkeys(range(50), 0)

# The regularization data is only needed when the penalty is enabled
if reg_lambda != 0:
    ewcData, synData = create_syn_data(model)

# Optimizer (SGD with the initial learning rate) and loss
optimizer = torch.optim.SGD(
    params=model.parameters(),
    lr=init_lr,
    momentum=momentum,
    weight_decay=l2,
)
criterion = torch.nn.CrossEntropyLoss()

# Classes encountered so far
seen_classes = set()
# For each of the 50 classes, the position assigned to it the first time it
# is encountered (-1 = class not encountered yet) ...
cf_labels_inc = [-1] * 50
# ... and the next position to hand out
cf_count_inc = 0

# brn_params_scale = {}

# --------------------------------- Training -----------------------------------

# Loop over the incremental training batches. For every batch: train the
# model, consolidate the weights, evaluate on the fixed test set and log the
# results to TensorBoard.
for i, train_batch in enumerate(dataset):

    if reg_lambda != 0:
        init_batch(model, ewcData, synData)

    # freeze brn parameters (every parameter whose name contains "bn")
    for name, param in model.named_parameters():
        if "bn" in name:
            param.requires_grad = False

    if i == 1:
        # From the second batch on, switch to the incremental settings.
        change_brn_pars(
            model, momentum=inc_update_rate, r_d_max_inc_step=0,
            r_max=max_r_max, d_max=max_d_max)

        # Split the parameters by identity (not by name) so that no parameter
        # of the output layer can end up in both groups.
        head_params = list(model.output.parameters())
        head_ids = {id(p) for p in head_params}
        base_params = [
            p for p in model.parameters() if id(p) not in head_ids
        ]
        # Rebind `optimizer`: the training step below uses this name, so the
        # incremental optimizer must replace the initial one, otherwise
        # `inc_lr` would never be applied. The output layer is trained with a
        # learning rate 10x higher than the rest of the network.
        optimizer = torch.optim.SGD([
            {'params': base_params},
            {'params': head_params, 'lr': inc_lr * 10.0}],
            lr=inc_lr, momentum=momentum, weight_decay=l2)

    train_x, train_y = train_batch
    train_x = preproc(train_x)
    seen_classes.update(train_y)

    cur_class = [int(o) for o in set(train_y)]
    model.cur_j = examples_per_class(train_y)

    # give every class an index following the order of first appearance
    for cc in cur_class:
        if cf_labels_inc[cc] == -1:
            cf_labels_inc[cc] = cf_count_inc
            cf_count_inc += 1

    print("----------- batch {0} -------------".format(i))
    print("train_x shape: {}, train_y shape: {}"
          .format(train_x.shape, train_y.shape))

    model.train()
    reset_weights(model, cur_class)

    # it_x_ep is the number of iterations (mini-batches) per epoch.
    # NOTE(review): presumably pad_data pads the data to a multiple of
    # mb_size -- confirm in utils.
    (train_x, train_y), it_x_ep = pad_data([train_x, train_y], mb_size)
    shuffle_in_unison([train_x, train_y], in_place=True)

    model = maybe_cuda(model, use_cuda=use_cuda)

    train_x = torch.from_numpy(train_x).type(torch.FloatTensor)
    train_y = torch.from_numpy(train_y).type(torch.LongTensor)

    # the first batch is trained with its own number of epochs
    if i == 0:
        train_ep = init_train_ep
    else:
        train_ep = inc_train_ep

    for ep in range(train_ep):
        # per-epoch accumulators: correct predictions and sum of the losses
        correct_cnt, running_loss = 0, 0.0

        for it in range(it_x_ep):

            if reg_lambda != 0:
                pre_update(model, synData)

            start = it * mb_size
            end = (it + 1) * mb_size

            optimizer.zero_grad()

            x_mb = maybe_cuda(train_x[start:end], use_cuda=use_cuda)
            y_mb = maybe_cuda(train_y[start:end], use_cuda=use_cuda)

            # no latent activations are passed in (latent_input=None)
            logits = model(
                x_mb, latent_input=None, return_lat_acts=False)

            _, pred_label = torch.max(logits, 1)
            correct_cnt += (pred_label == y_mb).sum().item()

            loss = criterion(logits, y_mb)
            if reg_lambda != 0:
                loss += compute_ewc_loss(model, ewcData, lambd=reg_lambda)
            running_loss += loss.item()

            loss.backward()
            optimizer.step()

            if reg_lambda != 0:
                post_update(model, synData)

            # Running statistics of the current epoch. The sum of the losses
            # is kept apart from the average: dividing the accumulator in
            # place at every iteration would corrupt the following values.
            # `criterion` already averages over the mini-batch, so the
            # average is taken over the number of mini-batches.
            acc = correct_cnt / ((it + 1) * y_mb.size(0))
            ave_loss = running_loss / (it + 1)

            # Log scalar values (scalar summary) to TB
            tot_it_step += 1
            writer.add_scalar('classifier_train_loss', ave_loss, tot_it_step)
            writer.add_scalar('classifier_train_accuracy', acc, tot_it_step)

    consolidate_weights_nic(model, cur_class, cur_class)

    if reg_lambda != 0:
        update_ewc_data(model, ewcData, synData, clip_to=0.001, c=0.00002)

    set_consolidate_weights(model)

    # evaluation on the fixed test set
    ave_loss, acc, accs, cf = get_accuracy_conf_matrix_nic(
        model, criterion, mb_size, test_x, test_y, preproc=preproc,
        cf_labels_inc=cf_labels_inc
    )

    # Log scalar values (scalar summary) to TB
    writer.add_scalar('classifier_test_loss', ave_loss, i)
    writer.add_scalar('classifier_test_accuracy', acc, i)

    # Log the confusion matrix (50 x 50) to TB as an image
    df_cm_retrain = pd.DataFrame(cf, range(50), range(50))
    plot_buf = gen_plot(df_cm_retrain)
    image_retrain = Image.open(plot_buf)
    image_retrain = ToTensor()(image_retrain)
    writer.add_image('Confusion matrix', image_retrain, i)

    # update number examples encountered over time
    for c, n in model.cur_j.items():
        model.past_j[c] += n

    print("---------------------------------")
    print("Accuracy: ", acc)
    print("---------------------------------")

writer.close()
