import argparse
import numpy as np
import random
import os
import sys
import time

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from tensorboardX import SummaryWriter

from attacks import *
from models import *

from utils.dir_utils import *
from utils.logging_utils import *
from utils.tinyimages_80mn_loader import *

# Arguments
def _build_parser():
    """Create the command-line parser for adversarial OE fine-tuning.

    Every option is declared as a ``(flag, add_argument-kwargs)`` pair. The
    pairs are registered in list order, which is also the order ``--help``
    lists them in. The comments inside the list only group related options
    for the reader.
    """
    options = [
        # Directory
        ('--data_dir', dict(default='/data_large/readonly', help='In-distribution data path')),
        ('--exp_dir', dict(default='./src/experiments', help='Path you want to store checkpoint and log')),
        # Training
        ('--batch_size', dict(default=128, type=int, help='The batch size of in-distribution samples')),
        ('--oe_batch_size', dict(default=256, type=int, help='The batch size of out-distribution samples')),
        ('--max_epoch', dict(default=10, type=int, help='The number of epochs')),
        ('--lr', dict(default=1e-3, type=float, help='Learning rate')),
        ('--weight_decay', dict(default=5e-4, type=float, help='Weight decay')),
        ('--resume', dict(default='./src/checkpoints/cifar10/wideresnet/s_pgd_adv_l2_0.25/ckpt.pt', type=str, help='The location of checkpoint you want to fine-tune')),
        ('--print_every', dict(default=10, type=int, help='Print frequency (in step)')),
        ('--save_every', dict(default=1, type=int, help='Save frequency (in epoch)')),
        ('--alpha', dict(default=1.0, type=float, help='Minimizing MSP')),
        ('--beta', dict(default=1.0, type=float, help='Maximizing safe spot objective')),
        ('--seed', dict(default=0, type=int, help='Random seed')),
        # Model
        ('--model_type', dict(default='WideResNet', type=str, help='Network structure')),
        ('--depth', dict(default=34, type=int, help='Depth')),
        ('--widen_factor', dict(default=10, type=int, help='Widen factor')),
        ('--droprate', dict(default=0.0, type=float, help='Dropout rate')),
        # Attack
        ('--attack_type', dict(default='PGDAttackL2', type=str, help='Attack type')),
        ('--epsilon', dict(default=0.25, type=float, help='Epsilon')),
        ('--step_size', dict(default=0.0625, type=float, help='Step size')),
        ('--num_steps', dict(default=7, type=int, help='The number of steps')),
    ]

    built = argparse.ArgumentParser()
    for flag, kwargs in options:
        built.add_argument(flag, **kwargs)
    return built


# Parsed once at import time; the rest of the script reads `args`.
parser = _build_parser()
args = parser.parse_args()

# Image transforms. Training applies random-crop (32px, 4px padding) and
# horizontal-flip augmentation; both pipelines end in ToTensor and do not
# normalize here.
transform_train = transforms.Compose(
    [transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor()]
)
transform_test = transforms.Compose([transforms.ToTensor()])

# Per-channel pixel statistics, moved to the GPU at import time. They are
# handed to ModelWrapper in the main block (presumably it normalizes inputs
# inside the model — TODO confirm).
mean, std = (
    torch.tensor(channel_values).cuda()
    for channel_values in ([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
)


if __name__ == '__main__':
    # Seed every RNG the script uses (torch CPU/GPU, numpy, random) and force
    # deterministic cuDNN kernels for reproducibility. Statement order below
    # matters for the RNG streams, so keep it stable.
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Directory layout: <exp_dir>/checkpoint and <exp_dir>/log.
    # (makedir comes from utils.dir_utils and is assumed to create the
    # directory if needed and return its path — TODO confirm.)
    exp_dir = makedir(args.exp_dir)
    model_dir = makedir(os.path.join(exp_dir, 'checkpoint'))
    log_dir = makedir(os.path.join(exp_dir, 'log'))

    # Hyperparameters
    print('\nHyperparameters')
    for key, val in vars(args).items():
        print('{}={}'.format(key, val))
    save_params(exp_dir, args)

    # Dataset
    print('\nProcessing data...')
    train_set = torchvision.datasets.CIFAR10(
        root=args.data_dir, train=True, download=False, transform=transform_train)
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=args.batch_size, shuffle=True)

    test_set = torchvision.datasets.CIFAR10(
        root=args.data_dir, train=False, download=False, transform=transform_test)
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=100, shuffle=False)

    # Out-of-distribution data. The loader does not shuffle; train() instead
    # sets a random `offset` on the dataset at the start of every epoch.
    ood_dataset = TinyImages(root=os.path.join(args.data_dir, 'tiny_images'), transform=transform_train)
    ood_loader = torch.utils.data.DataLoader(
        ood_dataset, batch_size=args.oe_batch_size, shuffle=False)

    # Model: the class is looked up by name among this module's globals
    # (e.g. the classes star-imported from `models`).
    print('\nBuilding model...')
    model_class = getattr(sys.modules[__name__], args.model_type)
    base_model = model_class(depth=args.depth, widen_factor=args.widen_factor, dropRate=args.droprate)
    model = base_model.cuda()
    model = torch.nn.DataParallel(model)
    # Presumably applies (x - mean) / std before the network — TODO confirm.
    model = ModelWrapper(model, mean, std)

    # Create attack
    attack_class = getattr(sys.modules[__name__], args.attack_type)
    attack_xent = attack_class(model, args.epsilon, args.step_size, args.num_steps, criterion='xent')
    # NOTE(review): attack_uniform is not used anywhere in train() — confirm
    # whether the OOD attack was meant to use it.
    attack_uniform = attack_class(model, args.epsilon, args.step_size, args.num_steps, criterion='uniform')

    # Create optimizer
    optimizer = optim.SGD(
        model.parameters(),
        lr=args.lr,
        momentum=0.9,
        weight_decay=args.weight_decay
    )

    def cosine_annealing(step, total_steps, lr_max, lr_min):
        """Return the cosine-decayed value going from lr_max (step 0) to lr_min (step total_steps)."""
        return lr_min + (lr_max - lr_min) * 0.5 * (1 + np.cos(step / total_steps * np.pi))

    # train() iterates zip(train_loader, ood_loader), which stops at the
    # shorter loader, so that is the number of optimizer steps per epoch.
    steps_per_epoch = min(len(train_loader), len(ood_loader))

    # LambdaLR multiplies args.lr by this factor. train() steps the scheduler
    # once per batch, so the lr decays from args.lr to 1e-6 over the whole run.
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lambda step: cosine_annealing(
            step,
            args.max_epoch * steps_per_epoch,
            1,
            1e-6 / args.lr
        )
    )

    # Load the checkpoint to fine-tune. Only the model weights are restored;
    # the optimizer and scheduler start fresh and epochs count from 1.
    start_epoch = 0
    print('\nLoading checkpoint...')
    # map_location='cpu' makes loading independent of the device the
    # checkpoint was saved from. load_state_dict then copies the tensors into
    # the CUDA-resident parameters. torch.load unpickles, so only point
    # --resume at trusted files.
    checkpoint = torch.load(args.resume, map_location='cpu')
    state_dict = checkpoint['state_dict']
    base_model.load_state_dict(state_dict)

    # Tensorboard
    writer = SummaryWriter(logdir=log_dir)

    def train(epoch):
        """Run one epoch of adversarial outlier-exposure fine-tuning.

        Each step pairs an in-distribution batch with an out-of-distribution
        (OOD) batch and splits the OOD batch into two halves:

        * In-distribution inputs are first attacked in targeted mode with
          their true labels (presumably moving them toward a low-loss "safe
          spot", per the ``xent_safe_adv`` meter name). They are then
          attacked again with the attack's default settings and trained with
          cross-entropy.
        * The first OOD half is used unmodified.
        * The second OOD half is attacked the same way, using the model's own
          clean predictions as labels.

        Both OOD halves are trained toward a uniform prediction via the mean
        of ``-log_softmax`` over batch and classes. The total loss is
        ``xent + args.alpha * xent_ood_1 + args.beta * xent_ood_2``.

        Args:
            epoch: 1-based epoch index, used for display and TensorBoard.
        """
        print('\nEpoch: {}'.format(epoch))
        print('Training...')

        batch_time_meter = AverageMeter('time', ':.2f')
        lr_meter = AverageMeter('lr', ':.6f')
        xent_meter = AverageMeter('xent_safe_adv', ':.3f')
        xent_ood_1_meter = AverageMeter('xent_ood', ':.3f')
        xent_ood_2_meter = AverageMeter('xent_ood_safe_adv', ':.3f')
        acc_meter = AverageMeter('acc', ':.3f')

        progress = ProgressMeter(
            len(train_loader),
            [batch_time_meter, lr_meter, xent_meter, xent_ood_1_meter, xent_ood_2_meter, acc_meter],
            prefix="Epoch: [{}]".format(epoch))

        model.train()
        end = time.time()

        # The OOD loader does not shuffle. A fresh random offset each epoch is
        # presumably how TinyImages yields a different OOD stream — TODO
        # confirm against the loader.
        ood_loader.dataset.offset = np.random.randint(len(ood_loader.dataset))

        # zip() stops at the shorter of the two loaders.
        for batch_idx, ((inputs, targets), (inputs_ood, targets_ood)) in enumerate(zip(train_loader, ood_loader)):
            inputs, targets = inputs.cuda(), targets.cuda()
            inputs_ood, targets_ood = inputs_ood.cuda(), targets_ood.cuda()

            # Split out-distribution samples (the second half gets the extra
            # sample when the batch size is odd).
            half = targets_ood.size(0) // 2
            inputs_ood_1, targets_ood_1 = inputs_ood[:half], targets_ood[:half]
            inputs_ood_2, targets_ood_2 = inputs_ood[half:], targets_ood[half:]

            # Run attacks with the model in eval mode.
            model.eval()

            inputs = attack_xent(inputs, targets, random_start=False, targeted=True).detach()
            inputs = attack_xent(inputs, targets).detach()

            # Only the predicted labels are needed, so don't build an autograd graph.
            with torch.no_grad():
                _, preds_ood_2 = model(inputs_ood_2).max(1)
            inputs_ood_2 = attack_xent(inputs_ood_2, preds_ood_2, random_start=False, targeted=True).detach()
            inputs_ood_2 = attack_xent(inputs_ood_2, preds_ood_2).detach()

            # Run one forward pass over the concatenation, so any batch-norm
            # layers see in-distribution and OOD samples together.
            model.train()

            n_in, n_ood_1, n_ood_2 = targets.size(0), targets_ood_1.size(0), targets_ood_2.size(0)
            inputs_cat = torch.cat([inputs, inputs_ood_1, inputs_ood_2], dim=0)
            outputs_cat = model(inputs_cat)
            # Split by explicit sizes; negative-index slicing breaks when a part
            # is empty (an end index of -0 yields an empty slice).
            outputs, outputs_ood_1, outputs_ood_2 = torch.split(outputs_cat, [n_in, n_ood_1, n_ood_2], dim=0)

            # Compute loss: cross-entropy in-distribution, and cross-entropy
            # against the uniform distribution for both OOD halves.
            xent = F.cross_entropy(outputs, targets)
            xent_ood_1 = torch.mean(-F.log_softmax(outputs_ood_1, dim=1))
            xent_ood_2 = torch.mean(-F.log_softmax(outputs_ood_2, dim=1))
            loss = xent + args.alpha * xent_ood_1 + args.beta * xent_ood_2

            # Run backward pass
            optimizer.zero_grad()
            loss.backward()
            # Record the lr this step actually uses, before the scheduler advances it.
            lr = optimizer.param_groups[0]['lr']
            optimizer.step()
            scheduler.step()

            # Statistics
            _, preds = outputs.max(1)
            correct = preds.eq(targets).sum().item()
            acc = correct / n_in

            lr_meter.update(lr)
            xent_meter.update(xent.item(), n_in)
            xent_ood_1_meter.update(xent_ood_1.item(), n_ood_1)
            xent_ood_2_meter.update(xent_ood_2.item(), n_ood_2)
            acc_meter.update(acc, n_in)
            batch_time_meter.update(time.time() - end)

            end = time.time()

            # Logging
            if batch_idx % args.print_every == 0:
                progress.display(batch_idx)

        # Summary
        writer.add_scalar('train_xent', xent_meter.avg, epoch)
        writer.add_scalar('train_acc', acc_meter.avg, epoch)
    
    def test(epoch):
        """Evaluate on test_loader, log the results, and periodically checkpoint.

        Prints the average cross-entropy and accuracy and writes them to
        TensorBoard. Every ``args.save_every`` epochs it saves the model,
        optimizer and scheduler state to ``<model_dir>/ckpt.pt``, overwriting
        the previous file.

        Args:
            epoch: 1-based epoch index, used for TensorBoard and the save schedule.
        """
        print('Testing...')

        batch_time_meter = AverageMeter('time', ':.2f')
        xent_meter = AverageMeter('xent', ':.3f')
        acc_meter = AverageMeter('acc', ':.3f')

        model.eval()
        end = time.time()

        # Inference only: no_grad avoids building an autograd graph per batch.
        with torch.no_grad():
            for inputs, targets in test_loader:
                inputs, targets = inputs.cuda(), targets.cuda()

                # Run forward pass
                outputs = model(inputs)

                # Compute loss
                xent = F.cross_entropy(outputs, targets)

                # Statistics
                _, preds = outputs.max(1)
                correct = preds.eq(targets).sum().item()
                acc = correct / targets.size(0)

                xent_meter.update(xent.item(), targets.size(0))
                acc_meter.update(acc, targets.size(0))
                batch_time_meter.update(time.time() - end)

                end = time.time()

        # Logging
        print('xent {:.3f}, acc: {:.3f}'.format(
            xent_meter.avg, acc_meter.avg))

        # Summary
        writer.add_scalar('test_xent', xent_meter.avg, epoch)
        writer.add_scalar('test_acc', acc_meter.avg, epoch)

        # Save checkpoint (only the unwrapped base_model weights).
        if epoch % args.save_every == 0:
            print('Saving...')
            state_dict = {
                'epoch': epoch,
                'state_dict': base_model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict()
            }
            torch.save(state_dict, os.path.join(model_dir, 'ckpt.pt'))
      
    # Epochs run from start_epoch + 1 through args.max_epoch inclusive; each
    # one trains, then evaluates (and possibly saves a checkpoint).
    first_epoch = start_epoch + 1
    last_epoch = args.max_epoch
    for epoch in range(first_epoch, last_epoch + 1):
        train(epoch)
        test(epoch)

