import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import argparse

import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

import matplotlib.pyplot as plt
import numpy as np
import pickle

from utils import *
from models import *

class MyDataset(torch.utils.data.Dataset):
	"""Dataset over parallel ``data`` / ``target`` sequences.

	Each item is a ``(sample, label)`` pair. When a truthy ``transform`` is
	supplied it is applied to the sample only; the label is returned as stored.
	"""

	def __init__(self, data, target, transform=None):
		self.data, self.target, self.transform = data, target, transform

	def __len__(self):
		"""Number of samples, taken from ``data``."""
		return len(self.data)

	def __getitem__(self, index):
		"""Return the (optionally transformed) sample and its label."""
		sample = self.data[index]
		sample = self.transform(sample) if self.transform else sample
		return sample, self.target[index]

import numpy as np
import pdb
from .strategy import Strategy
from sklearn.neighbors import NearestNeighbors
import pickle
from datetime import datetime
from sklearn.metrics import pairwise_distances

class CoreSet(Strategy):
    """Query strategy that picks samples by greedy furthest-first traversal.

    Each pick is the unlabeled point whose embedding is furthest from every
    labeled point and from every point already picked in the same call.
    """

    def __init__(self, X, Y, idxs_lb, net, handler, args, tor=1e-4):
        """Initialise the base strategy and remember the tolerance.

        Args:
            X, Y, idxs_lb, net, handler, args: Forwarded unchanged to
                ``Strategy.__init__``.
            tor: Tolerance stored on the instance; no method of this class
                reads it.
        """
        super(CoreSet, self).__init__(X, Y, idxs_lb, net, handler, args)
        self.tor = tor

    def furthest_first(self, X, X_set, n):
        """Greedily choose ``n`` rows of ``X`` that lie far from ``X_set``.

        Args:
            X: 2-D array of candidate points, one row per point.
            X_set: 2-D array of reference points; may have zero rows.
            n: Number of picks to make.

        Returns:
            List of ``n`` row indices into ``X``, in the order picked.
            NOTE(review): no guard against ``n`` exceeding the number of
            rows in ``X``; picks would presumably start repeating - confirm
            callers never request that.
        """
        m = np.shape(X)[0]
        if np.shape(X_set)[0] == 0:
            # Nothing to be close to yet: every candidate starts at +inf.
            min_dist = np.full(m, np.inf)
        else:
            min_dist = np.amin(pairwise_distances(X, X_set), axis=1)

        idxs = []
        for _ in range(n):
            idx = min_dist.argmax()
            idxs.append(idx)
            # Fold the new centre in with one vectorised pass instead of a
            # Python loop over all m candidates.
            dist_new_ctr = pairwise_distances(X, X[[idx], :])
            min_dist = np.minimum(min_dist, dist_new_ctr[:, 0])

        return idxs

    def sampling(self, n):
        """Return pool indices of ``n`` unlabeled samples to label next.

        Embeddings for the whole pool come from ``self.get_embedding``; the
        unlabeled rows are the candidates and the labeled rows are the
        reference set for ``furthest_first``.
        ``self.idxs_lb`` is presumably a boolean mask over the pool (it is
        negated with ``~`` and used as an index) - confirm in ``Strategy``.
        """
        idxs_unlabeled = np.arange(self.n_pool)[~self.idxs_lb]
        # Snapshot the labeled mask before the embedding pass runs.
        lb_flag = self.idxs_lb.copy()
        embedding = self.get_embedding(self.X, self.Y).numpy()

        chosen = self.furthest_first(
            embedding[idxs_unlabeled, :], embedding[lb_flag, :], n
        )
        return idxs_unlabeled[chosen]

def sampling_balance(pool_labels, n_samples=500, max_per_class=5):
	"""Draw distinct random pool indices with a cap on picks per class.

	Indices are drawn one at a time with ``np.random.choice`` and rejected
	when they were already taken or when their class already holds
	``max_per_class`` picks. The draw sequence is the same as a plain
	rejection loop, so results for a given NumPy seed are reproducible.

	Args:
		pool_labels: Class label of every pool sample; expected to be 1-D
			so that each element is a hashable scalar.
		n_samples: Number of indices to return (default 500).
		max_per_class: Maximum number of picks sharing one label (default 5).

	Returns:
		Integer array of ``n_samples`` distinct indices into ``pool_labels``,
		in the order they were accepted.

	Raises:
		ValueError: If the pool cannot provide ``n_samples`` indices under
			the per-class cap; rejection sampling would never terminate.
	"""
	labels = np.asarray(pool_labels)
	pool_size = labels.shape[0]

	# Feasibility check: each class can contribute at most
	# min(its size, max_per_class) picks.
	_, class_counts = np.unique(labels, return_counts=True)
	available = int(np.minimum(class_counts, max_per_class).sum())
	if available < n_samples:
		raise ValueError(
			f"cannot draw {n_samples} samples with at most {max_per_class} "
			f"per class: only {available} eligible samples in the pool"
		)

	out = []
	taken = set()   # O(1) duplicate check instead of scanning the list
	per_class = {}  # running count of accepted picks per label
	while len(out) < n_samples:
		idx = int(np.random.choice(pool_size, 1)[0])
		if idx in taken:
			continue
		label = labels[idx]
		picked = per_class.get(label, 0)
		if picked < max_per_class:
			per_class[label] = picked + 1
			taken.add(idx)
			out.append(idx)

	return np.array(out, dtype=np.intp)

if __name__=="__main__":
	# Demo experiment: train a ResNet on a growing labeled subset of the
	# CIFAR-100 arrays under data/, and record accuracy / ECE after each round.
	# fix_random_seed, ResNet, BasicBlock, train, test_model and calculate_ece
	# are presumably supplied by the wildcard imports from utils / models.
	parser = argparse.ArgumentParser()
	# The index is both the random seed and the output-file suffix, so a
	# missing value is reported by argparse instead of int(None) raising.
	parser.add_argument("--exp_idx", required=True, help="Index of experiment")
	args = parser.parse_args()
	try:
		seed = int(args.exp_idx)
	except ValueError:
		parser.error("--exp_idx must be an integer")
	fix_random_seed(seed)
	device = 'cuda'

	train_epochs = 30
	select_samples = 500
	train_data_path = 'data/cifar100_train_data.npy'
	train_labels_path = 'data/cifar100_train_labels.npy'
	test_data_path = 'data/cifar100_test_data.npy'
	test_labels_path = 'data/cifar100_test_labels.npy'

	# Create the results directory up front so a missing folder fails fast
	# rather than after the first full round of training.
	out_path = "out/cifar100/demo_coreset" + str(args.exp_idx)
	os.makedirs(os.path.dirname(out_path), exist_ok=True)

	net = ResNet(BasicBlock, [2, 2, 2, 2], 100)
	net = net.to(device)

	# net = torch.nn.DataParallel(net)
	# checkpoint = torch.load('./checkpoint/ckpt.pth')
	# net.load_state_dict(checkpoint['net'])

	transform = transforms.Compose([
		transforms.ToTensor(),
		transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
	])

	data = np.load(test_data_path)
	labels = np.load(test_labels_path)
	testloader = torch.utils.data.DataLoader(MyDataset(data, labels, transform=transform), batch_size=100, shuffle=False)

	# The pool loader is built inside the round loop, after the pool shrinks.
	pool_data = np.load(train_data_path)
	pool_labels = np.load(train_labels_path)

	list_selected_data, list_selected_labels = [], []
	list_acc, list_ece = [], []

	# Warm-up: move 99 class-balanced batches from the pool into the labeled
	# set without training, then draw one more batch that seeds round 0.
	# NOTE(review): that is 100 * 500 samples before the first query; if the
	# pool holds 50,000 rows (presumably the CIFAR-100 train split) it is
	# empty after round 0 - confirm the 99 is intended.
	for _ in range(99):
		idxs_unlabeled = sampling_balance(pool_labels)
		list_selected_data.append(pool_data[idxs_unlabeled])
		list_selected_labels.append(pool_labels[idxs_unlabeled])
		pool_data = np.delete(pool_data, idxs_unlabeled, 0)
		pool_labels = np.delete(pool_labels, idxs_unlabeled)
	idxs_unlabeled = sampling_balance(pool_labels)

	for rd in range(100):
		# Add the latest selection and flatten all batches into one train set.
		list_selected_data.append(pool_data[idxs_unlabeled])
		list_selected_labels.append(pool_labels[idxs_unlabeled])
		selected_data = np.asarray(list_selected_data)
		selected_data = np.reshape(selected_data, (selected_data.shape[0]*select_samples, 32, 32, 3))
		selected_labels = np.asarray(list_selected_labels)
		selected_labels = np.reshape(selected_labels, (selected_labels.shape[0]*select_samples,))
		trainloader = torch.utils.data.DataLoader(MyDataset(selected_data, selected_labels, transform=transform), batch_size=128, shuffle=True)

		# A fresh Adam optimizer each round; the network itself keeps its
		# weights from the previous round.
		criterion = nn.CrossEntropyLoss()
		optimizer = optim.Adam(net.parameters())
		for epoch in range(train_epochs):
			train(net, device, trainloader, criterion, optimizer)

		# Drop the samples just used for training from the pool.
		pool_data = np.delete(pool_data, idxs_unlabeled, 0)
		pool_labels = np.delete(pool_labels, idxs_unlabeled)
		pool_loader = torch.utils.data.DataLoader(MyDataset(pool_data, pool_labels, transform=transform), batch_size=512, shuffle=False)

		# NOTE(review): no module-level `sampling` is defined in this file
		# (only the CoreSet.sampling method); presumably it comes from the
		# wildcard imports - confirm.
		idxs_unlabeled = sampling(net, device, pool_loader, select_samples)

		outputs, targets, acc = test_model(net, device, testloader)
		ece = calculate_ece(torch.tensor(outputs), torch.tensor(targets))
		list_acc.append(acc)
		list_ece.append(ece)

		# Rewritten every round so partial results survive an interruption.
		# The file holds two consecutive pickles: accuracies, then ECEs.
		with open(out_path, "wb") as fp:
			pickle.dump(list_acc, fp)
			pickle.dump(list_ece, fp)