#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
import random
from collections import defaultdict

import torch
from torch.utils.data import DataLoader

from cl_tools import NCProtocolIterator, TransformationSubset
from utils.cwr_utils import *

# Set the CUDA device (based on your hardware): number GPUs in PCI bus order and
# expose only device 0 to this process.
# NOTE(review): these variables presumably only take effect if set before CUDA is
# first initialised by torch -- keep them ahead of any CUDA call.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "0",
})

from utils.imagenet_benchmark import make_imagenet_40_25_benchmark

import json
import configparser
import argparse
from pprint import pprint
from torch.utils.tensorboard import SummaryWriter
import PIL
import pandas as pd
from torchvision.transforms import ToTensor
import numpy as np
from models import change_bn_momentum
import math

# --------------------------------- Setup --------------------------------------

# Recover the experiment configuration name (and, optionally, the config file).
parser = argparse.ArgumentParser(description='Run CL experiments')
parser.add_argument('--name', dest='exp_name',  default='DEFAULT',
                    help='name of the experiment you want to run.')
parser.add_argument('--config', dest='config_file', default='params_random_replay.cfg',
                    help='path of the .cfg file holding the experiment sections.')
args = parser.parse_args()

# Recover the config file for the experiment. ConfigParser.read() silently skips
# files it cannot open (it returns the list of files actually read), which would
# otherwise surface later as an unhelpful KeyError -- fail early and clearly.
config = configparser.ConfigParser()
if not config.read(args.config_file):
    parser.error("cannot read config file '{}'".format(args.config_file))
if args.exp_name not in config:
    parser.error("experiment '{}' not found in config file '{}'".format(
        args.exp_name, args.config_file))
exp_config = config[args.exp_name]
print("Experiment name:", args.exp_name)
pprint(dict(exp_config))

# Recover parameters from the cfg file.
# NOTE(review): every value goes through eval(), so the .cfg file can execute
# arbitrary code -- only run with trusted config files. Values are presumably
# Python literals (strings quoted inside the .cfg, since exp_name is used as a str).
exp_name = eval(exp_config['exp_name'])
comment = eval(exp_config['comment'])
use_cuda = eval(exp_config['use_cuda'])
init_lr = eval(exp_config['init_lr'])
inc_lr = eval(exp_config['inc_lr'])
cwr_inc_lr = eval(exp_config['cwr_inc_lr'])
mb_size = eval(exp_config['mb_size'])
init_train_ep_classifier = eval(exp_config['init_train_ep_classifier'])
inc_train_ep_classifier = eval(exp_config['inc_train_ep_classifier'])
init_update_rate = eval(exp_config['init_update_rate'])
inc_update_rate = eval(exp_config['inc_update_rate'])
max_r_max = eval(exp_config['max_r_max'])
max_d_max = eval(exp_config['max_d_max'])
inc_step = eval(exp_config['inc_step'])
rm_classificator_sz = eval(exp_config['rm_classificator_sz'])
momentum = eval(exp_config['momentum'])
l2 = eval(exp_config['l2'])
freeze_below_layer = eval(exp_config['freeze_below_layer'])
latent_layer = eval(exp_config['latent_layer'])
ewc_lambda = eval(exp_config['ewc_lambda'])
wi = eval(exp_config['wi'])

# Set up the log dir for TensorBoard.
log_dir = 'logs/' + exp_name
writer = SummaryWriter(log_dir)

# Save the raw (un-evaluated) parameters as text in TensorBoard.
hyper = json.dumps(dict(exp_config))
writer.add_text("parameters", hyper, 0)
n_runs = 1

# NOTE(review): the next two dicts are not referenced anywhere below in this script.
dict_test_accuracy = defaultdict(list)
dict_test_loss = defaultdict(list)


def pad_dataset(dataset, mb_size, shuffle_dataset=False):
    """Pad ``dataset`` so that its length is a multiple of ``mb_size``.

    Padding reuses existing samples. The first ``n_to_add`` indexes are
    prepended to the full index list, and the result is wrapped in a
    ``TransformationSubset``. The indexes wrap around when the dataset is
    shorter than one minibatch.

    Args:
        dataset: a sized, indexable dataset.
        mb_size: minibatch size; must be positive.
        shuffle_dataset: if True, shuffle the padded index list. This only
            happens when padding is actually needed.

    Returns:
        A ``(dataset, n_iterations)`` tuple. ``n_iterations`` is
        ``ceil(len(dataset) / mb_size)``, i.e. the number of full minibatches
        in the padded dataset. The input object is returned unchanged when its
        length is already a multiple of ``mb_size``.

    Raises:
        ValueError: if ``mb_size`` is not positive.
    """
    if mb_size <= 0:
        raise ValueError("mb_size must be positive, got {}".format(mb_size))

    n_samples = len(dataset)
    n_missing = n_samples % mb_size
    n_iterations = n_samples // mb_size + (1 if n_missing > 0 else 0)

    # padding data to fix batch dimensions
    if n_missing > 0:
        n_to_add = mb_size - n_missing
        # Wrap with modulo: when the dataset is smaller than one minibatch,
        # n_to_add can exceed n_samples and plain range(n_to_add) would
        # produce out-of-range indexes.
        padded_indexes = [k % n_samples for k in range(n_to_add)] + list(range(n_samples))
        if shuffle_dataset:
            random.shuffle(padded_indexes)
        dataset = TransformationSubset(dataset, padded_indexes)

    return dataset, n_iterations


# --------------------------------- Single run -----------------------------------------
# Only run 0 is executed (the former `for run in range(n_runs)` loop is disabled).
run = 0
print("#################### RUN {} ####################".format(run))

# Benchmark protocol, train/test transforms and the classifier network; the class
# ordering is loaded from the given pickle file.
benchmark_parts = make_imagenet_40_25_benchmark(
    class_order_pkl='../data/seed_1993_imagenet_order_run_0.pkl',
)
protocol, transform, transform_test, classifier_model = benchmark_parts
print(classifier_model)

# Global TensorBoard step counter, and the replay memory (built after each batch).
tot_it_step_classifier = 0
rm_classifier = None

# Bookkeeping attached to the model: an initially empty `saved_weights` dict and,
# for every class id, counters that are presumably example counts -- `past_j`
# accumulates `cur_j` after every batch.
classifier_model.saved_weights = {}
classifier_model.past_j = dict.fromkeys(range(protocol.n_classes), 0)
classifier_model.cur_j = dict.fromkeys(range(protocol.n_classes), 0)

if ewc_lambda != 0:
    # EWC state is only allocated when the penalty is enabled.
    ewcData, synData = create_syn_data(classifier_model, classification_layer='fc')

# Plain SGD over every parameter; replaced by a per-parameter-group optimizer once
# incremental training starts.
classifier_optimizer = torch.optim.SGD(
    params=classifier_model.parameters(),
    lr=init_lr,
    momentum=momentum,
    weight_decay=l2,
)

change_bn_momentum(classifier_model, momentum=0.001)

# ---------------------------------- Utilities -----------------------------------------
seen_classes = set()  # every label observed so far
nr_seen_classes = 0
ce_loss = torch.nn.CrossEntropyLoss()  # mean-reduced loss, used for test evaluation
ce_loss_no_reduction = torch.nn.CrossEntropyLoss(reduction="none")  # per-sample training loss
# NOTE(review): lsm and nll are not referenced below in this script.
lsm = torch.nn.LogSoftmax(dim=1)
nll = torch.nn.NLLLoss(reduction="none")

# ------------------------------ Training procedure ------------------------------------
classifier_model.train()

# Benchmark layout used for per-task accuracy logging: N_TASKS tasks of
# CLASSES_PER_TASK consecutive class ids each.
CLASSES_PER_TASK = 40
N_TASKS = 25
# New-data patterns per minibatch in incremental batches; any remaining
# mb_size - NEW_SAMPLES_PER_MB slots are filled with latent replay patterns.
NEW_SAMPLES_PER_MB = 120
# Chunk size used when generating the random replay memory.
REPLAY_CHUNK_SZ = 128


def _make_random_replay_memory(n_samples, class_pool, chunk_sz=REPLAY_CHUNK_SZ):
    """Generate a random latent replay memory of ``n_samples`` patterns.

    Labels are drawn uniformly, with replacement, from ``class_pool``.
    Features are drawn from U(0, 1.8069623) in chunks of ``chunk_sz``.
    Each chunk is reshaped to ``(-1, 64, 56, 56)``.

    Returns:
        ``[features, labels]``: a float tensor and a long tensor.
    """
    labels = torch.from_numpy(np.random.choice(class_pool, n_samples)).type(torch.long)
    chunks = []
    for start in range(0, n_samples, chunk_sz):
        chunk_len = min(chunk_sz, n_samples - start)
        samples = np.random.uniform(0.0, 1.8069623, (chunk_len, 64 * 56 * 56))  # 90th percentile layer2
        chunks.append(torch.from_numpy(samples).type(torch.float).reshape((-1, 64, 56, 56)))
    return [torch.cat(chunks, dim=0), labels]


# loop over the training incremental batches
batch_info: NCProtocolIterator
for i, (train_batch, batch_info) in enumerate(protocol):
    print("----------- batch {0} -------------".format(i))

    if ewc_lambda != 0:
        init_batch(classifier_model, ewcData, synData, classification_layer='fc')

    train_y_all = np.array(train_batch.targets)
    print("train lab - min: {}".format(min(train_y_all)))
    print("train lab - max: {}".format(max(train_y_all)))
    print("Nr. classes in batch: {}".format(len(set(train_y_all))))

    if i == 1:
        # From the first incremental batch on, freeze the lower layers and switch to
        # an optimizer with per-group learning rates.
        if not freeze_up_to(classifier_model, freeze_below_layer, only_conv=False):  # Sets .eval() too
            raise ValueError('Can\'t find freeze below layer ' + str(freeze_below_layer))
        standard_params_lr, cwr_params_lr, frozen_params_lr = \
            ar1f_get_params_lr(classifier_model, 'fc', freeze_below_layer=freeze_below_layer)

        classifier_optimizer = torch.optim.SGD([
            {'params': standard_params_lr},  # Whole net - fc.weight - fc.bias
            {'params': cwr_params_lr, 'lr': cwr_inc_lr},  # fc.weight
            {'params': frozen_params_lr, 'lr': 0}  # fc.bias
        ], lr=inc_lr, momentum=momentum, weight_decay=l2)

    seen_classes.update(train_y_all)
    if len(seen_classes) > nr_seen_classes:
        print("New classes here!")
        nr_seen_classes = len(seen_classes)

    if i == 0:
        cur_class = [int(o) for o in set(train_y_all)]
        cur_class_batch = [int(o) for o in set(train_y_all)]
        classifier_model.cur_j = examples_per_class(train_y_all)
    else:
        # Classes to consolidate: those in the new batch plus those in the replay memory.
        cur_class = [int(o) for o in set(train_y_all).union(set(rm_classifier[1].tolist()))]
        cur_class_batch = [int(o) for o in set(train_y_all)]
        classifier_model.cur_j = examples_per_class(list(train_y_all))

    if i > 0:
        train_batch, _ = pad_dataset(train_batch, mb_size, shuffle_dataset=True)
        print("Padded dataset has {} patterns".format(len(train_batch)))
    else:
        print("train_y shape: {}".format(train_y_all.shape))

    # *************************************************************************************************************
    # *************************************** CLASSIFIER TRAINING *************************************************
    # *************************************************************************************************************
    print("************************ CLASSIFIER TRAINING ****************************************************")

    if i > 0:
        reset_weights(classifier_model, cur_class_batch)
    classifier_model = maybe_cuda(classifier_model, use_cuda=use_cuda)

    acc_classifier = None
    ave_loss_classifier = 0

    train_ep_classifier = init_train_ep_classifier if i == 0 else inc_train_ep_classifier

    # Replay patterns injected into each minibatch, and new-data patterns per minibatch.
    # Only incremental batches (i > 0) are trained.
    n2inject = max(0, mb_size - NEW_SAMPLES_PER_MB) if i > 0 else 0
    n_new = mb_size - n2inject

    # Sigmoid-shaped decay of the base learning rate across batches:
    # about 1.0 for the first batches, tending to 0.1 for large i.
    lr_multiplier = -(0.9 / (1 + math.pow(math.e, -((1.5 * i) - 8)))) + 1
    classifier_optimizer.param_groups[0]['lr'] = inc_lr * lr_multiplier
    # Remember every group's LR so the mid-training decay can be undone exactly.
    # An unconditional "* 10" afterwards has two problems: it compounds when no
    # decay happened (e.g. zero epochs), and it drifts through float rounding.
    base_lrs = [group['lr'] for group in classifier_optimizer.param_groups]

    for ep in range(train_ep_classifier):
        if i == 0:
            # Skip training for first batch
            print('Skipping training for the first batch')
            break

        if ep == (train_ep_classifier // 2):
            # One-off 10x decay halfway through the epochs (undone after training).
            for group in classifier_optimizer.param_groups:
                group['lr'] = group['lr'] * 0.1

        print("training ep: ", ep)
        print("lr = {}".format([group['lr'] for group in classifier_optimizer.param_groups]))
        correct_cnt, ave_loss_classifier = 0, 0

        print("total sz:", len(train_batch) + rm_classificator_sz)
        print("n2inject", n2inject)

        train_loader = DataLoader(train_batch, batch_size=n_new, shuffle=True,
                                  pin_memory=True, drop_last=True, num_workers=8)

        print("it x ep: ", len(train_loader))

        for it, (train_x, train_y) in enumerate(train_loader):
            if ewc_lambda != 0:
                pre_update(classifier_model, synData, classification_layer='fc')

            classifier_optimizer.zero_grad()

            y_mb = train_y
            if n2inject > 0:
                # Walk cyclically through the replay memory (wrap-around indexing).
                replay_idx = torch.arange(it * n2inject, (it + 1) * n2inject) % rm_classifier[0].size(0)
                lat_mb_x = maybe_cuda(rm_classifier[0][replay_idx], use_cuda=use_cuda)
                lat_mb_y = rm_classifier[1][replay_idx]
                y_mb = torch.cat((train_y, lat_mb_y), 0)

            x_mb = maybe_cuda(train_x, use_cuda=use_cuda)
            y_mb = maybe_cuda(y_mb, use_cuda=use_cuda)

            classifier_model.train()
            if not freeze_up_to(classifier_model, freeze_below_layer, only_conv=False):  # Sets .eval() too
                raise ValueError('Can\'t find freeze below layer ' + str(freeze_below_layer))

            if n2inject > 0:
                logits = classifier_model(x_mb, lat_mb_x, latent_layer=latent_layer, return_lat_acts=False)
            else:
                logits = classifier_model(x_mb)
            loss = ce_loss_no_reduction(logits, y_mb)
            ave_loss_classifier += loss.mean().item()

            # Replay part of the minibatch (the last n2inject losses): back-propagate it
            # with the model frozen up to "fc". The graph is kept for the second pass.
            classifier_model.train()
            if not freeze_up_to(classifier_model, "fc", only_conv=False):  # Sets .eval() too
                raise ValueError('Can\'t find freeze below layer fc')
            loss[n_new: mb_size].backward(
                maybe_cuda(torch.full((n2inject,), 1 / mb_size), use_cuda=use_cuda),
                retain_graph=True
            )

            # reset the required gradient for parameters
            for p in classifier_model.parameters():
                p.requires_grad = True

            # New-data part of the minibatch: back-propagate it with the usual freezing.
            classifier_model.train()
            if not freeze_up_to(classifier_model, freeze_below_layer, only_conv=False):  # Sets .eval() too
                raise ValueError('Can\'t find freeze below layer ' + str(freeze_below_layer))
            loss[0: n_new].backward(
                maybe_cuda(torch.full((n_new,), 1 / mb_size), use_cuda=use_cuda)
            )
            _, pred_label = torch.max(logits[0: n_new], 1)
            correct_cnt += (pred_label == y_mb[0: n_new]).sum()

            if ewc_lambda != 0:
                # The per-sample losses have already been back-propagated, so adding
                # the penalty to `loss` would not affect the gradients.
                # Back-propagate it on its own before the optimizer step.
                ewc_loss = compute_ewc_loss(classifier_model, ewcData, lambd=ewc_lambda,
                                            classification_layer='fc')
                if torch.is_tensor(ewc_loss) and ewc_loss.requires_grad:
                    ewc_loss.backward()

            classifier_optimizer.step()

            if ewc_lambda != 0:
                post_update(classifier_model, synData, classification_layer='fc')

            acc_classifier = correct_cnt.item() / \
                             ((it + 1) * y_mb[0: n_new].size(0))

            if it % 10 == 0:
                print(
                    '==>>> it: {}, avg. loss: {:.6f}, '
                    'running train acc: {:.3f}'.format(it, ave_loss_classifier / (it + 1), acc_classifier))
                # Log scalar values (scalar summary) to TB
                writer.add_scalar('classifier_train_loss', ave_loss_classifier / (it + 1), tot_it_step_classifier)
                writer.add_scalar('classifier_train_accuracy', acc_classifier, tot_it_step_classifier)
            tot_it_step_classifier += 1

    # Undo the mid-training decay by restoring the learning rates saved before training.
    for group, lr in zip(classifier_optimizer.param_groups, base_lrs):
        group['lr'] = lr

    w_changes = consolidate_weights(classifier_model, cur_class, cur_class_batch)
    if ewc_lambda != 0:
        update_ewc_data(classifier_model, ewcData, synData, clip_to=0.001, c=wi, classification_layer='fc')

    # update number examples encountered over time
    for c, n in classifier_model.cur_j.items():
        classifier_model.past_j[c] += n
    set_consolidate_weights(classifier_model)

    test_loader = DataLoader(batch_info.get_cumulative_test_set(), batch_size=mb_size, shuffle=False,
                             drop_last=False, num_workers=8, pin_memory=True)

    ave_loss_classifier, acc_classifier, accs, cf = get_accuracy_conf_matrix_from_dataloader(
        classifier_model, ce_loss, test_loader, protocol.n_classes)

    # ------------------ Statistics classifier -------------------------------------
    writer.add_scalar('classifier_test_loss', ave_loss_classifier, i)
    writer.add_scalar('classifier_test_accuracy', acc_classifier, i)

    hit_per_class, tot_test_class = accs

    df_cm = pd.DataFrame(cf, range(protocol.n_classes), range(protocol.n_classes))
    plot_buf = gen_plot(df_cm, max(tot_test_class))
    image_cf = PIL.Image.open(plot_buf)
    image_cf = ToTensor()(image_cf)
    writer.add_image('classifier_confusion_matrix/run_{}'.format(run), image_cf, i)

    print("---------------------------------")
    print("Accuracy classifier: ", acc_classifier)
    print("---------------------------------")

    # Per-task accuracy for tasks seen so far; unseen tasks are logged as 0.
    for past_task in range(i + 1):
        lo = past_task * CLASSES_PER_TASK
        hi = lo + CLASSES_PER_TASK
        acc_task = sum(hit_per_class[lo: hi]) / sum(tot_test_class[lo: hi])
        writer.add_scalar('task_accuracy/task_{}'.format(past_task), acc_task, i)

    for fut_task in range(i + 1, N_TASKS):
        writer.add_scalar('task_accuracy/task_{}'.format(fut_task), 0.0, i)

    if ewc_lambda != 0:
        weight_stats(classifier_model, ewcData, clip_to=0.001, classification_layer='fc')

    # *************************************************************************************************************
    # *************************************** REPLAY MEMORY CALCULATION *******************************************
    # *************************************************************************************************************

    print("************************ MEMORY CALCULATION ****************************************************")

    # --------------------------- Classificator replay memory ---------------------------------
    if rm_classificator_sz > 0:
        # Labels are drawn from the class ids actually seen so far. When those ids
        # are 0..n-1 this gives exactly the same draws as sampling from range(n).
        rm_classifier = _make_random_replay_memory(rm_classificator_sz, sorted(seen_classes))
    else:
        rm_classifier = [torch.as_tensor([], dtype=torch.float), torch.as_tensor([], dtype=torch.long)]

    # ########################  END OF BATCH  #################################
# ########################  END OF RUN  #################################
writer.close()
