# Load Data
#%%
import torchvision.datasets as dsets
from torch.utils.data import Subset, random_split
import torchvision.transforms as transforms
import numpy as np
import os
import argparse

# Command-line options. Later in this script --data_percent controls how much
# of the training set is kept, while --lr, --batch_size, --epochs, --samples
# and --pair_sampling are forwarded to the SimSHAP training call.
parser = argparse.ArgumentParser()
# Typed options, registered in order as (flag, value type, default).
for flag, value_type, default_value in (
    ('--lr', float, 1e-3),
    ('--batch_size', int, 256),
    ('--epochs', int, 500),
    ('--samples', int, 8),
    ('--data_percent', float, 1),
):
    parser.add_argument(flag, type=value_type, default=default_value)
# Boolean switch: False unless the flag is present on the command line.
parser.add_argument('--pair_sampling', action='store_true')
args = parser.parse_args()
#%%
# Transformations
# The same per-channel normalization (mean, std) is applied to training and
# evaluation images -- presumably the usual CIFAR-10 statistics; confirm if
# the dataset ever changes.
normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
# Training images additionally get random-crop and horizontal-flip augmentation.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])
transform_test = transforms.Compose([
    transforms.ToTensor(),
    normalize,
])

# Load train set
train_set = dsets.CIFAR10('./', train=True, download=True, transform=transform_train)
if args.data_percent != 1:
    # Keep a random fraction of the training images, sampled without
    # replacement. The subset size is derived from the dataset itself so it
    # always agrees with the population being sampled from.
    num_selected = int(len(train_set) * args.data_percent)
    selected_indices = np.random.choice(len(train_set), size=num_selected, replace=False)
    train_set = Subset(train_set, selected_indices)
# Load the test split and divide it at random into a validation half and a
# held-out test half (5000 images each).
val_set = dsets.CIFAR10('./', train=False, download=True, transform=transform_test)
val_set, test_set = random_split(val_set, [5000, 5000])
#%%
# Train Model
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import os.path
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from copy import deepcopy
import sys
sys.path.append('..')
from resnet import ResNet18
# Select device: use the GPU when CUDA is available, otherwise fall back to
# the CPU so the script still runs (slowly) on machines without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#%%
# Check for model
if os.path.isfile('cifar resnet.pt'):
    # Load saved model
    # NOTE(review): the checkpoint is a whole pickled module (see torch.save
    # below), so only load files written by this script. Recent PyTorch
    # releases may require weights_only=False here -- confirm against the
    # installed version.
    print('Loading saved model')
    model = torch.load('cifar resnet.pt').to(device)

else:
    # Create model
    model = ResNet18(num_classes=10).to(device)

    # Training parameters (fixed here; the command-line arguments are not
    # used for the classifier)
    lr = 1e-3
    mbsize = 256  # 16
    max_nepochs = 250
    loss_fn = nn.CrossEntropyLoss()
    lookback = 10  # epochs without a new best accuracy before stopping early
    verbose = True

    # Validation function
    val_loader = DataLoader(val_set, batch_size=mbsize, shuffle=False, num_workers=4)

    def validate(model):
        """Return (mean loss, mean accuracy) of `model` over `val_loader`.

        Both values are running means weighted by batch size, so a smaller
        final batch does not skew the result. The caller is responsible for
        switching the model to eval mode and disabling gradients.
        """
        n = 0
        mean_loss = 0
        mean_acc = 0

        for x, y in val_loader:
            # Move to GPU.
            n += len(x)
            x = x.to(device)
            y = y.to(device)

            # Get predictions.
            pred = model(x)

            # Update loss (incremental weighted-mean update).
            loss = loss_fn(pred, y).item()
            mean_loss += len(x) * (loss - mean_loss) / n

            # Update accuracy (same incremental update).
            acc = (torch.argmax(pred, dim=1) == y).float().mean().item()
            mean_acc += len(x) * (acc - mean_acc) / n

        return mean_loss, mean_acc

    # Data loader
    train_loader = DataLoader(train_set, batch_size=mbsize, shuffle=True,
                              drop_last=True, num_workers=4)

    # Setup
    optimizer = optim.Adam(model.parameters(), lr=lr)
    # Halve the learning rate when validation accuracy (mode='max') has not
    # improved for lookback // 2 epochs, down to a floor of 1e-5.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, factor=0.5, patience=lookback // 2, min_lr=1e-5,
        mode='max', verbose=verbose)
    loss_list = []
    acc_list = []
    min_criterion = np.inf
    min_epoch = 0

    # Train
    for epoch in range(max_nepochs):
        for x, y in tqdm(train_loader, desc='Training loop', leave=True):
            # Move to device.
            x = x.to(device=device)
            y = y.to(device=device)

            # Take gradient step.
            loss = loss_fn(model(x), y)
            loss.backward()
            optimizer.step()
            model.zero_grad()

        # Check progress.
        with torch.no_grad():
            # Calculate validation loss.
            model.eval()
            val_loss, val_acc = validate(model)
            model.train()
            if verbose:
                print('----- Epoch = {} -----'.format(epoch + 1))
                print('Val loss = {:.4f}'.format(val_loss))
                print('Val acc = {:.4f}'.format(val_acc))
            loss_list.append(val_loss)
            acc_list.append(val_acc)
            scheduler.step(val_acc)

            # Check convergence criterion (negated accuracy, so lower is
            # better).
            val_criterion = - val_acc
            if val_criterion < min_criterion:
                min_criterion = val_criterion
                min_epoch = epoch
                # Snapshot the weights of the best epoch so far.
                best_model = deepcopy(model)
                print('')
                print('New best epoch, acc = {:.4f}'.format(val_acc))
                print('')
            elif (epoch - min_epoch) == lookback:
                if verbose:
                    print('Stopping early')
                break

    # Keep best model
    model = best_model

    # Save model (moved to CPU for saving, then back to the device)
    model.cpu()
    torch.save(model, 'cifar resnet.pt')
    model.to(device)

# The classifier is not trained any further in this script; it only serves as
# a fixed reference model from here on. The best-epoch snapshot (and therefore
# the saved checkpoint) is taken while the model is in training mode, so
# switch to inference mode explicitly in both branches to keep any
# train-mode-dependent layers (e.g. batch norm, dropout) from changing its
# predictions.
model.eval()

#%%
# Train Surrogate
from fastshap import ImageSurrogate
from fastshap.utils import MaskLayer2d, KLDivLoss, DatasetInputOnly

# Reuse a previously saved surrogate network when one exists; otherwise build
# a fresh one and train it further below.
have_saved_surrogate = os.path.isfile('cifar surrogate.pt')
if have_saved_surrogate:
    print('Loading saved surrogate model')
    surr = torch.load('cifar surrogate.pt').to(device)
else:
    # New surrogate network: a mask layer (value=0, append=True) followed by
    # a ResNet18 that takes 4 input channels.
    surr = nn.Sequential(
        MaskLayer2d(value=0, append=True),
        ResNet18(in_channels=4, num_classes=10),
    ).to(device)

# Set up surrogate object (same wrapper for a loaded or a new network)
surrogate = ImageSurrogate(surr, width=32, height=32, superpixel_size=2)

if not have_saved_surrogate:
    # Set up datasets
    train_surr = DatasetInputOnly(train_set)
    val_surr = DatasetInputOnly(val_set)
    # The classifier's outputs are passed through a softmax before being
    # handed to the surrogate trainer as the original model.
    original_model = nn.Sequential(model, nn.Softmax(dim=1))

    # Train
    surrogate.train_original_model(
        train_surr, val_surr, original_model,
        batch_size=256, max_epochs=100, loss_fn=KLDivLoss(),
        lookback=10, bar=True, verbose=True)

    # Save surrogate (moved to CPU for saving, then back to the device)
    surr.cpu()
    torch.save(surr, 'cifar surrogate.pt')
    surr.to(device)

#%% Train Simshap
import sys
sys.path.append('..')
from unet import UNet
from simshap.simshap_sampling import SimSHAPSampling

#%%

# Explainer network that SimSHAP trains
explainer = UNet(n_classes=10, num_down=2, num_up=1, num_convs=3).to(device)
# SimSHAP trainer: pairs the explainer with the surrogate as its imputer
simshap = SimSHAPSampling(explainer=explainer, imputer=surrogate, device=device)

# Wrap the training and validation sets with DatasetInputOnly
simshap_train = DatasetInputOnly(train_set)
simshap_val = DatasetInputOnly(val_set)

# Training options: the first five come from the command line, the remaining
# ones are fixed for this script.
train_options = dict(
    batch_size=args.batch_size,
    lr=args.lr,
    num_samples=args.samples,
    paired_sampling=args.pair_sampling,
    max_epochs=args.epochs,
    validation_samples=8,
    lookback=10,
    bar=True,
    verbose=True,
)

# Train
simshap.train(simshap_train, simshap_val, **train_options)

# Save explainer (moved to CPU for saving, then back to the device)
explainer.cpu()
torch.save(explainer, 'cifar simshap ablation.pt')
explainer.to(device)
