import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import argparse

import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

import matplotlib.pyplot as plt
import numpy as np
import pickle

from utils import *
from models import *

class MyDataset(torch.utils.data.Dataset):
	"""In-memory map-style dataset over parallel ``data`` / ``target`` sequences.

	Item ``i`` is ``(transform(data[i]), target[i])`` when a (truthy)
	``transform`` is given, otherwise ``(data[i], target[i])``. Only the
	input is transformed; targets are returned unchanged.
	"""

	def __init__(self, data, target, transform=None):
		self.data, self.target, self.transform = data, target, transform

	def __len__(self):
		# Length follows the inputs; targets are assumed to be the same length.
		return len(self.data)

	def __getitem__(self, index):
		sample = self.data[index]
		sample = self.transform(sample) if self.transform else sample
		return sample, self.target[index]

def _entropy(p, axis):
	"""Shannon entropy (nats) of probabilities ``p`` along ``axis``, taking 0*log(0) as 0."""
	# Only take the log where p > 0; elsewhere log_p stays 0 so p*log_p is 0, not NaN.
	log_p = np.log(p, out=np.zeros_like(p), where=p > 0)
	return -(p * log_p).sum(axis)


def sampling(model, device, pool_loader, select_samples, n_drops=5):
	"""Pick the pool samples with the highest BALD (MC-dropout mutual information) score.

	Each batch from ``pool_loader`` is passed through ``model`` ``n_drops``
	times with only the dropout layers in training mode (all other layers,
	e.g. batch-norm, stay in eval mode). With ``probs`` the softmax outputs
	stacked as (n_drops, N, classes), the score of each sample is

		BALD = H[mean_k p_k] - mean_k H[p_k]

	i.e. the entropy of the averaged prediction minus the average entropy.

	Args:
		model: network whose output is treated as logits along dim 1.
		device: device each input batch is moved to before the forward pass.
		pool_loader: iterable of ``(data, target)`` batches in pool order. It
			must not shuffle, because the returned indices refer to that order.
			Targets are ignored.
		select_samples: maximum number of indices to return.
		n_drops: number of stochastic forward passes per batch.

	Returns:
		numpy integer array of at most ``select_samples`` pool indices, most
		informative first. It is empty when the loader yields no batches.

	Note:
		If ``model`` has no dropout layers, every pass is identical, all scores
		are ~0 and the selection is effectively arbitrary. ``model`` is left in
		eval mode on return.
	"""
	model.eval()
	# Re-enable only dropout, so the passes are stochastic without using batch
	# statistics or updating batch-norm running stats.
	dropout_types = (nn.Dropout, nn.Dropout2d, nn.Dropout3d, nn.AlphaDropout)
	for module in model.modules():
		if isinstance(module, dropout_types):
			module.train()

	batch_probs = []
	try:
		with torch.no_grad():
			for data, _ in pool_loader:
				data = data.to(device)
				# (n_drops, batch, classes)
				passes = torch.stack([torch.softmax(model(data), dim=1) for _ in range(n_drops)])
				batch_probs.append(passes.cpu().numpy())
	finally:
		model.eval()

	if not batch_probs:
		return np.array([], dtype=np.int64)

	# Join along the sample axis; this handles a smaller last batch naturally.
	probs = np.concatenate(batch_probs, axis=1)  # (n_drops, N, classes)
	entropy_of_mean = _entropy(probs.mean(0), axis=1)
	mean_entropy = _entropy(probs, axis=2).mean(0)
	# Negated BALD score: an ascending argsort puts the most informative samples first.
	U = mean_entropy - entropy_of_mean
	return np.argsort(U)[:select_samples]

def sampling_balance(pool_labels, n_samples=500, per_class=5):
	"""Randomly draw a class-balanced set of distinct pool indices.

	Indices are drawn one at a time with ``np.random.choice``. A draw is
	rejected if it was already picked or if its label already has
	``per_class`` picks. Drawing stops once ``n_samples`` indices have been
	accepted. With the defaults and 100 classes, this yields exactly 5
	indices per class.

	Args:
		pool_labels: array with one label per pool sample.
		n_samples: number of indices to return.
		per_class: maximum number of returned indices sharing one label.

	Returns:
		integer array of ``n_samples`` distinct indices into ``pool_labels``,
		in the order they were accepted.

	Raises:
		ValueError: if ``pool_labels`` does not hold one label per sample, or
			if the pool cannot supply ``n_samples`` indices under the
			per-class limit (otherwise the draw loop would never end).
	"""
	n_pool = pool_labels.shape[0]
	labels = np.asarray(pool_labels).ravel()
	if labels.shape[0] != n_pool:
		raise ValueError("pool_labels must contain exactly one label per sample")

	_, class_sizes = np.unique(labels, return_counts=True)
	capacity = int(np.minimum(class_sizes, per_class).sum())
	if capacity < n_samples:
		raise ValueError(
			f"cannot draw {n_samples} samples with at most {per_class} per class "
			f"from this pool (at most {capacity} possible)")

	chosen = []
	chosen_set = set()  # O(1) duplicate check instead of scanning `chosen`
	per_label = {}  # label -> number of accepted indices with that label
	while len(chosen) < n_samples:
		# Exactly one draw per iteration, so a seeded run picks the same indices as before.
		idx = np.random.choice(n_pool, 1)[0]
		if idx in chosen_set:
			continue
		label = labels[idx]
		taken = per_label.get(label, 0)
		if taken < per_class:
			chosen.append(idx)
			chosen_set.add(idx)
			per_label[label] = taken + 1
	return np.array(chosen, dtype=int)

if __name__ == "__main__":
	parser = argparse.ArgumentParser()
	# Required: the index seeds the RNGs and names the results file; int(None) would crash.
	parser.add_argument("--exp_idx", required=True, help="Index of experiment")
	args = parser.parse_args()
	fix_random_seed(int(args.exp_idx))
	# Fall back to CPU instead of failing when no GPU is visible.
	device = 'cuda' if torch.cuda.is_available() else 'cpu'

	train_epochs = 30
	select_samples = 500
	train_data_path = 'data/cifar100_train_data.npy'
	train_labels_path = 'data/cifar100_train_labels.npy'
	test_data_path = 'data/cifar100_test_data.npy'
	test_labels_path = 'data/cifar100_test_labels.npy'

	# Create the output directory up front so the first pickle.dump cannot fail.
	results_path = "out/cifar100/demo_bald" + str(args.exp_idx)
	os.makedirs(os.path.dirname(results_path), exist_ok=True)

	net = ResNet(BasicBlock, [2, 2, 2, 2], 100)
	net = net.to(device)

	transform = transforms.Compose([
		transforms.ToTensor(),
		transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
	])

	test_data = np.load(test_data_path)
	test_labels = np.load(test_labels_path)
	testloader = torch.utils.data.DataLoader(MyDataset(test_data, test_labels, transform=transform), batch_size=100, shuffle=False)

	pool_data = np.load(train_data_path)
	pool_labels = np.load(train_labels_path)

	list_selected_data, list_selected_labels = [], []
	list_acc, list_ece = [], []

	# Seed the labelled set with 99 class-balanced random batches removed from the pool.
	# NOTE(review): 99 * 500 = 49,500 samples; if the pool is the full 50,000-image
	# CIFAR-100 training set, only 500 remain and the loop below runs a single round.
	# Confirm this is intended.
	for _ in range(99):
		idxs_seed = sampling_balance(pool_labels)
		list_selected_data.append(pool_data[idxs_seed])
		list_selected_labels.append(pool_labels[idxs_seed])
		pool_data = np.delete(pool_data, idxs_seed, 0)
		pool_labels = np.delete(pool_labels, idxs_seed)
	idxs_unlabeled = sampling_balance(pool_labels)

	criterion = nn.CrossEntropyLoss()
	for _ in range(100):
		list_selected_data.append(pool_data[idxs_unlabeled])
		list_selected_labels.append(pool_labels[idxs_unlabeled])
		# Concatenate rather than asarray+reshape, so a final smaller batch is still accepted.
		selected_data = np.concatenate(list_selected_data, 0).reshape(-1, 32, 32, 3)
		selected_labels = np.concatenate(list_selected_labels, axis=None)
		trainloader = torch.utils.data.DataLoader(MyDataset(selected_data, selected_labels, transform=transform), batch_size=128, shuffle=True)

		# Fresh optimizer each round; the network keeps its weights across rounds.
		optimizer = optim.Adam(net.parameters())
		for _ in range(train_epochs):
			train(net, device, trainloader, criterion, optimizer)

		pool_data = np.delete(pool_data, idxs_unlabeled, 0)
		pool_labels = np.delete(pool_labels, idxs_unlabeled)
		# Once everything is labelled there is nothing left to query; record this round and stop.
		pool_exhausted = pool_data.shape[0] == 0
		if not pool_exhausted:
			pool_loader = torch.utils.data.DataLoader(MyDataset(pool_data, pool_labels, transform=transform), batch_size=512, shuffle=False)
			idxs_unlabeled = sampling(net, device, pool_loader, select_samples)

		outputs, targets, acc = test_model(net, device, testloader)
		ece = calculate_ece(torch.tensor(outputs), torch.tensor(targets))
		list_acc.append(acc)
		list_ece.append(ece)

		# Rewritten every round so partial results survive an interrupted run.
		with open(results_path, "wb") as fp:
			pickle.dump(list_acc, fp)
			pickle.dump(list_ece, fp)

		if pool_exhausted:
			break