import torch
import torch.nn as nn


# Default to the CPU and switch to CUDA only when PyTorch reports a usable GPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"


class LinearComparison(nn.Module):
    """Linear scorer: one scalar score per input row, optionally softmax-normalised.

    Args:
        input_dim: Size of the last dimension of the inputs fed to the layer.
    """

    def __init__(self, input_dim):
        super().__init__()
        # A single output unit, so each input row maps to exactly one score.
        self.linear = nn.Linear(input_dim, 1)

    def forward(self, x):
        """Return the linear scores of ``x`` normalised by a softmax over dim 0."""
        logits = self.linear(x).squeeze(-1)
        normalise = nn.Softmax(dim=0)
        return normalise(logits)

    def score(self, x):
        """Return the raw linear scores of ``x`` with the trailing size-1 dim removed."""
        raw = self.linear(x)
        return raw.squeeze(-1)


def proportional_nllloss(pred, y):
    """Negative log of the prediction mass carried by positively weighted entries.

    Every prediction is multiplied by its weight in ``y`` (described in the
    original code as ``gi``, the fraction of correct programs in a cluster).
    Products that are not strictly positive are replaced by zero, the remaining
    products are summed, and the negative natural log of that sum is returned.

    Args:
        pred: Tensor of predictions. ``train`` passes the model output with an
            extra leading dimension of size 1.
        y: Tensor of per-entry weights, broadcastable against ``pred``.

    Returns:
        A 0-dim tensor. It is ``inf`` when no product is strictly positive,
        because the sum is then zero.
    """
    # Compute the weighted predictions once and reuse them for mask and values.
    weighted = pred * y

    # Use a tensor of zeros (not a Python scalar) so the fill value always has
    # the same dtype and device as the weighted predictions.
    positive_only = torch.where(weighted > 0, weighted, torch.zeros_like(weighted))

    return -torch.log(positive_only.sum())


def train(training_X, training_y, model, optimizer, loss_fn=proportional_nllloss, l1reg_lmbda=None):
    """Run one pass over the training pairs, stepping the optimizer once per pair.

    Args:
        training_X: Iterable of input tensors, one per example.
        training_y: Iterable of target tensors, paired with ``training_X`` by
            position (iteration stops at the shorter of the two).
        model: Module to train. When ``l1reg_lmbda`` is set it must expose
            ``model.linear.weight``.
        optimizer: Optimizer over the model's parameters.
        loss_fn: Callable ``(pred, y) -> scalar tensor``; ``pred`` is the model
            output with an extra leading dimension of size 1.
        l1reg_lmbda: Optional coefficient for an L1 penalty on
            ``model.linear.weight``. ``None`` or ``0`` disables the penalty.

    Returns:
        The mean loss (including the L1 term, when enabled) over the pairs that
        were actually trained on, or ``nan`` when no pair was trained on.
    """
    model.train()

    total_loss = 0.0
    num_trained = 0
    for X, y in zip(training_X, training_y):
        # Skip pairs whose target has no strictly positive entry: there is no
        # good program to assign probability mass to.
        if not (y > 0).any():
            continue

        # Unsqueeze so the prediction is a matrix with 1 row and N columns.
        pred = torch.unsqueeze(model(X.float()), 0)

        loss = loss_fn(pred, y)

        # L1 regularisation. Added out of place so the tensor returned by
        # loss_fn is never mutated.
        if l1reg_lmbda:
            loss = loss + torch.abs(model.linear.weight).sum() * l1reg_lmbda

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        num_trained += 1
        total_loss += loss.item()

    # An average over zero examples is undefined; report NaN instead of
    # raising ZeroDivisionError when every pair was skipped or input was empty.
    if num_trained == 0:
        return float("nan")
    return total_loss / num_trained


def test(test_X, test_y, model, loss_fn=nn.NLLLoss()):
    """Evaluate ``model`` on the given pairs, print a summary and return the mean loss.

    Each target is read with ``.item()``, so it must hold exactly one value. A
    value of ``-1`` excludes the example from the loss average. Such examples
    still count in the accuracy denominator, where they can never be correct
    because ``argmax`` is never negative.

    NOTE(review): ``nn.NLLLoss`` expects log-probabilities. If ``model``
    returns softmax probabilities (as ``LinearComparison.forward`` does), the
    default loss is ``-p[y]`` rather than ``-log p[y]`` -- confirm this is
    intended.

    Args:
        test_X: Sequence of input tensors (``len`` is taken of it).
        test_y: Iterable of single-element target tensors paired with ``test_X``.
        model: Module to evaluate.
        loss_fn: Callable ``(pred, y) -> scalar tensor``; ``pred`` is the model
            output with an extra leading dimension of size 1.

    Returns:
        The mean loss over examples whose target is not ``-1``, or ``nan`` when
        there are none.
    """
    model.eval()
    test_loss, correct = 0, 0
    total = len(test_X)

    total_with_actual_cluster = 0

    with torch.no_grad():
        for X, y in zip(test_X, test_y):
            X, y = X.to(device), y.to(device)

            pred = model(X.float())

            # -1 marks an example with no actual cluster: leave it out of the loss.
            if y.item() != -1:
                test_loss += loss_fn(torch.unsqueeze(pred, 0), y).item()
                total_with_actual_cluster += 1

            correct += (pred.argmax(0) == y).type(torch.float).sum().item()

    # Guard both averages so empty input (or input where every target is -1)
    # no longer raises ZeroDivisionError.
    if total_with_actual_cluster:
        test_loss /= total_with_actual_cluster
    else:
        test_loss = float("nan")
    correct = correct / total if total else 0.0

    print(
        f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
    return test_loss


def test_thresholded(eval_X, eval_y, model):
    """Compute argmax accuracy of ``model`` over the given pairs.

    No skipping criterion is implemented here, so ``num_skipped`` is always 0
    and the two accuracy figures coincide for non-empty input.

    Args:
        eval_X: Sequence of input tensors (``len`` is taken of it).
        eval_y: Iterable of target tensors paired with ``eval_X`` by position.
        model: Module to evaluate.

    Returns:
        A tuple ``(accuracy, from_non_skipped, num_skipped)``:
        ``accuracy`` is correct predictions divided by ``len(eval_X)`` (``0.0``
        for empty input); ``from_non_skipped`` is correct predictions divided
        by the number of non-skipped examples (``1`` when there are none);
        ``num_skipped`` is the number of skipped examples.
    """
    model.eval()

    num_correct = 0
    num_skipped = 0
    total = len(eval_X)

    # Evaluation only: no gradients are needed.
    with torch.no_grad():
        for X, y in zip(eval_X, eval_y):
            X, y = X.to(device), y.to(device)

            # Cast to float like train() and test() do before calling the model.
            pred = model(X.float())

            num_correct += (pred.argmax(0) == y).type(torch.float).sum().item()

    # Derive both ratios from the raw count. Previously the count was divided
    # by `total` first and that fraction was divided again, so the second
    # figure was too small by a factor of `total`.
    accuracy = num_correct / total if total else 0.0

    from_non_skipped = num_correct / \
        (total - num_skipped) if num_skipped != total else 1

    return (accuracy, from_non_skipped, num_skipped)
