import os
import torch
import pickle
import numpy as np
import torch.nn as nn
from torchmetrics import Accuracy
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from adaptive import StaticMaskLayer2d
from resnet import ResNet18Backbone, ResNet18ClassifierHead
from adaptive import BaseModel
import argparse

# Set up command line arguments
parser = argparse.ArgumentParser()
parser.add_argument('--gpu', type=int, default=0)
parser.add_argument('--num_restarts', type=int, default=1)

def generate_random_mask(mask_width, num_features):
    """Build a square binary mask with exactly ``num_features`` ones.

    Args:
        mask_width: side length of the square mask; the mask has
            ``mask_width ** 2`` cells.
        num_features: number of cells set to 1, chosen uniformly at random
            without replacement from the global ``np.random`` state.

    Returns:
        A ``(mask_width, mask_width)`` float64 ndarray of zeros and ones.

    Raises:
        ValueError: if ``num_features`` is negative or exceeds
            ``mask_width ** 2``.
    """
    num_cells = mask_width ** 2
    # Explicit check rather than ``assert`` so validation survives ``python -O``.
    if not 0 <= num_features <= num_cells:
        raise ValueError(
            f'num_features must be in [0, {num_cells}], got {num_features}')
    mask = np.zeros(num_cells)
    # Same np.random.choice call as before, so seeded runs draw identical masks.
    idx = np.random.choice(num_cells, num_features, replace=False)
    mask[idx] = 1
    return mask.reshape(mask_width, mask_width)

if __name__ == '__main__':
    # Parse args
    args = parser.parse_args()

    # Transformations: random crop/flip augmentation for training only; val and
    # test images are just converted to tensors and normalized with the same
    # per-channel mean/std constants.
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    # Determine train/val split: hold out 10,000 of the 50,000 training images
    # for validation. The seed makes the split reproducible and also seeds the
    # global np.random state that generate_random_mask draws from below.
    np.random.seed(0)
    val_inds = np.sort(np.random.choice(50000, size=10000, replace=False))
    train_inds = np.setdiff1d(np.arange(50000), val_inds)

    data_root = '/tmp/cifar/'

    # Training dataset (augmented view of the training split)
    dataset = CIFAR10(data_root, download=True, train=True, transform=transform_train)
    train_dataset = torch.utils.data.Subset(dataset, train_inds)

    # Validation dataset: same underlying training images, but loaded with the
    # non-augmenting test transform.
    dataset = CIFAR10(data_root, download=True, train=True, transform=transform_test)
    val_dataset = torch.utils.data.Subset(dataset, val_inds)

    # Test dataset
    test_dataset = CIFAR10(data_root, download=True, train=False, transform=transform_test)

    # Training minibatch size. Observed: 128 -> ~8 min/epoch, ~2300 MB GPU
    # memory; 512 -> ~10 min/epoch, ~4200 MB (no speedup, cause unknown).
    mbsize = 128
    eval_mbsize = 1024

    # Mask geometry: presumably an 8x8 grid of 4x4-pixel patches tiling the
    # 32x32 input -- confirm against StaticMaskLayer2d.
    mask_width = 8
    patch_size = 4

    # Number of features (mask cells) to select
    max_features = (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20, 25, 30, 35)

    device = torch.device('cuda', args.gpu)

    # Create the output directory before spending hours training:
    # open(..., 'wb') does not create missing parent folders.
    results_dir = 'results'
    results_path = os.path.join(results_dir, 'random_mask_results.pkl')
    os.makedirs(results_dir, exist_ok=True)

    results_dict = {
        'acc': {num: [] for num in max_features}
    }
    for num in max_features:
        print(f'Number of features: {num}')
        for num_restart in range(args.num_restarts):
            print(f'Number of restart: {num_restart+1}')

            # Set up model: fixed random mask -> ResNet18 backbone ->
            # classifier head. A fresh mask is drawn for every restart.
            backbone = ResNet18Backbone()
            classifier_head = ResNet18ClassifierHead()
            mask = torch.tensor(
                generate_random_mask(mask_width=mask_width, num_features=num)
            ).float().to(device)
            mask_layer = StaticMaskLayer2d(mask=mask, mask_width=mask_width, patch_size=patch_size)
            model = nn.Sequential(mask_layer, backbone, classifier_head)
            trainer = BaseModel(model).to(device)

            trainer.fit(train_dataset,
                        val_dataset,
                        mbsize=mbsize,
                        lr=1e-3,
                        nepochs=100,
                        loss_fn=nn.CrossEntropyLoss(),
                        verbose=False)

            # Calculate test accuracy.
            # NOTE(review): torchmetrics >= 0.11 requires
            # Accuracy(task='multiclass', num_classes=10); confirm the pinned version.
            test_acc = trainer.evaluate(test_dataset, Accuracy(), eval_mbsize)
            print(f'Acc = {100*test_acc:.2f}\n')
            results_dict['acc'][num].append(test_acc)
        mean_acc = np.mean(results_dict['acc'][num])
        print(f'Average Acc = {100*mean_acc:.2f}\n')

        # Checkpoint after every feature budget so a crash in a later run does
        # not discard accuracies already collected; after the last budget the
        # file holds the complete results.
        with open(results_path, 'wb') as f:
            pickle.dump(results_dict, f)
