import argparse
import os
import errno
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data.sampler import SubsetRandomSampler

from densenet import DenseNet

parser = argparse.ArgumentParser(description='Cifar10 DenseNet')

# ---- Basic settings: reproducibility and where data / results live ----
parser.add_argument('--seed', type=int, default=1,
                    help='set seed')
parser.add_argument('--data_path', type=str, default='./data/',
                    help='path for saving data')
parser.add_argument('--base_path', type=str, default='./result/CIFAR10/',
                    help='base path for saving result')
parser.add_argument('--model_path', type=str, default='test_run/',
                    help='folder name for saving model')

# ---- DenseNet architecture ----
parser.add_argument('--depth', type=int, default=40,
                    help='Model depth')
parser.add_argument('--growth_rate', type=int, default=12,
                    help='Model parameters')

# ---- Optimisation schedule (SGD with step learning-rate decay) ----
parser.add_argument('--nepoch', type=int, default=300,
                    help='total number of training epochs')
# One or more 0-based epoch indices at which the learning rate is scaled by 0.1.
parser.add_argument('--lr_decay_time', type=int, nargs='+', default=[150, 225],
                    help='when to multiply lr by 0.1')
parser.add_argument('--init_lr', type=float, default=0.1,
                    help='initial learning rate')
parser.add_argument('--momentum', type=float, default=0.9,
                    help='momentum in SGD')
parser.add_argument('--weight_decay', type=float, default=0.0001,
                    help='weight decay in SGD')

# ---- Mini-batch sizes for each split ----
parser.add_argument('--batch_train', type=int, default=128,
                    help='batch size for training')
parser.add_argument('--batch_val', type=int, default=128,
                    help='batch size for validation')
parser.add_argument('--batch_test', type=int, default=128,
                    help='batch size for testing')

# Parsed once at import time; main() and its helpers read settings from this global.
args = parser.parse_args()


class _ECELoss(nn.Module):
    """Expected Calibration Error (ECE) of a classifier, computed from raw logits.

    This metric is not needed for temperature scaling itself; it is reported
    to measure calibration. Pass logits (pre-softmax scores), NOT probabilities.

    Each sample's top-1 softmax confidence is placed into one of ``n_bins``
    equal-width intervals ``(lower, upper]`` covering [0, 1]. For every
    non-empty bin, the gap ``|mean confidence - accuracy|`` is weighted by
    the fraction of samples that fall in the bin, and the weighted gaps are
    summed.

    Reference: Naeini, Mahdi Pakdaman, Gregory F. Cooper, and Milos
    Hauskrecht. "Obtaining Well Calibrated Probabilities Using Bayesian
    Binning." AAAI. 2015.
    """

    def __init__(self, n_bins=15):
        """
        Args:
            n_bins (int): number of equal-width confidence bins.
        """
        super().__init__()
        edges = torch.linspace(0, 1, n_bins + 1)
        # Plain tensor attributes (not registered buffers): forward() only
        # reads their Python scalar values, so their device never matters.
        self.bin_lowers = edges[:-1]
        self.bin_uppers = edges[1:]

    def forward(self, logits, labels):
        """Return the ECE as a 1-element tensor on ``logits.device``.

        Args:
            logits: unnormalised scores with classes along dim 1.
            labels: integer class indices, one per row of ``logits``.
        """
        probs = F.softmax(logits, dim=1)
        confidences, predictions = probs.max(dim=1)
        hits = predictions == labels

        ece = torch.zeros(1, device=logits.device)
        for lower, upper in zip(self.bin_lowers, self.bin_uppers):
            lo, hi = lower.item(), upper.item()
            in_bin = (confidences > lo) & (confidences <= hi)
            weight = in_bin.float().mean()
            # Skip empty bins: their means would be NaN and they carry zero weight.
            if weight.item() > 0:
                bin_confidence = confidences[in_bin].mean()
                bin_accuracy = hits[in_bin].float().mean()
                ece += (bin_confidence - bin_accuracy).abs() * weight
        return ece




def model_eval(net, data_loader, device, loss_func):
    """Evaluate ``net`` over every batch of ``data_loader``.

    All logits and labels are gathered on ``device`` first. Loss, accuracy
    and ECE are then computed once over the whole concatenated set.

    Args:
        net: model to evaluate. It is switched to ``eval()`` mode and left
            in that mode on return.
        data_loader: iterable of ``(images, labels)`` batches.
        device: device each batch is moved to before the forward pass. All
            returned tensors live on this device.
        loss_func: criterion applied to the concatenated logits and labels.

    Returns:
        tuple ``(accuracy, loss, ece_loss)``:
        - ``accuracy`` is a Python float in [0, 1];
        - ``loss`` is whatever ``loss_func`` returns on the full set;
        - ``ece_loss`` is the 1-element tensor produced by ``_ECELoss``.
    """
    net.eval()
    logits_list = []
    labels_list = []

    # No gradients are needed for evaluation. Disabling autograd here avoids
    # retaining a graph for the whole loader even if the caller forgets
    # torch.no_grad().
    with torch.no_grad():
        for images, labels in data_loader:
            images, labels = images.to(device), labels.to(device)
            logits_list.append(net(images))
            labels_list.append(labels)

        # Every batch already lives on ``device``. Concatenating keeps it
        # there, so CPU-only runs and non-default GPUs both work.
        logits = torch.cat(logits_list)
        labels = torch.cat(labels_list)
        total_count = labels.shape[0]

        loss = loss_func(logits, labels)
        accuracy = logits.argmax(1).eq(labels).sum().item() / total_count
        # _ECELoss has no parameters or buffers, so it needs no device move.
        ece_loss = _ECELoss()(logits, labels)

    return accuracy, loss, ece_loss


def _build_loaders():
    """Create the CIFAR-10 train / validation / test DataLoaders from the global ``args``."""
    normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2470, 0.2435, 0.2616])
    train_transform = transforms.Compose([transforms.RandomCrop(32, padding=4),
                                          transforms.RandomHorizontalFlip(),
                                          transforms.ToTensor(),
                                          normalize])
    test_transform = transforms.Compose([transforms.ToTensor(),
                                         normalize])

    dataset = datasets.CIFAR10(root=args.data_path, train=True, download=True, transform=train_transform)

    # The 90/10 split uses its own generator with a fixed seed, so the
    # partition is identical for every --seed. Fractional lengths require
    # PyTorch >= 1.13.
    data_seed = 0
    train_set, val_set = torch.utils.data.random_split(
        dataset, [0.9, 0.1], generator=torch.Generator().manual_seed(data_seed))
    # NOTE(review): val_set is a Subset of ``dataset`` and therefore inherits
    # train_transform (random crop/flip). It is not evaluated anywhere in main().

    test_set = datasets.CIFAR10(root=args.data_path, train=False, download=True, transform=test_transform)

    np.random.seed(args.seed)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=args.batch_train, shuffle=True, num_workers=4)
    val_loader = torch.utils.data.DataLoader(val_set, batch_size=args.batch_val, shuffle=False, num_workers=4)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=args.batch_test, shuffle=False, num_workers=4)
    return train_loader, val_loader, test_loader


def _train_one_epoch(net, loader, device, loss_func, optimizer):
    """Run one SGD pass over ``loader`` and return running ``(accuracy, mean loss, ECE)``.

    The metrics come from the outputs produced during the pass. The weights
    change between batches, so these are running estimates, not an
    evaluation of the final weights.
    """
    net.train()
    loss_sum = 0.0
    correct = 0
    seen = 0
    logits_seen = []
    labels_seen = []

    for images, targets in loader:
        images, targets = images.to(device), targets.to(device)
        output = net(images)
        loss = loss_func(output, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_size = images.shape[0]
        # nn.CrossEntropyLoss() (as built in main) averages over the batch,
        # so re-weight by batch size to get a per-sample mean for the epoch.
        loss_sum += loss.item() * batch_size
        correct += output.argmax(1).eq(targets).sum().item()
        seen += batch_size
        # Detached copies feed the epoch-level ECE without keeping autograd graphs alive.
        logits_seen.append(output.detach())
        labels_seen.append(targets)

    ece = _ECELoss()(torch.cat(logits_seen), torch.cat(labels_seen))
    return correct / seen, loss_sum / seen, ece.item()


def _evaluate(net, loader, device, loss_func):
    """Run ``model_eval`` without gradients and return ``(accuracy, loss, ece)`` as Python floats."""
    with torch.no_grad():
        accuracy, loss, ece_loss = model_eval(net, loader, device, loss_func)
    return float(accuracy), loss.item(), ece_loss.item()


def _save_results(filename, histories):
    """Pickle ``histories`` to ``filename`` (a binary pickle, despite the .txt name)."""
    import pickle

    with open(filename, 'wb') as f:
        pickle.dump(histories, f)


def main():
    """Train a DenseNet on CIFAR-10 and record per-epoch metrics.

    All settings come from the module-level ``args``. Files are written under
    ``{base_path}{model_path}densenet_{depth}/seed{seed}/``:

    - ``init_model.pt``: the untrained weights;
    - ``best_model.pt``: the weights with the highest test accuracy seen so far;
    - ``model{epoch}.pt``: saved after every epoch, with a 0-based epoch number;
    - ``result.txt``: a pickle of the six metric arrays, rewritten each epoch.

    Every metric array has ``nepoch + 1`` entries. Index 0 holds the
    untrained model and index ``k`` holds the state after ``k`` training
    epochs.
    """
    seed = args.seed
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Three equally sized dense blocks. The value is presumably layers per
    # block under the DenseNet-BC depth formula; confirm against densenet.DenseNet.
    block_config = [(args.depth - 4) // 6 for _ in range(3)]

    train_loader, _val_loader, test_loader = _build_loaders()

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    loss_func = nn.CrossEntropyLoss().to(device)
    net = DenseNet(growth_rate=args.growth_rate, block_config=block_config, num_classes=10).to(device)
    optimizer = torch.optim.SGD(net.parameters(), lr=args.init_lr, momentum=args.momentum,
                                weight_decay=args.weight_decay)

    PATH = "{}{}densenet_{}/seed{}/".format(args.base_path, args.model_path, args.depth, seed)
    os.makedirs(PATH, exist_ok=True)

    num_epochs = args.nepoch
    train_accuracy_path = np.zeros(num_epochs + 1)
    train_loss_path = np.zeros(num_epochs + 1)
    train_ece_loss_path = np.zeros(num_epochs + 1)

    test_accuracy_path = np.zeros(num_epochs + 1)
    test_loss_path = np.zeros(num_epochs + 1)
    test_ece_loss_path = np.zeros(num_epochs + 1)

    # Re-seed the global torch RNG after model construction. Presumably this
    # makes data shuffling and augmentation independent of how much
    # randomness model initialisation consumed.
    torch.manual_seed(seed)

    best_accuracy = 0

    torch.save(net.state_dict(), PATH + 'init_model.pt')

    # Index 0: metrics of the untrained network.
    # NOTE(review): train_loader applies train_transform, so these "train"
    # metrics are measured on augmented images.
    train_accuracy, train_loss, train_ece_loss = _evaluate(net, train_loader, device, loss_func)
    train_loss_path[0] = train_loss
    train_ece_loss_path[0] = train_ece_loss
    train_accuracy_path[0] = train_accuracy
    print("epoch: ", 0, ", train loss: ", train_loss, ", train ece loss: ", train_ece_loss, "train accuracy: ", train_accuracy)

    test_accuracy, test_loss, test_ece_loss = _evaluate(net, test_loader, device, loss_func)
    test_loss_path[0] = test_loss
    test_ece_loss_path[0] = test_ece_loss
    test_accuracy_path[0] = test_accuracy
    print("epoch: ", 0, ", test loss: ", test_loss, ", test ece loss: ", test_ece_loss, "test accuracy: ", test_accuracy)

    for epoch in range(num_epochs):
        # Step decay: scale lr by 0.1 at each 0-based epoch listed in --lr_decay_time.
        if epoch in args.lr_decay_time:
            for group in optimizer.param_groups:
                group['lr'] = group['lr'] * 0.1

        train_accuracy, train_loss, train_ece_loss = _train_one_epoch(net, train_loader, device, loss_func, optimizer)

        # Slot epoch + 1 holds the state after this epoch (slot 0 is the
        # untrained model), matching the test arrays below.
        step = epoch + 1
        train_loss_path[step] = train_loss
        train_ece_loss_path[step] = train_ece_loss
        train_accuracy_path[step] = train_accuracy
        print("epoch: ", step, ", train loss: ", train_loss, ", train ece loss: ", train_ece_loss, "train accuracy: ", train_accuracy)

        test_accuracy, test_loss, test_ece_loss = _evaluate(net, test_loader, device, loss_func)
        test_loss_path[step] = test_loss
        test_ece_loss_path[step] = test_ece_loss
        test_accuracy_path[step] = test_accuracy
        print("epoch: ", step, ", test loss: ", test_loss, ", test ece loss: ", test_ece_loss, "test accuracy: ", test_accuracy)

        # NOTE(review): the "best" checkpoint is selected on the *test* set,
        # which leaks test information into model selection. The held-out
        # validation split built in _build_loaders is currently unused.
        if test_accuracy > best_accuracy:
            best_accuracy = test_accuracy
            torch.save(net.state_dict(), PATH + 'best_model.pt')

        print('best accuracy:', best_accuracy)

        torch.save(net.state_dict(), PATH + 'model' + str(epoch) + '.pt')

        _save_results(PATH + 'result.txt',
                      [train_loss_path, train_ece_loss_path, train_accuracy_path,
                       test_loss_path, test_ece_loss_path, test_accuracy_path])



if __name__ == "__main__":
    # Start training only when run as a script; importing merely defines names (and parses args).
    main()
