import math

import numpy as np
import torch
import torch.nn.functional as F
from torchvision import datasets
from code.arch_v2 import DigitClassifier



@torch.no_grad()
def eval_classifier(model, x_pos, x_neg, *, bs=256):
    """Return the binary accuracy of ``model`` on positive and negative examples.

    ``model`` is called on mini-batches and is expected to return one logit per
    input example (the same contract ``train_classifier`` relies on when it feeds
    the output to ``binary_cross_entropy_with_logits``).

    Args:
        model: callable mapping a batch of inputs to logits.
        x_pos: examples whose correct label is positive.
        x_neg: examples whose correct label is negative.
        bs: evaluation mini-batch size (keyword-only).

    Returns:
        Fraction (float in [0, 1]) of examples classified correctly.

    Raises:
        ValueError: if both ``x_pos`` and ``x_neg`` are empty.
    """
    n_total = len(x_pos) + len(x_neg)
    if n_total == 0:
        raise ValueError('eval_classifier needs at least one example')

    n_correct = 0
    for x, is_pos in ((x_pos, True), (x_neg, False)):
        for start in range(0, len(x), bs):
            logits = model(x[start:start + bs])
            # Threshold the logit at 0 instead of sigmoid(logit) at 0.5: in
            # float32, sigmoid of a tiny non-zero logit rounds to exactly 0.5,
            # which the old strict comparisons scored as wrong for BOTH classes.
            hits = logits >= 0 if is_pos else logits < 0
            # Accumulate as a tensor so the device sync happens once, at the end.
            n_correct += hits.sum()

    return float(n_correct) / n_total



def _load_mnist_split(root, train, device):
    """Load one MNIST split as ``(images, labels)``.

    Images become a float tensor reshaped to ``(N, 1, 28, 28)`` and moved to
    ``device``; labels are returned unchanged from torchvision (not moved).
    """
    data = datasets.MNIST(root, train=train, download=True)
    # NOTE(review): dividing by 256 (not 255) maps pixels into [0, 255/256].
    # Kept as-is because the saved checkpoints are presumably consumed with
    # inputs scaled the same way -- confirm before changing.
    images = data.data.float().reshape(-1, 1, 28, 28).to(device) / 256.
    return images, data.targets


def train_classifier(digit, num_steps=5000, bs=256, lr=5e-4, *,
                     device='cuda', data_root='/tmp/compflow/data',
                     eval_every=1000):
    """Train a one-vs-rest MNIST classifier that detects ``digit``.

    Each step draws a class-balanced batch: ``bs // 2`` images of ``digit``
    (label 1) and ``bs - bs // 2`` images of other digits (label 0), sampled
    without replacement within the step. Training/test accuracy is printed
    every ``eval_every`` steps and after the final step.

    Args:
        digit: the MNIST digit treated as the positive class.
        num_steps: number of optimizer steps.
        bs: total batch size per step.
        lr: Adam learning rate.
        device: device for the data and model (default ``'cuda'``).
        data_root: directory where torchvision stores/downloads MNIST.
        eval_every: evaluation interval in steps.

    Returns:
        The trained ``DigitClassifier``, left on ``device`` in train mode.
    """
    X_tr, Y_tr = _load_mnist_split(data_root, True, device)
    X_te, Y_te = _load_mnist_split(data_root, False, device)
    assert len(X_tr) == len(Y_tr) == 60000

    X_tr_pos = X_tr[Y_tr == digit]
    X_tr_neg = X_tr[Y_tr != digit]
    X_te_pos = X_te[Y_te == digit]
    X_te_neg = X_te[Y_te != digit]

    # Split the batch so it totals exactly `bs` even when `bs` is odd.
    n_pos = bs // 2
    n_neg = bs - n_pos

    model = DigitClassifier().to(device)
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    for i in range(1, num_steps + 1):
        idx = np.random.choice(len(X_tr_pos), n_pos, replace=False)
        bx_pos, by_pos = X_tr_pos[idx], X_tr_pos.new_ones(n_pos)
        idx = np.random.choice(len(X_tr_neg), n_neg, replace=False)
        bx_neg, by_neg = X_tr_neg[idx], X_tr_neg.new_zeros(n_neg)
        bx = torch.cat([bx_pos, bx_neg], dim=0)
        by = torch.cat([by_pos, by_neg], dim=0)
        logits = model(bx)

        loss = F.binary_cross_entropy_with_logits(logits, by)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Single self-overwriting progress line ('\r', no newline).
        print(f'\rstep {i}/{num_steps} '
              f'loss {loss.item():.3f} ', end='', flush=True)

        if i % eval_every == 0 or i == num_steps:
            # Terminate the progress line so the summary gets its own line.
            print()
            model.eval()
            acc_tr = eval_classifier(model, X_tr_pos, X_tr_neg)
            acc_te = eval_classifier(model, X_te_pos, X_te_neg)
            print(f'Step {i} / {num_steps}. Train Accuracy: {acc_tr} Test Accuracy: {acc_te}')
            model.train()

    print(f'Training done for digit {digit}')
    return model

if __name__ == '__main__':
    # Train one binary detector per MNIST digit and store its weights on the
    # CPU at the checkpoint path DigitClassifier reports for that digit.
    for d in range(10):
        trained = train_classifier(digit=d).cpu()
        ckpt_path = DigitClassifier.get_ckpt_path(d)
        torch.save(trained.state_dict(), ckpt_path)
