#!/usr/bin/env python
# -*- coding: utf-8 -*-

# Python 2-3 compatible
from __future__ import print_function
from __future__ import division
from __future__ import absolute_import

import os

# Select the GPU(s) for this process (adapt to your hardware). Devices are
# enumerated by PCI bus ID rather than CUDA's default fastest-first order, and
# only device "0" is made visible. This runs before the project/torch imports
# below so the settings are in place before CUDA is initialised.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "0",
})

from data.core50_data_loader import CORE50
import copy
import json
from utils.utils import *
import configparser
import argparse
from pprint import pprint
from torch.utils.tensorboard import SummaryWriter

# --------------------------------- Setup --------------------------------------

# recover exp configuration name
parser = argparse.ArgumentParser(description='Run CL experiments')
parser.add_argument('--name', dest='exp_name', default='DEFAULT',
                    help='name of the experiment you want to run.')
args = parser.parse_args()

# recover config file for the experiment (path is relative to the working dir)
config_file = "params_ar1_random_replay.cfg"
config = configparser.ConfigParser()
# ConfigParser.read() silently skips files it cannot open and returns the list
# of files it actually parsed; without this check a missing file would surface
# later as a confusing KeyError on the first parameter lookup.
if not config.read(config_file):
    raise IOError("Could not read config file: {}".format(config_file))
# Report an unknown experiment name as a usage error listing the valid sections
# instead of a bare KeyError.
if args.exp_name not in config:
    parser.error("unknown experiment name {!r}; available: {}".format(
        args.exp_name, ", ".join(config.sections())))
exp_config = config[args.exp_name]
print("Experiment name:", args.exp_name)
pprint(dict(exp_config))

# Recover parameters from the cfg file and compute the dependent ones.
# NOTE(review): every value is a Python expression passed to eval(), so the cfg
# file is effectively executed as code -- only use trusted configuration files.
def _cfg(key):
    """Return the evaluated value of entry ``key`` of the selected experiment.

    The raw string is evaluated in this module's global namespace (used as both
    globals and locals), exactly like a module-level ``eval`` call.
    """
    return eval(exp_config[key], globals())


# classifier / training hyper-parameters
exp_name = _cfg('exp_name')
comment = _cfg('comment')
use_cuda = _cfg('use_cuda')
init_lr = _cfg('init_lr')
inc_lr = _cfg('inc_lr')
mb_size = _cfg('mb_size')
init_train_ep_classifier = _cfg('init_train_ep_classifier')
inc_train_ep_classifier = _cfg('inc_train_ep_classifier')
init_update_rate = _cfg('init_update_rate')
inc_update_rate = _cfg('inc_update_rate')
max_r_max = _cfg('max_r_max')
max_d_max = _cfg('max_d_max')
inc_step = _cfg('inc_step')
rm_classificator_sz = _cfg('rm_classificator_sz')
momentum = _cfg('momentum')
l2 = _cfg('l2')
freeze_below_layer = _cfg('freeze_below_layer')
latent_layer_num = _cfg('latent_layer_num')
ewc_lambda = _cfg('ewc_lambda')

# generator (CVAE) parameters
learning_rate_cvae = _cfg("learning_rate_cvae")
optimizer_cvae = _cfg("optimizer_cvae")
recon_loss_cvae = _cfg("recon_loss_cvae")
init_train_ep_generator = _cfg('init_train_ep_generator')
inc_train_ep_generator = _cfg('inc_train_ep_generator')
use_kld = _cfg("use_kld")
use_distillation = _cfg("use_distillation")
alpha = _cfg("alpha")
beta = _cfg("beta")
gamma = _cfg("gamma")
latent_size = _cfg("latent_size")
rm_generator_sz = _cfg("rm_generator_sz")

# TensorBoard: one log directory per experiment under logs_NC/.
log_dir = 'logs_NC/' + exp_name
writer = SummaryWriter(log_dir)

# Store the selected cfg section (raw string values) as a text summary at step 0.
hyper = json.dumps(dict(exp_config))
writer.add_text("parameters", hyper, 0)
n_runs = 1

# Test metrics keyed by incremental batch index 0..8; each list gets one value
# per run.
dict_test_accuracy = {batch_idx: [] for batch_idx in range(9)}
dict_test_loss = {batch_idx: [] for batch_idx in range(9)}

def _generate_random_replay(class_ids, n_samples, batch_size=128,
                            feat_shape=(512, 8, 8), high=3.2647081104914357):
    """Build a replay memory of random latent patterns with random labels.

    Args:
        class_ids: class labels to sample from (uniformly, with replacement).
        n_samples: number of replay patterns to generate.
        batch_size: number of patterns drawn per chunk; it only bounds the size
            of each temporary float64 array.
        feat_shape: shape of a single latent pattern.
        high: upper bound of the uniform feature distribution, U[0, high).
            The original comment calls it a 90th percentile, presumably of real
            latent activations (TODO confirm).

    Returns:
        ``[features, labels]``: a float32 tensor of shape
        ``(n_samples, *feat_shape)`` and an int64 tensor of shape
        ``(n_samples,)``.
    """
    new_labels = np.random.choice(class_ids, n_samples)
    new_labels = torch.from_numpy(new_labels).type(torch.long)
    n_feats = int(np.prod(feat_shape))
    chunks = []
    for start in range(0, n_samples, batch_size):
        batch_dim = min(batch_size, n_samples - start)
        samples = np.random.uniform(0.0, high, (batch_dim, n_feats))
        samples = torch.from_numpy(samples).type(torch.float)
        chunks.append(samples.reshape((-1,) + tuple(feat_shape)))
    return [torch.cat(chunks, dim=0), new_labels]


# The writer is closed in `finally` so pending TensorBoard events are flushed
# even if training crashes.
try:
    for run in range(n_runs):
        print("#################### RUN {} ####################".format(run))
        # Other variables init
        tot_it_step_classifier = 0
        rm_classifier = None
        rm_generator = None

        # Create the dataset object.
        # NOTE(review): the CORe50 run is hard-coded to 2 (independent of the
        # loop variable `run`) and the root is a placeholder path to edit.
        dataset = CORE50(root='your/path/here', scenario="nc", run=2)
        preproc = preprocess_imgs

        # Get the fixed test set.
        # np.int64 instead of the np.long alias, which NumPy 1.24-1.26 removed.
        test_x_all, test_y_all = dataset.get_test_set()
        test_y_all = test_y_all.astype(np.int64)

        # ********************** CLASSIFIER ******************************
        classifier_model = MyMobilenetV1(pretrained=True, latent_layer_num=latent_layer_num)

        # Per-class bookkeeping (50 classes) read by the consolidation helpers.
        classifier_model.saved_weights = {}
        classifier_model.past_j = {c: 0 for c in range(50)}
        classifier_model.cur_j = {c: 0 for c in range(50)}
        if ewc_lambda != 0:
            ewcData, synData = create_syn_data(classifier_model)

        classifier_optimizer = torch.optim.SGD(
            classifier_model.parameters(), lr=init_lr, momentum=momentum, weight_decay=l2
        )

        # ****************************** UTILITIES ******************************
        seen_classes = set()
        nr_seen_classes = 0
        ce_loss = torch.nn.CrossEntropyLoss()
        lsm = torch.nn.LogSoftmax(dim=1)
        nll = torch.nn.NLLLoss(reduction="none")

        # ************************* TRAINING PROCEDURE **************************

        # loop over the training incremental batches
        for i, train_batch in enumerate(dataset):

            if ewc_lambda != 0:
                init_batch(classifier_model, ewcData, synData)

            # freeze below conv5_4/dw
            freeze_up_to(classifier_model, freeze_below_layer, only_conv=False)

            if i == 1:
                # From the first incremental batch on: inc_lr everywhere, 10x for
                # the output layer.
                base_params = [p for name, p in classifier_model.named_parameters()
                               if name != 'output.weight']
                classifier_optimizer = torch.optim.SGD([
                    {'params': base_params},
                    {'params': classifier_model.output.parameters(), 'lr': inc_lr * 10}],
                    lr=inc_lr, momentum=momentum, weight_decay=l2
                )

            train_x, train_y = train_batch
            train_x = preproc(train_x)
            train_y = train_y.astype(np.int64)

            seen_classes.update(train_y)
            if len(seen_classes) > nr_seen_classes:
                print("New classes here!")
                nr_seen_classes = len(seen_classes)

            # cur_class_batch: classes of the new data; cur_class additionally
            # includes the classes present in the replay memory (after batch 0).
            cur_class_batch = [int(o) for o in set(train_y)]
            if i == 0:
                cur_class = list(cur_class_batch)
                classifier_model.cur_j = examples_per_class(train_y)
            else:
                cur_class = [int(o) for o in set(train_y).union(set(rm_classifier[1].tolist()))]
                classifier_model.cur_j = examples_per_class(list(train_y) + list(rm_classifier[1]))

            print(cur_class)
            print(classifier_model.cur_j)

            print("----------- batch {0} -------------".format(i))
            print("train_x shape: {}, train_y shape: {}"
                  .format(train_x.shape, train_y.shape))

            # Transform the datas
            (train_x, train_y), it_x_ep = pad_data([train_x, train_y], mb_size)
            shuffle_in_unison([train_x, train_y], in_place=True)
            train_x = torch.from_numpy(train_x).type(torch.FloatTensor)
            train_y = torch.from_numpy(train_y).type(torch.LongTensor)

            # ************************ CLASSIFIER TRAINING ***********************
            print("************************ CLASSIFIER TRAINING ****************************************************")

            reset_weights(classifier_model, cur_class_batch)
            classifier_model = maybe_cuda(classifier_model, use_cuda=use_cuda)

            acc_classifier = None
            ave_loss_classifier = 0

            if i == 0:
                train_ep_classifier = init_train_ep_classifier
            else:
                train_ep_classifier = inc_train_ep_classifier
                # Classes NOT in the current batch; replay patterns get zero
                # gradient on these log-softmax inputs. Built once per batch.
                mask = sorted(set(range(50)).difference(cur_class_batch))

                def mask_hook(module, in_grad, out_grad):
                    ret_grad = in_grad[0].detach().clone()
                    ret_grad[:, mask] = 0.
                    return (ret_grad,)

            for ep in range(train_ep_classifier):

                print("training ep: ", ep)
                correct_cnt, running_loss_sum = 0, 0.0

                if i > 0:
                    # Shrink the new-data share of each minibatch so that the new
                    # data plus the replay memory are spread over full minibatches.
                    cur_sz = train_x.size(0) // ((train_x.size(0) + rm_classificator_sz) // mb_size)
                    it_x_ep = train_x.size(0) // cur_sz
                    n2inject = max(0, mb_size - cur_sz)
                else:
                    cur_sz = mb_size
                    n2inject = 0
                print("total sz:", train_x.size(0) + rm_classificator_sz)
                print("n2inject", n2inject)
                print("it x ep: ", it_x_ep)

                # Number of new-data patterns in each minibatch.
                n_real = mb_size - n2inject

                for it in range(it_x_ep):

                    if ewc_lambda != 0:
                        pre_update(classifier_model, synData)

                    start = it * n_real
                    end = (it + 1) * n_real

                    classifier_optimizer.zero_grad()

                    x_mb = maybe_cuda(train_x[start:end], use_cuda=use_cuda)
                    y_mb = maybe_cuda(train_y[start:end], use_cuda=use_cuda)

                    if i == 0:
                        lat_mb_x = None
                    else:
                        # Cycle through the replay memory, n2inject patterns at a time.
                        replay_idx = range(it * n2inject, (it + 1) * n2inject)
                        lat_mb_x = np.take(rm_classifier[0], replay_idx, axis=0, mode='wrap')
                        lat_mb_y = np.take(rm_classifier[1], replay_idx, axis=0, mode='wrap')
                        y_mb = maybe_cuda(
                            torch.cat((train_y[start:end], lat_mb_y), 0),
                            use_cuda=use_cuda)
                        lat_mb_x = maybe_cuda(lat_mb_x, use_cuda=use_cuda)

                    freeze_up_to(classifier_model, freeze_below_layer, only_conv=False)
                    classifier_model.train()
                    classifier_model.lat_features.eval()

                    logits, lat_acts = classifier_model(x=x_mb, latent_input=lat_mb_x, return_lat_acts=True)
                    # Rows [0, n_real) are new data; replay rows are presumably
                    # appended after them by the model (TODO confirm).
                    if i > 0:
                        # Replay rows: back-propagate through the masking hook.
                        hook_handler = lsm.register_backward_hook(mask_hook)
                        loss_1 = nll(lsm(logits), y_mb)
                        loss_1[n_real: mb_size].backward(
                            maybe_cuda(torch.full((n2inject,), 1 / mb_size), use_cuda=use_cuda),
                            retain_graph=True
                        )
                        hook_handler.remove()

                    # New-data rows: unmasked loss, weighted 1/mb_size per pattern.
                    loss = nll(lsm(logits), y_mb)
                    running_loss_sum += loss.mean().item()
                    loss[0: n_real].backward(
                        maybe_cuda(torch.full((n_real,), 1 / mb_size), use_cuda=use_cuda)
                    )
                    _, pred_label = torch.max(logits[0: n_real], 1)
                    correct_cnt += (pred_label == y_mb[0: n_real]).sum()

                    if ewc_lambda != 0:
                        # The EWC penalty must be back-propagated before step():
                        # adding it to `loss` after loss.backward() would not
                        # change any gradient.
                        ewc_loss = compute_ewc_loss(classifier_model, ewcData, lambd=ewc_lambda)
                        if torch.is_tensor(ewc_loss) and ewc_loss.requires_grad:
                            ewc_loss.backward()

                    classifier_optimizer.step()

                    if ewc_lambda != 0:
                        post_update(classifier_model, synData)

                    acc_classifier = correct_cnt.item() / \
                                     ((it + 1) * y_mb[0: n_real].size(0))
                    # Mean of the per-iteration mean losses so far in this epoch.
                    ave_loss_classifier = running_loss_sum / (it + 1)

                    if it % 10 == 0:
                        print(
                            '==>>> it: {}, avg. loss: {:.6f}, '
                            'running train acc: {:.3f}'
                            .format(it, ave_loss_classifier, acc_classifier)
                        )
                    # Log scalar values (scalar summary) to TB
                    tot_it_step_classifier += 1
                    writer.add_scalar('classifier_train_loss', ave_loss_classifier, tot_it_step_classifier)
                    writer.add_scalar('classifier_train_accuracy', acc_classifier, tot_it_step_classifier)

            consolidate_weights(classifier_model, cur_class, cur_class_batch)
            if ewc_lambda != 0:
                update_ewc_data(classifier_model, ewcData, synData, 0.001, 1)

            # update number examples encountered over time
            for c, n in classifier_model.cur_j.items():
                classifier_model.past_j[c] += n
            set_consolidate_weights(classifier_model)

            ave_loss_classifier, acc_classifier, accs, cf = get_accuracy_conf_matrix(
                classifier_model, ce_loss, mb_size, test_x_all, test_y_all, preproc=preproc
            )

            # ------------------ Statistics classifier -------------------------
            writer.add_scalar('classifier_test_loss', ave_loss_classifier, i)
            writer.add_scalar('classifier_test_accuracy', acc_classifier, i)
            dict_test_accuracy[i].append(acc_classifier)
            dict_test_loss[i].append(ave_loss_classifier)

            print("---------------------------------")
            print("Accuracy classifier: ", acc_classifier)
            print("---------------------------------")

            # Mean absolute output weight of each of the 50 classes (a single
            # device-to-host copy instead of one per class).
            print("weight_magnitude")
            out_w = classifier_model.output.weight.detach().cpu().numpy()
            w_mag = np.abs(out_w[:50]).mean(axis=1).tolist()
            print(w_mag)

            # ******************** REPLAY MEMORY CALCULATION ********************
            print("************************ MEMORY CALCULATION ****************************************************")

            # Classificator replay memory: random latent features labelled with
            # the class ids actually seen so far (not 0..len(seen_classes)-1,
            # which need not be the seen ids).
            rm_classifier = _generate_random_replay(sorted(seen_classes), rm_classificator_sz)
finally:
    writer.close()
