import torch
import numpy as np
from sklearn.metrics import accuracy_score, f1_score
from torchmetrics.functional import f1_score as f1_score_torch
from torchmetrics.functional import auroc, accuracy


def weighted_accuracy(test_preds_emo, test_truth_emo):
    """Return the class-balanced binary accuracy of sign predictions.

    Values strictly greater than zero are treated as the positive class and
    everything else as negative. The score is the mean of the positive-class
    recall ``tp / p`` and the negative-class recall ``tn / n``, which is
    algebraically identical to ``(tp * (n / p) + tn) / (2 * n)``.

    Args:
        test_preds_emo: array-like of predicted scores.
        test_truth_emo: array-like of ground-truth scores, same length.

    Returns:
        float: the balanced accuracy. If one class is absent from the ground
        truth its recall is undefined, so only the other class's recall is
        returned. For empty input ``nan`` is returned.
    """
    truth_pos = np.asarray(test_truth_emo) > 0
    pred_pos = np.asarray(test_preds_emo) > 0

    p = float(np.sum(truth_pos))
    n = float(np.sum(~truth_pos))
    tp = float(np.sum(truth_pos & pred_pos))
    tn = float(np.sum(~truth_pos & ~pred_pos))

    # The previous closed form divided by both p and n, so a batch holding only
    # one ground-truth class raised ZeroDivisionError. Average only the recalls
    # that are defined instead (the same convention as sklearn's
    # balanced_accuracy_score).
    recalls = []
    if p > 0:
        recalls.append(tp / p)
    if n > 0:
        recalls.append(tn / n)
    if not recalls:
        return float("nan")
    return sum(recalls) / len(recalls)

def calculate_accuracy(outputs, targets, logger=None, num_classes=101):
    """Compute top-1 accuracy and macro F1 from class scores.

    Args:
        outputs: score tensor whose dim 1 indexes classes (top-1 is taken along
            dim 1). Presumably shape (batch, num_classes); TODO confirm.
        targets: tensor of integer class labels, one per sample.
        logger: optional object exposing ``add_log(name, value)`` and
            ``write_log(step)``. When given, both metrics are logged and
            ``write_log(0)`` is called.
        num_classes: number of classes passed to the macro F1 computation.
            Defaults to 101, the value previously hard-coded here.

    Returns:
        dict with ``'accuracy'`` (0-d tensor, fraction of correct top-1
        predictions) and ``'f1_macro'`` (as returned by torchmetrics).
    """
    with torch.no_grad():
        batch_size = targets.size(0)

        # topk(1) along the class dim gives (batch, 1). Flatten it to exactly one
        # label per sample. `pred.squeeze()` used to turn a batch of one into a
        # 0-d tensor that no longer matched the 1-d targets.
        _, pred = outputs.topk(1, 1, largest=True, sorted=True)
        pred = pred.reshape(-1)
        correct = pred.eq(targets.view(-1))
        n_correct_elems = correct.float().sum().item()
        acc = n_correct_elems / batch_size

        f1_macro = f1_score_torch(pred, targets.long(), average='macro',
                                  num_classes=num_classes, multiclass=True)
        if logger is not None:
            logger.add_log("accuracy", acc)
            logger.add_log("f1_macro", f1_macro)
            logger.write_log(0)

        return {'accuracy': torch.tensor(acc), 'f1_macro': f1_macro}

def calculate_auroc(outputs, targets, logger=None):
    """Compute AUROC, accuracy and F1 for classifier outputs.

    Two input layouts are handled:

    * 2-D ``outputs`` (presumably (batch, classes) logits; TODO confirm): a
      softmax is applied over the last dim, and weighted AUROC, weighted
      accuracy and macro F1 are computed. The class count is read from
      ``outputs.size(-1)``; it used to be fixed at 2, which only worked for
      two-column heads.
    * any other rank: ``outputs`` is cast to float and passed to the
      torchmetrics functions with ``average='micro'``. In this branch the value
      stored under ``'f1_macro'`` is therefore micro-averaged.

    Args:
        outputs: score tensor as described above.
        targets: label tensor. It is cast to ``long`` before scoring.
        logger: optional object exposing ``add_log(name, value)`` and
            ``write_log(step)``. When given, all three metrics are logged.

    Returns:
        dict with keys ``'auroc'``, ``'accuracy'`` and ``'f1_macro'``.
    """
    with torch.no_grad():
        target_labels = targets.long()
        if outputs.dim() == 2:
            num_classes = outputs.size(-1)
            all_logits = torch.softmax(outputs, dim=-1)
            auroc_score = auroc(all_logits, target_labels, average='weighted', num_classes=num_classes)
            acc = accuracy(all_logits, target_labels, average='weighted', num_classes=num_classes)
            f1_macro = f1_score_torch(all_logits, target_labels, average='macro', num_classes=num_classes)
        else:
            all_logits = outputs.float()
            auroc_score = auroc(all_logits, target_labels, average='micro')
            acc = accuracy(all_logits, target_labels, average='micro')
            # NOTE: micro-averaged despite the key name; kept for callers.
            f1_macro = f1_score_torch(all_logits, target_labels, average='micro')

        if logger is not None:
            logger.add_log("auroc", auroc_score)
            logger.add_log("accuracy", acc)
            logger.add_log("f1_macro", f1_macro)
            logger.write_log(0)

        return {'auroc': auroc_score, 'accuracy': acc, 'f1_macro': f1_macro}

def calculate_f1(outputs, targets, logger=None, num_classes=23):
    """Compute F1 scores for multi-label logits.

    A sigmoid is applied to ``outputs`` and the probabilities are scored
    against ``targets`` with torchmetrics using micro, macro, weighted and
    per-label (``average='none'``) averaging.

    Args:
        outputs: logit tensor, presumably (batch, num_classes); TODO confirm.
        targets: label tensor matching ``outputs``. It is cast to ``long``.
        logger: optional object exposing ``add_log(name, value)`` and
            ``write_log(step)``. When given, all four scores are logged.
        num_classes: label count passed to torchmetrics. Defaults to 23, the
            value previously hard-coded here.

    Returns:
        dict with ``'f1_macro'``, ``'f1_micro'``, ``'f1_weighted'`` and
        ``'f1_samples'``. Despite its name, ``'f1_samples'`` holds the
        per-label scores (``average='none'``), not sample-averaged F1.
    """
    # Evaluation only: avoid building an autograd graph for the sigmoid, as the
    # sibling metric helpers already do.
    with torch.no_grad():
        all_logits = torch.sigmoid(outputs)
        labels = targets.long()

        f1_micro = f1_score_torch(all_logits, labels, average='micro', num_classes=num_classes)
        f1_macro = f1_score_torch(all_logits, labels, average='macro', num_classes=num_classes)
        f1_samples = f1_score_torch(all_logits, labels, average='none', num_classes=num_classes)
        f1_weighted = f1_score_torch(all_logits, labels, average='weighted', num_classes=num_classes)

    if logger is not None:
        logger.add_log("f1_macro", f1_macro)
        logger.add_log("f1_micro", f1_micro)
        logger.add_log("f1_weighted", f1_weighted)
        logger.add_log("f1_samples", f1_samples)
        logger.write_log(0)

    return {'f1_macro': f1_macro,
            'f1_micro': f1_micro,
            'f1_weighted': f1_weighted,
            'f1_samples': f1_samples}

def eval_mosei(results, truths, logger=None, exclude_zero=False, classification=True):
    """Evaluate sentiment predictions: MAE, Pearson correlation, binary F1/accuracy.

    Args:
        results: model outputs. When ``classification`` is True they are
            reduced with ``argmax`` over dim 1 (presumably (batch, classes)
            logits; TODO confirm). Otherwise they are used as-is.
        truths: ground-truth tensor. It is flattened and compared element-wise
            with the flattened predictions.
        logger: optional object exposing ``add_log(name, value)`` and
            ``write_log(step)``. When given, all four metrics are logged.
        exclude_zero: if True, samples whose ground truth equals 0 are left out
            of the F1 / accuracy computation (MAE and correlation still use
            all samples).
        classification: selects argmax decoding and the binarisation threshold.
            Values ``> 3`` count as positive in classification mode and values
            ``> 0`` in regression mode. (Presumably 3 separates the
            non-positive classes from the positive ones; confirm against the
            label layout.)

    Returns:
        dict of 0-d tensors: ``'mae'``, ``'correlation'`` (``-100.`` when the
        correlation is undefined), ``'f1_score'`` (weighted F1) and
        ``'accuracy'``.
    """
    if classification:
        results = torch.argmax(results, dim=1).float()

    test_preds = results.view(-1).cpu().detach().numpy()
    test_truth = truths.view(-1).cpu().detach().numpy()

    # Boolean mask rather than an index list: an empty index list becomes a
    # float64 array, which numpy refuses to use as an index.
    if exclude_zero:
        keep = test_truth != 0
    else:
        keep = np.ones(test_truth.shape, dtype=bool)

    mae = np.mean(np.absolute(test_preds - test_truth))  # Average L1 distance between preds and truths

    # Pearson correlation is undefined for constant or single-element inputs;
    # report the -100 sentinel in those cases. np.ptp tests for constancy
    # exactly, without the rounding a variance computation can introduce.
    if test_preds.size < 2 or np.ptp(test_preds) == 0 or np.ptp(test_truth) == 0:
        corr = -100.
    else:
        corr = np.corrcoef(test_preds, test_truth)[0][1]

    threshold = 3 if classification else 0
    truth_pos = test_truth[keep] > threshold
    pred_pos = test_preds[keep] > threshold
    # sklearn's signature is (y_true, y_pred); 'weighted' averaging weights each
    # class by its support in y_true, so the argument order matters. (The
    # previous code passed predictions as y_true.)
    f_score = f1_score(truth_pos, pred_pos, average='weighted')
    acc = accuracy_score(truth_pos, pred_pos)

    # Log results
    if logger is not None:
        logger.add_log("mae", mae)
        logger.add_log("correlation", corr)
        logger.add_log("f1_score", f_score)
        logger.add_log("accuracy", acc)
        logger.write_log(0)

    # return dict of results
    return {'mae': torch.tensor(mae),
            'correlation': torch.tensor(corr),
            'f1_score': torch.tensor(f_score),
            'accuracy': torch.tensor(acc)}


def eval_mosi(results, truths, logger=None, exclude_zero=False):
    """Evaluate predictions by delegating to ``eval_mosei``.

    ``classification`` is not forwarded, so ``eval_mosei``'s default (True)
    applies. The returned dict is whatever ``eval_mosei`` returns.
    """
    return eval_mosei(
        results,
        truths,
        logger=logger,
        exclude_zero=exclude_zero,
    )


def _report_binary_emotion(name, scores, truth):
    """Print and return weighted F1 / accuracy for one two-column emotion head."""
    print(f"{name}: ")
    pred = np.argmax(scores, axis=1)
    f1 = f1_score(truth, pred, average='weighted')
    acc = accuracy_score(truth, pred)
    print("  - F1 Score: ", f1)
    print("  - Accuracy: ", acc)
    return {'f1': f1, 'accuracy': acc}


def eval_iemocap(results, truths, logger, single=-1):
    """Print weighted F1 and accuracy for each emotion's two-way prediction.

    Args:
        results: when ``single < 0``, reshaped to (-1, 4, 2), i.e. two scores
            for each of the four emotions. Otherwise reshaped to (-1, 2).
        truths: when ``single < 0``, reshaped to (-1, 4). Otherwise flattened.
        logger: accepted for interface compatibility; not used here.
        single: index into ``["Neutral", "Happy", "Sad", "Angry"]`` selecting
            one emotion head, or a negative value to evaluate all four.

    Returns:
        dict mapping emotion name to ``{'f1': ..., 'accuracy': ...}`` for the
        evaluated head(s). The function previously returned None, so existing
        callers that ignore the result are unaffected.
    """
    emos = ["Neutral", "Happy", "Sad", "Angry"]
    if single < 0:
        test_preds = results.view(-1, 4, 2).cpu().detach().numpy()
        test_truth = truths.view(-1, 4).cpu().detach().numpy()
        # Dict comprehension evaluates (and prints) in emotion order.
        return {
            emo: _report_binary_emotion(emo, test_preds[:, emo_ind], test_truth[:, emo_ind])
            for emo_ind, emo in enumerate(emos)
        }

    test_preds = results.view(-1, 2).cpu().detach().numpy()
    test_truth = truths.view(-1).cpu().detach().numpy()
    name = emos[single]
    return {name: _report_binary_emotion(name, test_preds, test_truth)}



