import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import argparse

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

import matplotlib.pyplot as plt
import numpy as np
import pickle

from utils import *
from models import *

class MyDataset(torch.utils.data.Dataset):
	"""Map-style dataset pairing an indexable sample collection with its targets.

	An optional ``transform`` callable is applied to each sample (never to the
	target) when the item is looked up.
	"""

	def __init__(self, data, target, transform=None):
		self.data, self.target, self.transform = data, target, transform

	def __len__(self):
		# The sample collection alone defines the dataset size.
		return len(self.data)

	def __getitem__(self, index):
		sample = self.data[index]
		# A falsy transform (e.g. None) means "hand back the raw sample".
		if self.transform:
			sample = self.transform(sample)
		return sample, self.target[index]

def sampling(model, device, pool_loader, select_samples, trainloader, num_classes=10):
	"""Return the pool positions with the highest per-sample score.

	The model's softmax outputs are collected for the labelled set
	(``trainloader``) and for the pool (``pool_loader``).  The pool outputs are
	then split into three consecutive chunks; each chunk is appended to the
	labelled outputs and passed, together with the labelled targets, to
	``get_ratio_canonical_per_samples``.  The three results are concatenated
	and sorted in descending order.

	NOTE(review): this assumes the helper returns one score per pool sample of
	the chunk, in chunk order, so that positions in the concatenated scores are
	positions in ``pool_loader``'s iteration order -- confirm against utils.
	The chunking is presumably there to bound the helper's memory use.

	Args:
		model: Network whose output has ``num_classes`` columns; it is switched
			to eval mode.
		device: Device the batches and accumulators are moved to.
		pool_loader: Loader over the candidate pool, yielding ``(data, target)``;
			the targets are ignored.
		select_samples: Maximum number of positions to return.
		trainloader: Loader over the labelled set, yielding ``(data, target)``.
		num_classes: Width of the model output.  Defaults to 10, the value that
			used to be hard-coded.

	Returns:
		NumPy array holding the positions of the (at most) ``select_samples``
		largest scores, largest first.
	"""
	model.eval()

	# The empty seed tensors keep the dtype promotion of the accumulated
	# tensors unchanged and keep torch.cat valid when a loader yields nothing.
	train_out_parts = [torch.empty((0, num_classes)).to(device)]
	train_tgt_parts = [torch.empty((0,), dtype=torch.int8).to(device)]
	pool_out_parts = [torch.empty((0, num_classes)).to(device)]

	with torch.no_grad():
		for data, target in trainloader:
			data, target = data.to(device), target.to(device)
			train_out_parts.append(torch.softmax(model(data), dim=1))
			train_tgt_parts.append(target)

		# Pool targets are never read, so they are not moved to the device.
		for data, _ in pool_loader:
			pool_out_parts.append(torch.softmax(model(data.to(device)), dim=1))

	# Concatenate once: growing a tensor inside the loops re-copied everything
	# accumulated so far on every batch.
	train_outputs = torch.cat(train_out_parts, 0)
	train_targets = torch.cat(train_tgt_parts, 0)
	outputs = torch.cat(pool_out_parts, 0)

	# Same three chunk boundaries as before: [0, n/3), [n/3, 2n/3), [2n/3, n).
	n_pool = outputs.shape[0]
	bounds = (0, int(n_pool / 3), int(2 * n_pool / 3), n_pool)
	chunk_scores = []
	for start, stop in zip(bounds[:-1], bounds[1:]):
		input_func = torch.cat((train_outputs, outputs[start:stop]), 0)
		chunk_scores.append(
			get_ratio_canonical_per_samples(input_func, train_targets, bandwidth=0.001, p=1, device=device)
		)

	ece = torch.cat(chunk_scores, 0)
	out = torch.argsort(ece, descending=True)[:select_samples]

	return out.detach().cpu().numpy()

	# tmp = torch.cat((ece, ece_3), 0).detach().cpu().numpy()
	# normalized_tmp = (tmp - tmp.min()) / (tmp.max() - tmp.min())
	# alpha = np.mean(normalized_tmp)
	# ece = -tmp
	# conf = outputs.max(1).values.detach().cpu().numpy()
	# normalized_ece = (ece - ece.min()) / (ece.max() - ece.min())
	# average_array = alpha * normalized_ece + (1 - alpha) * conf
	# return np.argsort(average_array)[:select_samples]

	# ece = torch.cat((ece, ece_3), 0).detach().cpu().numpy()
	# conf = outputs.max(1).values.detach().cpu().numpy()
	# out = np.column_stack((conf, ece))

	# return np.lexsort((out[:,0],-out[:,1]))[:select_samples]

def test_model(model, device, test_loader):
	"""Evaluate ``model`` on ``test_loader`` and collect its softmax outputs.

	Args:
		model: Network returning one row of class scores per sample; it is
			switched to eval mode.
		device: Device every batch is moved to before the forward pass.
		test_loader: Iterable of ``(data, target)`` batches that also exposes
			``.dataset``; ``len(test_loader.dataset)`` is the accuracy
			denominator.

	Returns:
		Tuple ``(outputs, targets, acc)``: the softmax outputs of all samples
		as one NumPy array, the matching targets as one NumPy array, and the
		accuracy in percent.
	"""
	model.eval()
	correct = 0
	outputs, targets = [], []
	with torch.no_grad():
		for data, target in test_loader:
			data, target = data.to(device), target.to(device)
			output = model(data)
			# Predicted class = column index of the largest raw score.
			_, pred = torch.max(output, 1)
			correct += (pred == target).sum().item()
			outputs.append(torch.softmax(output, dim=1).cpu().numpy())
			targets.append(target.cpu().numpy())

	n_total = len(test_loader.dataset)
	acc = 100. * correct / n_total
	print('\nTest set: Accuracy: {}/{} ({:.0f}%)\n'.format(correct, n_total, acc))

	# Join the per-batch arrays directly.  The previous stack-and-reshape of
	# "all batches but the last" raised IndexError when the loader produced a
	# single batch, because the stacked array was then empty and 1-D.
	outputs = np.concatenate(outputs, 0)
	targets = np.concatenate(targets, 0)

	return outputs, targets, acc


# Despite its name this is an evaluation helper, not a unit test: keep pytest
# from collecting it when the module is star-imported into a test file.
test_model.__test__ = False

def sampling_balance(pool_labels, n_samples=50, per_class=5):
	"""Randomly draw distinct pool indices with a cap on samples per class.

	Indices are drawn one at a time with ``np.random.choice``; a draw is kept
	unless the index was already chosen or its label already has ``per_class``
	accepted samples.  For a pool that can satisfy the request the sequence of
	random draws is the same as in the earlier fixed 50/5 implementation.

	Args:
		pool_labels: Array of labels, one per pool sample.
		n_samples: Number of indices to return (default 50).
		per_class: Maximum number of returned indices sharing a label
			(default 5).

	Returns:
		NumPy array of ``n_samples`` distinct indices into ``pool_labels``, in
		the order they were accepted.

	Raises:
		ValueError: If the pool cannot supply ``n_samples`` indices under the
			``per_class`` cap.  The rejection loop would otherwise never end.
	"""
	labels = np.asarray(pool_labels)
	if n_samples <= 0:
		return np.array([], dtype=np.intp)

	# Feasibility check: each class contributes at most per_class samples.
	_, class_sizes = np.unique(labels, return_counts=True)
	available = int(np.minimum(class_sizes, per_class).sum())
	if available < n_samples:
		raise ValueError(
			"cannot draw {} samples with at most {} per class: only {} available".format(
				n_samples, per_class, available))

	chosen = []
	chosen_set = set()      # O(1) duplicate check
	accepted_per_label = {}  # label -> number of accepted samples
	while len(chosen) < n_samples:
		idx = np.random.choice(labels.shape[0], 1)[0]
		if idx in chosen_set:
			continue
		label = labels[idx].item()
		n_label = accepted_per_label.get(label, 0)
		if n_label < per_class:
			chosen.append(idx)
			chosen_set.add(idx)
			accepted_per_label[label] = n_label + 1

	return np.array(chosen)

if __name__=="__main__":
	# Experiment driver: starting from 500 class-balanced samples of the SVHN
	# training pool, repeatedly train the network on everything selected so
	# far, let `sampling` pick the next 50 pool samples, and record test
	# accuracy and ECE after every round.
	parser = argparse.ArgumentParser()
	# Required: the value is passed to fix_random_seed and names the result
	# file, so a missing flag should be a usage error instead of the
	# TypeError raised by int(None).  It stays a string so the file name is
	# built from exactly what was typed.
	parser.add_argument("--exp_idx", required=True, help="Index of experiment")
	args = parser.parse_args()
	fix_random_seed(int(args.exp_idx))
	device = 'cuda'

	train_epochs = 30
	select_samples = 50
	train_data_path = 'data/SVHN_train_data.npy'
	train_labels_path = 'data/SVHN_train_labels.npy'
	test_data_path = 'data/SVHN_test_data.npy'
	test_labels_path = 'data/SVHN_test_labels.npy'

	# Create the result directory up front; otherwise a missing directory is
	# only noticed after the first full training round.
	os.makedirs("out/svhn", exist_ok=True)

	net = ResNet18()
	net = net.to(device)

	transform_train = transforms.Compose([
		transforms.ToTensor(),
		# transforms.RandomCrop(32, padding=4),
		# transforms.RandomHorizontalFlip(),
		transforms.Normalize((0.4377, 0.4438, 0.4728), (0.1980, 0.2010, 0.1970)),
	])

	transform_test = transforms.Compose([
		transforms.ToTensor(),
		transforms.Normalize((0.4377, 0.4438, 0.4728), (0.1980, 0.2010, 0.1970)),
	])

	data = np.load(test_data_path)
	labels = np.load(test_labels_path)
	testloader = torch.utils.data.DataLoader(MyDataset(data, labels, transform=transform_test), batch_size=100, shuffle=False)

	# The training split is the pool that samples are drawn from.  The pool
	# loader is built inside the round loop, after the selected samples have
	# been removed.
	pool_data = np.load(train_data_path)
	pool_labels = np.load(train_labels_path)

	list_selected_data, list_selected_labels = [], []
	list_acc, list_ece = [], []
	# Nine balanced batches are moved out of the pool here; the tenth is drawn
	# below and moved inside the first round, giving 10 * 50 initial samples.
	for _ in range(9):
		idxs_unlabeled = sampling_balance(pool_labels)
		list_selected_data.append(pool_data[idxs_unlabeled])
		list_selected_labels.append(pool_labels[idxs_unlabeled])
		pool_data = np.delete(pool_data, idxs_unlabeled, 0)
		pool_labels = np.delete(pool_labels, idxs_unlabeled)

	idxs_unlabeled = sampling_balance(pool_labels)

	for rd in range(100):
		# Add the most recently chosen pool samples to the labelled set.
		list_selected_data.append(pool_data[idxs_unlabeled])
		list_selected_labels.append(pool_labels[idxs_unlabeled])
		# Joining the batches along axis 0 yields the same arrays as the
		# former stack-and-reshape without hard-coding the image shape.
		selected_data = np.concatenate(list_selected_data, 0)
		selected_labels = np.concatenate(list_selected_labels, 0)
		trainloader = torch.utils.data.DataLoader(MyDataset(selected_data, selected_labels, transform=transform_train), batch_size=128, shuffle=True)

		# The network is created once above and keeps its weights across
		# rounds; only the optimizer is created anew each round.
		criterion = nn.CrossEntropyLoss()
		optimizer = optim.Adam(net.parameters())
		for epoch in range(train_epochs):
			train(net, device, trainloader, criterion, optimizer)

		# Remove the samples just used from the pool before scoring it, so the
		# indices returned by `sampling` refer to the reduced pool.
		pool_data = np.delete(pool_data, idxs_unlabeled, 0)
		pool_labels = np.delete(pool_labels, idxs_unlabeled)
		pool_loader = torch.utils.data.DataLoader(MyDataset(pool_data, pool_labels, transform=transform_test), batch_size=512, shuffle=False)

		# Selection sees the labelled set through the test-time transform.
		trainloader = torch.utils.data.DataLoader(MyDataset(selected_data, selected_labels, transform=transform_test), batch_size=128, shuffle=True)
		idxs_unlabeled = sampling(net, device, pool_loader, select_samples, trainloader)

		outputs, targets, acc = test_model(net, device, testloader)
		ece = calculate_ece(torch.tensor(outputs), torch.tensor(targets))
		list_acc.append(acc)
		list_ece.append(ece)

		# The file is rewritten every round with the full history as two
		# consecutive pickles (accuracies first, then ECE values).
		with open("out/svhn/ours" + str(args.exp_idx), "wb") as fp:
		# with open("out/svhn/ours_mix" + str(args.exp_idx), "wb") as fp:
			pickle.dump(list_acc, fp)
			pickle.dump(list_ece, fp)