#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
import random
from collections import defaultdict

import torch
from torch.utils.data import DataLoader

from cl_tools import NCProtocolIterator, TransformationSubset
from utils.cwr_utils import *

# set cuda device (based on your hardware): number GPUs by PCI bus id and
# make only GPU "0" visible to this process.
# NOTE(review): torch is already imported above; this presumably still takes
# effect because nothing touches CUDA before this point — confirm.
os.environ.update(
    CUDA_DEVICE_ORDER="PCI_BUS_ID",
    CUDA_VISIBLE_DEVICES="0",
)

from utils.imagenet_benchmark import make_imagenet_40_25_benchmark

import json
import configparser
import argparse
from pprint import pprint
from torch.utils.tensorboard import SummaryWriter
import PIL
import pandas as pd
from torchvision.transforms import ToTensor
import numpy as np
from models import change_bn_momentum
import math

# --------------------------------- Setup --------------------------------------

# recover exp configuration name (and, optionally, the cfg file defining it)
parser = argparse.ArgumentParser(description='Run CL experiments')
parser.add_argument('--name', dest='exp_name', default='DEFAULT',
                    help='name of the experiment you want to run.')
parser.add_argument('--config', dest='config_file', default='params_real_positive_replay.cfg',
                    help='cfg file containing the experiment sections.')
args = parser.parse_args()

# recover config file for the experiment.
# ConfigParser.read() silently skips files it cannot open and returns the list of
# files it actually parsed. Without these checks, a missing file or an unknown
# experiment name only surfaced later as an unhelpful KeyError.
config = configparser.ConfigParser()
if not config.read(args.config_file):
    parser.error("could not read config file '{}'".format(args.config_file))
if args.exp_name not in config:
    parser.error("experiment '{}' not found in '{}'".format(args.exp_name, args.config_file))
exp_config = config[args.exp_name]
print("Experiment name:", args.exp_name)
pprint(dict(exp_config))

# recover parameters from the cfg file and compute the dependent ones.
# NOTE(security): every value goes through eval(), i.e. the cfg file is executed
# as Python code. Only run with trusted cfg files; ast.literal_eval would be a
# safer choice if all values are plain literals.
exp_name = eval(exp_config['exp_name'])
comment = eval(exp_config['comment'])
use_cuda = eval(exp_config['use_cuda'])
init_lr = eval(exp_config['init_lr'])
inc_lr = eval(exp_config['inc_lr'])
cwr_inc_lr = eval(exp_config['cwr_inc_lr'])
mb_size = eval(exp_config['mb_size'])
init_train_ep_classifier = eval(exp_config['init_train_ep_classifier'])
inc_train_ep_classifier = eval(exp_config['inc_train_ep_classifier'])
init_update_rate = eval(exp_config['init_update_rate'])
inc_update_rate = eval(exp_config['inc_update_rate'])
max_r_max = eval(exp_config['max_r_max'])
max_d_max = eval(exp_config['max_d_max'])
inc_step = eval(exp_config['inc_step'])
rm_classificator_sz = eval(exp_config['rm_classificator_sz'])
momentum = eval(exp_config['momentum'])
l2 = eval(exp_config['l2'])
freeze_below_layer = eval(exp_config['freeze_below_layer'])
latent_layer = eval(exp_config['latent_layer'])
ewc_lambda = eval(exp_config['ewc_lambda'])
wi = eval(exp_config['wi'])

# setting up log dir for tensorboard
log_dir = 'logs/' + exp_name
writer = SummaryWriter(log_dir)

# Saving params
hyper = json.dumps(dict(exp_config))
writer.add_text("parameters", hyper, 0)
n_runs = 1

dict_test_accuracy = defaultdict(list)
dict_test_loss = defaultdict(list)


def pad_dataset(dataset, mb_size, shuffle_dataset=False):
    """Pad ``dataset`` so that its length becomes a multiple of ``mb_size``.

    When padding is needed, the dataset is wrapped in a ``TransformationSubset``.
    Its index list holds ``mb_size - len(dataset) % mb_size`` extra indices,
    taken cyclically from the start of the dataset, followed by all original
    indices.

    Args:
        dataset: a sized, indexable dataset.
        mb_size: minibatch size (must be > 0).
        shuffle_dataset: if True, shuffle the padded index list. This only
            happens when padding is actually applied; a dataset whose length is
            already a multiple of ``mb_size`` is returned unchanged.

    Returns:
        tuple: ``(dataset, it)``, where ``dataset`` is the (possibly padded)
        dataset and ``it`` is the number of ``mb_size`` minibatches needed to
        cover the original data (i.e. ``ceil(len(dataset) / mb_size)``).
    """
    n_patterns = len(dataset)
    n_missing = n_patterns % mb_size
    it = n_patterns // mb_size + (1 if n_missing > 0 else 0)

    if n_missing > 0:
        n_to_add = mb_size - n_missing
        # Wrap around so the padding indices stay in range even when the
        # dataset is shorter than the number of patterns to add.
        # n_missing > 0 guarantees n_patterns > 0, so the modulo is safe.
        padded_indexes = [k % n_patterns for k in range(n_to_add)] + list(range(n_patterns))
        if shuffle_dataset:
            random.shuffle(padded_indexes)
        dataset = TransformationSubset(dataset, padded_indexes)

    return dataset, it


# A single run is executed (run index 0); the loop over n_runs is disabled.
run = 0
print("#################### RUN {} ####################".format(run))
class_order_pkl = '../data/seed_1993_imagenet_order_run_0.pkl'
protocol, transform, transform_test, classifier_model = \
    make_imagenet_40_25_benchmark(class_order_pkl=class_order_pkl)
print(classifier_model)

# Global training-iteration counter (x-axis of the tensorboard train curves)
# and the latent replay memory, which is filled after the first batch.
tot_it_step_classifier, rm_classifier = 0, None

# ********************** CLASSIFIER ******************************

# An empty string in the cfg file means "no latent layer" for the model calls.
if latent_layer == "":
    latent_layer = None

# Bookkeeping attributes attached to the model: saved weights, plus per-class
# example counters (cur_j: current batch, past_j: accumulated over past batches).
classifier_model.saved_weights = dict()
classifier_model.past_j = dict.fromkeys(range(protocol.n_classes), 0)
classifier_model.cur_j = dict.fromkeys(range(protocol.n_classes), 0)

# EWC helper state (ewcData, synData) is only created when the penalty is enabled.
if ewc_lambda != 0:
    ewcData, synData = create_syn_data(classifier_model, classification_layer='fc')

# Initial optimizer: plain SGD over all model parameters.
classifier_optimizer = torch.optim.SGD(
    params=classifier_model.parameters(),
    lr=init_lr,
    momentum=momentum,
    weight_decay=l2,
)

# NOTE(review): presumably sets the BatchNorm momentum of the model to 0.001 —
# confirm against models.change_bn_momentum.
change_bn_momentum(classifier_model, momentum=0.001)

# ********************************** UTILITIES ************************************************
# Classes observed so far across the incremental batches.
nr_seen_classes = 0
seen_classes = set()

# Loss functions. ce_loss (mean reduction) is used for training and evaluation;
# the per-sample variants below are not referenced in the code visible here.
ce_loss = torch.nn.CrossEntropyLoss()
ce_loss_no_reduction = torch.nn.CrossEntropyLoss(reduction="none")
lsm, nll = torch.nn.LogSoftmax(dim=1), torch.nn.NLLLoss(reduction="none")

# Buffers for the latent activations / labels collected while training
# (two distinct lists).
lat_act_batch, lat_label_batch = [], []

# *************************************************************************************************************
# *************************************************************************************************************
# ***************************************  TRAINING PROCEDURE  ************************************************
# *************************************************************************************************************
# *************************************************************************************************************

classifier_model.train()

# loop over the training incremental batches
batch_info: NCProtocolIterator
for i, (train_batch, batch_info) in enumerate(protocol):
    print("----------- batch {0} -------------".format(i))

    if ewc_lambda != 0:
        init_batch(classifier_model, ewcData, synData, classification_layer='fc')

    train_y_all = np.array(train_batch.targets)
    print("train lab - min: {}".format(min(train_y_all)))
    print("train lab - max: {}".format(max(train_y_all)))
    print("Nr. classes in batch: {}".format(len(set(train_y_all))))

    if i == 1:
        # From the first incremental batch on: freeze the lower layers and switch to
        # an optimizer with separate learning rates per parameter group.
        if not freeze_up_to(classifier_model, freeze_below_layer, only_conv=False):  # Sets .eval() too
            raise ValueError('Can\'t find freeze below layer ' + str(freeze_below_layer))
        standard_params_lr, cwr_params_lr, frozen_params_lr = \
            ar1f_get_params_lr(classifier_model, 'fc', freeze_below_layer=freeze_below_layer)

        classifier_optimizer = torch.optim.SGD([
            {'params': standard_params_lr},  # Whole net - fc.weight - fc.bias
            {'params': cwr_params_lr, 'lr': cwr_inc_lr},  # fc.weight
            {'params': frozen_params_lr, 'lr': 0}  # fc.bias
        ], lr=inc_lr, momentum=momentum, weight_decay=l2)

    seen_classes.update(train_y_all)
    if len(seen_classes) > nr_seen_classes:
        print("New classes here!")
        recalculate_test_set = True
        nr_seen_classes = len(seen_classes)

    # Classes handled in this batch; for i > 0 this also includes the classes
    # stored in the replay memory.
    if i == 0:
        cur_class = [int(o) for o in set(train_y_all)]
        cur_class_batch = [int(o) for o in set(train_y_all)]
        classifier_model.cur_j = examples_per_class(train_y_all)
    else:
        cur_class = [int(o) for o in set(train_y_all).union(set(rm_classifier[1].tolist()))]
        cur_class_batch = [int(o) for o in set(train_y_all)]
        classifier_model.cur_j = examples_per_class(list(train_y_all) + rm_classifier[1].tolist())

    if i > 0:
        memory_labels = rm_classifier[1].tolist()
        # max() of an empty list raises ValueError; the memory is empty e.g. when
        # rm_classificator_sz == 0.
        if memory_labels:
            print("labels in memory:",
                  examples_per_class(memory_labels, classes_n=max(memory_labels) + 1))
        else:
            print("labels in memory: none (replay memory is empty)")

        train_batch, it_x_ep = pad_dataset(train_batch, mb_size, shuffle_dataset=True)
        print("Padded dataset has {} patterns".format(len(train_batch)))
    else:
        print("train_y shape: {}".format(train_y_all.shape))

    # *************************************************************************************************************
    # *************************************** CLASSIFIER TRAINING *************************************************
    # *************************************************************************************************************
    print("************************ CLASSIFIER TRAINING ****************************************************")

    if i > 0:
        reset_weights(classifier_model, cur_class)
    classifier_model = maybe_cuda(classifier_model, use_cuda=use_cuda)

    cur_ep_classifier = 0

    acc_classifier = None
    ave_loss_classifier = 0

    if i == 0:
        train_ep_classifier = init_train_ep_classifier
    else:
        train_ep_classifier = inc_train_ep_classifier
    # Sigmoid-shaped schedule: the multiplier goes from ~1 (i = 0) towards 0.5 as i grows.
    lr_multiplier = -(0.5 / (1 + math.pow(math.e, -((1.5 * i) - 8)))) + 1
    classifier_optimizer.param_groups[0]['lr'] = inc_lr * lr_multiplier
    # Snapshot the per-group learning rates so the mid-training decay below can be
    # undone exactly. A blanket "* 10" afterwards would inflate the rates whenever
    # no decay was applied (e.g. the first batch breaks out of the epoch loop early,
    # or zero epochs are configured).
    base_lrs = [group['lr'] for group in classifier_optimizer.param_groups]

    h = min(rm_classificator_sz // (i + 1), len(train_batch))
    print("nr. pattern to insert in memory: {}".format(h))

    for ep in range(train_ep_classifier):
        if i == 0:
            # Skip training for first batch: only compute the latent activations
            # used to fill the replay memory.
            print('Skipping training for the first batch')
            print("but have to calculate memory!")

            lat_act_batch = []
            lat_label_batch = []
            train_loader = DataLoader(train_batch, batch_size=mb_size, shuffle=True,
                                      pin_memory=True, drop_last=False, num_workers=8)
            classifier_model.eval()
            with torch.no_grad():
                for it, (train_x, train_y) in enumerate(train_loader):
                    train_x = maybe_cuda(train_x, use_cuda=use_cuda)
                    _, lat_acts = classifier_model(train_x, None, latent_layer=latent_layer, return_lat_acts=True)
                    lat_act_batch.append(lat_acts.detach().cpu())
                    lat_label_batch.append(train_y.detach().cpu())
                    if mb_size * it > h:
                        break
                    if it % 100 == 0:
                        print("{}/{}".format((it + 1) * mb_size, len(train_batch)))

            break

        if ep == (train_ep_classifier // 2):
            for group in classifier_optimizer.param_groups:
                group['lr'] = group['lr'] * 0.1

        print("training ep: ", ep)
        print("lr = {}".format([group['lr'] for group in classifier_optimizer.param_groups]))
        correct_cnt, ave_loss_classifier = 0, 0

        if i > 0:
            # Each minibatch holds cur_sz new patterns plus n2inject replayed latent
            # patterns. max(1, ...) avoids a ZeroDivisionError when the batch plus
            # the memory is smaller than one minibatch.
            cur_sz = len(train_batch) // max(1, (len(train_batch) + rm_classificator_sz) // mb_size)
            n2inject = max(0, mb_size - cur_sz)
        else:
            n2inject = 0

        print("total sz:", len(train_batch) + rm_classificator_sz)
        print("n2inject", n2inject)

        train_loader = DataLoader(train_batch, batch_size=(mb_size - n2inject), shuffle=True,
                                  pin_memory=True, drop_last=True, num_workers=8)

        print("it x ep: ", len(train_loader))

        if ep == 0:  # first epoch of the batch
            lat_act_batch = []
            lat_label_batch = []

        for it, (train_x, train_y) in enumerate(train_loader):
            if ewc_lambda != 0:
                pre_update(classifier_model, synData, classification_layer='fc')

            classifier_optimizer.zero_grad()

            y_mb = train_y

            if n2inject > 0:
                # Next n2inject replayed patterns, wrapping around the memory.
                lat_mb_x = np.take(rm_classifier[0], range(it * n2inject, (it + 1) * n2inject), axis=0, mode='wrap')
                lat_mb_y = np.take(rm_classifier[1], range(it * n2inject, (it + 1) * n2inject), axis=0, mode='wrap')
                y_mb = torch.cat((train_y, lat_mb_y), 0)
                lat_mb_x = maybe_cuda(lat_mb_x, use_cuda=use_cuda)

            x_mb = maybe_cuda(train_x, use_cuda=use_cuda)
            y_mb = maybe_cuda(y_mb, use_cuda=use_cuda)

            if i > 0:
                classifier_model.train()
                if not freeze_up_to(classifier_model, freeze_below_layer, only_conv=False):  # Sets .eval() too
                    raise ValueError('Can\'t find freeze below layer ' + str(freeze_below_layer))
            else:
                classifier_model.train()

            if n2inject > 0:
                logits, lat_acts = classifier_model(x_mb, lat_mb_x, latent_layer=latent_layer, return_lat_acts=True)
                # Collect latent activations only during the last epoch, until about h patterns are stored.
                if ep == train_ep_classifier - 1 and (cur_sz * it) < h:
                    lat_act_batch.append(lat_acts.detach().cpu())
                    lat_label_batch.append(train_y)
            else:
                logits = classifier_model(x_mb)
            loss = ce_loss(logits, y_mb)
            # The logged training loss is the cross-entropy term only.
            ave_loss_classifier += loss.mean().item()
            if ewc_lambda != 0:
                # The EWC penalty must be added BEFORE backward(); adding it
                # afterwards meant it never contributed to the gradients.
                loss = loss + compute_ewc_loss(classifier_model, ewcData, lambd=ewc_lambda,
                                               classification_layer='fc')
            loss.backward()
            _, pred_label = torch.max(logits[0: mb_size - n2inject], 1)
            correct_cnt += (pred_label == y_mb[0: mb_size - n2inject]).sum()

            classifier_optimizer.step()

            if ewc_lambda != 0:
                post_update(classifier_model, synData, classification_layer='fc')

            acc_classifier = correct_cnt.item() / \
                             ((it + 1) * y_mb[0: mb_size - n2inject].size(0))

            if it % 10 == 0:
                print(
                    '==>>> it: {}, avg. loss: {:.6f}, '
                    'running train acc: {:.3f}'.format(it, ave_loss_classifier / (it + 1), acc_classifier))
                writer.add_scalar('classifier_train_loss', ave_loss_classifier / (it + 1), tot_it_step_classifier)
                writer.add_scalar('classifier_train_accuracy', acc_classifier, tot_it_step_classifier)
            tot_it_step_classifier += 1
        cur_ep_classifier += 1

    # Undo the mid-training decay by restoring the rates saved before the epoch loop.
    for group, base_lr in zip(classifier_optimizer.param_groups, base_lrs):
        group['lr'] = base_lr

    # replay positive
    w_changes = consolidate_weights(classifier_model, cur_class, cur_class)  # positive replay
    if ewc_lambda != 0:
        update_ewc_data(classifier_model, ewcData, synData, clip_to=0.001, c=wi, classification_layer='fc')

    # update number examples encountered over time
    for c, n in classifier_model.cur_j.items():
        classifier_model.past_j[c] += n
    set_consolidate_weights(classifier_model)

    test_loader = DataLoader(batch_info.get_cumulative_test_set(), batch_size=mb_size, shuffle=False,
                             drop_last=False, num_workers=8, pin_memory=True)

    ave_loss_classifier, acc_classifier, accs, cf = get_accuracy_conf_matrix_from_dataloader(
        classifier_model, ce_loss, test_loader, protocol.n_classes)

    # ------------------ Statistics classifier -------------------------------------
    writer.add_scalar('classifier_test_loss', ave_loss_classifier, i)
    writer.add_scalar('classifier_test_accuracy', acc_classifier, i)

    hit_per_class, tot_test_class = accs

    df_cm = pd.DataFrame(cf, range(protocol.n_classes), range(protocol.n_classes))
    plot_buf = gen_plot(df_cm, max(tot_test_class))
    image_cf = PIL.Image.open(plot_buf)
    image_cf = ToTensor()(image_cf)
    writer.add_image('classifier_confusion_matrix/run_{}'.format(run), image_cf, i)

    print("---------------------------------")
    print("Accuracy classifier: ", acc_classifier)
    print("---------------------------------")

    # Per-task accuracy. Assumes 40 classes per task and 25 tasks, as the
    # benchmark name suggests; TODO confirm against make_imagenet_40_25_benchmark.
    for past_task in range(i + 1):
        tot_hit = sum(hit_per_class[past_task * 40: (past_task * 40) + 40])
        tot_class_task = sum(tot_test_class[past_task * 40: (past_task * 40) + 40])
        acc_task = tot_hit / tot_class_task
        writer.add_scalar('task_accuracy/task_{}'.format(past_task), acc_task, i)

    for fut_task in range(i + 1, 25):
        writer.add_scalar('task_accuracy/task_{}'.format(fut_task), 0.0, i)

    if ewc_lambda != 0:
        weight_stats(classifier_model, ewcData, clip_to=0.001, classification_layer='fc')

    # *************************************************************************************************************
    # *************************************** REPLAY MEMORY CALCULATION *******************************************
    # *************************************************************************************************************

    print("************************ MEMORY CALCULATION ****************************************************")

    # --------------------------- Classificator replay memory ---------------------------------
    if rm_classificator_sz > 0 and i > 0 and not lat_act_batch:
        # No latent activations were collected in this batch (e.g. n2inject == 0,
        # so the collecting branch never ran). torch.cat([]) would fail, so keep
        # the current memory unchanged.
        print("No latent activations collected: replay memory left unchanged")
    elif rm_classificator_sz > 0:
        print("Fill replay memory")
        cur_acts = torch.cat(lat_act_batch, dim=0)
        print("dim cur act", cur_acts.shape)
        cur_labels = torch.cat(lat_label_batch, dim=0)
        print("dim cur labels", cur_labels.shape)

        # how many patterns to save for next iter
        h = min(rm_classificator_sz // (i + 1), cur_acts.size(0))
        print("h", h)
        # replace patterns in random memory
        if i == 0:
            rm_classifier = [cur_acts[:h], cur_labels[:h]]
        else:
            idxs_2_replace = np.random.choice(
                rm_classifier[0].size(0), h, replace=False
            )
            for j, idx in enumerate(idxs_2_replace):
                rm_classifier[0][idx] = cur_acts[j]
                rm_classifier[1][idx] = cur_labels[j]
    else:
        rm_classifier = [torch.as_tensor([], dtype=torch.float), torch.as_tensor([], dtype=torch.long)]


    # ########################  END OF BATCH  #################################
# ########################  END OF RUN  #################################
writer.close()
