import warnings

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CIFAR10
from tqdm import tqdm

import mobilenetv2

DATASET_DIR = '/home/xxx/DBQ/'

# Shared preprocessing for both splits: tensor conversion plus per-channel
# mean/std normalization. Random-crop/flip augmentation is intentionally not
# applied (the checkpoint loaded below is a "noaug" one); presumably statistics
# should be gathered on the same un-augmented inputs -- confirm before enabling.
CIFAR10_TRANSFORM = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
])

train_set = CIFAR10(DATASET_DIR, train=True, download=True, transform=CIFAR10_TRANSFORM)
test_set = CIFAR10(DATASET_DIR, train=False, download=True, transform=CIFAR10_TRANSFORM)

# No shuffling (DataLoader default): batches come in a fixed order, which the
# patch/PCA step below relies on when it takes the leading patches.
train_loader = DataLoader(train_set, batch_size=128, num_workers=4)
val_loader = DataLoader(test_set, batch_size=128, num_workers=4)

model = mobilenetv2.MobileNetV2(num_classes=10).cuda().train()

# NOTE: torch.load unpickles the file -- only use it with trusted checkpoints.
state_dict = torch.load('/home/xxx/pytorch-cifar/checkpoint/ckpt_noaug_37.pth', map_location='cpu')['net']

# Checkpoint keys carry a 'module.' prefix (presumably saved from an
# nn.DataParallel wrapper); strip it where present to match the bare model.
_CKPT_PREFIX = 'module.'
stripped_state_dict = {
    (k[len(_CKPT_PREFIX):] if k.startswith(_CKPT_PREFIX) else k): p
    for k, p in state_dict.items()
}
# The classifier weight is withheld here and installed separately below.
state_dict_ = {k: p for k, p in stripped_state_dict.items() if 'linear.weight' not in k}

# strict=False is required because 'linear.weight' is deliberately withheld
# (and 'linear.bias' may have no counterpart in the model). Surface any *other*
# mismatch instead of silently leaving those layers at random initialization.
_load_result = model.load_state_dict(state_dict_, strict=False)
_missing = [k for k in _load_result.missing_keys if k != 'linear.weight']
_unexpected = [k for k in _load_result.unexpected_keys if k != 'linear.bias']
if _missing or _unexpected:
    warnings.warn(f'checkpoint/model key mismatch: missing={_missing}, unexpected={_unexpected}')

# Fold the bias into the weight as an extra trailing column: W' = [W | b].
with torch.no_grad():
    model.linear.weight.copy_(torch.cat(
        [stripped_state_dict['linear.weight'], stripped_state_dict['linear.bias'][:, None]],
        dim=1,
    ))

# Number of leading 3x3x3 patches fed to the PCA below.
PCA_NUM_PATCHES = 20000

inputs = []
num_patches = 0
with torch.no_grad():
    for X, _ in tqdm(train_loader):
        # unfold(kernel=3, padding=1) -> (B, 27, positions); permute so each row
        # is one 27-dim (3 channels x 3 x 3) patch.
        patches = F.unfold(X.cuda(non_blocking=True), 3, padding=1).permute(0, 2, 1)
        inputs.append(patches.reshape(-1, 1024, 27))
        num_patches += patches.shape[0] * patches.shape[1]
        # Only the first PCA_NUM_PATCHES patches are used, so stop here rather
        # than keeping every patch of the training set on the GPU (roughly
        # 5.5 GB of float32 for CIFAR-10's 50k training images). The loader is
        # unshuffled, so the selected patches are the same either way.
        if num_patches >= PCA_NUM_PATCHES:
            break

inputs = torch.cat(inputs)

# Randomized PCA with q=27, i.e. the full patch dimensionality; only the
# principal directions V are kept.
U, S, V = torch.pca_lowrank(inputs.reshape(-1, 27)[:PCA_NUM_PATCHES], 27)
del U, S


# Set running statistics (train-mode forward pass over the training set).
def training_perf():
    """Run one forward pass over ``train_loader`` in train mode and report metrics.

    The model is kept in ``train()`` mode, so normalization layers that track
    running statistics (presumably MobileNetV2's BatchNorm layers) update them
    as a side effect -- this pass is what sets the running statistics. No
    parameters are updated. Any dropout in the model would also be active, so
    the reported numbers are train-mode metrics.

    Prints the mean cross-entropy loss and the accuracy (4 decimals) over
    ``train_set``.

    Returns:
        Tuple ``(mean_loss, accuracy)`` as Python floats.
    """
    model.train()
    criterion = torch.nn.CrossEntropyLoss(reduction='sum')
    loss_accum = 0
    correct_count = 0
    # No backward pass is run, so don't build an autograd graph per batch.
    # Train-mode running-statistic updates still happen under no_grad.
    with torch.no_grad():
        for X, y in tqdm(train_loader):
            X = X.cuda(non_blocking=True)
            y = y.cuda(non_blocking=True)
            logits = model(X)
            # Accumulate on-device tensors to avoid a host sync per batch.
            loss_accum += criterion(logits, y)
            correct_count += (torch.argmax(logits, dim=1) == y).sum()
    mean_loss = (loss_accum / len(train_set)).item()
    accuracy = (correct_count / len(train_set)).item()
    print(mean_loss)
    print(f'{accuracy:.4f}')
    return mean_loss, accuracy


# Script entry: one train-mode pass over the training set, which updates the
# model's running statistics and prints the training loss and accuracy.
training_perf()

