import os
import torch
import pickle
import numpy as np
import torch.nn as nn
from torchmetrics import Accuracy
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from adaptive import StaticMaskLayer2d
from resnet import ResNet18Backbone, ResNet18ClassifierHead
from adaptive import BaseModel
import argparse

# Set up command line arguments.
# Both options are plain integers; the defaults run a single training restart on GPU 0.
parser = argparse.ArgumentParser(
    description='Train CIFAR-10 classifiers on center-cropped patch masks of increasing width.'
)
parser.add_argument('--gpu', type=int, default=0,
                    help='index of the CUDA device to run on (default: 0)')
parser.add_argument('--num_restarts', type=int, default=1,
                    help='number of independent training runs per crop width; '
                         'one run is selected using its validation score (default: 1)')

def generate_center_mask(mask_width, crop_width):
    """Build a square binary mask with a centered square of ones.

    Args:
        mask_width: Side length of the (square) mask.
        crop_width: Side length of the centered square that is set to 1.
            Must not exceed ``mask_width``.

    Returns:
        A ``(mask_width, mask_width)`` float NumPy array that is 1 inside the
        centered ``crop_width x crop_width`` window and 0 elsewhere. When the
        window cannot be perfectly centered, it is placed according to
        ``mask_width // 2 - crop_width // 2``.

    Raises:
        AssertionError: If ``crop_width`` is larger than ``mask_width``.
    """
    assert crop_width <= mask_width, 'crop_width must not exceed mask_width'
    mask = np.zeros((mask_width, mask_width))
    # Use the function's own parameter here; relying on a same-named global
    # would make the result depend on unrelated module state.
    start_idx = mask_width // 2 - crop_width // 2
    end_idx = start_idx + crop_width
    mask[start_idx:end_idx, start_idx:end_idx] = 1
    return mask

if __name__ == '__main__':
    # Script entry point: for each center-crop width, train a masked ResNet-18
    # classifier on CIFAR-10 (optionally with several restarts), keep the
    # restart with the best validation accuracy, and record its test accuracy.

    # Parse args
    args = parser.parse_args()

    # Transformations: augmentation is applied to training images only;
    # validation/test images are just converted and normalized.
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    # Determine train/val split (fixed seed so the split is reproducible):
    # 10000 of the 50000 training images are held out for validation.
    np.random.seed(0)
    val_inds = np.sort(np.random.choice(50000, size=10000, replace=False))
    train_inds = np.setdiff1d(np.arange(50000), val_inds)

    # Training dataset
    dataset = CIFAR10('/tmp/cifar/', download=True, train=True, transform=transform_train)
    train_dataset = torch.utils.data.Subset(dataset, train_inds)

    # Validation dataset: same underlying images, but without augmentation
    dataset = CIFAR10('/tmp/cifar/', download=True, train=True, transform=transform_test)
    val_dataset = torch.utils.data.Subset(dataset, val_inds)

    # Test dataset
    test_dataset = CIFAR10('/tmp/cifar/', download=True, train=False, transform=transform_test)

    # Minibatch size used for training
    mbsize = 128  # Takes ~8min per epoch, takes ~2300Mb of GPU memory
    # mbsize = 512  # Inexplicably, each epoch now takes ~10min and GPU memory is just 4200

    # Crop widths to evaluate; results are keyed by cropwidth**2
    # (presumably the number of selected patches -- confirm against StaticMaskLayer2d).
    max_cropwidth = (1, 2, 3, 4, 5, 6, 7)

    device = torch.device('cuda', args.gpu)
    results_dict = {
        # Placeholders; each entry is replaced by the test accuracy below.
        'acc': {cropwidth*cropwidth:[] for cropwidth in max_cropwidth}
    }
    for cropwidth in max_cropwidth:
        print(cropwidth*cropwidth)
        mask = torch.tensor(generate_center_mask(mask_width=8, crop_width=cropwidth)).float().to(device)
        mask_layer = StaticMaskLayer2d(mask=mask, mask_width=8, patch_size=4)
        best_model = None
        # The validation metric is an accuracy, so higher is better: start
        # from -inf and keep the restart with the LARGEST value.
        best_val_acc = -np.inf
        for _ in range(args.num_restarts):
            # Set up model
            # Shared backbone
            backbone = ResNet18Backbone()
            # Classifier head
            classifier_head = ResNet18ClassifierHead()
            # Mask layer is applied first, then backbone, then head
            model = nn.Sequential(mask_layer, backbone, classifier_head)
            trainer = BaseModel(model).to(device)

            trainer.fit(train_dataset,
                        val_dataset,
                        mbsize=mbsize,
                        lr=1e-3,
                        nepochs=100,
                        loss_fn=nn.CrossEntropyLoss(),
                        verbose=False)

            val_acc = trainer.evaluate(val_dataset, Accuracy(), 1024)
            if val_acc > best_val_acc:
                best_model = trainer
                best_val_acc = val_acc
        # Get best model
        trainer = best_model

        # Calculate test set performance
        test_acc = trainer.evaluate(test_dataset, Accuracy(), 1024)
        print(f'Acc = {100*test_acc:.2f}\n')
        results_dict['acc'][cropwidth*cropwidth] = test_acc

    # Make sure the output directory exists so a long run is not lost to a
    # missing-folder error at the very end.
    os.makedirs('results', exist_ok=True)
    with open('results/center_crop_results.pkl', 'wb') as f:
        pickle.dump(results_dict, f)
