import argparse
import os
import random
import sys
import time
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torchvision
from torch.utils.data.sampler import SubsetRandomSampler

# Keep the interpreter from writing .pyc bytecode caches to disk
sys.dont_write_bytecode = True
# Put the model, task and save directories at the front of the import path.
# The slice assignment yields the same order as inserting './models',
# './tasks' and './save' at index 0 one after another.
sys.path[:0] = ['./save', './tasks', './models']
# Logging utility
from util import log

# Perturbations selectable by name through --train_noise / --test_noise.
# train() applies the chosen entry to each input batch and main() puts it at
# the front of the test-time transform pipeline.
noise_dict = dict(
    blackwhite=transforms.Grayscale(num_output_channels=3),
    jitter=transforms.ColorJitter(brightness=.5, hue=.3),
    blur=transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 2)),
    invert=transforms.RandomInvert(p=1),
    vflip=transforms.RandomVerticalFlip(p=1),
    rot90=transforms.RandomRotation(degrees=(89, 91)),
)

# Same keys as noise_dict, but built from the randomized variants
# (p=0.25 where the transform takes a probability, degrees=90 for rotation).
# CustomizedTransform below builds its own copy of this table.
noise_transform_dict = dict(
    blackwhite=transforms.RandomGrayscale(p=0.25),
    jitter=transforms.ColorJitter(brightness=.5, hue=.3),
    blur=transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 2)),
    invert=transforms.RandomInvert(p=0.25),
    vflip=transforms.RandomVerticalFlip(p=0.25),
    rot90=transforms.RandomRotation(degrees=90),
)


class CustomizedTransform(nn.Module):
    """Augmentation that applies one randomly selected noise transform per call.

    The pool holds six named transforms ('blackwhite', 'jitter', 'blur',
    'invert', 'vflip', 'rot90'). The one named by ``excluded_transform`` is
    left out of the pool; a value matching none of the names (for example
    ``'no'`` or ``None``) leaves all six available.
    """

    def __init__(self, excluded_transform):
        super().__init__()
        candidates = dict(
            blackwhite=transforms.RandomGrayscale(p=0.25),
            jitter=transforms.ColorJitter(brightness=.5, hue=.3),
            blur=transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 2)),
            invert=transforms.RandomInvert(p=0.25),
            vflip=transforms.RandomVerticalFlip(p=0.25),
            rot90=transforms.RandomRotation(degrees=90),
        )
        self.noise_transform_dict = candidates
        # Names eligible for sampling, in the table's insertion order
        self.noise_list = [name for name in candidates if name != excluded_transform]

    def forward(self, img):
        """Apply one uniformly chosen transform from the pool to ``img``."""
        chosen = random.choice(self.noise_list)
        transform = self.noise_transform_dict[chosen]
        return transform(img)


# Method for creating directory if it doesn't exist yet
def check_path(path):
    """Create the directory ``path`` if nothing exists at that location.

    Missing parent directories are created as well. If ``path`` already
    exists (as a directory or as anything else) the call does nothing.

    Args:
        path: Directory path to create.
    """
    if not os.path.exists(path):
        # exist_ok covers another process creating the directory between
        # the check above and this call.
        os.makedirs(path, exist_ok=True)


def train(args, model, device, optimizer, epoch, train_loader):
    """Run one training epoch and log the mean batch loss and accuracy.

    Args:
        args: Parsed command-line namespace. Only ``args.train_noise`` is read
            here: any value other than ``'no'`` is used as a key into the
            module-level ``noise_dict`` and that transform is applied to every
            input batch after it has been moved to ``device``.
        model: Called as ``model(x_seq, device)``; must return a pair
            ``(y_pred_linear, y_pred)``. The first element is passed to the
            cross-entropy loss, the second is compared with the labels to
            compute accuracy.
        device: Device the inputs and labels are moved to.
        optimizer: Optimizer that updates the model parameters.
        epoch: Epoch number, used only in the log line.
        train_loader: Iterable yielding ``(x_seq, y)`` batches.
    """
    # Set to training mode
    model.train()
    # The loss module holds no per-batch state, so build it once per epoch
    loss_fn = nn.CrossEntropyLoss()

    loss_sum = 0.0
    acc_sum = 0.0
    num_batches = 0
    # Iterate over batches
    for x_seq, y in train_loader:
        # Load data to device
        x_seq = x_seq.to(device)
        if args.train_noise != 'no':
            x_seq = noise_dict[args.train_noise](x_seq)
        y = y.to(device)
        # Zero out gradients for optimizer
        optimizer.zero_grad()
        # Run model
        y_pred_linear, y_pred = model(x_seq, device)
        # Loss
        loss = loss_fn(y_pred_linear, y.long())

        # Update model
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 10)
        optimizer.step()

        # Accumulate a plain float: summing the loss tensor itself would keep
        # autograd history attached to the running total for the whole epoch.
        loss_sum += loss.item()
        acc_sum += torch.eq(y_pred, y).float().mean().item() * 100.0
        num_batches += 1

    if num_batches == 0:
        # Nothing to average; avoid dividing by zero in the summary below.
        log.info('[Epoch: ' + str(epoch) + '] no training batches were provided')
        return

    log.info('[Epoch: ' + str(epoch) + '] ' + \
             '[Loss = ' + '{:.4f}'.format(loss_sum / num_batches) + '] ' + \
             '[Accuracy = ' + '{:.2f}'.format(acc_sum / num_batches) + '] ')


def test(args, model, device, test_loader):
    """Evaluate ``model`` on ``test_loader`` and log mean batch loss/accuracy.

    Args:
        args: Parsed command-line namespace (not read by this function; kept
            so the call signature mirrors ``train``).
        model: Called as ``model(x_seq, device)``; must return a pair
            ``(y_pred_linear, y_pred)``. The first element is passed to the
            cross-entropy loss, the second is compared with the labels to
            compute accuracy.
        device: Device the inputs and labels are moved to.
        test_loader: Iterable yielding ``(x_seq, y)`` batches.
    """
    log.info('Evaluating on test set...')
    # Set to eval mode
    model.eval()
    loss_fn = nn.CrossEntropyLoss()

    all_acc = []
    all_loss = []
    # No parameter update happens here, so skip autograd bookkeeping to save
    # memory and time during evaluation.
    with torch.no_grad():
        # Iterate over batches
        for x_seq, y in test_loader:
            x_seq = x_seq.to(device)
            y = y.to(device)

            # Run model
            y_pred_linear, y_pred = model(x_seq, device)
            # Loss
            loss = loss_fn(y_pred_linear, y.long())
            all_loss.append(loss.item())
            # Accuracy
            acc = torch.eq(y_pred, y).float().mean().item() * 100.0
            all_acc.append(acc)

    # Report overall test performance (unweighted mean over batches)
    avg_loss = np.mean(all_loss)
    avg_acc = np.mean(all_acc)

    log.info('[Test Summary] ' + \
             '[Loss = ' + '{:.4f}'.format(avg_loss) + '] ' + \
             '[Accuracy = ' + '{:.2f}'.format(avg_acc) + ']')


# Per-channel (mean, std) passed to transforms.Normalize for each dataset
_NORMALIZATION_STATS = {
    'cifar10': ((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    'cifar100': ((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
    'svhn': ((0.4376821, 0.4437697, 0.47280442), (0.19803012, 0.20101562, 0.19703614)),
}


def _build_transforms(dataset, augmentation, test_noise):
    """Return ``(train_transform, test_transform)`` for ``dataset``.

    Both pipelines end with ``ToTensor`` followed by ``Normalize`` using the
    statistics of ``dataset``. The training pipeline is prefixed with a
    ``CustomizedTransform`` (excluding ``test_noise``) when ``augmentation``
    is non-zero; the test pipeline is prefixed with ``noise_dict[test_noise]``
    unless ``test_noise`` is ``'no'``.
    """
    mean, std = _NORMALIZATION_STATS[dataset]

    train_steps = [transforms.ToTensor(), transforms.Normalize(mean, std)]
    if augmentation != 0:
        train_steps.insert(0, CustomizedTransform(excluded_transform=test_noise))

    test_steps = [transforms.ToTensor(), transforms.Normalize(mean, std)]
    if test_noise != 'no':
        test_steps.insert(0, noise_dict[test_noise])

    return transforms.Compose(train_steps), transforms.Compose(test_steps)


def main():
    """Parse command-line options, build data loaders and a model, then train.

    The model module is imported by name from ``--model_name`` and must expose
    a ``Model(args)`` class. Training runs for ``--epochs`` epochs; every tenth
    epoch the model is evaluated on the test set and, when ``--save_model`` is
    1, its ``state_dict`` is written to ``./save/<model_name>.pth``. A final
    evaluation runs after the last epoch.
    """
    # Settings
    parser = argparse.ArgumentParser()
    # Model settings
    parser.add_argument('--model_name', type=str, default='ViT-InLay')
    parser.add_argument('--norm_type', type=str, default='nonorm',
                        help="{'nonorm', 'contextnorm', 'tasksegmented_contextnorm'}")
    parser.add_argument('--patch_size', type=int, default=8)
    parser.add_argument('--vit_patch_size', type=int, default=8)
    parser.add_argument('--attention', type=str, default='inner')
    parser.add_argument('--depth', type=int, default=6)
    parser.add_argument('--skip_connection', type=int, default=0)
    parser.add_argument('--train_value', type=int, default=1)
    parser.add_argument('--std', type=float, default=1)
    parser.add_argument('--activation', type=str, default='tanh')
    parser.add_argument('--project_higher', type=int, default=1)
    parser.add_argument('--ignore_diag', type=int, default=1)
    parser.add_argument('--save_model', type=int, default=0)

    # Training settings
    # Restricting the choices turns an unsupported dataset into an immediate
    # usage error instead of an undefined-loader failure later on.
    parser.add_argument('--dataset', type=str, default='cifar10',
                        choices=sorted(_NORMALIZATION_STATS))
    parser.add_argument('--augmentation', type=int, default=0)
    parser.add_argument('--num_train', type=int, default=50000)
    parser.add_argument('--train_batch_size', type=int, default=32)
    parser.add_argument('--lr', type=float, default=5e-4)
    parser.add_argument('--weight_decay', type=float, default=0)
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--log_interval', type=int, default=10)

    parser.add_argument('--train_noise', type=str, default='no')

    # Test settings
    parser.add_argument('--num_test', type=int, default=20000, help="number of testing data points")
    parser.add_argument('--test_batch_size', type=int, default=100)
    # Default to 'no' so that omitting the flag means "no test-time noise"
    # rather than a lookup of noise_dict[None].
    parser.add_argument('--test_noise', type=str, default='no')

    # Device settings
    parser.add_argument('--no-cuda', action='store_true', default=False)
    parser.add_argument('--device', type=int, default=0)
    args = parser.parse_args()

    # Set up cuda
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda:" + str(args.device) if use_cuda else "cpu")

    # Build the transform pipelines; train and test always share the
    # normalization statistics of the selected dataset.
    train_transform, test_transform = _build_transforms(
        args.dataset, args.augmentation, args.test_noise)

    # Convert to PyTorch DataLoaders
    if args.dataset == 'cifar10':
        trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                                download=True, transform=train_transform)
        # CIFAR-10 trains on a random subset of args.num_train images
        subset_indices = random.sample(range(len(trainset)), args.num_train)
        train_loader = torch.utils.data.DataLoader(trainset, batch_size=args.train_batch_size,
                                                   num_workers=2, sampler=SubsetRandomSampler(subset_indices))
        testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                               download=True, transform=test_transform)
    elif args.dataset == 'cifar100':
        trainset = torchvision.datasets.CIFAR100(root='./data/cifar100', train=True,
                                                 download=True, transform=train_transform)
        train_loader = torch.utils.data.DataLoader(trainset, batch_size=args.train_batch_size,
                                                   num_workers=2, shuffle=True)
        testset = torchvision.datasets.CIFAR100(root='./data/cifar100', train=False,
                                                download=True, transform=test_transform)
    elif args.dataset == 'svhn':
        trainset = torchvision.datasets.SVHN(root='./data/svhn', split='train',
                                             download=True, transform=train_transform)
        train_loader = torch.utils.data.DataLoader(trainset, batch_size=args.train_batch_size,
                                                   num_workers=2, shuffle=True)
        testset = torchvision.datasets.SVHN(root='./data/svhn', split='test',
                                            download=True, transform=test_transform)

    test_loader = torch.utils.data.DataLoader(testset, batch_size=args.test_batch_size,
                                              shuffle=True, num_workers=2)

    # Create model
    model_class = __import__(args.model_name)
    model = model_class.Model(args).to(device)
    if use_cuda:
        # Only touch CUDA when it is actually in use; an unconditional
        # .cuda() fails on CPU-only machines and defeats --no-cuda.
        model = model.cuda()
        model = nn.DataParallel(model)
    # NOTE(review): on CPU the model is no longer wrapped in DataParallel, so
    # saved state_dict keys lack the 'module.' prefix that GPU runs produce.

    # Append relevant hyperparameter values to model name
    args.model_name = args.model_name + '_' + args.norm_type + '_lr' + str(args.lr)

    # Create optimizer
    log.info('Setting up optimizer...')
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    # Train
    log.info('Training begins...')
    for epoch in range(1, args.epochs + 1):
        # Training loop
        train(args, model, device, optimizer, epoch, train_loader)
        if epoch % 10 == 0:
            test(args, model, device, test_loader)
            if args.save_model == 1:
                # Make sure the checkpoint directory exists before writing
                check_path('./save')
                torch.save(model.state_dict(), f'./save/{args.model_name}.pth')
    # Test model
    test(args, model, device, test_loader)


# Run the experiment only when this file is executed as a script, so that
# importing the module does not start training.
if __name__ == '__main__':
    main()
