import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import argparse

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import random_split

import matplotlib.pyplot as plt
import numpy as np
import pickle

from utils import *
from models import *

class MyDataset(torch.utils.data.Dataset):
	"""Map-style dataset that pairs an indexable `data` with an indexable `target`.

	Args:
		data: Indexable collection of samples; its length is the dataset length.
		target: Indexable collection of labels, looked up with the same index.
		transform: Optional callable applied to each sample (not to the label)
			when it is fetched.
	"""

	def __init__(self, data, target, transform=None):
		self.data, self.target, self.transform = data, target, transform

	def __len__(self):
		# Only `data` determines the size; `target` is not cross-checked.
		return len(self.data)

	def __getitem__(self, index):
		"""Return the `(sample, label)` pair stored at `index`."""
		sample = self.data[index]
		if self.transform:
			sample = self.transform(sample)
		return sample, self.target[index]

class CalibrationModel(nn.Module):
    """Temperature-scaling head: divides logits by a single learnable scalar."""

    def __init__(self):
        super().__init__()
        # Initialised to 1 so the module starts out as an identity map on logits.
        initial_temperature = torch.ones(1)
        self.temperature = nn.Parameter(initial_temperature)

    def forward(self, logits):
        """Return `logits` divided by the current temperature."""
        temperature = self.temperature
        return logits / temperature

def calibrate_model(calib_model, model, valid_loader, device=None, epochs=30, lr=0.001):
    """Fit `calib_model` on the outputs of a frozen `model` using `valid_loader`.

    Only the parameters of `calib_model` are optimised (Adam, cross-entropy
    loss). `model` is put in eval mode and evaluated without autograd, so its
    weights and `.grad` buffers are left untouched.

    Args:
        calib_model: Trainable module applied to the logits produced by `model`.
        model: Base classifier whose logits are being calibrated.
        valid_loader: Iterable yielding `(data, target)` tensor batches.
        device: Device the batches are moved to. Defaults to the device of
            `calib_model`'s first parameter, so the function no longer depends
            on a module-level `device` global being defined.
        epochs: Number of full passes over `valid_loader`. The previous
            implementation iterated `range(1, epochs)` and therefore ran one
            pass fewer than the configured 30.
        lr: Learning rate for the Adam optimiser.

    Returns:
        None. `calib_model` is updated in place and left in train mode.
    """
    if device is None:
        device = next(calib_model.parameters()).device

    model.eval()
    calib_model.train()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(calib_model.parameters(), lr=lr)

    for _ in range(epochs):
        for data, target in valid_loader:
            data = data.to(device)
            # CrossEntropyLoss expects integer class indices.
            target = target.to(device=device, dtype=torch.long)

            # The base model is frozen here: skip building its autograd graph
            # so backward() does not accumulate gradients on its parameters.
            with torch.no_grad():
                logits = model(data)

            optimizer.zero_grad()
            loss = criterion(calib_model(logits), target)
            loss.backward()
            optimizer.step()

def sampling(model, device, pool_loader, select_samples, calib_model=None):
	"""Return pool positions of the samples the model is least confident about.

	Confidence is the largest softmax probability of each sample, computed from
	`model`'s logits after optional rescaling by `calib_model`.

	Args:
		model: Classifier returning logits of shape (batch, classes).
		device: Device the input batches are moved to.
		pool_loader: Iterable of `(data, target)` batches; targets are ignored.
			Returned indices refer to the iteration order of this loader, so it
			must not shuffle if they are used to index the underlying arrays.
		select_samples: Maximum number of indices to return.
		calib_model: Optional module applied to the logits before the softmax.
			When None the raw logits are used.

	Returns:
		1-D integer array of at most `select_samples` positions, ordered from
		lowest to highest confidence. Empty if `pool_loader` yields no batches.
	"""
	model.eval()
	batch_probs = []
	with torch.no_grad():
		for data, _ in pool_loader:
			logits = model(data.to(device))
			if calib_model is not None:
				logits = calib_model(logits)
			batch_probs.append(torch.softmax(logits, dim=1).cpu().numpy())

	if not batch_probs:
		return np.empty(0, dtype=np.intp)

	# Concatenating along axis 0 works for any number of batches and any mix
	# of batch sizes, unlike stacking the batches and reshaping.
	probs = np.concatenate(batch_probs, axis=0)
	return np.argsort(probs.max(axis=1))[:select_samples]

def sampling_balance(pool_labels, num_samples=10):
	"""Randomly pick `num_samples` pool indices whose labels are all different.

	Indices are drawn one at a time with `np.random.choice` and a draw is kept
	only if its label has not been picked yet, so the result holds exactly one
	sample for each of `num_samples` distinct labels.

	Args:
		pool_labels: 1-D array of labels (one scalar label per pool entry).
		num_samples: Number of indices to return. Defaults to 10, the value
			that was previously hard-coded.

	Returns:
		Integer array of `num_samples` indices into `pool_labels`.

	Raises:
		ValueError: If `pool_labels` has fewer than `num_samples` distinct
			labels; the rejection loop could otherwise never terminate.
	"""
	if num_samples <= 0:
		return np.empty(0, dtype=np.intp)

	pool_labels = np.asarray(pool_labels)
	n_distinct = np.unique(pool_labels).size
	if n_distinct < num_samples:
		raise ValueError(
			"cannot draw {} samples with distinct labels from a pool with "
			"only {} distinct labels".format(num_samples, n_distinct)
		)

	chosen, seen_labels = [], set()
	while len(chosen) < num_samples:
		idx = np.random.choice(pool_labels.shape[0], 1)[0]
		label = pool_labels[idx].item()
		# An already-chosen index always carries an already-seen label, so a
		# single label check also rejects repeated indices.
		if label not in seen_labels:
			seen_labels.add(label)
			chosen.append(idx)
	return np.array(chosen)

if __name__=="__main__":
	# Script entry point: runs 100 rounds of a select -> retrain -> calibrate ->
	# evaluate loop on the MNIST arrays under data/, and after every round
	# writes the accumulated accuracy/ECE lists to out/mnist/.
	# fix_random_seed, MNIST_Net, train, test_model and calculate_ece are not
	# defined in this file; they come from the star imports of utils/models.
	parser = argparse.ArgumentParser()
	parser.add_argument("--exp_idx", help="Index of experiment")
	args = parser.parse_args()
	# exp_idx doubles as the random seed and as the output-file suffix.
	# NOTE(review): the flag is not marked required, so omitting it makes
	# int(None) raise TypeError here.
	fix_random_seed(int(args.exp_idx))
	# NOTE(review): device is hard-coded with no CPU fallback; the top of the
	# file sets CUDA_VISIBLE_DEVICES to "1".
	device = 'cuda' 

	train_epochs = 30
	# Number of pool samples added per round; the reshapes in the loop below
	# rely on every selection holding exactly this many items.
	select_samples = 10
	train_data_path = 'data/mnist_train_data.npy'
	train_labels_path = 'data/mnist_train_labels.npy'
	test_data_path = 'data/mnist_test_data.npy'
	test_labels_path = 'data/mnist_test_labels.npy'

	# The network is created once and never re-created inside the round loop.
	net = MNIST_Net()
	net = net.to(device)

	# presumably the usual MNIST mean/std normalisation constants - TODO confirm
	transform = transforms.Compose([
		transforms.ToTensor(),
		transforms.Normalize((0.1307,), (0.3081,))
	])

	# Held-out test set, evaluated at the end of every round.
	data_c = np.load(test_data_path) 
	labels_c = np.load(test_labels_path)
	testloader_c = torch.utils.data.DataLoader(MyDataset(data_c, labels_c, transform=transform), batch_size=100, shuffle=False)

	# The training split serves as the unlabeled pool.
	pool_data = np.load(train_data_path)
	pool_labels = np.load(train_labels_path)
	# NOTE(review): this loader is never iterated; pool_loader is rebuilt inside
	# the loop before its first use.
	pool_loader = torch.utils.data.DataLoader(MyDataset(pool_data, pool_labels, transform=transform), batch_size=128, shuffle=False)
	
	list_selected_data, list_selected_labels = [], []
	# Per-round metrics: acc/ece on the test set, uacc/uece on the remaining pool.
	list_acc, list_ece, list_uacc, list_uece = [], [], [], []
	# Initial selection: 10 samples with pairwise distinct labels, moved out of
	# the pool.
	idxs_unlabeled = sampling_balance(pool_labels)
	list_selected_data.append(pool_data[idxs_unlabeled])
	list_selected_labels.append(pool_labels[idxs_unlabeled])
	pool_data = np.delete(pool_data, idxs_unlabeled, 0)
	pool_labels = np.delete(pool_labels, idxs_unlabeled)
	# A second balanced draw; it is appended at the top of the first round, so
	# round 0 works with 20 selected samples.
	idxs_unlabeled = sampling_balance(pool_labels)
	
	for rd in range(100):
		# idxs_unlabeled indexes the *current* pool arrays; the same indices are
		# deleted from the pool further down, after training.
		list_selected_data.append(pool_data[idxs_unlabeled])
		list_selected_labels.append(pool_labels[idxs_unlabeled])
		# Flatten the list of per-round selections into one array of samples.
		selected_data = np.asarray(list_selected_data)
		selected_data = np.reshape(selected_data, (selected_data.shape[0]*select_samples, 28, 28, 1))
		selected_labels = np.asarray(list_selected_labels)
		selected_labels = np.reshape(selected_labels, (selected_labels.shape[0]*select_samples,))

		# Split the dataset
		# The split is positional (first 80% train, last 20% validation), so the
		# validation part consists of the most recently selected samples.
		idx_split = int(0.8*len(selected_data))
		train_data, val_data = selected_data[:idx_split], selected_data[idx_split:]
		train_labels, val_labels = selected_labels[:idx_split], selected_labels[idx_split:]

		trainloader = torch.utils.data.DataLoader(MyDataset(train_data, train_labels, transform=transform), batch_size=128, shuffle=True)
		# The whole validation split is served as a single batch.
		val_loader = torch.utils.data.DataLoader(MyDataset(val_data, val_labels, transform=transform), batch_size=len(val_data))

		# A fresh Adam optimizer is built every round, while `net` itself is the
		# object created before the loop.
		criterion = nn.CrossEntropyLoss()
		optimizer = optim.Adam(net.parameters())
		for epoch in range(train_epochs):
			train(net, device, trainloader, criterion, optimizer)

		# A new temperature model is fitted on the validation split each round.
		ts_calib_model = CalibrationModel().to(device)
		calibrate_model(ts_calib_model, net, val_loader)
		ts_calib_model.eval()

		# Drop the samples appended at the top of this round from the pool, then
		# rebuild the loader without shuffling so that positions returned by
		# sampling() line up with rows of pool_data.
		pool_data = np.delete(pool_data, idxs_unlabeled, 0)
		pool_labels = np.delete(pool_labels, idxs_unlabeled)
		pool_loader = torch.utils.data.DataLoader(MyDataset(pool_data, pool_labels, transform=transform), batch_size=512, shuffle=False)

		# Selection for the next round: the select_samples pool entries with the
		# lowest maximum softmax probability.
		idxs_unlabeled = sampling(net, device, pool_loader, select_samples, ts_calib_model)

		# Metrics on the remaining pool, which at this point still contains the
		# samples just chosen for the next round.
		outputs, targets, acc = test_model(net, device, pool_loader, ts_calib_model)
		ece = calculate_ece(torch.tensor(outputs), torch.tensor(targets))
		list_uacc.append(acc)
		list_uece.append(ece)

		# Metrics on the held-out test set.
		outputs, targets, acc = test_model(net, device, testloader_c, ts_calib_model)
		ece = calculate_ece(torch.tensor(outputs), torch.tensor(targets))
		list_acc.append(acc)
		list_ece.append(ece)

		# The results file is overwritten every round ("wb") with four
		# consecutive pickles; a reader must call pickle.load four times in this
		# order: uacc, uece, acc, ece.
		# NOTE(review): the out/mnist directory is not created here and must
		# already exist.
		with open("out/mnist/demo_ERM_cal" + str(args.exp_idx), "wb") as fp:
			pickle.dump(list_uacc, fp)
			pickle.dump(list_uece, fp)
			pickle.dump(list_acc, fp)
			pickle.dump(list_ece, fp)
			# pickle.dump(outputs, fp)
			# pickle.dump(targets, fp)