import logging
import os
import numpy as np
import torch
from tqdm import tqdm
import warnings
from typing import Tuple

import tools.utils as utils

from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
from tools.utils import RunPhase

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Installs a global "ignore" filter: every warning is silenced process-wide
# from the moment this module is imported.
warnings.filterwarnings('ignore')


def get_outputs_labels(model, loss_fn, dataloader, params, run_phase):
    """Run ``model`` over ``dataloader`` and gather predictions, labels and names.

    For every batch, ``batch['input']`` is indexed along its second axis once
    per entry of ``params['preproc']['eval_cut_section_sec']``. The model is
    applied to each slice, and both the raw outputs and their softmax are
    averaged over the slices.

    Args:
        model: Callable network. It is switched to eval mode and left there.
        loss_fn: Callable invoked as ``loss_fn(averaged_raw_output, labels)``;
            only used when ``run_phase == RunPhase.TRAIN``.
        dataloader: Sized iterable of dict batches with the keys ``'input'``,
            ``'label'`` and ``'fname'``.
        params: Configuration mapping; ``params['cuda']`` and
            ``params['preproc']['eval_cut_section_sec']`` are read.
        run_phase: Phase marker compared against ``RunPhase.TRAIN``.

    Returns:
        A tuple ``(output, labels, names, valid_loss)``:

        * ``output`` - averaged softmax outputs of all batches, concatenated
          along axis 0 as a numpy array.
        * ``labels`` - ``argmax(-1)`` of ``batch['label']`` for all batches,
          concatenated along axis 0 (presumably class indices recovered from
          one-hot labels - confirm against the dataset).
        * ``names`` - concatenation of every ``batch['fname']``.
        * ``valid_loss`` - in the TRAIN phase, the batch-size-weighted sum of
          ``loss_fn`` results divided by the number of evaluated samples
          (its type is whatever ``loss_fn`` returns); otherwise ``0.0``.
    """
    model.eval()

    # Loop-invariant configuration, looked up once.
    num_sections = len(params['preproc']['eval_cut_section_sec'])
    use_cuda = params['cuda']
    compute_loss = run_phase == RunPhase.TRAIN

    output_all = []
    labels_all = []
    name_all = []
    valid_loss = 0

    # Nothing here is back-propagated, so keep the whole pass (loss included)
    # free of autograd bookkeeping.
    with torch.no_grad():
        for batch in tqdm(dataloader, total=len(dataloader)):
            input_batch = batch['input'].type(torch.FloatTensor)
            label_batch = batch['label'].type(torch.FloatTensor).argmax(-1)

            # move to GPU if requested
            if use_cuda:
                input_batch = input_batch.cuda(non_blocking=True)
                label_batch = label_batch.cuda(non_blocking=True)

            # One forward pass per section, then average over the sections.
            logits_per_section = [model(input_batch[:, i]) for i in range(num_sections)]
            probs_per_section = [torch.softmax(logits, -1) for logits in logits_per_section]

            output_batch = sum(probs_per_section) / num_sections
            output_batch_before_act = sum(logits_per_section) / num_sections

            if compute_loss:
                # Weight by batch size so that the final division yields a
                # per-sample average (assumes loss_fn returns a batch mean -
                # TODO confirm).
                valid_loss += loss_fn(output_batch_before_act, label_batch) * len(output_batch_before_act)

            # detach, move to cpu, convert to numpy arrays
            output_all.append(output_batch.detach().cpu().numpy())
            labels_all.append(label_batch.detach().cpu().numpy())
            name_all.append(batch['fname'])

    output = np.concatenate(output_all, axis=0)
    labels = np.concatenate(labels_all, axis=0)
    names = np.concatenate(name_all, axis=0)

    # Normalise by the number of samples actually evaluated; this can differ
    # from len(dataloader.dataset) when the loader drops the last batch or
    # uses a sampler that visits only part of the dataset.
    valid_loss = valid_loss / len(labels)

    return output, labels, names, valid_loss


def get_summary(output: np.array,
                labels: np.array) -> Tuple[dict, str]:
    """Compute accuracy, macro F1 and macro one-vs-rest AUC for predictions.

    Args:
        output: 2-D array of per-class scores, one row per sample. The
            predicted class is the ``argmax`` along axis 1.
        labels: Ground-truth class labels, one per sample.

    Returns:
        A tuple ``(summary, metrics_string)`` where ``summary`` maps ``'Acc'``,
        ``'F1'`` and ``'AUC'`` to their values and ``metrics_string`` holds
        the same three metrics, one per line, formatted with four decimals.
    """
    predictions = np.argmax(output, axis=1)

    # For a two-class problem roc_auc_score expects a 1-D score vector (the
    # score of the class with the greater label) and rejects an (n, 2)
    # matrix, so pass only the second column in that case.
    auc_scores = output[:, 1] if output.shape[1] == 2 else output

    summary = {
        'Acc': accuracy_score(labels, predictions),
        'F1': f1_score(labels, predictions, average='macro'),
        'AUC': roc_auc_score(labels, auc_scores, average='macro', multi_class='ovr'),
    }
    metrics_string = ''.join(
        "{}: {:05.4f}".format(metric, summary[metric]) + '\n'
        for metric in ('Acc', 'F1', 'AUC')
    )

    return summary, metrics_string


def evaluate(model, loss_fn, dataloader, params, run_phase, test_save_dir="", model_saved_name=""):
    """Evaluate ``model`` on ``dataloader`` and optionally save the summary.

    Predictions and labels come from ``get_outputs_labels`` and the metrics
    from ``get_summary``. When both ``test_save_dir`` and ``model_saved_name``
    are non-empty, the summary is also written as ``inference_summary.csv``
    into ``<test_save_dir>/<model_saved_name>`` via ``utils.wirte_log_csv``.

    Args:
        model, loss_fn, dataloader, params, run_phase: Forwarded unchanged to
            ``get_outputs_labels``.
        test_save_dir: Root directory for saved results; ``""`` disables
            saving.
        model_saved_name: Sub-directory name under ``test_save_dir``; ``""``
            disables saving.

    Returns:
        ``(valid_loss, summary, metrics_string, output, labels, names)``.
    """
    output, labels, names, valid_loss = get_outputs_labels(model, loss_fn, dataloader, params, run_phase)
    summary, metrics_string = get_summary(output, labels)

    if test_save_dir != "" and model_saved_name != "":  # inference only, not validation
        save_path = os.path.join(test_save_dir, model_saved_name)
        # exist_ok avoids the race between checking for the directory and
        # creating it.
        os.makedirs(save_path, exist_ok=True)
        # write inference_summary as csv
        utils.wirte_log_csv(path=save_path, name='inference_summary.csv', summary_dict=summary)

    return valid_loss, summary, metrics_string, output, labels, names
