import numpy as np
import torch
import wandb

from utils.config import cfg
from sklearn.metrics import roc_auc_score, average_precision_score
from tqdm import tqdm

def train_epoch(train_loader, model, device, optimiser, loss_fn, scheduler):
    """Train ``model`` for one pass over ``train_loader``.

    Targets equal to NaN (detected via ``y != y``) are masked out of the loss.
    When ``cfg.metric == 'ACC'`` targets are passed to ``loss_fn`` with their
    original dtype; otherwise predictions and targets are both cast to float32.

    Args:
        train_loader: iterable of batches exposing ``.to(device)`` and ``.y``.
        model: module being trained; switched to train mode here.
        device: device each batch is moved to.
        optimiser: optimiser stepped once per batch.
        loss_fn: callable ``loss_fn(pred, target)`` returning a scalar tensor
            (``backward()`` is called on it without a gradient argument).
        scheduler: optional; when not None, ``scheduler.step(total_loss)`` is
            called once at the end of the epoch.

    Returns:
        The summed (not averaged) loss over all batches, as a Python float.
    """
    model.train()

    total_loss = 0.0

    for batch in tqdm(train_loader, desc="Iteration"):
        batch = batch.to(device)

        pred = model(batch)
        optimiser.zero_grad()

        # NaN != NaN, so this mask drops unlabeled targets.
        is_labeled = batch.y == batch.y

        if cfg.metric == 'ACC':
            loss = loss_fn(pred.to(torch.float32)[is_labeled], batch.y[is_labeled])
        else:
            loss = loss_fn(pred.to(torch.float32)[is_labeled], batch.y.to(torch.float32)[is_labeled])

        # Accumulate a plain float: adding the tensor itself would chain every
        # batch's autograd history onto total_loss for the whole epoch.
        total_loss += loss.item()

        loss.backward()
        optimiser.step()

    if scheduler is not None:
        # Passing the epoch loss suggests a metric-driven scheduler such as
        # ReduceLROnPlateau -- NOTE(review): confirm against the caller.
        scheduler.step(total_loss)

    if cfg.wandb.use:
        wandb.log({
            'loss': total_loss,
        })

    return total_loss
    

def eval_epoch(loader, model, device):
    """Evaluate ``model`` on every batch of ``loader`` and return ``compute_metric``'s score.

    Forward passes run under ``torch.no_grad()`` with the model in eval mode.
    When ``cfg.metric == 'ACC'`` each prediction row is reduced to its arg-max
    index and targets are flattened to a single column. Otherwise targets are
    reshaped to the prediction's shape and raw predictions are kept.
    """
    model.eval()
    targets = []
    outputs = []

    for batch in tqdm(loader, desc="Iteration"):
        batch = batch.to(device)

        with torch.no_grad():
            out = model(batch).detach()

        if cfg.metric != 'ACC':
            target = batch.y.view(out.shape)
        else:
            target = batch.y.view(-1, 1)
            out = torch.argmax(out, dim=1).view(-1, 1)

        targets.append(target.detach().cpu())
        outputs.append(out.cpu())

    y_true = torch.cat(targets, dim=0).numpy()
    y_pred = torch.cat(outputs, dim=0).numpy()
    return compute_metric(y_true, y_pred)

def compute_metric(y_true, y_pred, *, metric=None):
    """Score predictions with the selected metric.

    Args:
        y_true: array of targets, one column per task (the helpers index
            ``[:, i]``).
        y_pred: predictions aligned column-for-column with ``y_true``.
        metric: one of ``'ROCAUC'``, ``'AP'`` or ``'ACC'``. Defaults to
            ``cfg.metric``, so existing callers are unaffected.

    Returns:
        The task-averaged score returned by the matching ``_eval_*`` helper.

    Raises:
        ValueError: if the metric name is not supported.
    """
    if metric is None:
        metric = cfg.metric

    if metric == 'ROCAUC':
        return _eval_rocauc(y_true, y_pred)
    elif metric == 'AP':
        return _eval_ap(y_true, y_pred)
    elif metric == 'ACC':
        return _eval_acc(y_true, y_pred)
    else:
        raise ValueError(
            "Unsupported metric {!r}; expected 'ROCAUC', 'AP' or 'ACC'.".format(metric))

# Evaluation functions from OGB. 
def _eval_rocauc(y_true, y_pred):
    '''
        Return the ROC-AUC averaged over tasks (columns).

        A task contributes only if it has at least one positive (1) and one
        negative (0) label; NaN labels inside a contributing task are ignored.
        Raises RuntimeError when no task qualifies.
    '''
    scores = []
    n_tasks = y_true.shape[1]

    for task in range(n_tasks):
        labels = y_true[:, task]
        has_pos = np.sum(labels == 1) > 0
        has_neg = np.sum(labels == 0) > 0
        # ROC-AUC needs both classes present in the task.
        if not (has_pos and has_neg):
            continue
        # NaN != NaN, so this mask keeps only the labeled entries.
        mask = labels == labels
        scores.append(roc_auc_score(labels[mask], y_pred[mask, task]))

    if not scores:
        raise RuntimeError(
            'No positively labeled data available. Cannot compute ROC-AUC.')

    return sum(scores) / len(scores)

def _eval_ap(y_true, y_pred):
    '''
        Return the Average Precision (AP) averaged over tasks (columns).

        A task contributes only if it has at least one positive (1) and one
        negative (0) label; NaN labels inside a contributing task are ignored.
        Raises RuntimeError when no task qualifies.
    '''
    per_task_ap = []

    for col in range(y_true.shape[1]):
        labels = y_true[:, col]
        # Skip tasks lacking either a positive or a negative label.
        if not (np.sum(labels == 1) > 0 and np.sum(labels == 0) > 0):
            continue
        # NaN != NaN, so this mask keeps only the labeled entries.
        keep = labels == labels
        per_task_ap.append(
            average_precision_score(labels[keep], y_pred[keep, col]))

    if not per_task_ap:
        raise RuntimeError(
            'No positively labeled data available. Cannot compute Average Precision.')

    return sum(per_task_ap) / len(per_task_ap)

def _eval_acc(y_true, y_pred):
    '''
        Return the accuracy averaged over tasks (columns).

        Entries whose label is NaN are treated as unlabeled and ignored.
        Tasks with no labeled entries are skipped, because accuracy is
        undefined for them. Raises RuntimeError when no task has any
        labeled entry.
    '''
    acc_list = []

    for i in range(y_true.shape[1]):
        # NaN != NaN, so this mask keeps only the labeled entries.
        is_labeled = y_true[:, i] == y_true[:, i]
        if not np.any(is_labeled):
            # Avoid dividing by zero for a task with no labels.
            continue
        correct = y_true[is_labeled, i] == y_pred[is_labeled, i]
        acc_list.append(float(np.sum(correct)) / len(correct))

    if len(acc_list) == 0:
        raise RuntimeError(
            'No labeled data available. Cannot compute accuracy.')

    return sum(acc_list) / len(acc_list)
