#!/usr/bin/env python
# -*- coding: utf-8 -*-

# Python 2-3 compatible
from __future__ import print_function
from __future__ import division
from __future__ import absolute_import

from data_loader import CORE50
import copy
import os
import json
from utils import *
import configparser
import argparse
from pprint import pprint
from torch.utils.tensorboard import SummaryWriter
import pandas as pd
from PIL import Image
from torchvision.transforms import ToTensor
from models.mobilenet import MyMobilenetV1

# --------------------------------- Setup --------------------------------------

# Recover the experiment (cfg section) name from the command line.
parser = argparse.ArgumentParser(description='Run CL experiments')
parser.add_argument('--name', dest='exp_name',  default='DEFAULT',
                    help='name of the experiment you want to run.')
args = parser.parse_args()

# Select the CUDA device (adapt to your hardware): devices are numbered in
# PCI bus order and only device "0" is made visible to this process.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Recover the cfg section of the experiment. ConfigParser.read() silently
# ignores files it cannot open (it returns the list of files actually read),
# so a wrong path or an unknown --name would otherwise only surface later as
# an opaque KeyError; report both up front through argparse instead.
cfg_path = "params_ar1_random_replay.cfg"
config = configparser.ConfigParser()
if not config.read(cfg_path):
    parser.error("cannot read config file '{}'".format(cfg_path))
if args.exp_name not in config:
    parser.error("unknown experiment '{}'; available: {}".format(
        args.exp_name, ", ".join(config.sections())))
exp_config = config[args.exp_name]
print("Experiment name:", args.exp_name)
pprint(dict(exp_config))

# Recover the hyper-parameters from the selected cfg section. Each raw value
# is a Python expression that is evaluated to obtain the actual object.
# NOTE(review): eval() executes arbitrary code from the cfg file, so only run
# trusted configs; ast.literal_eval would be safer if every value is a literal.
def _cfg_value(key):
    """Return the selected cfg section's entry `key`, evaluated with eval()."""
    return eval(exp_config[key])


exp_name = _cfg_value('exp_name')
comment = _cfg_value('comment')
use_cuda = _cfg_value('use_cuda')
init_lr = _cfg_value('init_lr')
inc_lr = _cfg_value('inc_lr')
mb_size = _cfg_value('mb_size')
init_train_ep = _cfg_value('init_train_ep')
inc_train_ep = _cfg_value('inc_train_ep')
init_update_rate = _cfg_value('init_update_rate')
inc_update_rate = _cfg_value('inc_update_rate')
max_r_max = _cfg_value('max_r_max')
max_d_max = _cfg_value('max_d_max')
inc_step = _cfg_value('inc_step')
rm_sz = _cfg_value('rm_sz')
momentum = _cfg_value('momentum')
l2 = _cfg_value('l2')
freeze_below_layer = _cfg_value('freeze_below_layer')
latent_layer_num = _cfg_value('latent_layer_num')
reg_lambda = _cfg_value('reg_lambda')

# TensorBoard writer: one log directory per experiment under logs_NIC/.
log_dir = 'logs_NIC/' + exp_name
writer = SummaryWriter(log_dir=log_dir)

# Store the raw (unevaluated) cfg section as a text summary at step 0.
hyper = json.dumps(dict(exp_config))
writer.add_text(tag="parameters", text_string=hyper, global_step=0)

# Global iteration counter for TB logging, and the replay memory
# (None until an incremental batch builds it).
tot_it_step, rm = 0, None

# Create the dataset object (point `root` at your CORe50 copy).
dataset = CORE50(root='your/path/here', scenario="nicv2_391", run=2)
preproc = preprocess_imgs

# The test set is fixed: the same data is used for every evaluation.
test_x, test_y = dataset.get_test_set()

# Model setup (pretrained MobileNetV1; `latent_layer_num` presumably selects
# the layer where latent replay patterns are injected -- see the model).
model = MyMobilenetV1(pretrained=True, latent_layer_num=latent_layer_num)
# Replace BN layers with Batch Renormalization layers.
replace_bn_with_brn(
    model, momentum=init_update_rate, r_d_max_inc_step=inc_step,
    max_r_max=max_r_max, max_d_max=max_d_max
)
freeze_up_to(model, freeze_below_layer, only_conv=False)

# Per-class bookkeeping for the 50 classes: saved weights, plus the number
# of examples seen in past batches (past_j) and in the current one (cur_j).
model.saved_weights = {}
model.past_j = dict.fromkeys(range(50), 0)
model.cur_j = dict.fromkeys(range(50), 0)

if reg_lambda != 0:
    # The regularization is based on Synaptic Intelligence as described in
    # the paper. ewcData is a list of two elements (best parameters,
    # importance), while synData is a dictionary with all the trajectory
    # data needed by SI.
    ewcData, synData = create_syn_data(model)

# Optimizer setup (learning rate `init_lr`).
optimizer = torch.optim.SGD(
    model.parameters(), lr=init_lr, momentum=momentum, weight_decay=l2
)
criterion = torch.nn.CrossEntropyLoss()

# Classes seen so far, and for each of the 50 classes the order in which it
# first appeared in a training batch (-1 = not seen yet); cf_labels_inc is
# handed to the confusion-matrix evaluation.
seen_classes = set()
cf_labels_inc = [-1] * 50
cf_count_inc = 0


# --------------------------------- Training -----------------------------------

# Number of new (non-replayed) patterns per minibatch in incremental batches;
# the remaining `mb_size - cur_sz` slots are filled with latent replay
# patterns. (The AR1 formula would be
# train_x.size(0) // ((train_x.size(0) + rm_sz) // mb_size).)
REPLAY_CUR_SZ = 107

# loop over the training incremental batches
for i, train_batch in enumerate(dataset):

    if reg_lambda != 0:
        init_batch(model, ewcData, synData)

    # we freeze the layers below the replay layer since the first batch
    freeze_up_to(model, freeze_below_layer, only_conv=False)

    if i == 1:
        # From the first incremental batch on, BRN uses `inc_update_rate`
        # as momentum and r_max / d_max are no longer increased.
        change_brn_pars(
            model, momentum=inc_update_rate, r_d_max_inc_step=0,
            r_max=max_r_max, d_max=max_d_max)

        # Incremental-phase optimizer: `inc_lr` for the backbone and a 10x
        # larger lr for the output layer. Parameters are split by identity,
        # so every output-layer parameter (weight and, if present, bias)
        # lands in exactly one group; torch rejects a parameter that
        # appears in two groups.
        output_param_ids = {id(p) for p in model.output.parameters()}
        base_params = [p for p in model.parameters()
                       if id(p) not in output_param_ids]
        classifier_optimizer = torch.optim.SGD([
            {'params': base_params},
            {'params': model.output.parameters(), 'lr': inc_lr * 10.0}],
            lr=inc_lr, momentum=momentum, weight_decay=l2)
        # Use it for all incremental batches; otherwise training keeps the
        # batch-0 optimizer and `inc_lr` has no effect at all.
        optimizer = classifier_optimizer

    train_x, train_y = train_batch
    train_x = preproc(train_x)
    seen_classes.update(train_y)

    if i > 0:
        # Random latent replay: draw `rm_sz` synthetic latent patterns
        # (uniform noise reshaped to 512x8x8) labelled with classes seen in
        # earlier batches but absent from the current one.
        past_classes = list(seen_classes.difference(train_y))
        if not past_classes:
            # Every seen class is in the current batch: np.random.choice
            # raises on an empty population, so sample from all seen classes.
            past_classes = list(seen_classes)
        new_labels = np.random.choice(past_classes, rm_sz)
        print(set(train_y))
        print(set(new_labels))
        new_labels = torch.from_numpy(new_labels).type(torch.long)

        # Generate the features in chunks of 128 rows to bound the size of
        # the float64 buffer returned by np.random.uniform.
        list_memory = []
        for start_row in range(0, rm_sz, 128):
            batch_dim = min(128, rm_sz - start_row)
            samples = np.random.uniform(0.0, 3.2647081104914357, (batch_dim, 512 * 8 * 8))  # 90th percentile
            samples = torch.from_numpy(samples).type(torch.float)
            list_memory.append(samples.reshape((-1, 512, 8, 8)))
        generated_features = torch.cat(list_memory, dim=0)

        # The memory is rebuilt from scratch for every batch and only read
        # below, so a defensive deep copy would just double its footprint.
        rm = [generated_features, new_labels]

    if i == 0:
        cur_class = [int(o) for o in set(train_y)]
        model.cur_j = examples_per_class(train_y)
    else:
        cur_class = [int(o) for o in set(train_y).union(set(rm[1].tolist()))]
        model.cur_j = examples_per_class(list(train_y) + list(rm[1]))
    cur_class_batch = [int(o) for o in set(train_y)]

    # Give each class the next free index the first time it appears in a
    # training batch (cf_labels_inc is passed to the evaluation below).
    for cc in cur_class_batch:
        if cf_labels_inc[cc] == -1:
            cf_labels_inc[cc] = cf_count_inc
            cf_count_inc += 1

    print("----------- batch {0} -------------".format(i))
    print("train_x shape: {}, train_y shape: {}"
          .format(train_x.shape, train_y.shape))
    print("CUR CLASS = {}".format(cur_class))

    model.train()
    model.lat_features.eval()

    reset_weights(model, cur_class_batch)
    cur_ep = 0

    (train_x, train_y), it_x_ep = pad_data([train_x, train_y], mb_size)
    shuffle_in_unison([train_x, train_y], in_place=True)

    model = maybe_cuda(model, use_cuda=use_cuda)
    acc = None
    ave_loss = 0

    train_x = torch.from_numpy(train_x).type(torch.FloatTensor)
    train_y = torch.from_numpy(train_y).type(torch.LongTensor)

    if i == 0:
        train_ep = init_train_ep
    else:
        train_ep = inc_train_ep

    for ep in range(train_ep):

        print("training ep: ", ep)
        correct_cnt, ave_loss = 0, 0
        # running sums for the epoch's average loss and accuracy
        tot_loss, n_seen = 0.0, 0

        # computing how many patterns to inject in the latent replay layer
        if i > 0:
            # Clamp to mb_size: a larger cur_sz would make it_x_ep too small
            # for the mb_size-sized steps below and silently skip data.
            cur_sz = min(REPLAY_CUR_SZ, mb_size)
            it_x_ep = train_x.size(0) // cur_sz
            n2inject = mb_size - cur_sz
        else:
            n2inject = 0
        print("total sz:", train_x.size(0) + rm_sz)
        print("n2inject", n2inject)
        print("it x ep: ", it_x_ep)

        for it in range(it_x_ep):

            if reg_lambda != 0:
                pre_update(model, synData)

            start = it * (mb_size - n2inject)
            end = (it + 1) * (mb_size - n2inject)

            optimizer.zero_grad()

            x_mb = maybe_cuda(train_x[start:end], use_cuda=use_cuda)

            if i == 0:
                lat_mb_x = None
                y_mb = maybe_cuda(train_y[start:end], use_cuda=use_cuda)

            else:
                lat_mb_x = rm[0][it*n2inject: (it + 1)*n2inject]
                lat_mb_y = rm[1][it*n2inject: (it + 1)*n2inject]
                y_mb = maybe_cuda(
                    torch.cat((train_y[start:end], lat_mb_y), 0),
                    use_cuda=use_cuda)
                lat_mb_x = maybe_cuda(lat_mb_x, use_cuda=use_cuda)

            # if lat_mb_x is not None, this tensor will be concatenated in
            # the forward pass on-the-fly in the latent replay layer
            logits, lat_acts = model(
                x_mb, latent_input=lat_mb_x, return_lat_acts=True)

            _, pred_label = torch.max(logits, 1)
            correct_cnt += (pred_label == y_mb).sum()

            loss = criterion(logits, y_mb)
            if reg_lambda != 0:
                loss += compute_ewc_loss(model, ewcData, lambd=reg_lambda)
            tot_loss += loss.item()

            loss.backward()
            optimizer.step()

            if reg_lambda != 0:
                post_update(model, synData)

            # `criterion` already returns the minibatch mean, so the running
            # average loss is the mean of the per-iteration losses; accuracy
            # uses the actual number of patterns seen so far.
            n_seen += y_mb.size(0)
            ave_loss = tot_loss / (it + 1)
            acc = correct_cnt.item() / n_seen

            if it % 10 == 0:
                print(
                    '==>>> it: {}, avg. loss: {:.6f}, '
                    'running train acc: {:.3f}'
                        .format(it, ave_loss, acc)
                )

            # Log scalar values (scalar summary) to TB
            tot_it_step += 1
            writer.add_scalar('classifier_train_loss', ave_loss, tot_it_step)
            writer.add_scalar('classifier_train_accuracy', acc, tot_it_step)

        cur_ep += 1

    consolidate_weights_nic(model, cur_class, cur_class_batch)
    if reg_lambda != 0:
        update_ewc_data(model, ewcData, synData, clip_to=0.001, c=0.00002)

    set_consolidate_weights(model)
    ave_loss, acc, accs, cf = get_accuracy_conf_matrix_nic(
        model, criterion, mb_size, test_x, test_y, preproc=preproc, cf_labels_inc=cf_labels_inc
    )

    # Log scalar values (scalar summary) to TB
    writer.add_scalar('classifier_test_loss', ave_loss, i)
    writer.add_scalar('classifier_test_accuracy', acc, i)

    df_cm_retrain = pd.DataFrame(cf, range(50), range(50))
    plot_buf = gen_plot(df_cm_retrain)
    image_retrain = Image.open(plot_buf)
    image_retrain = ToTensor()(image_retrain)
    writer.add_image('Confusion matrix', image_retrain, i)

    # update number examples encountered over time
    for c, n in model.cur_j.items():
        model.past_j[c] += n

    print("---------------------------------")
    print("Accuracy: ", acc)
    print("---------------------------------")

    if reg_lambda != 0:
        weight_stats(model, ewcData, clip_to=0.001)

writer.close()
