#!/usr/bin/env python
# -*- coding: utf-8 -*-

# Python 2-3 compatible
from __future__ import print_function
from __future__ import division
from __future__ import absolute_import

import os

# Select the GPU this process may use (adjust the index to your hardware).
# Devices are enumerated by PCI bus id and only device "0" is exposed.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "0",
})

from data.core50_data_loader import CORE50
import copy
import json
from utils.utils import *
import configparser
import argparse
from pprint import pprint
from torch.utils.tensorboard import SummaryWriter

# --------------------------------- Setup --------------------------------------

# Resolve which experiment to run from the command line (--name selects the
# section of the cfg file; it defaults to the DEFAULT section).
parser = argparse.ArgumentParser(description='Run CL experiments')
parser.add_argument('--name', dest='exp_name',  default='DEFAULT',
                    help='name of the experiment you want to run.')
args = parser.parse_args()

# Load the experiment section from the cfg file.
# ConfigParser.read() silently skips files it cannot open and returns only the
# list of files it actually parsed, so check that list to fail with a clear
# message instead of an unrelated KeyError further down.
cfg_path = "params_ar1_positive_replay.cfg"
config = configparser.ConfigParser()
if not config.read(cfg_path):
    raise IOError(
        "Could not read configuration file '{}' "
        "(path is relative to the current working directory).".format(cfg_path)
    )
if args.exp_name not in config:
    # Report a bad --name as a command-line error together with valid choices.
    parser.error("unknown experiment '{}'; available sections: {}".format(
        args.exp_name, ", ".join(config.sections()) or "<none>"))
exp_config = config[args.exp_name]
print("Experiment name:", args.exp_name)
pprint(dict(exp_config))

# Recover the experiment parameters from the selected cfg section.
# configparser returns every value as a string; each one is turned into a
# Python object by evaluating it as a Python expression.
# NOTE(review): eval() executes arbitrary code found in the cfg file, so the
# file must be trusted. If every value is a plain literal (number, string,
# bool, ...), ast.literal_eval would be a safer replacement -- confirm against
# the actual cfg contents before switching.

# Settings read by the classifier training / replay code in this script
# (some of them are not referenced again in the visible part of the file).
exp_name = eval(exp_config['exp_name'])
comment = eval(exp_config['comment'])
use_cuda = eval(exp_config['use_cuda'])
init_lr = eval(exp_config['init_lr'])
inc_lr = eval(exp_config['inc_lr'])
mb_size = eval(exp_config['mb_size'])
init_train_ep_classifier = eval(exp_config['init_train_ep_classifier'])
inc_train_ep_classifier = eval(exp_config['inc_train_ep_classifier'])
init_update_rate = eval(exp_config['init_update_rate'])
inc_update_rate = eval(exp_config['inc_update_rate'])
max_r_max = eval(exp_config['max_r_max'])
max_d_max = eval(exp_config['max_d_max'])
inc_step = eval(exp_config['inc_step'])
rm_classificator_sz = eval(exp_config['rm_classificator_sz'])
momentum = eval(exp_config['momentum'])
l2 = eval(exp_config['l2'])
freeze_below_layer = eval(exp_config['freeze_below_layer'])
latent_layer_num = eval(exp_config['latent_layer_num'])
ewc_lambda = eval(exp_config['ewc_lambda'])

# Settings whose names refer to a CVAE / generator; none of them is read
# again in the visible part of this script.
learning_rate_cvae = eval(exp_config["learning_rate_cvae"])
optimizer_cvae = eval(exp_config["optimizer_cvae"])
recon_loss_cvae = eval(exp_config["recon_loss_cvae"])
init_train_ep_generator = eval(exp_config['init_train_ep_generator'])
inc_train_ep_generator = eval(exp_config['inc_train_ep_generator'])
use_kld = eval(exp_config["use_kld"])
use_distillation = eval(exp_config["use_distillation"])
alpha = eval(exp_config["alpha"])
beta = eval(exp_config["beta"])
gamma = eval(exp_config["gamma"])
latent_size = eval(exp_config["latent_size"])
rm_generator_sz = eval(exp_config["rm_generator_sz"])

# Number of times the whole experiment is repeated.
n_runs = 1

# Per-batch test metrics: one list per incremental batch index (0..8), each
# list collecting one entry per run.
dict_test_accuracy = {batch_idx: [] for batch_idx in range(9)}
dict_test_loss = {batch_idx: [] for batch_idx in range(9)}

# TensorBoard writer; logs are grouped by experiment name under logs_NC/.
log_dir = 'logs_NC/' + exp_name
writer = SummaryWriter(log_dir)

# Store the raw experiment section as text so the logs document the settings.
hyper = json.dumps(dict(exp_config))
writer.add_text("parameters", hyper, 0)

# Run the whole incremental experiment n_runs times. Each run builds a fresh
# classifier, trains it batch by batch while replaying stored latent
# activations from an external memory (rm_classifier), evaluates on the fixed
# test set after every batch and logs the results to TensorBoard.
for run in range(n_runs):
    print("#################### RUN {} ####################".format(run))
    # Other variables init
    tot_it_step_classifier = 0
    # rm_classifier becomes [latent activations, labels] after the first batch
    rm_classifier = None
    rm_generator = None

    # Create the dataset object
    # NOTE(review): the dataset `run` index is fixed to 2 and does not follow
    # the loop variable -- confirm this is intended.
    dataset = CORE50(root='your/path/here', scenario="nc", run=2)
    preproc = preprocess_imgs

    # Get the fixed test set
    test_x_all, test_y_all = dataset.get_test_set()
    # np.long was deprecated in NumPy 1.20 and removed in 1.24; use an
    # explicit 64-bit integer dtype instead.
    test_y_all = test_y_all.astype(np.int64)

    # ********************** CLASSIFIER ******************************
    classifier_model = MyMobilenetV1(pretrained=True, latent_layer_num=latent_layer_num)

    # Per-class bookkeeping attached to the model: saved weights and the
    # number of examples seen in the past / in the current batch (50 classes).
    classifier_model.saved_weights = {}
    classifier_model.past_j = {i: 0 for i in range(50)}
    classifier_model.cur_j = {i: 0 for i in range(50)}
    if ewc_lambda != 0:
        ewcData, synData = create_syn_data(classifier_model)

    classifier_optimizer = torch.optim.SGD(
        classifier_model.parameters(), lr=init_lr, momentum=momentum, weight_decay=l2
    )

    # ********************************** UTILITIES ************************************************
    seen_classes = set()
    nr_seen_classes = 0
    ce_loss = torch.nn.CrossEntropyLoss()

    ce_loss_no_reduction = torch.nn.CrossEntropyLoss(reduction="none")

    lsm = torch.nn.LogSoftmax(dim=1)
    nll = torch.nn.NLLLoss(reduction="none")

    # *************************************************************************************************************
    # *************************************************************************************************************
    # ***************************************  TRAINING PROCEDURE  ************************************************
    # *************************************************************************************************************
    # *************************************************************************************************************

    # loop over the training incremental batches
    for i, train_batch in enumerate(dataset):

        if ewc_lambda != 0:
            init_batch(classifier_model, ewcData, synData)

        # freeze the layers below `freeze_below_layer`
        freeze_up_to(classifier_model, freeze_below_layer, only_conv=False)

        if i == 1:
            # From the second batch on, switch to the incremental learning
            # rate; the 'output' parameters get a 10x larger one.
            base_params = list(map(lambda p: p[1], filter(lambda kv: kv[0] not in ['output.weight'],
                                                          classifier_model.named_parameters())))
            classifier_optimizer = torch.optim.SGD([
                {'params': base_params},
                {'params': classifier_model.output.parameters(), 'lr': inc_lr * 10}],
                lr=inc_lr, momentum=momentum, weight_decay=l2
            )

        train_x, train_y = train_batch
        train_x = preproc(train_x)
        # np.long was removed in NumPy 1.24; see the note on the test labels.
        train_y = train_y.astype(np.int64)

        seen_classes.update(train_y)
        if len(seen_classes) > nr_seen_classes:
            print("New classes here!")
            recalculate_test_set = True
            nr_seen_classes = len(seen_classes)

        # cur_class: classes in this batch plus (after the first batch) the
        # classes held in the replay memory; cur_class_batch: this batch only.
        if i == 0:
            cur_class = [int(o) for o in set(train_y)]
            cur_class_batch = [int(o) for o in set(train_y)]
            classifier_model.cur_j = examples_per_class(train_y)
        else:
            cur_class = [int(o) for o in set(train_y).union(set(rm_classifier[1].tolist()))]
            cur_class_batch = [int(o) for o in set(train_y)]
            classifier_model.cur_j = examples_per_class(list(train_y) + list(rm_classifier[1]))

        print(cur_class)
        print(classifier_model.cur_j)

        print("----------- batch {0} -------------".format(i))
        print("train_x shape: {}, train_y shape: {}"
              .format(train_x.shape, train_y.shape))

        # Transform the data: pad to the minibatch size, shuffle, move to torch
        (train_x, train_y), it_x_ep = pad_data([train_x, train_y], mb_size)
        shuffle_in_unison([train_x, train_y], in_place=True)
        train_x = torch.from_numpy(train_x).type(torch.FloatTensor)
        train_y = torch.from_numpy(train_y).type(torch.LongTensor)

        # *************************************************************************************************************
        # *************************************** CLASSIFIER TRAINING *************************************************
        # *************************************************************************************************************
        print("************************ CLASSIFIER TRAINING ****************************************************")

        reset_weights(classifier_model, cur_class_batch)
        classifier_model = maybe_cuda(classifier_model, use_cuda=use_cuda)

        cur_ep_classifier = 0

        acc_classifier = None
        ave_loss_classifier = 0

        if i == 0:
            train_ep_classifier = init_train_ep_classifier
        else:
            train_ep_classifier = inc_train_ep_classifier

        for ep in range(train_ep_classifier):

            print("training ep: ", ep)
            correct_cnt, ave_loss_classifier = 0, 0
            # Sum of the per-minibatch (mean-reduced) CE losses of this epoch;
            # kept apart from the reported average so the latter can be
            # recomputed each iteration without corrupting the accumulator.
            loss_sum = 0.0

            # After the first batch every minibatch is split between cur_sz
            # new samples and n2inject patterns taken from the replay memory.
            if i > 0:
                cur_sz = train_x.size(0) // ((train_x.size(0) + rm_classificator_sz) // mb_size)
                # cur_sz = 127
                it_x_ep = train_x.size(0) // cur_sz
                n2inject = max(0, mb_size - cur_sz)
            else:
                cur_sz = mb_size
                n2inject = 0
            print("total sz:", train_x.size(0) + rm_classificator_sz)
            print("n2inject", n2inject)
            print("it x ep: ", it_x_ep)

            for it in range(it_x_ep):

                if ewc_lambda != 0:
                    pre_update(classifier_model, synData)

                start = it * (mb_size - n2inject)
                end = (it + 1) * (mb_size - n2inject)

                classifier_optimizer.zero_grad()

                x_mb = maybe_cuda(train_x[start:end], use_cuda=use_cuda)
                y_mb = maybe_cuda(train_y[start:end], use_cuda=use_cuda)

                if i == 0:
                    lat_mb_x = None
                else:
                    # mode='wrap' makes the indices wrap around when they run
                    # past the end of the replay memory.
                    lat_mb_x = np.take(rm_classifier[0], range(it * n2inject, (it + 1) * n2inject), axis=0, mode='wrap')
                    lat_mb_y = np.take(rm_classifier[1], range(it * n2inject, (it + 1) * n2inject), axis=0, mode='wrap')
                    # labels: new samples first, replayed patterns after them
                    y_mb = maybe_cuda(
                        torch.cat((train_y[start:end], lat_mb_y), 0),
                        use_cuda=use_cuda)
                    lat_mb_x = maybe_cuda(lat_mb_x, use_cuda=use_cuda)

                freeze_up_to(classifier_model, freeze_below_layer, only_conv=False)
                classifier_model.train()
                # keep the lat_features sub-module in eval mode while training
                classifier_model.lat_features.eval()
                logits, lat_acts = classifier_model(x=x_mb, latent_input=lat_mb_x, return_lat_acts=True)
                # collect latent volumes only for the first ep
                # we need to store them to eventually add them into the external
                # replay memory
                if ep == 0:
                    lat_acts = lat_acts.cpu().detach()
                    if it == 0:
                        cur_acts = copy.deepcopy(lat_acts)
                    else:
                        cur_acts = torch.cat((cur_acts, lat_acts), 0)

                loss = ce_loss(logits, y_mb)
                # Only the cross-entropy part is reported as training loss.
                loss_sum += loss.item()

                # The EWC penalty must be part of the loss *before* backward();
                # adding it afterwards leaves the gradients unaffected.
                if ewc_lambda != 0:
                    loss = loss + compute_ewc_loss(classifier_model, ewcData, lambd=ewc_lambda)

                loss.backward()
                # accuracy is measured on the new samples only (the first
                # mb_size - n2inject rows of the minibatch)
                _, pred_label = torch.max(logits[0: mb_size - n2inject], 1)
                correct_cnt += (pred_label == y_mb[0: mb_size - n2inject]).sum()

                classifier_optimizer.step()

                if ewc_lambda != 0:
                    post_update(classifier_model, synData)

                acc_classifier = correct_cnt.item() / \
                                 ((it + 1) * y_mb[0: mb_size - n2inject].size(0))
                # ce_loss is already averaged over the minibatch, so the
                # running mean is the loss sum over the number of iterations.
                ave_loss_classifier = loss_sum / (it + 1)

                if it % 10 == 0:
                    print(
                        '==>>> it: {}, avg. loss: {:.6f}, '
                        'running train acc: {:.3f}'
                            .format(it, ave_loss_classifier, acc_classifier)
                    )
                # Log scalar values (scalar summary) to TB
                tot_it_step_classifier += 1
                writer.add_scalar('classifier_train_loss', ave_loss_classifier, tot_it_step_classifier)
                writer.add_scalar('classifier_train_accuracy', acc_classifier, tot_it_step_classifier)

            cur_ep_classifier += 1

        w_changes = consolidate_weights(classifier_model, cur_class, cur_class)
        if ewc_lambda != 0:
            update_ewc_data(classifier_model, ewcData, synData, 0.001, 1)

        # update number examples encountered over time
        for c, n in classifier_model.cur_j.items():
            classifier_model.past_j[c] += n
        set_consolidate_weights(classifier_model)

        ave_loss_classifier, acc_classifier, accs, cf = get_accuracy_conf_matrix(
            classifier_model, ce_loss, mb_size, test_x_all, test_y_all, preproc=preproc
        )

        # ------------------ Statistics classifier -------------------------------------
        writer.add_scalar('classifier_test_loss', ave_loss_classifier, i)
        writer.add_scalar('classifier_test_accuracy', acc_classifier, i)
        dict_test_accuracy[i].append(acc_classifier)
        dict_test_loss[i].append(ave_loss_classifier)

        print("---------------------------------")
        print("Accuracy classifier: ", acc_classifier)
        print("---------------------------------")

        # mean absolute weight of each of the 50 output rows
        w_mag = [0.0 for _ in range(50)]
        print("weight_magnitude")
        with torch.no_grad():
            for we in range(50):
                w_mean = np.average(np.abs(classifier_model.output.weight.detach().cpu().numpy()[we]))
                w_mag[we] = w_mean

        # *************************************************************************************************************
        # *************************************** REPLAY MEMORY CALCULATION *******************************************
        # *************************************************************************************************************

        print("************************ MEMORY CALCULATION ****************************************************")

        # --------------------------- Classificator replay memory ---------------------------------

        print("fill memory...")
        # how many patterns to save for next iter: an equal share of the
        # memory per batch seen so far, capped by what was collected
        h = min(rm_classificator_sz // (i + 1), cur_acts.size(0))
        print("h", h)

        print("cur_acts sz:", cur_acts.size(0))
        idxs_cur = np.random.choice(
            cur_acts.size(0), h, replace=False
        )
        rm_add = [cur_acts[idxs_cur], train_y[idxs_cur]]
        print("rm_add size", rm_add[0].size(0))

        # replace patterns in random memory
        if i == 0:
            rm_classifier = copy.deepcopy(rm_add)
        else:
            # overwrite h randomly chosen memory slots with the new patterns
            idxs_2_replace = np.random.choice(
                rm_classifier[0].size(0), h, replace=False
            )
            for j, idx in enumerate(idxs_2_replace):
                rm_classifier[0][idx] = rm_add[0][j]
                rm_classifier[1][idx] = rm_add[1][j]
        # NOTE(review): the return value is discarded; if this helper does not
        # shuffle in place the call has no effect -- confirm against utils.
        shuffle_in_unison_pytorch([rm_classifier[0], rm_classifier[1]])

writer.close()
