import numpy as np
from collections import Counter
from sklearn.model_selection import train_test_split
import torch
from torchvision.datasets import MNIST, CIFAR10, FashionMNIST
from torchvision.transforms import ToTensor, Compose, Normalize, Lambda

from utils.log import log

def set_data(data, gt, cfg):
	"""Build a contaminated subset of ``data`` around one "normal" class.

	All samples of the normal class are kept. Abnormal samples (every other
	class) are sub-sampled, stratified by class, so that they make up roughly
	a fraction ``cfg['ct'] + cfg['abnormal_ct']`` of the returned set.

	Args:
		data: Tensor of samples, indexed along dim 0.
		gt: CPU tensor of class labels aligned with ``data``. It is indexed
			as a 1-D tensor here -- NOTE(review): confirm callers never pass
			multi-dimensional labels.
		cfg: Mapping that provides
			``'normal'``      -- label of the normal class,
			``'ct'``          -- share of the output that is abnormal but
			                     NOT flagged (abnormality 0),
			``'abnormal_ct'`` -- share of the output that is abnormal and
			                     flagged (abnormality 1),
			``'normal_ct'``   -- share of normal samples that get flagged
			                     (abnormality 1), in ``[0, 0.5]``.

	Returns:
		``(data, gt, abnormality)`` where the normal samples come first,
		followed by the sampled abnormal ones. ``gt`` holds the original
		class labels and ``abnormality`` is a 0/1 tensor of the same shape
		and dtype as ``gt``.

	Raises:
		AssertionError: if ``ct + abnormal_ct >= 1`` or ``normal_ct`` is
			outside ``[0, 0.5]``.
		ValueError: if contamination is requested but ``gt`` contains no
			abnormal sample, or if scikit-learn rejects the requested split
			(e.g. more abnormal samples requested than are available).
	"""
	# Contamination Generation Stage
	normal = cfg['normal']
	counter = Counter(gt.numpy())
	n_normal = counter[normal]
	total_abnormal = sum(counter.values()) - n_normal

	ct = cfg['ct'] + cfg['abnormal_ct']
	# Strictly below 1: ct == 1 would divide by zero in the next statement.
	assert ct < 1
	# Number of abnormal samples needed so that they form a fraction `ct`
	# of (normal + abnormal).
	n_abnormal = ct / (1 - ct) * n_normal
	if total_abnormal == 0:
		# Nothing to sample from. That is fine when no contamination was
		# requested, otherwise the request cannot be satisfied.
		if n_abnormal != 0:
			raise ValueError(
				"Contamination requested but the data holds no abnormal samples"
			)
		ab_ratio = 0
	else:
		ab_ratio = n_abnormal / total_abnormal

	# Normal Contamination Generation
	# Some of the normal class (proportional to normal_ct) become labeled abnormal.
	normal_idx = torch.where(gt == normal)[0]
	n_data = data[normal_idx]
	n_gt = gt[normal_idx]
	normal_ct = cfg['normal_ct']
	n_abnormality = torch.zeros_like(n_gt)
	assert normal_ct >= 0 and normal_ct <= 0.5
	if normal_ct != 0:
		# Sample plain integer positions and index with an explicit 1-D long
		# tensor, so the result does not depend on how torch interprets
		# nested sequences used as an index.
		positions = np.arange(n_gt.shape[0])
		flagged, _ = train_test_split(positions, train_size=normal_ct)
		n_abnormality[torch.as_tensor(flagged, dtype=torch.long)] = 1

	if ab_ratio == 0:
		# No abnormal samples requested: return the normal class only.
		return n_data, n_gt, n_abnormality

	# Contamination Generation
	# Class-stratified subset of the abnormal samples.
	ab_idx = torch.where(gt != normal)[0]
	chosen, _ = train_test_split(
		ab_idx.numpy(), train_size=ab_ratio, stratify=gt[ab_idx].numpy()
	)
	chosen = torch.as_tensor(chosen, dtype=torch.long)
	ab_data = data[chosen]
	ab_gt = gt[chosen]

	data = torch.cat((n_data, ab_data), 0)
	gt = torch.cat((n_gt, ab_gt), 0)

	# Abnormal Contamination Generation
	# Share of the sampled abnormal data that is flagged as abnormal; the
	# remainder stays unflagged (abnormality 0).
	# ab_ratio != 0 implies ct != 0, so the denominator is non-zero here.
	abnormal_ct = cfg['abnormal_ct'] / (cfg['abnormal_ct'] + cfg['ct'])
	ab_abnormality = torch.zeros_like(ab_gt)
	if abnormal_ct == 1:
		ab_abnormality = torch.ones_like(ab_gt)
	elif abnormal_ct != 0:
		positions = np.arange(ab_gt.shape[0])
		flagged, _ = train_test_split(
			positions, train_size=abnormal_ct, stratify=ab_gt.numpy()
		)
		ab_abnormality[torch.as_tensor(flagged, dtype=torch.long)] = 1
	abnormality = torch.cat((n_abnormality, ab_abnormality), 0)
	return data, gt, abnormality


	
def _load_dataset(dataset, raw_name):
	"""Load the train/test splits of ``dataset`` as (N, H, W, C) tensors plus label tensors."""
	if dataset == "cifar10":
		train_set = CIFAR10(root='./data', train=True, download=True)
		test_set = CIFAR10(root='./data', train=False, download=True)
		# The CIFAR10 images/targets are not tensors here (hence the
		# conversions); the images are used as-is with the channel axis last.
		train_data = torch.from_numpy(train_set.data)
		train_gt = torch.from_numpy(np.array(train_set.targets))
		test_data = torch.from_numpy(test_set.data)
		test_gt = torch.from_numpy(np.array(test_set.targets))
		return train_data, train_gt, test_data, test_gt

	if dataset == "mnist":
		dataset_cls = MNIST
	elif dataset == "fashionmnist":
		dataset_cls = FashionMNIST
	else:
		raise ValueError(f"Unknown dataset name: {raw_name}")

	train_set = dataset_cls(root='./data', train=True, download=True)
	test_set = dataset_cls(root='./data', train=False, download=True)
	# Append a trailing channel axis so these match the channel-last layout
	# that the later permute(0, 3, 1, 2) expects.
	train_data = train_set.data.unsqueeze(-1)
	test_data = test_set.data.unsqueeze(-1)
	return train_data, train_set.targets, test_data, test_set.targets


def make_data(cfg):
	"""Create contaminated train/valid/test tensors for one normal class.

	Args:
		cfg: Mapping with ``'dataset'`` (``"mnist"``, ``"fashionmnist"`` or
			``"cifar10"``, case-insensitive), ``'normal'`` (label of the
			normal class) and the contamination keys read by ``set_data``
			(``'ct'``, ``'abnormal_ct'``, ``'normal_ct'``). ``cfg`` is not
			modified.

	Returns:
		A 10-tuple ``(train_data, train_gt, train_abnormality, valid_data,
		valid_gt, valid_abnormality, test_data, test_gt, test_abnormality,
		transform)``. The ``*_data`` tensors are laid out (N, C, H, W) and
		divided by 255.0, the ``*_gt`` tensors are 0 for the normal class and
		1 otherwise, ``test_abnormality`` is the same tensor as ``test_gt``
		and ``transform`` is ``Compose([ToTensor()])``.

	Raises:
		ValueError: if ``cfg`` has no ``'dataset'`` entry or names an
			unknown dataset.
	"""
	try:
		dataset = cfg['dataset'].lower()
	except KeyError as err:
		# A missing key raises KeyError (not ValueError); translate it into
		# the intended, more descriptive error.
		raise ValueError("Dataset name is not included in Json file") from err

	normal = cfg['normal']

	# Load Dataset
	train_data, train_gt, test_data, test_gt = _load_dataset(dataset, cfg['dataset'])

	# The test set always uses a fixed 30% unflagged contamination. Work on a
	# copy so the caller's cfg is never left modified, even if set_data raises.
	test_cfg = dict(cfg)
	test_cfg['ct'] = 0.3
	test_cfg['abnormal_ct'] = 0.0
	test_data, test_gt, test_abnormality = set_data(test_data, test_gt, test_cfg)

	# Hold out 10% for validation, then shrink the training pool so that it
	# keeps about twice as many normal samples as the test set contains.
	# NOTE(review): train_test_split rejects a float train_size >= 1, so this
	# relies on the training pool being large enough -- confirm for new datasets.
	n_test_normal = int((test_gt == normal).sum())
	train_data, valid_data, train_gt, valid_gt = train_test_split(
		train_data, train_gt, train_size=0.9, stratify=train_gt
	)
	n_train_normal = int((train_gt == normal).sum())
	train_size = 2 * n_test_normal / n_train_normal
	train_data, _, train_gt, _ = train_test_split(
		train_data, train_gt, train_size=train_size, stratify=train_gt
	)

	train_data, train_gt, train_abnormality = set_data(train_data, train_gt, cfg)
	valid_data, valid_gt, valid_abnormality = set_data(valid_data, valid_gt, cfg)
	transform = Compose([ToTensor()])

	# Log Information
	log(f"Normal Class: {normal}")
	log_str = "Selectivity\n{0:>20}{1:>20}{2:>20}{3:>20}{4:>20}{5:>20}\n".format("NO.", "Total Train", "Normal Label", "Abnormal Label", "Valid", "Test")
	train_counter = Counter(train_gt.numpy())
	valid_counter = Counter(valid_gt.numpy())
	test_counter = Counter(test_gt.numpy())
	abnormality_counter = Counter(train_abnormality.numpy())

	# Every class label that occurs in any of the three splits.
	gt_list = np.unique(list(np.unique(train_gt)) + list(np.unique(valid_gt)) + list(np.unique(test_gt)))
	train_total = 0
	valid_total = 0
	test_total = 0
	n_total = 0
	ab_total = 0
	for label in gt_list:
		train_total += train_counter[label]
		valid_total += valid_counter[label]
		test_total += test_counter[label]
		# Per-class split of the training samples by abnormality flag (0 / 1).
		flag_counter = Counter(train_abnormality[train_gt == label].numpy())
		n_total += flag_counter[0]
		ab_total += flag_counter[1]
		log_str += f"{label:>20}{train_counter[label]:>20}{flag_counter[0]:>20}{flag_counter[1]:>20}{valid_counter[label]:>20}{test_counter[label]:>20}\n"
	log_str += f"{'Total':>20}{train_total:>20}{n_total:>20}{ab_total:>20}{valid_total:>20}{test_total:>20}\n"
	log_str += f"\n{'':>20}{'Train':>10}{'Valid':>10}{'Test':>10}"
	log_str += f"\n{'True Normal Total':>20}{train_counter[normal]:>10}{valid_counter[normal]:>10}{test_counter[normal]:>10}\n"
	log_str += f"{'True Abnormal Total':>20}{train_total-train_counter[normal]:>10}{valid_total-valid_counter[normal]:>10}{test_total-test_counter[normal]:>10}\n"
	# NOTE(review): the two figures below subtract whole-set counters, so they
	# only isolate abnormal samples when no normal sample is flagged
	# (normal_ct == 0) -- confirm this is the intended reading.
	log_str += f"{'Normal Like Abnormal':>20}{abnormality_counter[0] - train_counter[normal]:>10}{'':>10}{'':>10}\n"
	log_str += f"{'Labled Abnormal':>20}{abnormality_counter[1]:>10}{'':>10}{'':>10}\n"
	log_str += f"{'Train Abnormality 0':>20}{abnormality_counter[0]:>10}\n"
	log_str += f"{'Train Abnormality 1':>20}{abnormality_counter[1]:>10}\n"
	log_str += f"True Contamination(train): {(train_total-train_counter[normal]-abnormality_counter[1])/train_total*100:.3f} %\n"
	log(f"\n{log_str}")

	# (N, H, W, C) -> (N, C, H, W), scaled by 1/255.
	train_data = train_data.permute(0, 3, 1, 2) / 255.0
	valid_data = valid_data.permute(0, 3, 1, 2) / 255.0
	test_data = test_data.permute(0, 3, 1, 2) / 255.0

	# Collapse class labels to binary: 0 = normal class, 1 = anything else.
	train_gt = torch.where(train_gt == normal, 0, 1)
	valid_gt = torch.where(valid_gt == normal, 0, 1)
	test_gt = torch.where(test_gt == normal, 0, 1)
	# For the test split the abnormality flags are simply the binary labels.
	test_abnormality = test_gt

	return train_data, train_gt, train_abnormality, valid_data, valid_gt, valid_abnormality, test_data, test_gt, test_abnormality, transform


