from __future__ import print_function
import argparse
import csv
import os
import time
import numpy as np
import torch
from torch.autograd import Variable
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import models
from torch.utils.data import Subset
import random
import matplotlib.pyplot as plt
import torch.nn.functional as F


# Command-line configuration for a single training run.
parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
# NOTE(review): --resume is parsed, but nothing in the visible code reads
# args.resume; training always starts from scratch.
parser.add_argument('--resume', '-r', action='store_true',
                    help='resume from checkpoint')
parser.add_argument('--save', action='store_true',
                    help='save checkpoint')
parser.add_argument('--model', default="ResNet18", type=str,
                    help='model type (default: ResNet18)')
parser.add_argument('--loss', default='ce', choices=['ce', 'mse'],
                    help='training loss: cross-entropy (ce) or MSE on softmax outputs (mse)')
parser.add_argument('--seed', default=0, type=int, help='random seed')
parser.add_argument('--batch-size', default=125, type=int, help='batch size')
parser.add_argument('--epoch', default=200, type=int,
                    help='total epochs to run')
# Stores False into args.augment; augmentation is therefore ON unless this flag is given.
parser.add_argument('--no-augment', dest='augment', action='store_false',
                    help='disable standard augmentation (random crop + horizontal flip); '
                         'augmentation is enabled by default')
parser.add_argument('--decay', default=1e-4, type=float, help='weight decay')
parser.add_argument('--alpha', default=1.0, type=float,
                    help='mixup interpolation coefficient (default: 1)')
parser.add_argument('--proportion', default=1.0, type=float,
                    help='proportion of samples from original training set used in training')
parser.add_argument('--plot', action='store_true',
                    help='save loss/accuracy curves into the results directory after training')
args = parser.parse_args()


# CIFAR-10 class names.
classes = (
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)


# Device availability, queried once at start-up.
use_cuda = torch.cuda.is_available()
cuda_count = int(torch.cuda.device_count())


# Run-wide bookkeeping.
best_acc = 0  # best test accuracy (not updated anywhere in the visible code)
start_epoch = 0  # first epoch index; nothing visible here sets it from a checkpoint


# Per-epoch metric histories, appended by main() and plotted by plot_curve().
train_loss_list, test_loss_list = [], []
train_acc_list, test_acc_list = [], []
real_loss_list = []


# Seed Python's, NumPy's and PyTorch's random number generators.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(args.seed)
if use_cuda:
    torch.cuda.manual_seed_all(args.seed)
    # Let cuDNN autotune convolution algorithms for speed; per the PyTorch
    # reproducibility notes this may make runs non-deterministic.
    cudnn.benchmark = True


# Start-up banner with the key hyper-parameters.
print('==> Program starts..')
for _label, _value in (('proportion', args.proportion),
                       ('alpha', args.alpha),
                       ('seed', args.seed)):
    print('{}: {}'.format(_label, _value))
if use_cuda:
    print('using CUDA..')
    print('num of devices: {}\n'.format(cuda_count))


# Data
print('==> Preparing data..')

# Per-channel normalization constants shared by the train and test pipelines.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)

# Optional random augmentations, applied before tensor conversion.
train_augmentations = []
if args.augment:
    print('Using augmentation\n')
    train_augmentations = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
    ]
else:
    print('No augmentation\n')

transform_train = transforms.Compose(train_augmentations + [
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])


# Test images are only converted and normalized.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

trainset = datasets.CIFAR10(root='../datasets/CIFAR-10/', train=True, download=False,
                            transform=transform_train)

if args.proportion != 1.0:
    # Keep, for every class, the first int(proportion * class size) samples in
    # dataset order; the subset lists the classes in the order of `classes`.
    #
    # Labels are read from the dataset's label list instead of via
    # `trainset[i][1]`: indexing the dataset converts and transforms every image
    # (including the random augmentations, which draw from the RNG) just to read
    # its label. Since those RNG draws no longer happen here, seeded runs with
    # --proportion < 1 may not bit-match runs made with the old per-item loop.
    labels = getattr(trainset, 'targets', None)
    if labels is None:
        # Fallback for dataset objects without `.targets`: slow per-item lookup.
        labels = [trainset[i][1] for i in range(len(trainset))]

    class_indices = {trainset.class_to_idx[name]: [] for name in classes}
    for i, label in enumerate(labels):
        label = int(label)
        if label in class_indices:
            class_indices[label].append(i)

    subset_indices = []
    for name in classes:
        members = class_indices[trainset.class_to_idx[name]]
        subset_indices.extend(members[:int(args.proportion * len(members))])

    trainset = Subset(trainset, subset_indices)

trainloader = torch.utils.data.DataLoader(trainset,
                                          batch_size=args.batch_size,
                                          shuffle=True, num_workers=6)

testset = datasets.CIFAR10(root='../datasets/CIFAR-10', train=False, download=False,
                           transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100,
                                         shuffle=False, num_workers=6)


# Model
print('==> Building model..')
print('model: {}'.format(args.model))
# Look the architecture up by name in the local `models` package and instantiate it.
net = vars(models)[args.model]()
if use_cuda:
    net.cuda()
if cuda_count > 1:
    # Replicate the model across every visible GPU.
    net = torch.nn.DataParallel(net)


# Loss selected by --loss: cross-entropy for 'ce', mean squared error otherwise.
if args.loss == 'ce':
    criterion = nn.CrossEntropyLoss()
else:
    criterion = nn.MSELoss()

optimizer = optim.SGD(
    net.parameters(),
    lr=args.lr,
    momentum=0.9,
    weight_decay=args.decay,
)


# Output directory name encodes the run configuration.
aug_tag = 'aug' if args.augment else 'no-aug'
loss_tag = 'CE' if isinstance(criterion, nn.CrossEntropyLoss) else 'MSE'
results_dir = 'results_CIFAR10_mixup-losses-alpha_{}_{}_{}loss_a{}_p{}_e{}'.format(
    args.model, aug_tag, loss_tag, args.alpha, args.proportion, args.epoch)

# Best effort: report (but do not abort on) failure to create the directory.
if not os.path.isdir(results_dir):
    try:
        os.mkdir(results_dir)
    except OSError as err:
        print(err)

# Per-epoch log for this seed, plus two summary logs shared by all seeds.
logname1 = '{}/log_train_{}_{}.csv'.format(results_dir, args.model, args.seed)
logname2 = results_dir + '/log_final_all.csv'
logname3 = results_dir + '/log_mixup_losses_alpha.csv'


def mixup_label(lamda, target_a, target_b, num_class=10):
    """Build soft targets for a mixup batch.

    Row ``i`` of the result holds ``lamda`` at column ``target_a[i]`` plus
    ``1 - lamda`` at column ``target_b[i]`` (the two add up when the labels
    coincide); all other entries are zero.

    Args:
        lamda: Mixing weight given to ``target_a``.
        target_a: 1-D tensor of class indices for the original samples.
        target_b: 1-D tensor of class indices for the mixing partners.
        num_class: Number of columns in the result.

    Returns:
        Tensor of shape ``(len(target_a), num_class)`` in the default float
        dtype.
    """
    rows = target_a.shape[0]

    def weighted_onehot(labels, weight):
        # Zero matrix with `weight` written at each row's label column.
        encoded = torch.zeros(rows, num_class, device=labels.device)
        return encoded.scatter_(1, labels.unsqueeze(1), weight)

    return weighted_onehot(target_a, lamda) + weighted_onehot(target_b, 1 - lamda)


def onehot_transform(target, num_class=10):
    """Return a float64 one-hot encoding of integer class labels.

    Args:
        target: 1-D tensor of class indices.
        num_class: Width of the encoding. Defaults to 10, the value that used
            to be hard-coded; this mirrors ``mixup_label``'s parameter.

    Returns:
        Tensor of shape ``(len(target), num_class)`` on ``target.device`` with
        dtype ``torch.float64``.
    """
    # NOTE(review): float64 differs from the default float32 that mixup_label
    # produces (and, presumably, from the network outputs), so MSELoss against
    # this target runs in mixed precision. Kept to preserve logged loss values.
    onehot = torch.zeros(target.shape[0], num_class,
                         dtype=torch.float64, device=target.device)
    onehot.scatter_(1, target.unsqueeze(1), 1.)
    return onehot


def mixup_data(x, y, alpha, use_cuda=True):
    '''Returns mixed inputs, pairs of targets, and lambda.

    Draws ``lam ~ Beta(alpha, alpha)`` (``lam = 1``, i.e. no mixing, when
    ``alpha <= 0``) and blends every sample of ``x`` with a randomly chosen
    partner from the same batch.

    Args:
        x: Batch of inputs; dimension 0 is the batch dimension.
        y: Targets aligned with ``x`` along dimension 0.
        alpha: Beta-distribution concentration parameter.
        use_cuda: Kept for backward compatibility only. The permutation is
            always placed on ``x.device``, so CPU inputs also work with the
            default ``use_cuda=True``.

    Returns:
        ``(mixed_x, y_a, y_b, lam)`` with
        ``mixed_x = lam * x + (1 - lam) * x[perm]``, ``y_a = y`` and
        ``y_b = y[perm]`` for a random permutation ``perm`` of the batch.
    '''
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1

    batch_size = x.size()[0]
    # Draw the permutation with the default CPU generator, then move it next to
    # the data; this keeps seeded permutations identical on CPU and GPU.
    index = torch.randperm(batch_size).to(x.device)

    # x[index] (not x[index, :]) so 1-D inputs are supported as well.
    mixed_x = lam * x + (1 - lam) * x[index]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam


def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Return the mixup loss: ``lam``-weighted blend of the losses of ``pred``
    against both target sets ``y_a`` and ``y_b``."""
    loss_a = criterion(pred, y_a)
    loss_b = criterion(pred, y_b)
    return lam * loss_a + (1 - lam) * loss_b


def train(epoch):
    """Run one epoch of mixup training over ``trainloader``.

    Each batch is mixed with ``mixup_data`` and the optimizer steps on the
    loss of the mixed batch. The same unmixed batch is also passed through
    the network under ``torch.no_grad()`` to measure the "real" loss and the
    training accuracy.

    With ``nn.CrossEntropyLoss`` the mixed loss is ``mixup_criterion``. With
    ``nn.MSELoss`` the network outputs are softmax-normalized and compared
    with soft targets from ``mixup_label`` (mixed batch) or one-hot targets
    from ``onehot_transform`` (unmixed batch).

    Args:
        epoch: Epoch index; only used in the progress message.

    Returns:
        ``(mean mixed loss, training accuracy in percent (tensor),
        mean loss on the unmixed inputs)``, all averaged per sample.
    """
    print('\nEpoch: {}'.format(epoch))
    net.train()
    use_mse = isinstance(criterion, nn.MSELoss)
    train_loss = 0
    real_loss = 0
    correct = 0
    total = 0

    for inputs, targets in trainloader:
        batch_size = targets.shape[0]
        total += batch_size

        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()

        mixup_inputs, targets_a, targets_b, lam = mixup_data(inputs, targets, args.alpha, use_cuda)

        # dim=1 is the class axis of (batch, classes) outputs; it is also what
        # the deprecated implicit-dim softmax chose for 2-D inputs.
        if use_mse:
            mixup_labels = mixup_label(lam, targets_a, targets_b)
            mixup_outputs = F.softmax(net(mixup_inputs), dim=1)
            loss = criterion(mixup_outputs, mixup_labels)
        else:
            mixup_outputs = net(mixup_inputs)
            loss = mixup_criterion(criterion, mixup_outputs, targets_a, targets_b, lam)
        train_loss += loss.item() * batch_size

        # NOTE: net is still in train() mode, so this clean forward pass uses
        # batch statistics and, if the model has BatchNorm layers, also
        # updates their running statistics.
        with torch.no_grad():
            if use_mse:
                outputs = F.softmax(net(inputs), dim=1)
                batch_real_loss = criterion(outputs, onehot_transform(targets))
            else:
                outputs = net(inputs)
                batch_real_loss = criterion(outputs, targets)
        real_loss += batch_real_loss.item() * batch_size

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accuracy is measured on the unmixed inputs.
        _, predicted = torch.max(outputs, 1)
        correct += predicted.eq(targets).cpu().sum().float()

    return train_loss / total, 100. * correct / total, real_loss / total


def test(epoch):
    """Evaluate ``net`` on ``testloader`` in eval mode without gradients.

    With ``nn.MSELoss`` the outputs are softmax-normalized and compared with
    one-hot targets from ``onehot_transform``; otherwise the criterion is
    applied to the raw outputs and class indices.

    Args:
        epoch: Epoch index (currently unused; kept for the caller's interface).

    Returns:
        ``(mean test loss per sample, test accuracy in percent (tensor))``.
    """
    net.eval()
    use_mse = isinstance(criterion, nn.MSELoss)
    test_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, targets in testloader:
            batch_size = targets.shape[0]
            total += batch_size

            if use_cuda:
                inputs, targets = inputs.cuda(), targets.cuda()

            if use_mse:
                # dim=1: class axis of (batch, classes) outputs.
                outputs = F.softmax(net(inputs), dim=1)
                loss = criterion(outputs, onehot_transform(targets))
            else:
                outputs = net(inputs)
                loss = criterion(outputs, targets)

            test_loss += loss.item() * batch_size
            _, predicted = torch.max(outputs, 1)
            correct += predicted.eq(targets).cpu().sum()

    return test_loss / total, 100. * correct / total


def checkpoint(test_acc, train_loss, epoch):
    """Save the current weights and RNG states for this run configuration.

    The file is written to
    ``checkpoints/mixup_loss_alpha/<model>_<aug>_<loss>loss_a<alpha>_p<proportion>_e<epochs>/ckpt.t7_<seed>``
    and is overwritten by every later call with the same configuration.

    Args:
        test_acc: Test accuracy stored alongside the weights.
        train_loss: Training loss stored alongside the weights.
        epoch: Epoch index stored alongside the weights.
    """
    # Save checkpoint.
    print('Saving..')
    # Store the unwrapped module's weights so they load without DataParallel.
    model = net.module if isinstance(net, nn.DataParallel) else net
    state = {
        'epoch': epoch,
        'state_dict': model.state_dict(),
        'test_acc': test_acc,
        'train_loss': train_loss,
        'cuda_count': cuda_count,
        # Only the CPU torch RNG and NumPy RNG are captured (not CUDA RNGs or
        # Python's `random`).
        'torch_rng': torch.get_rng_state(),
        'numpy_rng': np.random.get_state()
    }

    checkpoint_dir = 'checkpoints/mixup_loss_alpha/{}_{}_{}loss_a{}_p{}_e{}'.format(
        args.model,
        'aug' if args.augment else 'no-aug',
        'CE' if isinstance(criterion, nn.CrossEntropyLoss) else 'MSE',
        args.alpha, args.proportion, args.epoch)
    # makedirs also creates a missing 'checkpoints' parent, where chained
    # os.mkdir calls raised FileNotFoundError.
    os.makedirs(checkpoint_dir, exist_ok=True)

    torch.save(state, checkpoint_dir + '/ckpt.t7_' + str(args.seed))

    return None


# Learning-rate milestones: each lies halfway between the previous milestone
# and the final epoch (e.g. 100, 150, 175, 187 for a 200-epoch run).
adjust_point_1 = args.epoch // 2
adjust_point_2 = adjust_point_1 + (args.epoch - adjust_point_1) // 2
adjust_point_3 = adjust_point_2 + (args.epoch - adjust_point_2) // 2
adjust_point_4 = adjust_point_3 + (args.epoch - adjust_point_3) // 2

# Divisors applied at the 3rd and 4th milestones; a value of 1 means no decay.
if args.epoch >= 600:
    lr_division_1, lr_division_2 = 5, 2
elif args.epoch >= 400:
    lr_division_1, lr_division_2 = 10, 1
else:
    lr_division_1, lr_division_2 = 1, 1

def adjust_learning_rate(optimizer, epoch):
    """Set every param group's learning rate from the step schedule.

    Starting at ``args.lr``, the rate is divided by 10, 10,
    ``lr_division_1`` and ``lr_division_2`` for each of the milestones
    ``adjust_point_1`` .. ``adjust_point_4`` that ``epoch`` has reached.
    A message is printed on the epoch equal to a milestone (the 3rd/4th
    only for runs of at least 400/600 epochs), and the new rate is printed.

    Args:
        optimizer: Optimizer whose ``param_groups`` are updated in place.
        epoch: Epoch index compared against the milestones.
    """
    schedule = (
        (adjust_point_1, 10),
        (adjust_point_2, 10),
        (adjust_point_3, lr_division_1),
        (adjust_point_4, lr_division_2),
    )
    lr = args.lr
    for milestone, divisor in schedule:
        if epoch >= milestone:
            lr /= divisor

    announcements = (
        (adjust_point_1, True, '1st'),
        (adjust_point_2, True, '2nd'),
        (adjust_point_3, args.epoch >= 400, '3rd'),
        (adjust_point_4, args.epoch >= 600, '4th'),
    )
    for milestone, enabled, ordinal in announcements:
        if epoch == milestone and enabled:
            print('adjusting learning rate for the {} time'.format(ordinal))

    for group in optimizer.param_groups:
        group['lr'] = lr

    print('lr: {}'.format(lr))


def plot_curve(epoch_min):
    """Save a 2x2 grid of per-epoch loss and accuracy curves as a PNG.

    Top row: training loss, testing loss. Bottom row: training accuracy,
    testing accuracy. A thin red vertical line marks ``epoch_min``. The image
    is written into ``results_dir``.

    Args:
        epoch_min: Epoch to highlight (main() passes the epoch with the lowest
            training loss).
    """
    # One x value per recorded epoch. For a complete run started at epoch 0
    # this equals range(0, args.epoch), but it stays aligned with the data if
    # fewer epochs were recorded.
    epochs = range(start_epoch, start_epoch + len(train_loss_list))

    # (subplot position, series, y-limits, y-label, label the x axis?)
    panels = (
        (1, train_loss_list, (0, 2), 'Training Loss', False),
        (2, test_loss_list, (0, 2), 'Testing Loss', False),
        (3, train_acc_list, (0, 100), 'Training Accuracy', True),
        (4, test_acc_list, (0, 100), 'Testing Accuracy', True),
    )

    fig = plt.figure(figsize=(36, 18))
    # Figure-level title: plt.title() before the first subplot would attach to
    # an implicit full-figure Axes that the subplots then overlap (or remove).
    fig.suptitle('Losses & Accuracies vs Epochs', fontsize=30)

    for position, series, (y_low, y_high), ylabel, with_xlabel in panels:
        plt.subplot(2, 2, position)
        plt.xticks(fontsize=20)
        plt.yticks(fontsize=20)
        plt.plot(epochs, series, '.-')
        plt.axvline(x=epoch_min, color='red', linewidth=0.5)
        plt.ylim(y_low, y_high)
        if with_xlabel:
            plt.xlabel('Epoch', fontsize=30)
        plt.ylabel(ylabel, fontsize=30)

    plt.savefig(results_dir + '/CIFAR10_{}_{}_{}_a{}_p{}_e{}_losses_accs_{}.png'.format(
        args.model,
        'aug' if args.augment else 'no-aug',
        args.loss,
        args.alpha, args.proportion, args.epoch, args.seed), dpi=100)
    plt.close('all')

    return None


def _create_csv_with_header(path, header):
    """Create ``path`` containing only the ``header`` row, unless it already exists."""
    if not os.path.exists(path):
        # newline='' as required by the csv module, so rows are not
        # double-terminated on platforms that translate '\n'.
        with open(path, 'w', newline='') as logfile:
            csv.writer(logfile, delimiter=',').writerow(header)


# The per-seed log must be creatable; failures propagate.
_create_csv_with_header(
    logname1, ['epoch', 'train loss', 'train acc', 'test loss', 'test acc', 'real loss'])

# The summary logs are shared across seeds; creation errors are only reported.
for _path, _header in (
        (logname2, ['seed', 'train loss', 'train acc', 'test acc', 'real loss', 'test loss']),
        (logname3, ['seed', 'initial train loss', 'min train loss',
                    'epoch min', 'test acc', 'train acc', 'real loss', 'test loss']),
):
    try:
        _create_csv_with_header(_path, _header)
    except OSError as err:
        print(err)


def main():
    """Train and evaluate for epochs ``start_epoch`` .. ``args.epoch - 1``.

    After every epoch the metrics are appended to the history lists and to
    ``logname1``, and the learning rate is updated. The epoch with the lowest
    mixed training loss is tracked (and checkpointed with ``--save``); its
    metrics are appended to ``logname3``. The final epoch's metrics are
    appended to ``logname2``. With ``--plot`` the curves are saved.
    """
    if args.epoch <= start_epoch:
        print('Nothing to train: --epoch {} <= start epoch {}'.format(args.epoch, start_epoch))
        return None

    epoch_min = start_epoch
    train_loss_best = None

    for epoch in range(start_epoch, args.epoch):
        train_loss, train_acc, real_loss = train(epoch)
        train_loss_list.append(train_loss)
        train_acc_list.append(train_acc)
        real_loss_list.append(real_loss)

        test_loss, test_acc = test(epoch)
        test_loss_list.append(test_loss)
        test_acc_list.append(test_acc)

        # The first epoch always becomes the initial best, so the *_min
        # statistics exist even when the loss never improves afterwards
        # (e.g. --epoch 1); they were previously left undefined.
        if train_loss_best is None or train_loss < train_loss_best:
            train_loss_best = train_loss
            train_acc_min = train_acc
            test_acc_min = test_acc
            real_loss_min = real_loss
            test_loss_min = test_loss
            epoch_min = epoch

            if args.save:
                checkpoint(test_acc_min, train_loss_best, epoch_min)

        train_acc_float = train_acc.item()
        test_acc_float = test_acc.item()

        adjust_learning_rate(optimizer, epoch)

        with open(logname1, 'a', newline='') as logfile:
            csv.writer(logfile, delimiter=',').writerow(
                [epoch, train_loss, train_acc_float, test_loss, test_acc_float, real_loss])

    # Final-epoch summary.
    with open(logname2, 'a', newline='') as logfile:
        csv.writer(logfile, delimiter=',').writerow(
            [args.seed, train_loss, train_acc_float, test_acc_float, real_loss, test_loss])

    # Summary at the epoch with the lowest training loss.
    # NOTE(review): the 'initial train loss' column is written empty.
    with open(logname3, 'a', newline='') as logfile:
        csv.writer(logfile, delimiter=',').writerow(
            [args.seed, '', train_loss_best, epoch_min, test_acc_min.item(),
             train_acc_min.item(), real_loss_min, test_loss_min])

    if args.plot:
        plot_curve(epoch_min)

    return None


if __name__ == '__main__':
    main()
