import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import argparse

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

import matplotlib.pyplot as plt
import numpy as np
import pickle

from utils import *
from models import *

class MyDataset(torch.utils.data.Dataset):
	"""Map-style dataset over in-memory samples and their targets.

	Item ``i`` is ``(transform(data[i]), target[i])``; the transform is skipped
	when it is ``None`` (or otherwise falsy).
	"""

	def __init__(self, data, target, transform=None):
		self.data, self.target, self.transform = data, target, transform

	def __len__(self):
		# The dataset length is taken from the samples, not the targets.
		return len(self.data)

	def __getitem__(self, index):
		sample = self.data[index]
		if not self.transform:
			return sample, self.target[index]
		# The transform runs before the target lookup, as in a plain two-step fetch.
		return self.transform(sample), self.target[index]

def sampling(model, device, pool_loader, select_samples, n_drops=5):
	"""Return the pool indices with the highest BALD score (MC-dropout estimate).

	Each pool sample gets ``n_drops`` stochastic forward passes with dropout
	active. It is scored by the mutual information between its predicted label
	and the model weights:

		MI = H[mean_t p_t] - mean_t H[p_t]

	Args:
		model: classifier whose output is fed to ``softmax(dim=1)``; assumed to
			be logits of shape (batch, n_classes). For the stochastic passes only
			``nn.Dropout``/``Dropout2d``/``Dropout3d``/``AlphaDropout`` submodules
			are put in train mode; everything else stays in eval mode. If the
			model has no such layers (e.g. it uses ``F.dropout`` keyed on
			``self.training``), every pass is identical and all scores are ~0.
		device: device the model lives on.
		pool_loader: loader over the unlabeled pool yielding ``(data, target)``
			batches in a fixed order; targets are ignored.
		select_samples: number of indices to return.
		n_drops: number of MC-dropout forward passes per sample (default 5).

	Returns:
		numpy array of up to ``select_samples`` positions (in loader order),
		ordered by decreasing mutual information.

	Raises:
		ValueError: if ``n_drops`` is smaller than 1.
	"""
	if n_drops < 1:
		raise ValueError("n_drops must be at least 1")
	eps = 1e-12  # keeps log() finite when a softmax probability underflows to 0
	dropout_types = (nn.Dropout, nn.Dropout2d, nn.Dropout3d, nn.AlphaDropout)

	model.eval()
	# eval() alone would switch dropout off, making all passes identical and the
	# BALD score zero; re-enable only the dropout layers (batch norm etc. stay in eval).
	for module in model.modules():
		if isinstance(module, dropout_types):
			module.train()

	batch_probs = []  # one (n_drops, batch, n_classes) array per batch
	try:
		with torch.no_grad():
			for data, _ in pool_loader:
				data = data.to(device)
				probs = torch.stack(
					[torch.softmax(model(data), dim=1) for _ in range(n_drops)]
				)
				batch_probs.append(probs.cpu().numpy())
	finally:
		# Leave the model in plain eval mode (dropout off), as the original did.
		model.eval()

	if not batch_probs:
		return np.array([], dtype=int)

	probs = np.concatenate(batch_probs, axis=1).astype(np.float64)  # (n_drops, N, C)
	mean_probs = probs.mean(axis=0)  # (N, C)
	entropy_of_mean = -(mean_probs * np.log(mean_probs + eps)).sum(axis=1)
	mean_entropy = -(probs * np.log(probs + eps)).sum(axis=2).mean(axis=0)
	mutual_info = entropy_of_mean - mean_entropy
	# Highest mutual information first; stable sort keeps ties in pool order.
	return np.argsort(-mutual_info, kind="stable")[:select_samples]

def sampling_balance(pool_labels, n_samples=10):
	"""Pick ``n_samples`` random pool indices whose labels are pairwise distinct.

	Indices are drawn uniformly at random with ``np.random.choice``. A draw is
	rejected when its label was already picked, so the result is a
	class-balanced batch.

	Args:
		pool_labels: numpy array of pool labels, indexed along axis 0.
		n_samples: number of indices to return (default 10).

	Returns:
		numpy array of ``n_samples`` distinct indices into ``pool_labels``,
		in the order they were drawn.

	Raises:
		ValueError: if ``pool_labels`` has fewer than ``n_samples`` distinct
			values (rejection sampling could never finish).
	"""
	n_distinct = len(np.unique(pool_labels))
	if n_distinct < n_samples:
		raise ValueError(
			f"need {n_samples} distinct labels in the pool, found {n_distinct}"
		)

	picked = []
	while len(picked) < n_samples:
		idx = np.random.choice(pool_labels.shape[0], 1)[0]
		# An already-picked index necessarily has its label among the picked
		# labels, so this test alone also prevents duplicate indices.
		if pool_labels[idx] not in pool_labels[picked]:
			picked.append(idx)
	return np.array(picked, dtype=int)

if __name__=="__main__":
	parser = argparse.ArgumentParser()
	# Required: the index seeds the RNG and names the output file; without it
	# the script would crash at int(None) below.
	parser.add_argument("--exp_idx", required=True, help="Index of experiment")
	args = parser.parse_args()
	fix_random_seed(int(args.exp_idx))
	if torch.cuda.is_available():
		device = 'cuda'
	else:
		# Fall back instead of failing on the first .to(device) call.
		print("CUDA not available, running on CPU")
		device = 'cpu'

	train_epochs = 30
	select_samples = 10  # samples acquired per round (sampling_balance also returns 10)
	train_data_path = 'data/mnist_train_data.npy'
	train_labels_path = 'data/mnist_train_labels.npy'
	test_data_path = 'data/mnist_test_data.npy'
	test_labels_path = 'data/mnist_test_labels.npy'
	# Create the results directory up front so the per-round dump cannot fail
	# after training has already run.
	os.makedirs("out/mnist", exist_ok=True)

	# A single network is created once and keeps training across all rounds.
	net = MNIST_Net()
	net = net.to(device)

	transform = transforms.Compose([
		transforms.ToTensor(),
		transforms.Normalize((0.1307,), (0.3081,))
	])

	data_c = np.load(test_data_path)
	labels_c = np.load(test_labels_path)
	testloader_c = torch.utils.data.DataLoader(MyDataset(data_c, labels_c, transform=transform), batch_size=100, shuffle=False)

	pool_data = np.load(train_data_path)
	pool_labels = np.load(train_labels_path)

	list_selected_data, list_selected_labels = [], []
	list_acc, list_ece, list_uacc, list_uece = [], [], [], []
	# Seed the labeled set with one class-balanced batch and remove it from the pool.
	idxs_unlabeled = sampling_balance(pool_labels)
	list_selected_data.append(pool_data[idxs_unlabeled])
	list_selected_labels.append(pool_labels[idxs_unlabeled])
	pool_data = np.delete(pool_data, idxs_unlabeled, 0)
	pool_labels = np.delete(pool_labels, idxs_unlabeled)
	# A second balanced batch becomes the acquisition of round 0.
	idxs_unlabeled = sampling_balance(pool_labels)

	for rd in range(100):
		# idxs_unlabeled index into the current pool (before this round's deletion).
		list_selected_data.append(pool_data[idxs_unlabeled])
		list_selected_labels.append(pool_labels[idxs_unlabeled])
		selected_data = np.concatenate(list_selected_data).reshape(-1, 28, 28, 1)
		selected_labels = np.concatenate(list_selected_labels).reshape(-1)
		trainloader = torch.utils.data.DataLoader(MyDataset(selected_data, selected_labels, transform=transform), batch_size=128, shuffle=True)

		# A fresh optimizer each round; the network weights carry over.
		criterion = nn.CrossEntropyLoss()
		optimizer = optim.Adam(net.parameters())
		for epoch in range(train_epochs):
			train(net, device, trainloader, criterion, optimizer)

		pool_data = np.delete(pool_data, idxs_unlabeled, 0)
		pool_labels = np.delete(pool_labels, idxs_unlabeled)
		pool_loader = torch.utils.data.DataLoader(MyDataset(pool_data, pool_labels, transform=transform), batch_size=512, shuffle=False)

		# Next round's acquisition: positions in the reduced pool, in loader order.
		idxs_unlabeled = sampling(net, device, pool_loader, select_samples)

		# outputs, targets, acc = test_model(net, device, pool_loader)
		# ece = calculate_ece(torch.tensor(outputs), torch.tensor(targets))
		# list_uacc.append(acc)
		# list_uece.append(ece)

		outputs, targets, acc = test_model(net, device, testloader_c)
		ece = calculate_ece(torch.tensor(outputs), torch.tensor(targets))
		list_acc.append(acc)
		list_ece.append(ece)

		# Rewritten every round so partial results survive an interrupted run.
		with open("out/mnist/demo_bald" + str(args.exp_idx), "wb") as fp:
			# pickle.dump(list_uacc, fp)
			# pickle.dump(list_uece, fp)
			pickle.dump(list_acc, fp)
			pickle.dump(list_ece, fp)
			# pickle.dump(outputs, fp)
			# pickle.dump(targets, fp)