import torch
import numpy as np


def fit(train_loader, val_loader, model, loss_fn, optimizer, scheduler, n_epochs, cuda, log_interval, metrics=None,
        start_epoch=0):
    """
    Train ``model`` for epochs ``start_epoch`` .. ``n_epochs - 1`` and print a one-line summary per epoch.

    Loaders, model, loss function and metrics should work together for a given task,
    i.e. the model should be able to process the data output of the loaders, and the
    loss function should process the target output of the loaders and the outputs of the model.

    Examples: Classification: batch loader, classification model, NLL loss, accuracy metric
    Siamese network: Siamese loader, siamese model, contrastive loss
    Online triplet learning: batch loader, embedding model, online triplet loss

    Each epoch runs ``train_epoch`` then ``test_epoch``, prints the train loss, the
    validation loss averaged over ``len(val_loader)`` and every metric's value, and
    finally steps ``scheduler`` once.

    :param metrics: optional list of metric objects exposing ``name()`` and ``value()``;
        defaults to no metrics.
    :param start_epoch: epoch index to resume from; the scheduler is stepped this many
        times up front so it starts in the matching state.
    :return: None
    """
    # Use a fresh list per call rather than a shared mutable default argument.
    if metrics is None:
        metrics = []

    # Fast-forward the scheduler when resuming (no-op for start_epoch <= 0).
    for _ in range(start_epoch):
        scheduler.step()

    for epoch in range(start_epoch, n_epochs):

        # Train stage.
        # train_epoch/test_epoch return more than (loss, metrics) -- a bad-triplet
        # ratio comes third -- so keep only the two values used here.
        train_loss, metrics = train_epoch(train_loader, model, loss_fn, optimizer, cuda, log_interval, metrics)[:2]
        message = 'Epoch: {}/{}. Train loss: {:.4f}'.format(epoch + 1, n_epochs, train_loss)

        for metric in metrics:
            message += '\t{}: {}'.format(metric.name(), metric.value())

        # Validation stage: test_epoch returns the loss summed over batches.
        val_loss, metrics = test_epoch(val_loader, model, loss_fn, cuda, metrics)[:2]
        val_loss /= len(val_loader)

        message += ' Validation loss: {:.4f}'.format(val_loss)
        for metric in metrics:
            message += '\t{}: {}'.format(metric.name(), metric.value())

        print(message)
        scheduler.step()

def Logits_Loss(logits, targets):
    """
    Sum the mean-reduced cross-entropy of each (logit, target) pair.

    ``logits`` and ``targets`` are iterated in lockstep; pairing stops at the shorter
    of the two. Returns the accumulated loss, or the integer ``0`` when there are no pairs.
    """
    cross_entropy = torch.nn.functional.cross_entropy
    total = 0
    for pair_logits, pair_targets in zip(logits, targets):
        pair_loss = cross_entropy(pair_logits, pair_targets, reduction='mean')
        total += pair_loss
    return total

def train_epoch(train_loader, model, loss_fn, optimizer, cuda, log_interval, metrics, use_pred_loss=False, verbose=False):
    """
    Run one optimisation pass over ``train_loader``.

    For every batch the model is called as ``model(*data)`` and must return
    ``(outputs, logits)``; ``loss_fn(*outputs)`` must return
    ``(loss_outputs, triplet_losses)``. The loss that is back-propagated and logged is
    ``loss_outputs`` (or its first element when it is a tuple/list). When
    ``use_pred_loss`` is true, ``Logits_Loss(logits, targets)`` is back-propagated as
    well, but it is not included in the reported loss.

    :param log_interval: with ``verbose``, print a progress line every this many batches.
    :param metrics: list of metric objects; each is reset first and then called as
        ``metric(outputs, targets, loss_outputs)`` after every batch.
    :return: ``(mean loss per batch, metrics, bad_triplets)`` where ``bad_triplets`` is
        the number of zero-valued entries seen in ``triplet_losses`` divided by
        ``len(train_loader.dataset)``. An empty loader/dataset yields ``0.0`` for the
        corresponding value instead of raising.
    """
    for metric in metrics:
        metric.reset()

    model.train()  # set train mode
    losses = []
    total_loss = 0
    nb_batches = len(train_loader)
    bad_triplets = 0
    n_batches_seen = 0

    for batch_idx, (data, targets) in enumerate(train_loader):
        print("\rbatch {}/{} ".format(batch_idx, nb_batches), end="")
        targets = targets if len(targets) > 0 else None
        if type(data) not in (tuple, list):
            data = (data,)
        if cuda:
            data = tuple(d.cuda() for d in data)
            if targets is not None:
                if type(targets) in (tuple, list):
                    targets = tuple(t.cuda() for t in targets)
                else:
                    # Keep a single tensor intact, exactly as the CPU path does,
                    # instead of splitting it into a tuple of rows.
                    targets = targets.cuda()

        optimizer.zero_grad()
        outputs, logits = model(*data)

        if type(outputs) not in (tuple, list):
            outputs = (outputs,)

        if use_pred_loss:
            logits_loss = Logits_Loss(logits, targets)

        # NOTE(review): targets are not forwarded to loss_fn here, whereas test_epoch
        # appends them to the loss inputs -- confirm this asymmetry is intended.
        loss_outputs, triplet_losses = loss_fn(*outputs)
        # Count zero-valued entries directly; this needs no host copy and does not
        # depend on how numpy.argwhere happens to treat a torch tensor.
        bad_triplets += int((triplet_losses == 0).sum())

        loss = loss_outputs[0] if type(loss_outputs) in (tuple, list) else loss_outputs
        losses.append(loss.item())
        total_loss += loss.item()

        if use_pred_loss:
            # Gradients of both losses accumulate; the graph must survive for the
            # second backward pass below.
            logits_loss.backward(retain_graph=True)

        loss.backward()
        optimizer.step()

        for metric in metrics:
            metric(outputs, targets, loss_outputs)

        if batch_idx % log_interval == 0 and verbose:
            message = '    Train: [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                batch_idx * len(data[0]), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), np.mean(losses))
            for metric in metrics:
                message += '\t{}: {}'.format(metric.name(), metric.value())

            print(message)
            losses = []  # the running mean restarts after each progress line

        n_batches_seen += 1

    # Guard both averages so an empty loader/dataset does not raise.
    total_loss /= max(n_batches_seen, 1)
    n_samples = len(train_loader.dataset)
    bad_triplets = bad_triplets / n_samples if n_samples else 0.0
    return total_loss, metrics, bad_triplets


def test_epoch(val_loader, model, loss_fn, cuda, metrics, use_pred_loss=False):
    """
    Evaluate ``model`` on ``val_loader`` without tracking gradients.

    For every batch the model is called as ``model(*data)`` and must return
    ``(outputs, logits)`` (the logits are not used here). ``loss_fn`` receives the
    outputs followed by the batch target, when there is one, and must return
    ``(loss_outputs, triplet_losses)``.

    :param metrics: list of metric objects; each is reset first and then called as
        ``metric(outputs, target, loss_outputs)`` after every batch, where ``target``
        is a one-element tuple wrapping the batch target (or ``None``).
    :param use_pred_loss: accepted for signature parity with ``train_epoch``; unused.
    :return: ``(val_loss, metrics, bad_triplets)`` where ``val_loss`` is the loss
        SUMMED over batches (the caller averages it) and ``bad_triplets`` is the number
        of zero-valued entries seen in ``triplet_losses`` divided by
        ``len(val_loader.dataset)`` (``0.0`` for an empty dataset).
    """
    with torch.no_grad():
        for metric in metrics:
            metric.reset()
        model.eval()  # set evaluation mode
        val_loss = 0
        bad_triplets = 0
        for data, target in val_loader:
            target = target if len(target) > 0 else None
            if type(data) not in (tuple, list):
                data = (data,)
            if cuda:
                data = tuple(d.cuda() for d in data)
                if target is not None:
                    # Accept a tuple/list of target tensors as train_epoch does,
                    # as well as a single tensor.
                    if type(target) in (tuple, list):
                        target = tuple(t.cuda() for t in target)
                    else:
                        target = target.cuda()

            outputs, _ = model(*data)

            if type(outputs) not in (tuple, list):
                outputs = (outputs,)
            # Build a new tuple so that a list returned by the model is not extended
            # in place with the target before it is handed to the metrics.
            loss_inputs = tuple(outputs)
            if target is not None:
                target = (target,)
                loss_inputs += target

            loss_outputs, triplet_losses = loss_fn(*loss_inputs)
            # Count zero-valued entries directly; this needs no host copy and does not
            # depend on how numpy.argwhere happens to treat a torch tensor.
            bad_triplets += int((triplet_losses == 0).sum())

            loss = loss_outputs[0] if type(loss_outputs) in (tuple, list) else loss_outputs
            val_loss += loss.item()

            for metric in metrics:
                metric(outputs, target, loss_outputs)

    n_samples = len(val_loader.dataset)
    bad_triplets = bad_triplets / n_samples if n_samples else 0.0
    return val_loss, metrics, bad_triplets


# Despite its name this is a library helper, not a pytest test: opt out of collection
# in case the module is star-imported into a test file.
test_epoch.__test__ = False

def test_epoch_pred(val_loader, model, loss_fn, cuda):
    """
    Compute the classification accuracy of ``model.get_logits`` over ``val_loader``.

    The predicted class of each sample is the arg-max of its logits along dimension 1;
    accuracy is the number of predictions equal to the target divided by the number
    of targets seen. Batches that carry no targets are skipped, since there is
    nothing to score them against.

    :param loss_fn: unused; kept so the signature stays parallel to ``test_epoch``.
    :return: the accuracy as a 0-dim tensor (``0.0`` when no sample was scored).
    """
    with torch.no_grad():
        model.eval()  # set evaluation mode

        acc = 0
        nsamples = 0
        for data, target in val_loader:
            if len(target) == 0:
                continue
            if type(data) not in (tuple, list):
                data = (data,)
            if cuda:
                data = tuple(d.cuda() for d in data)
                target = target.cuda()

            logits = model.get_logits(*data)
            _, y_pred = torch.max(logits, 1)
            acc += torch.sum(y_pred == target)
            nsamples += len(target)

        if nsamples == 0:
            # Nothing was scored; avoid dividing by zero.
            return torch.tensor(0.0)
        acc = acc * 1. / nsamples
    return acc


# Despite its name this is a library helper, not a pytest test: opt out of collection
# in case the module is star-imported into a test file.
test_epoch_pred.__test__ = False

