'''Train a two-class classifier on the Faces_detection HDF5 data with PyTorch.'''
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import numpy as np

import torchvision
import torchvision.transforms as transforms

import os
import argparse
import pickle
from sklearn.metrics import log_loss, brier_score_loss
from scipy.special import expit


from models import *
from utils import progress_bar
from baseline_trainable import BaselineTrainable, BaselineEnsemble, FocalLoss, EntropyRegularizedLoss, MMCELoss
from baseline_method_v2 import Our_method, calibrate_model
import h5py
from tqdm import tqdm


# Command-line interface for the training script.
parser = argparse.ArgumentParser(description='PyTorch Training')
parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
parser.add_argument('--resume', '-r', action='store_true',
                    help='resume from checkpoint')
parser.add_argument('--model-name', default='ckpt.pth', type=str,
                    help='checkpoint file name inside --model-dir')
# NOTE: because the option is required, its default is never used; both are
# kept so the existing command-line interface stays unchanged.
parser.add_argument('--model-dir', default='crossentropy', type=str, required=True,
                    help='directory that receives the checkpoint')
# The former default ('equal') is not a label type the script accepts, so
# omitting --label_type always failed later on. 'eq' is the accepted
# spelling, and ``choices`` rejects unsupported values at parse time.
parser.add_argument('--label_type', default='eq', type=str,
                    choices=['unif', 'sig', 'eq', 'scaled', 'mid', 'step'],
                    help='how the probabilistic labels are simulated')
parser.add_argument('--methods', default='ce', type=str,
                    help='ce, deepensemble, ours_bin, ours_kd, focal, entropy or MMCE')
parser.add_argument('--dataset', default='cifar10', type=str,
                    help="'face' builds torchvision's resnet18, anything else the local ResNet")
args = parser.parse_args()

# Prefer the GPU when CUDA is usable; otherwise everything runs on the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
best_acc = 0  # best test accuracy
start_epoch = 0  # start from epoch 0 or last checkpoint epoch

# Data
print('==> Preparing data..')

# Per-channel normalisation statistics shared by both pipelines.
_NORM_MEAN = (0.4914, 0.4822, 0.4465)
_NORM_STD = (0.2023, 0.1994, 0.2010)

# Training pipeline: random 32x32 crop (4 px padding) and horizontal flip,
# then tensor conversion and normalisation.
# NOTE(review): neither pipeline is passed to a dataset in the code visible
# in this file -- confirm whether they are still needed.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
])

# Evaluation pipeline: tensor conversion and normalisation only.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
])


def unpickling(file):
    """Load and return the object pickled in ``file``.

    Args:
        file: Path of the pickle file to read.

    Returns:
        The unpickled Python object.

    Note:
        Unpickling can execute arbitrary code, so only call this on files
        from a trusted source.
    """
    # The context manager closes the handle even if unpickling raises.
    with open(file, 'rb') as handle:
        return pickle.load(handle)


class Face_dataset(torch.utils.data.Dataset):
    """Dataset that reads images and labels from HDF5 files under ``root``.

    Indexing yields ``(data, target_10, probs, index, target_prob)``; see
    ``__getitem__`` for what each element holds.
    """

    def __init__(self, root, prob_type, mode='train', std_classifier=False):
        """Record the configuration, read the sample keys and open the files.

        Args:
            root: Path prefix the file names are appended to (plain string
                concatenation, so it normally ends with a separator).
            prob_type: Name of the label simulation; selects the label file
                and the function applied to the stored score.
            mode: Split name used inside the file names.
            std_classifier: When true, ``target_10`` is replaced by
                ``(10 * target_prob).long()`` in ``__getitem__``.
        """
        self.root = root
        self.mode = mode
        self.prob_type = prob_type
        self.std_classifier = std_classifier
        self.proposed_probs = None
        # The keys of the image file define the samples and their order.
        with h5py.File(root + mode + '_im.h5', 'r') as image_file:
            self.keys = list(image_file.keys())
        if not hasattr(self, 'img_data'):
            self.open_hdf5()

    def open_hdf5(self):
        """Open the image/label files and choose ``prob_sim_func``.

        Raises:
            NotImplementedError: If ``prob_type`` has no simulation function.
                The image and label files are already open at that point.
        """
        prefix = self.root + self.mode
        self.img_data = h5py.File(prefix + '_im.h5', 'r')

        # 'eq' labels use a different file naming scheme from the others.
        if self.prob_type == 'eq':
            label_path = prefix + '_' + self.prob_type + '_prob.h5'
        else:
            label_path = self.root + "labels_" + self.prob_type + '_' + self.mode + '.h5'
        self.target_10 = h5py.File(label_path, 'r')

        def _step(x):
            # Piecewise-constant map: five bands of width 0.2 -> 0.1 ... 0.9.
            return (x < 0.2) * 0.1 + ((x >= 0.2) & (x < 0.4)) * 0.3 + \
                ((x >= 0.4) & (x < 0.6)) * 0.5 + ((x >= 0.6) & (x < 0.8)) * 0.7 + (x >= 0.8) * 0.9

        simulators = {
            'eq': lambda x: x,
            'unif': lambda x: x,
            'sig': lambda x: expit((x - 0.29) * 25),
            'scaled': lambda x: x / 2.5,
            'mid': lambda x: x / 3.0 + 0.35,
            'step': _step,
        }
        if self.prob_type not in simulators:
            raise NotImplementedError
        self.prob_sim_func = simulators[self.prob_type]

        self.target_prob = h5py.File(prefix + '_label.h5', 'r')

    def __getitem__(self, index):
        """Return ``(data, target_10, probs, index, target_prob)`` for one sample.

        ``data`` is the stored image with its last axis moved to the front
        and divided by 255. ``target_prob`` is entry ``[0, 0]`` of the label
        record divided by 100, capped at 1 and passed through
        ``prob_sim_func``. ``probs`` is ``proposed_probs[index]`` when
        ``proposed_probs`` is set, else 0.
        """
        key = self.keys[index]

        image = torch.tensor(self.img_data[key]).clone()
        data = image.permute(2, 0, 1) / 255

        target_10 = torch.tensor(np.array(self.target_10[key])).clone()

        raw_score = torch.tensor(self.target_prob[key][0, 0]) / 100.0
        target_prob = self.prob_sim_func(torch.minimum(raw_score, torch.tensor(1.)))
        # NOTE(review): open_hdf5 raises for 'ineq', so this halving looks
        # unreachable through the normal constructor -- confirm intent.
        if self.prob_type == 'ineq' and target_prob <= 0.5:
            target_prob /= 2

        probs = 0 if self.proposed_probs is None else self.proposed_probs[index]

        if self.std_classifier:
            target_10 = (10 * target_prob).long()
        return data, target_10, probs, index, target_prob

    def __len__(self):
        """Return the number of sample keys."""
        return len(self.keys)

# Validate with an explicit raise: an ``assert`` is removed when Python runs
# with -O, which would let an unsupported label type through to file opening.
_SUPPORTED_LABEL_TYPES = {'unif', 'sig', 'eq', 'scaled', 'mid', 'step'}
if args.label_type not in _SUPPORTED_LABEL_TYPES:
    raise ValueError(
        'unsupported --label_type {!r}; expected one of {}'.format(
            args.label_type, sorted(_SUPPORTED_LABEL_TYPES)))

# One dataset per split, all read from the same root with the same label type.
trainset = Face_dataset(root='./Faces_detection/', prob_type=args.label_type,
                        mode='train', std_classifier=False)
valset = Face_dataset(root='./Faces_detection/', prob_type=args.label_type,
                      mode='val', std_classifier=False)
testset = Face_dataset(root='./Faces_detection/', prob_type=args.label_type,
                       mode='test', std_classifier=False)

# Only the training loader shuffles; validation and test keep dataset order.
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=128, shuffle=True, num_workers=4)
valloader = torch.utils.data.DataLoader(
    valset, batch_size=128, shuffle=False, num_workers=4)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=128, shuffle=False, num_workers=4)

# NOTE(review): CIFAR-10 class names; not referenced by the code visible in
# this file -- confirm whether anything still needs them.
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

# Model
print('==> Building model..')
# Both options produce a two-output network; only the architecture source
# differs (torchvision's resnet18 vs. the ResNet from the local models package).
if args.dataset == 'face':
    net = torchvision.models.resnet18(num_classes=2)
else:
    net = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=2)
net = net.to(device)
if device == 'cuda':
    net = torch.nn.DataParallel(net)
    # cudnn.benchmark = True

if args.resume:
    # Load checkpoint.
    print('==> Resuming from checkpoint..')
    # Explicit raise instead of ``assert`` so the check survives python -O.
    if not os.path.isdir('checkpoint'):
        raise FileNotFoundError('Error: no checkpoint directory found!')
    checkpoint = torch.load('./checkpoint/'+args.model_name)
    net.load_state_dict(checkpoint['net'])
    # The 'ce' training loop in this script stores 'val_loss' and no 'acc'
    # entry, so indexing checkpoint['acc'] raised KeyError when resuming from
    # such a file. Keep the current value when the key is absent.
    best_acc = checkpoint.get('acc', best_acc)
    start_epoch = checkpoint['epoch']

# Loss: unweighted cross-entropy. A class-weighted variant
# (weights [1/0.9, 10] passed as ``weight=``) is left disabled.
criterion = nn.CrossEntropyLoss()


# Optimiser: SGD with momentum and weight decay. Disabled alternatives were
# SGD without weight decay and Adam with weight_decay=5e-4.
optimizer = optim.SGD(
    net.parameters(),
    lr=args.lr,
    momentum=0.9,
    weight_decay=5e-4,
)

# Scheduler: halve the learning rate when the monitored value stops
# decreasing (patience 5), never going below 1e-6. A cosine-annealing
# schedule (T_max=200) is the disabled alternative.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    'min',
    patience=5,
    factor=0.5,
    min_lr=1e-6,
    verbose=True,
)


# Training
def train(epoch):
    """Run one training epoch and print the mean per-batch loss.

    Works on the module-level ``net``, ``optimizer``, ``criterion``,
    ``trainloader`` and ``device``.

    Args:
        epoch: Epoch number; only used in the printed header.
    """
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0.0
    num_batches = 0
    # Each batch has five fields; only the inputs and targets are used here.
    for inputs, targets, _, _, _ in trainloader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        num_batches += 1
    # Counting batches explicitly keeps an empty loader from failing on an
    # unbound loop variable; in that case the mean is reported as NaN.
    mean_loss = train_loss / num_batches if num_batches else float('nan')
    print("Loss: %.3f" % mean_loss)


def test(net, dataloader):
    """Collect the predicted class-1 probability and the label of every sample.

    Args:
        net: Model to evaluate; it is switched to eval mode.
        dataloader: Loader whose batches are ``(inputs, label, _, idx, _)``,
            where ``idx`` gives each sample's position in
            ``dataloader.dataset``.

    Returns:
        Tuple ``(targets_probs, labels)`` of NumPy arrays with one entry per
        dataset sample, filled at the positions given by ``idx``.
    """
    net.eval()
    num_samples = len(dataloader.dataset)
    targets_probs = np.zeros(num_samples)
    labels = np.zeros(num_samples)
    with torch.no_grad():
        for inputs, label, _, idx, _ in dataloader:
            outputs = net(inputs.to(device))
            out_prob = F.softmax(outputs, dim=1)
            # Column 1 of the softmax is the probability of class 1.
            targets_probs[idx] = out_prob[:, 1].cpu().numpy()
            labels[idx] = label
    return targets_probs, labels

# Checkpoint location shared by every method below.
save_path = os.path.join(args.model_dir, args.model_name)
print(save_path)

# Every method writes into args.model_dir, so create it once up front.
# makedirs(exist_ok=True) also handles nested directories and a directory
# that appears between a check and the creation, which the former
# isdir()/mkdir() pairs did not.
os.makedirs(args.model_dir, exist_ok=True)


def _load_ce_pretrained_resnet18(lr=1e-4):
    """Build a 2-class resnet18 initialised from the stored 'ce' checkpoint.

    The weights come from ``./checkpoints/<label_type>/ce/ckpt.pth`` (its
    'net' entry). Returns the model together with a fresh SGD optimiser for
    fine-tuning at learning rate ``lr``.
    """
    model = torchvision.models.resnet18(num_classes=2)
    model = model.to(device)
    if device == 'cuda':
        model = torch.nn.DataParallel(model)
    model_path = "./checkpoints/{}/ce/ckpt.pth".format(args.label_type)
    model.load_state_dict(torch.load(model_path)['net'])
    finetune_optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
    return model, finetune_optimizer


if args.methods == 'ce':
    # Plain cross-entropy training; the checkpoint is rewritten whenever the
    # validation log-loss improves, and the scheduler follows the same value.
    min_val_loss = 1e10
    for epoch in range(start_epoch, start_epoch+200):
        train(epoch)
        val_targets_probs, labels = test(net, valloader)

        val_loss = log_loss(y_true=labels, y_pred=val_targets_probs)
        if min_val_loss > val_loss:
            print('Saving..')
            state = {
                'net': net.state_dict(),
                'val_loss': val_loss,
                'epoch': epoch,
            }
            print('val_loss: {:.3f}'.format(val_loss))
            torch.save(state, save_path)
            min_val_loss = val_loss
        scheduler.step(val_loss)

elif args.methods == 'deepensemble':
    M = 5
    adversarial_epsilon = 0.01
    baseline = BaselineEnsemble(net, M, optimizer, criterion, trainset, valset, adversarial_epsilon=adversarial_epsilon,
                                 save_dir=save_path, num_epoch=200)
    baseline.fit()

elif args.methods == 'ours_bin':
    net, optimizer = _load_ce_pretrained_resnet18(lr=1e-4)

    m_kwargs = {
        "net": net,  # early stopped network
        "optimizer": optimizer,  # optimizer for finetuning
        "train_dataset": trainset,  # train set
        "val_dataset": valset,  # validation set for finetuning stopping
        "num_epoch": 200,  # max number of epochs for finetuning
        "n_bins": 15,  # number of bins for updated probabilistic labels
        "calpertrain": 2,
        "finetune_type": "bin",
        "save_dir": save_path
    }
    calibrate_model(Our_method, m_kwargs=m_kwargs, test_dataset=testset)

elif args.methods == 'ours_kd':
    net, optimizer = _load_ce_pretrained_resnet18(lr=1e-4)

    m_kwargs = {
        "net": net,  # early stopped network
        "optimizer": optimizer,  # optimizer for finetuning
        "train_dataset": trainset,  # train set
        "val_dataset": valset,  # validation set for finetuning stopping
        "num_epoch": 200,  # max number of epochs for finetuning
        "calpertrain": 2,
        "finetune_type": "kde",
        "sigma": 0.05,
        "window": 500,
        "save_dir": save_path
    }
    calibrate_model(Our_method, m_kwargs=m_kwargs, test_dataset=testset)

else:
    # Any other --methods value lands here. Values other than the three
    # below keep the module-level criterion unchanged.
    if args.methods == "focal":
        criterion = FocalLoss(alpha=None, gamma=2.0)
    elif args.methods == "entropy":
        criterion = EntropyRegularizedLoss(beta=1.0)
    elif args.methods == "MMCE":
        criterion = MMCELoss(beta=3.0)
    baseline = BaselineTrainable(net, optimizer, criterion, trainset, valset,
                                 save_dir=save_path, num_epoch=200)
    baseline.fit()
