import numpy as np
import torch
from torch.nn import functional as F
import time
from models.utils import getModel


def get_models_in_buffer(cfg, buffer):
    """Instantiate one evaluation-mode network per checkpoint in ``buffer``.

    Args:
        cfg: Configuration object. The attributes read here are ``net_type``,
            ``inputs``, ``outputs``, ``dim``, ``priors``, ``layer_type``,
            ``activation_type``, ``neurons`` (forwarded to ``getModel``) and
            ``device``.
        buffer: Indexable, sized collection whose items are passed to
            ``torch.load`` (presumably checkpoint paths -- confirm with
            the caller).

    Returns:
        A list with one network per buffer entry, in buffer order. Each
        network has its state dict loaded, is moved to ``cfg.device`` and is
        switched to ``eval()`` mode. An empty buffer yields an empty list.
    """

    def _restore(checkpoint):
        # Build a fresh network, then overwrite its weights from the checkpoint.
        model = getModel(cfg.net_type, cfg.inputs, cfg.outputs,
                         cfg.dim, cfg.priors, cfg.layer_type,
                         cfg.activation_type, cfg.neurons)
        weights = torch.load(checkpoint, map_location=torch.device(cfg.device))
        model.load_state_dict(weights)
        model.to(cfg.device)
        model.eval()
        return model

    return [_restore(buffer[position]) for position in range(len(buffer))]


def evaluation(cfg, test_dl_list, buffer, sim, start_time, num_samples=10):
    """Evaluate the buffered models on every task and log the time spent.

    For each batch, every model in the buffer is sampled ``num_samples``
    times. The model whose sampled predictions have the lowest standard
    deviation (averaged over the whole batch) is selected, and its averaged
    prediction is the one scored against the labels.

    Args:
        cfg: Configuration object. This function reads
            ``cfg.tasks_description`` (its length gives the number of tasks),
            ``cfg.device`` and ``cfg.folder``; it is also forwarded to
            ``get_models_in_buffer``.
        test_dl_list: Indexable collection holding one iterable per task,
            each yielding ``(images, labels)`` pairs.
        buffer: Collection of checkpoints handed to ``get_models_in_buffer``.
        sim: Simulation identifier; used to build the output directory name
            ``<cfg.folder>-<sim>``.
        start_time: Reference timestamp (same clock as ``time.time()``) used
            for the "time from start" column.
        num_samples: Number of stochastic forward passes per model and batch.
            Defaults to 10, the value that was previously hard-coded.

    Returns:
        A list with one accuracy per task followed by the mean of those
        accuracies as the last element.

    Side effects:
        Appends one ``eval`` row to
        ``<cfg.folder>-<sim>/time_measurement.csv``.
    """
    # Do not consider the time for loading models, because the model is kept
    # in memory in the baseline code: shift the reference point forward by
    # the loading duration.
    load_started = time.time()
    net_list = get_models_in_buffer(cfg, buffer)
    start_time += time.time() - load_started

    accs = []
    eval_total_compute_time = 0.0
    for task_number in range(len(cfg.tasks_description)):
        correct_count = 0
        total_count = 0
        for images, labels in test_dl_list[task_number]:
            images = images.to(cfg.device)
            labels = labels.to(cfg.device)

            # Only the prediction work below is counted as compute time;
            # fetching the batch and moving it to the device are excluded.
            start_time_eval = time.time()

            predictions_all_models = []
            var_models = []
            for net in net_list:
                # Note: it is important to have detach() here.
                y = soft_predictions(images, num_samples, net).detach()

                # Spread across the stochastic samples, reduced to one
                # scalar for the whole batch.
                var_models.append(torch.mean(torch.std(y, dim=0, unbiased=False)))
                # Prediction averaged over the samples.
                predictions_all_models.append(torch.mean(y, dim=0))

            predictions_all_models = torch.stack(predictions_all_models)
            var_models = torch.stack(var_models)

            # Keep the model with the lowest spread for this batch.
            indice_model_to_keep = torch.argmin(var_models, dim=0)
            selected = predictions_all_models[indice_model_to_keep]

            correct_count += (torch.argmax(selected, dim=1) == labels).int().sum().item()
            # Count samples from the tensor that is actually scored rather
            # than from a variable leaked out of the model loop.
            total_count += selected.size(0)

            eval_total_compute_time += time.time() - start_time_eval

        accs.append(float(correct_count) / total_count)

    # The last entry is the mean accuracy over all tasks.
    accs.append(np.mean(accs))

    # Columns: type,step,count,sum time in type,sum time from start
    row = ','.join(['eval', str(-1), str(1),
                    str(eval_total_compute_time),
                    str(time.time() - start_time)])
    # The context manager closes the file; no explicit close() is needed.
    with open(str(cfg.folder) + '-' + str(sim) + '/time_measurement.csv', 'a') as f:
        f.write(row + '\n')

    return accs


def soft_predictions(images, num_samples, net):
    """Run ``net`` on ``images`` ``num_samples`` times and stack the softmax outputs.

    Args:
        images: Input batch passed unchanged to ``net``.
        num_samples: Number of forward passes to perform. Must be at least 1,
            since ``torch.stack`` rejects an empty list.
        net: Callable whose return value is indexable; element ``[0]`` is
            taken as the logits (assumed to be shaped ``(batch, classes)`` --
            confirm against the model definition).

    Returns:
        A tensor with a new leading dimension of size ``num_samples`` holding
        the class probabilities of each forward pass.
    """
    # dim=1 is stated explicitly: relying on the implicit dimension choice of
    # log_softmax is deprecated and emits a warning on every call. For 2-D
    # logits the implicit rule selects dim=1, so the values are unchanged.
    return torch.stack([
        torch.exp(F.log_softmax(net(images)[0], dim=1))
        for _ in range(num_samples)
    ])


# def eval_accuracy(prob, labels):
#     assert (prob.size()[0] == labels.size()[0])
#
#     equal = (torch.argmax(prob, dim=1) == labels).float()
#     accuracy = torch.sum(equal) / labels.size()[0]

#     return accuracy
