import os
import argparse
import random

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

import matplotlib.pyplot as plt
import numpy as np
import pickle
import torch.nn.functional as F

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
# (This name is assigned again further down, before the experiments start.)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Transform applied while exporting raw MNIST to .npy files: ToTensor only,
# no normalization. The name `transform` is rebound later in this file to a
# pipeline that also normalizes.
transform = transforms.Compose(
    [transforms.ToTensor(),])


def write_to_file(is_trained):
    """Download one MNIST split and save it under ``tmp/`` as NumPy arrays.

    Args:
        is_trained: ``True`` exports the training split to
            ``tmp/mnist_train_data.npy`` / ``tmp/mnist_train_labels.npy``;
            ``False`` exports the test split to
            ``tmp/mnist_test_data.npy`` / ``tmp/mnist_test_labels.npy``.

    The images are written channels-last: the batches produced by the loader
    are transposed with ``(0, 2, 3, 1)`` before saving.
    """
    dataset = torchvision.datasets.MNIST(root='./tmp', train=is_trained,
                                         download=True, transform=transform)
    # shuffle=True is kept, so the sample order inside the saved files is a
    # random permutation; images and labels stay aligned because they are
    # collected from the same batches.
    loader = torch.utils.data.DataLoader(dataset, batch_size=128, shuffle=True)

    # The batches are converted to NumPy directly; copying them to the
    # accelerator first would only add a device round trip.
    image_batches = []
    label_batches = []
    for data, target in loader:
        image_batches.append(data.numpy())
        label_batches.append(target.numpy())

    images = np.concatenate(image_batches, axis=0)
    images = np.transpose(images, (0, 2, 3, 1))
    labels = np.concatenate(label_batches, axis=0)

    split = 'train' if is_trained else 'test'
    np.save('tmp/mnist_%s_data.npy' % split, images)
    np.save('tmp/mnist_%s_labels.npy' % split, labels)

# Export both MNIST splits to tmp/*.npy when the script runs; the experiments
# further down load their data from these files.
write_to_file(is_trained = True)
write_to_file(is_trained = False)

class MNIST_Net(nn.Module):
    """Two-convolution CNN that maps an image batch to 10 raw class scores.

    The first fully connected layer takes 9216 flattened features, so the
    convolutional stack must produce exactly that many values per sample.
    The forward pass returns logits; no softmax is applied here.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1)
        self.dropout1 = nn.Dropout(p=0.25)
        self.dropout2 = nn.Dropout(p=0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Return the (batch, 10) logits for input batch ``x``."""
        # Convolutional feature extractor with a single 2x2 max-pool.
        features = F.relu(self.conv1(x))
        features = F.relu(self.conv2(features))
        features = self.dropout1(F.max_pool2d(features, 2))

        # Classifier head on the flattened feature map.
        hidden = F.relu(self.fc1(torch.flatten(features, 1)))
        return self.fc2(self.dropout2(hidden))

def fix_random_seed(seed_value):
    """Seed the Python, NumPy and PyTorch random generators with one value.

    When CUDA is available the CUDA generators are seeded as well, and cuDNN
    is disabled, with benchmarking off and the deterministic flag on.

    Args:
        seed_value: seed passed unchanged to every generator.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed_value)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed_all(seed_value)
    torch.cuda.manual_seed(seed_value)

    cudnn = torch.backends.cudnn
    cudnn.enabled = False
    cudnn.benchmark = False
    cudnn.deterministic = True

def train(model, device, trainloader, criterion, optimizer):
    """Run one optimisation pass over ``trainloader``, printing running stats.

    Args:
        model: network to optimise; it is switched to training mode.
        device: device the batches are moved to.
        trainloader: sized iterable yielding ``(inputs, targets)`` batches.
        criterion: loss called as ``criterion(outputs, targets)``.
        optimizer: optimizer stepped once per batch.

    After every batch a line is printed with the batch index, the number of
    batches, the mean loss so far and the running accuracy. Nothing is
    returned.
    """
    model.train()
    running_loss = 0
    n_correct = 0
    n_seen = 0
    for step, batch in enumerate(trainloader, start=1):
        inputs, targets = (tensor.to(device) for tensor in batch)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # Running statistics over all batches seen so far in this pass.
        running_loss += loss.item()
        predicted = outputs.max(1)[1]
        n_seen += targets.size(0)
        n_correct += predicted.eq(targets).sum().item()

        print(step - 1, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
              % (running_loss/step, 100.*n_correct/n_seen, n_correct, n_seen))

def calculate_ece(softmaxes, labels, n_bins=10):
    """Expected calibration error over equal-width confidence bins.

    Args:
        softmaxes: 2-D tensor of class probabilities, one row per sample.
        labels: tensor of true class indices, one per row of ``softmaxes``.
        n_bins: number of equal-width bins covering ``(0, 1]``.

    Returns:
        A Python float: the sum over non-empty bins of
        ``|mean confidence - accuracy|`` weighted by the fraction of samples
        that fall in the bin. Bins are open on the left and closed on the
        right.
    """
    edges = torch.linspace(0, 1, n_bins + 1)
    confidences, predictions = softmaxes.max(dim=1)
    hits = predictions.eq(labels)

    ece = torch.zeros(1, device=softmaxes.device)
    for lower, upper in zip(edges[:-1].tolist(), edges[1:].tolist()):
        in_bin = confidences.gt(lower) & confidences.le(upper)
        weight = in_bin.float().mean()
        if not weight.item() > 0:
            # Empty bin: contributes nothing.
            continue
        bin_accuracy = hits[in_bin].float().mean()
        bin_confidence = confidences[in_bin].mean()
        ece += torch.abs(bin_confidence - bin_accuracy) * weight
    return ece.item()

def test_model(model, device, test_loader, calib_model = None):
    model.eval()
    correct = 0
    outputs, targets = [], []
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            if calib_model is None:
                output = model(data)
            else:
                output = calib_model(model(data))
            _, pred = torch.max(output.data, 1)
            correct += (pred == target).sum().item()
            output = torch.softmax(output, dim=1)
            outputs.append(output.cpu().numpy())
            targets.append(target.cpu().numpy())

    acc = 100. * correct / len(test_loader.dataset)
    print('\nTest set: Accuracy: {}/{} ({:.0f}%)\n'.format(
        correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

    outputs = np.asarray(outputs)
    targets = np.asarray(targets)
    outputs = outputs.reshape(outputs.shape[0] * outputs.shape[1], outputs.shape[2])
    targets = targets.reshape(outputs.shape[0],)

    return outputs, targets, acc

class MyDataset(torch.utils.data.Dataset):
    """Dataset over parallel ``data`` / ``target`` sequences.

    Args:
        data: indexable collection of samples.
        target: indexable collection of labels, aligned with ``data``.
        transform: optional callable applied to each sample (not to the
            label) when it is fetched.
    """

    def __init__(self, data, target, transform=None):
        self.data = data
        self.target = target
        self.transform = transform

    def __len__(self):
        """Number of samples, taken from ``data``."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(sample, label)`` at ``index``, transforming the sample if configured."""
        sample = self.data[index]
        sample = self.transform(sample) if self.transform else sample
        return sample, self.target[index]

def sampling_balance(pool_labels, num_classes=10):
    """Draw random pool indices so that every picked sample has a distinct label.

    Indices are drawn one at a time with ``np.random.choice`` and kept only
    when their label has not been picked yet (rejection sampling).

    Args:
        pool_labels: 1-D NumPy array of class labels for the pool.
        num_classes: how many distinct labels to collect (default 10).

    Returns:
        NumPy array of ``num_classes`` pool indices, in the order drawn.

    Raises:
        ValueError: if the pool holds fewer than ``num_classes`` distinct
            labels, in which case the rejection loop could never finish.
    """
    n_distinct = np.unique(pool_labels).size
    if n_distinct < num_classes:
        raise ValueError(
            'pool has only %d distinct labels, cannot pick %d'
            % (n_distinct, num_classes))

    picked = []
    seen_labels = set()
    while len(picked) < num_classes:
        idx = np.random.choice(pool_labels.shape[0], 1)[0]
        label = pool_labels[idx]
        # An index that was already picked always carries a seen label, so
        # the label check alone also rules out duplicate indices.
        if label not in seen_labels:
            seen_labels.add(label)
            picked.append(idx)
    return np.array(picked)

def sampling(model, device, pool_loader, select_samples):
    """Random acquisition: pick ``select_samples`` distinct pool indices uniformly.

    Only the size of the pool matters for a uniform draw, so no predictions
    are computed. ``model`` and ``device`` are accepted for signature
    compatibility with the other acquisition functions and are not used.

    Args:
        model: unused.
        device: unused.
        pool_loader: loader over the unlabeled pool. Its ``dataset`` length is
            used when available; otherwise the batches are counted.
        select_samples: number of indices to draw without replacement.

    Returns:
        NumPy array of ``select_samples`` indices into the pool.

    NOTE(review): because the pool is no longer iterated, PyTorch's global
    RNG may be consumed differently than when a full pass was run here;
    NumPy's RNG, which drives the draw itself, is used exactly as before.
    Confirm if bit-exact reproduction of earlier runs is required.
    """
    dataset = getattr(pool_loader, 'dataset', None)
    if dataset is not None:
        pool_size = len(dataset)
    else:
        # Plain iterable of (data, target) batches: count the samples.
        pool_size = sum(len(target) for _, target in pool_loader)
    return np.random.choice(pool_size, select_samples, replace=False)

fix_random_seed(0)

# Select the device by availability. Requesting 'cuda' unconditionally makes
# every later `.to(device)` fail on a machine without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Experiment settings shared by all acquisition strategies below.
train_epochs = 30       # passes over the labelled set per acquisition round
select_samples = 10     # samples moved from the pool to the labelled set per round
train_data_path = 'tmp/mnist_train_data.npy'
train_labels_path = 'tmp/mnist_train_labels.npy'
test_data_path = 'tmp/mnist_test_data.npy'
test_labels_path = 'tmp/mnist_test_labels.npy'

net = MNIST_Net()
net = net.to(device)

# Transform used by every dataset from here on: tensor conversion followed by
# normalization with mean 0.1307 and std 0.3081.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# Evaluation set: the exported MNIST test split, served in fixed order.
data_c = np.load(test_data_path)
labels_c = np.load(test_labels_path)
testloader_c = torch.utils.data.DataLoader(MyDataset(data_c, labels_c, transform=transform), batch_size=100, shuffle=False)

# Unlabeled pool: the exported MNIST training split.
pool_data = np.load(train_data_path)
pool_labels = np.load(train_labels_path)
# NOTE(review): this loader is rebuilt inside the round loop before it is read.
pool_loader = torch.utils.data.DataLoader(MyDataset(pool_data, pool_labels, transform=transform), batch_size=128, shuffle=False)

# Experiment 1: random acquisition (results are plotted as "Random").
list_selected_data, list_selected_labels = [], []
list_acc_rd, list_ece_rd = [], []
# Start the labelled set with a class-balanced draw, remove those samples
# from the pool, then draw a second class-balanced batch that the first
# round adds before training.
idxs_unlabeled = sampling_balance(pool_labels)
list_selected_data.append(pool_data[idxs_unlabeled])
list_selected_labels.append(pool_labels[idxs_unlabeled])
pool_data = np.delete(pool_data, idxs_unlabeled, 0)
pool_labels = np.delete(pool_labels, idxs_unlabeled)
idxs_unlabeled = sampling_balance(pool_labels)

for rd in range(50):
	# Move the most recently chosen pool samples into the labelled set.
	list_selected_data.append(pool_data[idxs_unlabeled])
	list_selected_labels.append(pool_labels[idxs_unlabeled])
	# Flatten the per-round chunks; the reshape relies on every chunk holding
	# exactly `select_samples` items.
	selected_data = np.asarray(list_selected_data)
	selected_data = np.reshape(selected_data, (selected_data.shape[0]*select_samples, 28, 28, 1))
	selected_labels = np.asarray(list_selected_labels)
	selected_labels = np.reshape(selected_labels, (selected_labels.shape[0]*select_samples,))
	trainloader = torch.utils.data.DataLoader(MyDataset(selected_data, selected_labels, transform=transform), batch_size=128, shuffle=True)

	# The optimizer is recreated each round; `net` is not, so it keeps
	# training from the weights of the previous round.
	criterion = nn.CrossEntropyLoss()
	optimizer = optim.Adam(net.parameters())
	for epoch in range(train_epochs):
		train(net, device, trainloader, criterion, optimizer)

	# Remove the newly labelled samples so the indices chosen next refer to
	# the reduced pool.
	pool_data = np.delete(pool_data, idxs_unlabeled, 0)
	pool_labels = np.delete(pool_labels, idxs_unlabeled)
	pool_loader = torch.utils.data.DataLoader(MyDataset(pool_data, pool_labels, transform=transform), batch_size=512, shuffle=False)

	# Choose the samples for the next round with the `sampling` definition
	# in effect at this point of the script.
	idxs_unlabeled = sampling(net, device, pool_loader, select_samples)

	# Record test accuracy and expected calibration error for this round.
	outputs, targets, acc = test_model(net, device, testloader_c)
	ece = calculate_ece(torch.tensor(outputs), torch.tensor(targets))
	list_acc_rd.append(acc)
	list_ece_rd.append(ece)

def sampling(model, device, pool_loader, select_samples):
    """Least-confidence acquisition on randomly rescaled logits.

    Each sample's logits are multiplied by its own random factor drawn from
    ``torch.rand`` before the softmax; the samples with the smallest maximum
    probability are selected.

    Args:
        model: network producing logits; it is switched to eval mode.
        device: device the batches and random factors are moved to.
        pool_loader: loader yielding ``(data, target)`` batches in pool order.
        select_samples: number of pool indices to return.

    Returns:
        NumPy array with the indices of the ``select_samples`` least
        confident pool samples, least confident first.
    """
    model.eval()
    prob_batches = []
    with torch.no_grad():
        for data, _ in pool_loader:
            logits = model(data.to(device))
            # One scaling factor per sample, broadcast over the classes.
            rand = torch.rand(logits.shape[0], 1).to(device)
            probs = torch.softmax(logits * rand, dim=1)
            prob_batches.append(probs.cpu().numpy())

    # Concatenating along the sample axis works for any number of batches and
    # any batch sizes, including a pool that fits in a single batch.
    outputs = np.concatenate(prob_batches, 0)
    return np.argsort(outputs.max(1))[:select_samples]

# Experiment 2: least-confidence acquisition on randomly rescaled logits
# (results are plotted as "Uncalibrated-least-conf"). The seed is reset and a
# new network is created so this run does not continue from the previous one.
fix_random_seed(0)
net = MNIST_Net()
net = net.to(device)

# Evaluation set: the exported MNIST test split, served in fixed order.
data_c = np.load(test_data_path)
labels_c = np.load(test_labels_path)
testloader_c = torch.utils.data.DataLoader(MyDataset(data_c, labels_c, transform=transform), batch_size=100, shuffle=False)

# Unlabeled pool: reloaded from disk, so it is complete again.
pool_data = np.load(train_data_path)
pool_labels = np.load(train_labels_path)
# NOTE(review): this loader is rebuilt inside the round loop before it is read.
pool_loader = torch.utils.data.DataLoader(MyDataset(pool_data, pool_labels, transform=transform), batch_size=128, shuffle=False)

list_selected_data, list_selected_labels = [], []
list_acc_uerm, list_ece_uerm = [], []
# Start the labelled set with a class-balanced draw, remove those samples
# from the pool, then draw a second class-balanced batch that the first
# round adds before training.
idxs_unlabeled = sampling_balance(pool_labels)
list_selected_data.append(pool_data[idxs_unlabeled])
list_selected_labels.append(pool_labels[idxs_unlabeled])
pool_data = np.delete(pool_data, idxs_unlabeled, 0)
pool_labels = np.delete(pool_labels, idxs_unlabeled)
idxs_unlabeled = sampling_balance(pool_labels)

for rd in range(50):
	# Move the most recently chosen pool samples into the labelled set.
	list_selected_data.append(pool_data[idxs_unlabeled])
	list_selected_labels.append(pool_labels[idxs_unlabeled])
	# Flatten the per-round chunks; the reshape relies on every chunk holding
	# exactly `select_samples` items.
	selected_data = np.asarray(list_selected_data)
	selected_data = np.reshape(selected_data, (selected_data.shape[0]*select_samples, 28, 28, 1))
	selected_labels = np.asarray(list_selected_labels)
	selected_labels = np.reshape(selected_labels, (selected_labels.shape[0]*select_samples,))
	trainloader = torch.utils.data.DataLoader(MyDataset(selected_data, selected_labels, transform=transform), batch_size=128, shuffle=True)

	# The optimizer is recreated each round; `net` keeps its weights.
	criterion = nn.CrossEntropyLoss()
	optimizer = optim.Adam(net.parameters())
	for epoch in range(train_epochs):
		train(net, device, trainloader, criterion, optimizer)

	# Remove the newly labelled samples so the indices chosen next refer to
	# the reduced pool.
	pool_data = np.delete(pool_data, idxs_unlabeled, 0)
	pool_labels = np.delete(pool_labels, idxs_unlabeled)
	pool_loader = torch.utils.data.DataLoader(MyDataset(pool_data, pool_labels, transform=transform), batch_size=512, shuffle=False)

	# Choose the samples for the next round with the `sampling` definition
	# in effect at this point of the script.
	idxs_unlabeled = sampling(net, device, pool_loader, select_samples)

	# Record test accuracy and expected calibration error for this round.
	outputs, targets, acc = test_model(net, device, testloader_c)
	ece = calculate_ece(torch.tensor(outputs), torch.tensor(targets))
	list_acc_uerm.append(acc)
	list_ece_uerm.append(ece)

def sampling(model, device, pool_loader, select_samples):
    """Least-confidence acquisition: pick the pool samples the model is least sure about.

    Confidence is the largest softmax probability of a sample.

    Args:
        model: network producing logits; it is switched to eval mode.
        device: device the batches are moved to.
        pool_loader: loader yielding ``(data, target)`` batches in pool order.
        select_samples: number of pool indices to return.

    Returns:
        NumPy array with the indices of the ``select_samples`` least
        confident pool samples, least confident first.
    """
    model.eval()
    prob_batches = []
    with torch.no_grad():
        for data, _ in pool_loader:
            probs = torch.softmax(model(data.to(device)), dim=1)
            prob_batches.append(probs.cpu().numpy())

    # Concatenating along the sample axis works for any number of batches and
    # any batch sizes, including a pool that fits in a single batch.
    outputs = np.concatenate(prob_batches, 0)
    return np.argsort(outputs.max(1))[:select_samples]

# Experiment 3: plain least-confidence acquisition (results are plotted as
# "Least-conf"). The seed is reset and a new network is created so this run
# does not continue from the previous one.
fix_random_seed(0)
net = MNIST_Net()
net = net.to(device)

# Evaluation set: the exported MNIST test split, served in fixed order.
data_c = np.load(test_data_path)
labels_c = np.load(test_labels_path)
testloader_c = torch.utils.data.DataLoader(MyDataset(data_c, labels_c, transform=transform), batch_size=100, shuffle=False)

# Unlabeled pool: reloaded from disk, so it is complete again.
pool_data = np.load(train_data_path)
pool_labels = np.load(train_labels_path)
# NOTE(review): this loader is rebuilt inside the round loop before it is read.
pool_loader = torch.utils.data.DataLoader(MyDataset(pool_data, pool_labels, transform=transform), batch_size=128, shuffle=False)

list_selected_data, list_selected_labels = [], []
list_acc_erm, list_ece_erm = [], []
# Start the labelled set with a class-balanced draw, remove those samples
# from the pool, then draw a second class-balanced batch that the first
# round adds before training.
idxs_unlabeled = sampling_balance(pool_labels)
list_selected_data.append(pool_data[idxs_unlabeled])
list_selected_labels.append(pool_labels[idxs_unlabeled])
pool_data = np.delete(pool_data, idxs_unlabeled, 0)
pool_labels = np.delete(pool_labels, idxs_unlabeled)
idxs_unlabeled = sampling_balance(pool_labels)

for rd in range(50):
	# Move the most recently chosen pool samples into the labelled set.
	list_selected_data.append(pool_data[idxs_unlabeled])
	list_selected_labels.append(pool_labels[idxs_unlabeled])
	# Flatten the per-round chunks; the reshape relies on every chunk holding
	# exactly `select_samples` items.
	selected_data = np.asarray(list_selected_data)
	selected_data = np.reshape(selected_data, (selected_data.shape[0]*select_samples, 28, 28, 1))
	selected_labels = np.asarray(list_selected_labels)
	selected_labels = np.reshape(selected_labels, (selected_labels.shape[0]*select_samples,))
	trainloader = torch.utils.data.DataLoader(MyDataset(selected_data, selected_labels, transform=transform), batch_size=128, shuffle=True)

	# The optimizer is recreated each round; `net` keeps its weights.
	criterion = nn.CrossEntropyLoss()
	optimizer = optim.Adam(net.parameters())
	for epoch in range(train_epochs):
		train(net, device, trainloader, criterion, optimizer)

	# Remove the newly labelled samples so the indices chosen next refer to
	# the reduced pool.
	pool_data = np.delete(pool_data, idxs_unlabeled, 0)
	pool_labels = np.delete(pool_labels, idxs_unlabeled)
	pool_loader = torch.utils.data.DataLoader(MyDataset(pool_data, pool_labels, transform=transform), batch_size=512, shuffle=False)

	# Choose the samples for the next round with the `sampling` definition
	# in effect at this point of the script.
	idxs_unlabeled = sampling(net, device, pool_loader, select_samples)

	# Record test accuracy and expected calibration error for this round.
	outputs, targets, acc = test_model(net, device, testloader_c)
	ece = calculate_ece(torch.tensor(outputs), torch.tensor(targets))
	list_acc_erm.append(acc)
	list_ece_erm.append(ece)

def get_ratio_canonical_per_samples(f, y, bandwidth, p, device):
    """Kernel-based calibration error for the rows of ``f`` that follow the first ``len(y)`` rows.

    The first ``len(y)`` rows of ``f`` are paired with the labels ``y``; the
    remaining rows are the query rows. For each query row a kernel-weighted
    average of the one-hot labels is formed and compared with the row's own
    probabilities.

    Args:
        f: 2-D tensor of class probabilities, labelled rows first.
        y: integer class labels for the first ``len(y)`` rows of ``f``.
        bandwidth: kernel bandwidth passed to ``get_kernel``.
        p: exponent applied to the absolute differences.
        device: device passed to ``get_kernel``.

    Returns:
        1-D tensor with one value per query row. When ``f`` has more than 60
        columns the call is delegated to ``get_ratio_canonical_log`` and its
        result is returned instead.
    """
    n_classes = f.shape[1]
    if n_classes > 60:
        # Slower but more numerically stable implementation for larger number of classes
        return get_ratio_canonical_log(f, y, bandwidth, p, device)

    n_labelled = y.shape[0]
    kern = torch.exp(get_kernel(f, bandwidth, device))
    y_onehot = nn.functional.one_hot(y, num_classes=n_classes).to(torch.float32)

    # Kernel block between query rows (index >= n_labelled) and labelled rows.
    cross_kern = kern[n_labelled:, :n_labelled]

    weighted_labels = torch.matmul(cross_kern, y_onehot)
    # Clamp the normaliser to avoid division by 0.
    den = torch.clamp(cross_kern.sum(dim=1), min=1e-10)

    ratio = weighted_labels / den.unsqueeze(-1)
    return torch.sum(torch.abs(ratio - f[n_labelled:]) ** p, dim=1)


# Note for training: Make sure there are at least two examples for every class present in the batch, otherwise
# LogsumexpBackward returns nans.
def get_ratio_canonical_log(f, y, bandwidth, p, device='cpu'):
    """Mean kernel-based calibration error, computed in log space.

    Args:
        f: 2-D tensor of class probabilities, one row per sample.
        y: integer class labels. The label mask is broadcast against the
            columns of the ``len(f) x len(f)`` kernel, so ``len(y)`` is
            expected to equal ``len(f)``.
        bandwidth: kernel bandwidth passed to ``get_kernel``.
        p: exponent applied to the absolute differences.
        device: device passed to ``get_kernel``.

    Returns:
        Scalar tensor: the per-sample errors summed over the classes and
        averaged over the samples.
    """
    log_kern = get_kernel(f, bandwidth, device)
    y_onehot = nn.functional.one_hot(y, num_classes=f.shape[1]).to(torch.float32)
    # 0 where the label matches the class, -inf elsewhere, so adding it to
    # the log-kernel masks out samples of other classes.
    log_y = torch.log(y_onehot)
    log_den = torch.logsumexp(log_kern, dim=1)
    final_ratio = 0
    for k in range(f.shape[1]):
        # Broadcast the class mask over the kernel rows directly. Building a
        # separate ones tensor would place it on the CPU and break the
        # addition when `f` lives on another device.
        log_kern_y = log_kern + log_y[:, k].unsqueeze(0)
        log_inner_ratio = torch.logsumexp(log_kern_y, dim=1) - log_den
        inner_ratio = torch.exp(log_inner_ratio)
        inner_diff = torch.abs(inner_ratio - f[:, k])**p
        final_ratio += inner_diff

    return torch.mean(final_ratio)

def get_kernel(f, bandwidth, device):
    """Pairwise log-kernel matrix of ``f`` with its diagonal suppressed.

    A beta kernel is used when ``f`` has a single column, a Dirichlet kernel
    otherwise.

    Args:
        f: 2-D tensor of probabilities, one row per sample.
        bandwidth: kernel bandwidth.
        device: device the diagonal mask is moved to before it is added.

    Returns:
        The squeezed log-kernel with the most negative float added on the
        diagonal, so a sample gets effectively no weight from itself.
    """
    single_column = f.shape[1] == 1
    if single_column:
        log_kern = beta_kernel(f, f, bandwidth)
    else:
        log_kern = dirichlet_kernel(f, bandwidth)
    log_kern = log_kern.squeeze()

    # Trick: -inf on the diagonal
    diag_mask = torch.diag(torch.finfo(torch.float).min * torch.ones(len(f)))
    return log_kern + diag_mask.to(device)


def beta_kernel(z, zi, bandwidth=0.1):
    """Log-density of a beta kernel centred on ``zi``, evaluated at ``z``.

    The beta parameters are ``zi / bandwidth + 1`` and
    ``(1 - zi) / bandwidth + 1``.

    Args:
        z: tensor of evaluation points; a new axis is inserted before its
            last dimension so it broadcasts against ``zi``.
        zi: tensor of kernel centres.
        bandwidth: kernel bandwidth.

    Returns:
        Tensor of log-densities with the broadcast shape of the two inputs.
    """
    alpha = zi / bandwidth + 1
    beta = (1 - zi) / bandwidth + 1
    points = z.unsqueeze(-2)

    # Log of the beta function, i.e. the normalising constant.
    log_norm = torch.lgamma(alpha) + torch.lgamma(beta) - torch.lgamma(alpha + beta)
    log_unnorm = (alpha - 1) * points.log() + (beta - 1) * (1 - points).log()

    return log_unnorm - log_norm


def dirichlet_kernel(z, bandwidth=0.1):
    """Pairwise log-density matrix of Dirichlet kernels built from the rows of ``z``.

    Each row defines a Dirichlet with concentration ``row / bandwidth + 1``.

    Args:
        z: 2-D tensor of probabilities; values are clamped to at least 1e-10
            before the logarithm is taken.
        bandwidth: kernel bandwidth.

    Returns:
        Square tensor whose entry ``[i, j]`` is the log-density of row ``i``
        under the Dirichlet defined by row ``j``.
    """
    z = z.clamp(min=1e-10)
    alphas = z / bandwidth + 1

    # Log of the multivariate beta function, one value per kernel centre.
    log_norm = torch.lgamma(alphas).sum(dim=1) - torch.lgamma(alphas.sum(dim=1))
    log_unnorm = torch.log(z) @ (alphas - 1).T

    return log_unnorm - log_norm

def _softmax_over_loader(model, device, loader):
    """Return the softmax outputs and targets of every batch in ``loader``, concatenated once."""
    prob_batches, target_batches = [], []
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            prob_batches.append(torch.softmax(model(data), dim=1))
            target_batches.append(target)
    return torch.cat(prob_batches, 0), torch.cat(target_batches, 0)


def sampling(model, device, pool_loader, select_samples, trainloader):
    """Acquisition by per-sample kernel calibration error, ties broken by confidence.

    Predictions on the labelled set (``trainloader``) act as references for
    ``get_ratio_canonical_per_samples``, which scores every pool sample. Pool
    samples are ranked by decreasing score and, for equal scores, by
    increasing maximum softmax probability.

    Args:
        model: network producing logits; it is switched to eval mode.
        device: device used for inference and for the kernel computation.
        pool_loader: loader yielding ``(data, target)`` pool batches in order.
        select_samples: number of pool indices to return.
        trainloader: loader over the currently labelled samples.

    Returns:
        NumPy array of ``select_samples`` indices into the pool.
    """
    model.eval()
    train_outputs, train_targets = _softmax_over_loader(model, device, trainloader)
    # One-hot encoding needs int64 class indices; make that explicit instead
    # of depending on the dtype the loader happens to deliver.
    train_targets = train_targets.long()
    outputs, _ = _softmax_over_loader(model, device, pool_loader)

    # The pool is scored in two halves, so each kernel matrix only covers
    # the labelled set plus half of the pool.
    half = outputs.shape[0] // 2
    ece_parts = []
    for pool_chunk in (outputs[:half], outputs[half:]):
        stacked = torch.cat((train_outputs, pool_chunk), 0)
        ece_parts.append(get_ratio_canonical_per_samples(
            stacked, train_targets, bandwidth=0.001, p=1, device=device))

    ece = torch.cat(ece_parts, 0).detach().cpu().numpy()
    conf = outputs.max(1).values.detach().cpu().numpy()

    # np.lexsort sorts by its last key first: descending score, then
    # ascending confidence.
    return np.lexsort((conf, -ece))[:select_samples]

# Experiment 4: acquisition by per-sample kernel calibration error (results
# are plotted as "Ours"). The seed is reset and a new network is created so
# this run does not continue from the previous one.
fix_random_seed(0)

net = MNIST_Net()
net = net.to(device)

# Evaluation set: the exported MNIST test split, served in fixed order.
data_c = np.load(test_data_path)
labels_c = np.load(test_labels_path)
testloader_c = torch.utils.data.DataLoader(MyDataset(data_c, labels_c, transform=transform), batch_size=100, shuffle=False)

# Unlabeled pool: reloaded from disk, so it is complete again.
pool_data = np.load(train_data_path)
pool_labels = np.load(train_labels_path)
# NOTE(review): this loader is rebuilt inside the round loop before it is read.
pool_loader = torch.utils.data.DataLoader(MyDataset(pool_data, pool_labels, transform=transform), batch_size=128, shuffle=False)

list_selected_data, list_selected_labels = [], []
list_acc_ours, list_ece_ours = [], []
# Start the labelled set with a class-balanced draw, remove those samples
# from the pool, then draw a second class-balanced batch that the first
# round adds before training.
idxs_unlabeled = sampling_balance(pool_labels)
list_selected_data.append(pool_data[idxs_unlabeled])
list_selected_labels.append(pool_labels[idxs_unlabeled])
pool_data = np.delete(pool_data, idxs_unlabeled, 0)
pool_labels = np.delete(pool_labels, idxs_unlabeled)
idxs_unlabeled = sampling_balance(pool_labels)

for rd in range(50):
	# Move the most recently chosen pool samples into the labelled set.
	list_selected_data.append(pool_data[idxs_unlabeled])
	list_selected_labels.append(pool_labels[idxs_unlabeled])
	# Flatten the per-round chunks; the reshape relies on every chunk holding
	# exactly `select_samples` items.
	selected_data = np.asarray(list_selected_data)
	selected_data = np.reshape(selected_data, (selected_data.shape[0]*select_samples, 28, 28, 1))
	selected_labels = np.asarray(list_selected_labels)
	selected_labels = np.reshape(selected_labels, (selected_labels.shape[0]*select_samples,))
	trainloader = torch.utils.data.DataLoader(MyDataset(selected_data, selected_labels, transform=transform), batch_size=128, shuffle=True)

	# The optimizer is recreated each round; `net` keeps its weights.
	criterion = nn.CrossEntropyLoss()
	optimizer = optim.Adam(net.parameters())
	for epoch in range(train_epochs):
		train(net, device, trainloader, criterion, optimizer)

	# Remove the newly labelled samples so the indices chosen next refer to
	# the reduced pool.
	pool_data = np.delete(pool_data, idxs_unlabeled, 0)
	pool_labels = np.delete(pool_labels, idxs_unlabeled)
	pool_loader = torch.utils.data.DataLoader(MyDataset(pool_data, pool_labels, transform=transform), batch_size=512, shuffle=False)

	# This strategy also receives the labelled-set loader, which it uses as
	# the reference set when scoring the pool.
	idxs_unlabeled = sampling(net, device, pool_loader, select_samples, trainloader)

	# Record test accuracy and expected calibration error for this round.
	outputs, targets, acc = test_model(net, device, testloader_c)
	ece = calculate_ece(torch.tensor(outputs), torch.tensor(targets))
	list_acc_ours.append(acc)
	list_ece_ours.append(ece)

# Two panels side by side: calibration error (left) and accuracy (right),
# one curve per acquisition strategy, indexed by acquisition round.
fig, axs = plt.subplots(1, 2, figsize=(10, 5))

# (label, ECE curve, accuracy curve, colour); None keeps matplotlib's default colour.
curves = [
    ("Random", list_ece_rd, list_acc_rd, None),
    ("Uncalibrated-least-conf", list_ece_uerm, list_acc_uerm, "#bcbd22"),
    ("Least-conf", list_ece_erm, list_acc_erm, "#ff7f0e"),
    ("Ours", list_ece_ours, list_acc_ours, "blue"),
]
panels = [
    (axs[0], 'Expected Calibration Error', 1),
    (axs[1], 'Accuracy', 2),
]

for ax, title, column in panels:
    for curve in curves:
        colour_kwargs = {} if curve[3] is None else {"color": curve[3]}
        ax.plot(curve[column], marker='.', label=curve[0], **colour_kwargs)
    ax.set_title(title)
    ax.grid(True)

# plt.legend() attaches the legend to the current axes only.
plt.legend()
plt.tight_layout()
plt.savefig("out.png")

# with open("tmp/demo5", "wb") as fp:
# 	pickle.dump(list_ece_rd, fp)
# 	pickle.dump(list_ece_uerm, fp)
# 	pickle.dump(list_ece_erm, fp)
# 	pickle.dump(list_ece_ours, fp)
# 	pickle.dump(list_acc_rd, fp)
# 	pickle.dump(list_acc_uerm, fp)
# 	pickle.dump(list_acc_erm, fp)
# 	pickle.dump(list_acc_ours, fp)