import argparse
import os,sys
import numpy as np
import torch
import torch.nn as nn
sys.path.append('../')
sys.path.append(os.getcwd())

from pprint import  pformat
import yaml
import logging
import time
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
from defense.base import defense
import scipy
from utils.aggregate_block.train_settings_generate import argparser_criterion, argparser_opt_scheduler
from utils.trainer_cls import PureCleanModelTrainer
from utils.aggregate_block.fix_random import fix_random
from utils.aggregate_block.model_trainer_generate import generate_cls_model
from utils.log_assist import get_git_info
from utils.aggregate_block.dataset_and_transform_generate import get_input_shape, get_num_classes, get_transform
from utils.save_load_attack import load_attack_result, save_defense_result
from utils.nCHW_nHWC import *

import tqdm
import heapq
from PIL import Image
from utils.bd_dataset_v2 import dataset_wrapper_with_transform,xy_iter, prepro_cls_DatasetBD_v2
from utils.trainer_cls import Metric_Aggregator, PureCleanModelTrainer, all_acc, general_plot_for_epoch, given_dataloader_test
from collections import Counter
import copy
from torch.utils.data import DataLoader
from sklearn.metrics import confusion_matrix
import csv
from sklearn import metrics

import math

def get_features_labels(args, model, target_layer, data_loader):
	"""Extract per-sample activations of ``target_layer`` for every batch in ``data_loader``.

	A forward hook on ``target_layer`` captures its output during each forward
	pass, and that output is reduced to one vector per sample:

	* ``args.model == "densenet161"``: ReLU, adaptive average pool to 1x1, flatten.
	* ``args.model == "mobilenet_v2"``: adaptive average pool to 1x1, flatten.
	* any other model: 2-D outputs are kept unchanged; higher-rank outputs are
	  summed over every dimension after the second (a spatial sum for NCHW maps).

	Args:
		args: namespace providing ``device`` (passed to ``Tensor.to``) and
			``model`` (architecture name selecting the reduction above).
		model: network to run; it is switched to eval mode.
		target_layer: submodule of ``model`` whose forward output is captured.
		data_loader: iterable yielding ``(inputs, targets, *extra)`` batches.

	Returns:
		``(features, labels)`` numpy arrays, concatenated across batches in
		iteration order.
	"""
	# Closure-local storage instead of a module-level global.
	captured = {}

	def feature_hook(module, input_, output_):
		# Keep the latest activation; it is consumed right after model(inputs).
		# NOTE(review): under nn.DataParallel the hook presumably fires once per
		# replica, so only one replica's chunk would be kept -- confirm.
		captured['feature'] = output_

	handle = target_layer.register_forward_hook(feature_hook)
	model.eval()
	features = []
	labels = []
	try:
		with torch.no_grad():
			for inputs, targets, *_ in data_loader:
				inputs, targets = inputs.to(args.device), targets.to(args.device)
				model(inputs)
				feature = captured['feature']
				if args.model == "densenet161":
					feature = torch.nn.functional.relu(feature)
					feature = torch.nn.functional.adaptive_avg_pool2d(feature, (1, 1))
					feature = torch.sum(torch.flatten(feature, 2), 2)
				elif args.model == "mobilenet_v2":
					feature = torch.nn.functional.adaptive_avg_pool2d(feature, (1, 1))
					feature = torch.sum(torch.flatten(feature, 2), 2)
				elif feature.ndim != 2:
					feature = torch.sum(torch.flatten(feature, 2), 2)
				features.append(feature.detach().cpu().numpy())
				labels.append(targets.cpu().numpy())
	finally:
		# Always detach the hook, even when a forward pass raises.
		handle.remove()

	return np.concatenate(features, axis=0), np.concatenate(labels, axis=0)




class samde(defense):
	"""SAMDE detector: flag suspicious (likely poisoned) samples in a backdoored training set.

	Pipeline (see :meth:`filtering`):

	1. Load the attacked model and its data from ``record/<result_file>/attack_result.pt``.
	2. Extract ``args.target_layer`` features (via ``get_features_labels``) for a
	   small class-balanced clean reference set drawn from the clean test split
	   and for every training sample.
	3. Per class, project features onto the top ``args.dimension`` SVD directions
	   and iteratively absorb training samples into a clean cluster by Mahalanobis
	   distance; samples never absorbed are flagged as suspicious.
	4. Save the unflagged training indices, the per-class projection / SVD
	   matrices and a detection-metrics CSV.
	"""

	def __init__(self, args):
		"""Merge the YAML config with CLI arguments and derive dataset-dependent settings.

		Args:
			args: ``argparse.Namespace`` built from :meth:`add_arguments`. It is
				mutated in place: its attributes become the YAML values from
				``args.yaml_path`` overridden by every non-``None`` CLI value.

		If ``result_file`` ends up set, :meth:`set_result` is called immediately.
		"""
		with open(args.yaml_path, 'r') as f:
			# An empty YAML file loads as None; treat it as "no defaults".
			defaults = yaml.safe_load(f) or {}

		# NOTE: options that declare a non-None argparse default (target_layer,
		# dimension, epsilon, var_module_loc, yaml_path, amp) therefore always
		# override the corresponding YAML entries.
		defaults.update({k: v for k, v in args.__dict__.items() if v is not None})
		args.__dict__ = defaults

		args.terminal_info = sys.argv
		args.num_classes = get_num_classes(args.dataset)
		args.input_height, args.input_width, args.input_channel = get_input_shape(args.dataset)
		args.img_size = (args.input_height, args.input_width, args.input_channel)
		args.dataset_path = f"{args.dataset_path}/{args.dataset}"

		self.args = args

		if args.__dict__.get('result_file') is not None:
			self.set_result(args.result_file)

	@staticmethod
	def add_arguments(parser):
		"""Register the SAMDE command-line options on ``parser``."""
		parser.add_argument('--device', type=str, help='cuda, cpu')
		parser.add_argument("-pm","--pin_memory", type=lambda x: str(x) in ['True', 'true', '1'], help = "dataloader pin_memory")
		parser.add_argument("-nb","--non_blocking", type=lambda x: str(x) in ['True', 'true', '1'], help = ".to(), set the non_blocking = ?")
		parser.add_argument("-pf", '--prefetch', type=lambda x: str(x) in ['True', 'true', '1'], help='use prefetch')
		parser.add_argument('--amp', default = False, type=lambda x: str(x) in ['True','true','1'])

		parser.add_argument('--checkpoint_load', type=str, help='the location of load model')
		parser.add_argument('--checkpoint_save', type=str, help='the location of checkpoint where model is saved')
		parser.add_argument('--log', type=str, help='the location of log')
		parser.add_argument("--dataset_path", type=str, help='the location of data')
		parser.add_argument('--dataset', type=str, help='mnist, cifar10, cifar100, gtrsb, tiny') 
		parser.add_argument('--result_file', type=str, help='the location of result')

		parser.add_argument('--epochs', type=int)
		parser.add_argument('--batch_size', type=int)
		# Worker counts are integers (DataLoader rejects floats).
		parser.add_argument("--num_workers", type=int)
		parser.add_argument('--lr', type=float)
		parser.add_argument('--lr_scheduler', type=str, help='the scheduler of lr')
		parser.add_argument('--steplr_stepsize', type=int)
		parser.add_argument('--steplr_gamma', type=float)
		# NOTE(review): type=list splits the CLI string into single characters;
		# milestones are presumably expected from the YAML file instead.
		parser.add_argument('--steplr_milestones', type=list)
		parser.add_argument('--model', type=str, help='resnet18')

		parser.add_argument('--client_optimizer', type=int)
		parser.add_argument('--sgd_momentum', type=float)
		parser.add_argument('--wd', type=float, help='weight decay of sgd')
		parser.add_argument('--frequency_save', type=int,
						help=' frequency_save, 0 is never')

		parser.add_argument('--random_seed', type=int, help='random seed')
		parser.add_argument('--yaml_path', type=str, default="./config/detection/samde/cifar10.yaml", help='the path of yaml')
		parser.add_argument('--clean_sample_num', type=int)

		# Name of the submodule (as in model.named_modules()) to extract features from.
		parser.add_argument('--target_layer', type=str, default='avg_pool')
		# Number of SVD directions kept per class.
		parser.add_argument('--dimension', type=int, default=30)
		# Mahalanobis radius scale; iteration t (t >= 1) admits distances <= epsilon*(t+1)/10.
		parser.add_argument('--epsilon', type=float, default=10)

		parser.add_argument('--var_module_loc', type=str, default='/detection/variance', help='the location of variance module')

	def set_result(self, result_file):
		"""Prepare output directories for ``result_file`` and load its attack result.

		Creates ``record/<result_file>/detection/samde_pretrain/`` (stored as
		``args.save_path``) and ``record/<result_file><args.var_module_loc>``
		(stored as ``args.var_module_path``). When not already configured, it
		also creates default ``checkpoint_save`` / ``log`` directories below
		``save_path``. Finally it loads ``record/<result_file>/attack_result.pt``
		into ``self.result``.
		"""
		args = self.args
		attack_file = 'record/' + result_file
		save_path = 'record/' + result_file + '/detection/samde_pretrain/'
		os.makedirs(save_path, exist_ok=True)
		args.var_module_path = 'record/' + result_file + args.var_module_loc
		os.makedirs(args.var_module_path, exist_ok=True)
		args.save_path = save_path
		# getattr: None-valued CLI options are dropped during the YAML merge in
		# __init__, so these attributes may be missing entirely.
		if getattr(args, 'checkpoint_save', None) is None:
			args.checkpoint_save = save_path + 'detection_info/'
			os.makedirs(args.checkpoint_save, exist_ok=True)
		if getattr(args, 'log', None) is None:
			args.log = save_path + 'log/'
			os.makedirs(args.log, exist_ok=True)
		self.result = load_attack_result(attack_file + '/attack_result.pt')
		# Remember which result is loaded so detection() can skip a redundant reload.
		self._loaded_result_file = result_file

	def set_trainer(self, model):
		"""Wrap ``model`` in a ``PureCleanModelTrainer`` stored as ``self.trainer``."""
		self.trainer = PureCleanModelTrainer(
			model = model,
		)

	def set_logger(self):
		"""Attach file (under ``args.log``) and console handlers to the root logger and log the config."""
		args = self.args
		logFormatter = logging.Formatter(
			fmt='%(asctime)s [%(levelname)-8s] [%(filename)s:%(lineno)d] %(message)s',
			datefmt='%Y-%m-%d:%H:%M:%S',
		)
		logger = logging.getLogger()

		fileHandler = logging.FileHandler(args.log + '/' + time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime()) + '.log')
		fileHandler.setFormatter(logFormatter)
		logger.addHandler(fileHandler)

		consoleHandler = logging.StreamHandler()
		consoleHandler.setFormatter(logFormatter)
		logger.addHandler(consoleHandler)

		logger.setLevel(logging.INFO)
		logging.info(pformat(args.__dict__))

		try:
			logging.info(pformat(get_git_info()))
		except Exception:
			# Best effort: git metadata is optional.
			logging.info('Getting git info fails.')

	def set_devices(self):
		"""Copy ``args.device`` to ``self.device``."""
		self.device = self.args.device

	def cal(self, true, pred):
		"""Return ``(TN, FP, FN, TP)`` for binary labels (1 = positive / poisoned).

		``labels=[0, 1]`` forces a 2x2 matrix even when only one class occurs
		in ``true`` and ``pred``; without it the 4-way unpacking would fail.
		"""
		TN, FP, FN, TP = confusion_matrix(true, pred, labels=[0, 1]).ravel()
		return TN, FP, FN, TP

	def metrix(self, TN, FP, FN, TP):
		"""Return ``(TPR, FPR, precision, accuracy)``; a ratio with a zero denominator is reported as 0."""
		TPR = TP / (TP + FN) if TP + FN != 0 else 0
		FPR = FP / (FP + TN) if FP + TN != 0 else 0
		precision = TP / (TP + FP) if TP + FP != 0 else 0
		total = TN + FP + FN + TP
		acc = (TP + TN) / total if total != 0 else 0
		return TPR, FPR, precision, acc

	def _load_model(self):
		"""Build ``args.model``, load the attacked weights and move it to the device in eval mode.

		For a multi-GPU device string such as ``"cuda:2,3,7"`` the model is
		wrapped in ``DataParallel`` and ``args.device`` is rewritten to the
		first listed GPU.
		"""
		args = self.args
		model = generate_cls_model(args.model, args.num_classes)
		model.load_state_dict(self.result['model'])
		if "," in self.device:
			model = torch.nn.DataParallel(
				model,
				device_ids=[int(i) for i in args.device[5:].split(",")]  # eg. "cuda:2,3,7" -> [2,3,7]
			)
			args.device = f'cuda:{model.device_ids[0]}'
		model.to(args.device)
		model.eval()
		return model

	def _clean_reference_loader(self, test_tran):
		"""Sample a class-balanced clean reference set from the clean test split.

		Draws ``max(1, int(clean_sample_num / num_classes))`` indices per class
		with ``np.random.choice`` (with replacement, its default). It returns a
		shuffled DataLoader over those samples that applies ``test_tran``.
		"""
		args = self.args
		images = []
		labels = []
		for img, label in self.result['clean_test'].wrapped_dataset:
			images.append(img)
			labels.append(label)
		logging.info("clean_sample_num: {}, num_classes: {}".format(args.clean_sample_num, args.num_classes))
		per_class = max(1, int(args.clean_sample_num / args.num_classes))
		label_array = np.array(labels)  # built once, reused for every class
		class_idx_whole = np.concatenate(
			[np.random.choice(np.where(label_array == i)[0], per_class) for i in range(args.num_classes)],
			axis=0,
		)
		image_c = [images[i] for i in class_idx_whole]
		label_c = [labels[i] for i in class_idx_whole]
		logging.info("clean sample num: {}".format(len(image_c)))
		clean_dataset = xy_iter(image_c, label_c, transform=test_tran)
		return DataLoader(clean_dataset, args.batch_size, shuffle=True)

	def _filter_class(self, target_class, feats_inspection, labels_inspection, feats_clean, labels_clean, pindex):
		"""Iteratively separate clean from suspicious training samples of one class.

		Returns ``None`` if no training sample has label ``target_class``. Otherwise
		it returns ``(suspicious_loc, projection, eigs)``:

		* ``suspicious_loc``: training-set positions still outside the clean
		  cluster after the last iteration.
		* ``projection``: ``eigs.T @ inv(sqrtm(cov))``, where ``cov`` is the
		  covariance computed in the last iteration.
		* ``eigs``: the top ``args.dimension`` right singular vectors.
		"""
		from scipy.linalg import sqrtm
		from scipy.spatial.distance import cdist

		args = self.args
		inspection_loc = np.where(labels_inspection == target_class)[0]
		if len(inspection_loc) == 0:
			return None
		inspection = feats_inspection[inspection_loc]
		clean = feats_clean[labels_clean == target_class]
		# Per-class poisoned / clean totals: denominators of the logged TPR / FPR.
		poison_total = len(np.intersect1d(inspection_loc, pindex))
		clean_total = len(np.setdiff1d(inspection_loc, pindex))

		# Project onto the top `args.dimension` right singular vectors of the
		# centred (inspection + clean) features of this class.
		feature_all = np.concatenate((inspection, clean), axis=0)
		if target_class == 0:
			feature_path = os.path.join(args.var_module_path, f'{args.target_layer}_{args.dimension}_feature_class_inspection.npy')
			np.save(feature_path, feature_all)
			logging.info(f"Feature saved to {feature_path}")
		feature_all -= np.mean(feature_all, axis=0)
		_, _, vt = np.linalg.svd(feature_all, full_matrices=False)
		eigs = vt[0:args.dimension]
		inspection = np.dot(inspection, eigs.T)
		clean = np.dot(clean, eigs.T)

		for iteration in range(10):
			center = np.mean(clean, axis=0)
			cov = np.cov(clean, rowvar=False)
			distance = cdist(inspection, [center], 'mahalanobis', VI=np.linalg.pinv(cov))
			if iteration == 0:
				# Seed the clean cluster with the closest 10% of training samples.
				threshold = np.percentile(distance, 10)
			else:
				# Then admit samples within a linearly growing radius.
				threshold = args.epsilon / 10 * (iteration + 1)
			absorbed = np.where(distance <= threshold)[0]
			clean = np.concatenate((clean, inspection[absorbed]), axis=0)
			inspection = np.delete(inspection, absorbed, axis=0)
			inspection_loc = np.delete(inspection_loc, absorbed, axis=0)
			# Share of poisoned / clean samples of this class still flagged.
			poison_left = len(np.intersect1d(inspection_loc, pindex))
			clean_left = len(np.setdiff1d(inspection_loc, pindex))
			tpr = poison_left / poison_total if poison_total != 0 else 0
			fpr = clean_left / clean_total if clean_total != 0 else 0
			logging.info("target class: {}, iter: {}, tpr: {}, fpr: {}".format(target_class, iteration, tpr, fpr))

		if len(inspection) != 0:
			logging.info("target class: {}, suspicious indices: {}".format(target_class, len(inspection_loc)))
		projection = np.dot(eigs.T, np.linalg.inv(sqrtm(cov)))
		return inspection_loc, projection, eigs

	def _write_detection_report(self, pindex, suspicious, num_train, start):
		"""Compute detection metrics for the flagged positions and append them to ``detection_info.csv``."""
		args = self.args
		# Ground-truth / prediction masks over training positions (1 = poisoned / flagged).
		true_index = np.zeros(num_train, dtype=int)
		true_index[pindex[pindex < num_train]] = 1
		findex = np.zeros(num_train, dtype=int)
		findex[suspicious] = 1

		# cal() fixes the label order, so when nothing is flagged every poisoned
		# sample is correctly counted as a false negative (not a false positive).
		tn, fp, fn, tp = self.cal(true_index, findex)
		TPR, FPR, precision, acc = self.metrix(tn, fp, fn, tp)

		# flag_list is never populated in this implementation; it only preserves
		# the log line and the CSV "target" column format.
		flag_list = []
		if len(suspicious) == 0:
			target = 'None'
		else:
			logging.info("Flagged label list: {}".format(",".join(["{}: {}".format(y_label, s) for y_label, s in flag_list])))
			target = [i for i, j in flag_list]
			# F1 score with false negatives weighted 9x.
			weighted_fn = fn * 9
			w_precision = tp / (tp + fp) if tp + fp != 0 else 0
			w_recall = tp / (tp + weighted_fn) if tp + weighted_fn != 0 else 0
			fw1 = 2 * (w_precision * w_recall) / (w_precision + w_recall) if w_precision + w_recall != 0 else 0
			logging.info("weighted F1 (FN x9): {}, time (min): {}".format(fw1, (time.perf_counter() - start) / 60))

		record = getattr(args, 'result_file', None) or self._loaded_result_file
		with open(args.save_path + '/detection_info.csv', 'a', newline='', encoding='utf-8') as f:
			csv_write = csv.writer(f)
			csv_write.writerow(['record', 'TN','FP','FN','TP','TPR','FPR', 'target'])
			csv_write.writerow([record, tn, fp, fn, tp, TPR, FPR, target])
		logging.info("TPR: {}, FPR: {}".format(TPR, FPR))

	def filtering(self):
		"""Run SAMDE detection on the loaded attack result and write its outputs.

		Requires :meth:`set_result` to have been called. Writes to
		``args.var_module_path`` the class-0 raw features
		(``*_feature_class_inspection.npy``), the per-class projection matrices
		and the SVD bases. Writes to ``args.save_path`` the file
		``rindex_loc.npy`` (training indices NOT flagged) and appends to
		``detection_info.csv``.
		"""
		args = self.args
		start = time.perf_counter()
		self.set_devices()
		fix_random(args.random_seed)

		### a. load model, bd train data and transforms
		model = self._load_model()
		test_tran = get_transform(args.dataset, *([args.input_height, args.input_width]), train=False)
		bd_train_dataset = self.result['bd_train'].wrapped_dataset
		# Positions of poisoned training samples (ground truth for the metrics).
		pindex = np.where(np.array(bd_train_dataset.poison_indicator) == 1)[0]

		# NOTE(review): under DataParallel module names gain a "module." prefix,
		# so this lookup presumably needs the prefixed name -- confirm.
		target_layer = dict(model.named_modules())[args.target_layer]

		### b. features of a small clean reference set drawn from the test split
		clean_dataloader = self._clean_reference_loader(test_tran)
		clean_features, clean_labels = get_features_labels(args, model, target_layer, clean_dataloader)
		logging.info("get clean feature")

		### c./d. features of the (possibly poisoned) training set, in dataset order
		train_dataset = self.result['bd_train']
		# NOTE(review): assumes the wrapped dataset honours `.transform`; verify the
		# test-time transform really replaces the training one here.
		train_dataset.wrapped_dataset.transform = test_tran
		train_dataloader = DataLoader(train_dataset, args.batch_size, shuffle=False)
		train_features, train_labels = get_features_labels(args, model, target_layer, train_dataloader)

		### e. per-class iterative clean-cluster growth
		suspicious_indices = []
		projection_matrix = []
		svd_matrix = []
		for target_class in range(args.num_classes):
			outcome = self._filter_class(target_class, train_features, train_labels, clean_features, clean_labels, pindex)
			if outcome is None:
				continue  # no training sample carries this label
			class_suspicious, projection, eigs = outcome
			suspicious_indices.extend(class_suspicious)
			projection_matrix.append(projection)
			svd_matrix.append(eigs)

		var_module_path = os.path.join(args.var_module_path, f'{args.target_layer}_{args.dimension}_projection_matrix.npy')
		svd_module_path = os.path.join(args.var_module_path, f'{args.target_layer}_{args.dimension}_svd_matrix.npy')
		np.save(var_module_path, projection_matrix)
		np.save(svd_module_path, svd_matrix)
		logging.info(f"Projection matrix saved to {var_module_path}")
		logging.info(f"SVD matrix saved to {svd_module_path}")

		### f. save indices of unflagged samples and the detection metrics
		num_train = len(train_dataset)
		suspicious = np.asarray(suspicious_indices, dtype=np.intp)
		keep = np.ones(num_train, dtype=bool)
		keep[suspicious] = False
		np.save(args.save_path + '/rindex_loc.npy', np.where(keep)[0])
		self._write_detection_report(pindex, suspicious, num_train, start)

	def detection(self, result_file):
		"""Load ``result_file`` (unless it is already loaded), configure logging and run :meth:`filtering`."""
		if getattr(self, '_loaded_result_file', None) != result_file:
			self.set_result(result_file)
		self.set_logger()
		return self.filtering()


if __name__ == '__main__':
	# Build the CLI, merge it with the YAML config, then run detection.
	cli_parser = argparse.ArgumentParser(description=sys.argv[0])
	samde.add_arguments(cli_parser)
	args = cli_parser.parse_args()
	samde_method = samde(args)
	# result_file is either missing (None values are dropped in the YAML merge)
	# or explicitly None; both cases fall back to the demo record.
	if getattr(args, 'result_file', None) is None:
		args.result_file = 'defense_test_badnet'
	result = samde_method.detection(args.result_file)
