from torch.optim import Adam
import numpy as np
import torch
from training.utils import load_model
from training.memory_management import compute_M, erase_LIM, tri_2_square
import os
import time
import sys
import collections
from torch.nn import functional as F
from training import metrics, utils
from models.utils import getModel

# Module-wide default device, chosen once at import time: the first CUDA GPU
# when one is available, otherwise the CPU.
# TODO: the config file also specifies a device; this global can disagree
# with it and the two should be unified.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")


def train_mini_batch(net,
                     i,
                     optimizer,
                     criterion,
                     trainloader,
                     inputs,
                     labels,
                     num_ens=1,
                     beta_type=0.1,
                     epoch=None,
                     num_epochs=None):
    """Run one optimisation step of ``net`` on a single mini-batch.

    The network is evaluated ``num_ens`` times on ``inputs``; the per-pass
    log-softmax outputs are combined with ``utils.logmeanexp`` and the KL
    terms returned by the network are averaged. The resulting loss from
    ``criterion`` is back-propagated and ``optimizer`` takes one step.

    Args:
        net: callable returning ``(net_out, kl)`` for a batch of inputs.
        i: batch index; ``i - 1`` is forwarded to ``metrics.get_beta``.
        optimizer: optimizer whose gradients are zeroed and then stepped.
        criterion: callable ``(log_outputs, labels, kl, beta)`` returning
            ``(loss, loss_nll_only)``.
        trainloader: only its length is used (passed to ``metrics.get_beta``).
        inputs: batch of inputs, already on the same device as ``net``.
        labels: batch of targets.
        num_ens: number of forward passes to combine.
        beta_type: forwarded to ``metrics.get_beta``.
        epoch: forwarded to ``metrics.get_beta``.
        num_epochs: forwarded to ``metrics.get_beta``.

    Returns:
        Tuple ``(kl, loss, acc, loss_nll_only)`` where ``kl`` is a Python
        float, ``loss`` and ``loss_nll_only`` are NumPy values and ``acc`` is
        whatever ``metrics.acc`` returns for the combined outputs.
    """
    optimizer.zero_grad()

    # Collect one log-probability tensor per forward pass. Stacking them
    # (instead of writing into a buffer allocated on a fixed global device)
    # keeps the result on whatever device/dtype the network output uses.
    sample_log_probs = []
    kl = 0.0
    for _ in range(num_ens):
        net_out, _kl = net(inputs)
        kl += _kl
        sample_log_probs.append(F.log_softmax(net_out, dim=1))
    kl = kl / num_ens

    # Shape: (batch, classes, num_ens); the ensemble axis is reduced below.
    outputs = torch.stack(sample_log_probs, dim=2)
    log_outputs = utils.logmeanexp(outputs, dim=2)

    beta = metrics.get_beta(i - 1, len(trainloader), beta_type, epoch,
                            num_epochs)
    loss, loss_nll_only = criterion(log_outputs, labels, kl, beta)
    # NOTE(review): retain_graph=True is kept from the original code;
    # presumably part of the graph is reused across calls - confirm before
    # removing it.
    loss.backward(retain_graph=True)
    optimizer.step()

    return (kl.item(),
            loss.detach().cpu().numpy(),
            metrics.acc(log_outputs.detach(), labels),
            loss_nll_only.detach().cpu().numpy())


def training(cfg, train_dl_list, sim, start_time):
    """Train one network on every task in sequence and collect snapshots.

    For each task the network is trained for ``cfg.n_epochs`` epochs with
    ``train_mini_batch``. After every epoch the weights are written to
    ``previous.save``. When the task index changes, the weights saved at the
    end of the previous epoch are reloaded and stored as a numbered snapshot
    whose path is appended to the buffer. Whenever the buffer reaches
    ``cfg.capacity + 1`` entries, ``erase_LIM`` selects one entry to drop.
    A final snapshot is stored after the last task.

    Per-batch timing rows are appended to ``time_measurement.csv`` inside the
    run directory ``<cfg.folder>-<sim>``.

    Args:
        cfg: configuration object. Attributes read here: ``net_type``,
            ``inputs``, ``outputs``, ``dim``, ``priors``, ``layer_type``,
            ``activation_type``, ``neurons``, ``device``, ``folder``,
            ``tasks_description``, ``transfer_posterior``, ``lr``,
            ``n_epochs`` and ``capacity``.
        train_dl_list: one data loader per task; each must expose
            ``.dataset``, support ``len()`` and yield ``(inputs, labels)``.
        sim: simulation identifier, used as suffix of the run directory.
        start_time: reference timestamp (same clock as ``time.time()``) used
            for the "time from start" column of the timing file.

    Returns:
        List of file paths of the model snapshots kept in the buffer.
    """
    net = getModel(cfg.net_type, cfg.inputs, cfg.outputs,
                   cfg.dim, cfg.priors, cfg.layer_type,
                   cfg.activation_type, cfg.neurons
                   ).to(cfg.device)

    # All artefacts of this run live in "<folder>-<sim>".
    run_dir = str(cfg.folder) + '-' + str(sim)
    timing_path = run_dir + '/time_measurement.csv'
    previous_path = run_dir + '/previous.save'

    os.makedirs('results', exist_ok=True)
    os.makedirs(run_dir, exist_ok=True)

    with open(timing_path, 'a') as f:
        f.write('type,step,count,sum_time_in_type,sum_time_from_start\n')

    # init
    n_tasks = len(cfg.tasks_description)
    buffer = []
    M = []
    counter = 0
    hist_losses = collections.deque([0])
    detection = False
    window_size = 6  # hyper parameter: max number of past losses kept

    count_total_learn = 0
    time_total_learn = 0.0

    for task in range(n_tasks):
        print('Training task:', task)

        # Explicit comparison kept from the original: only a value equal to
        # False triggers the reset (None, for instance, does not).
        if cfg.transfer_posterior == False:  # noqa: E712
            net.reset_params()

        task_loader = train_dl_list[task]
        criterion = metrics.ELBO(len(task_loader.dataset)).to(cfg.device)
        optimizer = Adam(net.parameters(), lr=cfg.lr)

        for j in range(cfg.n_epochs):  # loop over the dataset multiple times
            # The very first epoch has no previous loss/model to compare to.
            is_first_epoch = task == 0 and j == 0

            training_loss = 0.0
            accs = []

            for i, (inputs, labels) in enumerate(task_loader, 1):
                # Use the configured device so that the batch lives where
                # the model and the criterion were placed.
                inputs = inputs.to(cfg.device)
                labels = labels.to(cfg.device)
                start_time_learn = time.time()

                _, _, acc, loss_nll_only = train_mini_batch(
                    net, i, optimizer, criterion, task_loader, inputs,
                    labels)
                training_loss += loss_nll_only
                accs.append(acc)

                count_total_learn += 1
                time_total_learn += time.time() - start_time_learn

                with open(timing_path, 'a') as f:
                    # 'type,step,count,sum time in type,sum time from start\n'
                    f.write('train,' + str(i) + ',' + str(count_total_learn) +
                            ',' + str(time_total_learn) +
                            ',' + str(time.time() - start_time) + '\n')

            if not is_first_epoch:
                # Record the loss of the previous epoch in the sliding window.
                previous_train_loss = train_loss
                hist_losses.append(previous_train_loss)
                if len(hist_losses) > window_size:
                    hist_losses.popleft()

            train_loss = training_loss / len(task_loader)
            train_acc = np.mean(accs)
            print('train loss', train_loss, flush=True)
            print('train acc', train_acc, flush=True)

            if not is_first_epoch:
                # Loss-jump detector. It is computed but currently NOT used
                # as the trigger: the known task index is used instead (see
                # the "temporary" condition below).
                try:
                    loss_jump = (train_loss - previous_train_loss >
                                 np.std(hist_losses))
                    if cfg.n_epochs == 1:
                        detection = loss_jump
                    else:
                        detection = loss_jump and len(hist_losses) > 2
                except Exception as err:
                    # Best effort: a failing detector must not stop training,
                    # but it should not fail silently either.
                    print('task-change detection failed:', err, flush=True)

                if previous_task != task:  # temporary
                # if detection:
                    # Restore the weights saved at the end of the previous
                    # epoch and store them as a numbered snapshot.
                    load_model(net, previous_path)
                    snapshot_path = (run_dir + '/model-' + str(counter) +
                                     '.save')
                    torch.save(net.state_dict(), snapshot_path)
                    buffer.append(snapshot_path)
                    counter += 1

                    # flush the loss record
                    hist_losses = collections.deque([0])

                    if len(buffer) > 1:
                        # compute M; `inputs` is the last mini-batch of the
                        # epoch that just finished.
                        M = compute_M(inputs, buffer, M, cfg)

                    if len(buffer) == cfg.capacity + 1:
                        # Buffer over capacity: drop the entry chosen by
                        # erase_LIM.
                        indice_least_important_model, M = erase_LIM(M)
                        buffer.pop(indice_least_important_model)

                    if cfg.transfer_posterior:
                        net.reset_priors_networks()

            previous_task = task
            torch.save(net.state_dict(), previous_path)

    # save the last model
    snapshot_path = run_dir + '/model-' + str(counter) + '.save'
    torch.save(net.state_dict(), snapshot_path)
    buffer.append(snapshot_path)
    counter += 1

    if len(buffer) == cfg.capacity + 1:
        # compute M remove last model; `inputs` is the last mini-batch seen
        # during training.
        M = compute_M(inputs, buffer, M, cfg)
        indice_least_important_model, M = erase_LIM(M)
        buffer.pop(indice_least_important_model)

    return buffer
