import numpy as np
import torch
import os
from operator import truediv
import time
import collections

from sklearn import metrics, preprocessing
import data_utils as du

def evaluate_accuracy_new(data_iter, net, loss, device):
    """Evaluate ``net`` on ``data_iter`` and collect its predictions.

    Args:
        data_iter: iterable of ``(X, y)`` batches; ``y`` holds integer labels.
        net: network producing per-class scores along dim 1.
        loss: criterion called as ``loss(y_hat, y.long())``.
        device: device each batch is moved to before the forward pass.

    Returns:
        ``(acc, loss, preds, gt)``:
        ``acc`` and ``loss`` are whatever ``du.Averager.item()`` reports.
        Per-batch accuracy is added with ``n=y.shape[0]``, presumably a
        sample-weighted mean; confirm against ``du.Averager``.
        Per-batch loss is added without a weight.
        ``preds`` and ``gt`` are the argmax predictions and the labels,
        concatenated into NumPy arrays.

    Raises:
        ValueError: if ``data_iter`` yields no batches.

    The network is switched to eval mode for the pass. Its previous
    train/eval mode is restored afterwards, even if an exception is raised.
    """
    val_loss = du.Averager()
    val_acc = du.Averager()

    preds = []
    gt = []
    was_training = net.training
    net.eval()  # disable dropout / use running BN statistics once, not per batch
    try:
        with torch.no_grad():
            for X, y in data_iter:
                X = X.to(device)
                y = y.to(device)
                y_hat = net(X)
                l = loss(y_hat, y.long())
                pred = y_hat.argmax(dim=1)

                val_acc.add(
                    v=(pred == y).float().mean().cpu().item(),
                    n=y.shape[0],
                )
                val_loss.add(l.cpu().item())

                preds.append(pred.cpu().numpy())
                gt.append(y.cpu().numpy())
    finally:
        # Give the caller back the network in the mode it handed it to us.
        net.train(was_training)

    if not preds:
        raise ValueError("data_iter yielded no batches")
    preds = np.concatenate(preds, axis=0)
    gt = np.concatenate(gt, axis=0)
    return val_acc.item(), val_loss.item(), preds, gt

def evaluate_accuracy(data_iter, net, loss, device):
    """Compute the accuracy and mean batch loss of ``net`` over ``data_iter``.

    Args:
        data_iter: iterable of ``(X, y)`` batches; ``y`` holds integer labels.
        net: network producing per-class scores along dim 1.
        loss: criterion called as ``loss(y_hat, y.long())``.
        device: device each batch is moved to before the forward pass.

    Returns:
        ``[accuracy, mean_loss]``:
        ``accuracy`` is a float, the number of correct predictions divided by
        the number of samples.
        ``mean_loss`` is the average of the per-batch values returned by
        ``loss``; it is a tensor when ``loss`` returns tensors.

    Raises:
        ValueError: if ``data_iter`` yields no samples.

    The network is switched to eval mode for the pass. Its previous
    train/eval mode is restored afterwards, even if an exception is raised.
    """
    acc_sum, n = 0.0, 0
    # Accumulators live outside the loop. Resetting them per batch made the
    # function report only the last batch's loss.
    test_l_sum, test_num = 0, 0
    was_training = net.training
    net.eval()  # disable dropout / use running BN statistics
    try:
        with torch.no_grad():
            for X, y in data_iter:
                X = X.to(device)
                y = y.to(device)
                y_hat = net(X)
                test_l_sum += loss(y_hat, y.long())
                test_num += 1
                acc_sum += (y_hat.argmax(dim=1) == y).float().sum().cpu().item()
                n += y.shape[0]
    finally:
        net.train(was_training)
    if n == 0:
        raise ValueError("data_iter yielded no samples")
    return [acc_sum / n, test_l_sum / test_num]


def aa_and_each_accuracy(confusion_matrix):
    """Return per-class accuracies and their average from a confusion matrix.

    Each entry of ``each_acc`` is the diagonal count divided by the sum of
    its row. With scikit-learn's layout (rows = true class) this is
    per-class recall. A row that sums to zero gives 0/0 = NaN, which
    ``np.nan_to_num`` replaces with 0.

    Returns:
        ``(each_acc, average_acc)``: float array of per-row ratios and its
        mean.
    """
    cm = np.asarray(confusion_matrix)
    correct = np.diag(cm)
    row_totals = cm.sum(axis=1)
    per_class = np.nan_to_num(correct / row_totals)
    return per_class, np.mean(per_class)



def record_output(oa_ae, aa_ae, kappa_ae, element_acc_ae, training_time_ae, testing_time_ae, path):
    """Append a text summary of repeated-run metrics to the file at ``path``.

    Args:
        oa_ae, aa_ae, kappa_ae: per-run OA / AA / kappa values. Each list is
            written verbatim, followed by its mean and standard deviation.
        element_acc_ae: per-run per-class accuracies. Their mean and std are
            taken along axis 0, i.e. across runs.
        training_time_ae, testing_time_ae: per-run times. The lines labelled
            "Total average ... time" report their sum (``np.sum``).
        path: output file. It is opened in append mode as UTF-8 because the
            text contains the non-ASCII character ``±``.

    All values are computed before the file is opened, so a failure does
    not leave a partial record. The file is closed even if a write fails.
    """
    element_mean = np.mean(element_acc_ae, axis=0)
    element_std = np.std(element_acc_ae, axis=0)
    lines = [
        'OAs for each iteration are:' + str(oa_ae) + '\n',
        'AAs for each iteration are:' + str(aa_ae) + '\n',
        'KAPPAs for each iteration are:' + str(kappa_ae) + '\n' + '\n',
        'mean_OA ± std_OA is: ' + str(np.mean(oa_ae)) + ' ± ' + str(np.std(oa_ae)) + '\n',
        'mean_AA ± std_AA is: ' + str(np.mean(aa_ae)) + ' ± ' + str(np.std(aa_ae)) + '\n',
        'mean_KAPPA ± std_KAPPA is: ' + str(np.mean(kappa_ae)) + ' ± ' + str(np.std(kappa_ae)) + '\n' + '\n',
        'Total average Training time is: ' + str(np.sum(training_time_ae)) + '\n',
        'Total average Testing time is: ' + str(np.sum(testing_time_ae)) + '\n' + '\n',
        "Mean of all elements in confusion matrix: " + str(element_mean) + '\n',
        "Standard deviation of all elements in confusion matrix: " + str(element_std) + '\n',
    ]
    with open(path, 'a', encoding='utf-8') as f:
        f.writelines(lines)



def eval_data(test_iter, net, gt, device):
    """Run ``net`` over ``test_iter`` and score its predictions against ``gt``.

    Args:
        test_iter: iterable of ``(X, y)`` batches. ``y`` is ignored; labels
            come from ``gt``.
        net: network producing per-class scores along dim 1. It is left in
            eval mode.
        gt: ground-truth labels, presumably in the same order as the samples
            yielded by ``test_iter`` (TODO confirm with callers).
        device: device the inputs are moved to.

    Returns:
        ``(overall_acc, each_acc, average_acc, confusion_matrix, kappa)``.
        ``confusion_matrix`` follows scikit-learn's convention: rows are the
        true class, columns the predicted class. ``each_acc`` is therefore
        per-class recall, and ``average_acc`` is its mean.
    """
    net.eval()
    pred_test = []
    with torch.no_grad():
        for X, _ in test_iter:
            y_hat = net(X.to(device))
            pred_test.extend(y_hat.argmax(dim=1).cpu().numpy())

    overall_acc = metrics.accuracy_score(gt, pred_test)
    # scikit-learn expects (y_true, y_pred). Passing predictions first
    # transposed the matrix and turned each_acc into per-class precision.
    confusion_matrix = metrics.confusion_matrix(gt, pred_test)
    each_acc, average_acc = aa_and_each_accuracy(confusion_matrix)
    kappa = metrics.cohen_kappa_score(gt, pred_test)

    return overall_acc, each_acc, average_acc, confusion_matrix, kappa
    