import argparse
import os,sys
import numpy as np
import torch
import torch.nn as nn
sys.path.append('../')
sys.path.append(os.getcwd())

from pprint import  pformat
import yaml
import logging
import time
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
from defense.base import defense
import scipy
from utils.aggregate_block.train_settings_generate import argparser_criterion, argparser_opt_scheduler
from utils.trainer_cls import PureCleanModelTrainer
from utils.aggregate_block.fix_random import fix_random
from utils.aggregate_block.model_trainer_generate import generate_cls_model
from utils.log_assist import get_git_info
from utils.aggregate_block.dataset_and_transform_generate import get_input_shape, get_num_classes, get_transform
from utils.save_load_attack import load_attack_result, save_defense_result
from utils.nCHW_nHWC import *

import tqdm
import heapq
from PIL import Image
from utils.bd_dataset_v2 import dataset_wrapper_with_transform,xy_iter, prepro_cls_DatasetBD_v2
from utils.trainer_cls import Metric_Aggregator, PureCleanModelTrainer, all_acc, general_plot_for_epoch, given_dataloader_test
from collections import Counter
import copy
from torch.utils.data import DataLoader
from sklearn.metrics import confusion_matrix
import csv
from sklearn import metrics

import math

def estimate_variance(args, feats_clean, feats_inspection, class_indices_clean, class_indices_inspection, pindex, save_path):
	"""Estimate one whitening projection matrix per class and save them to disk.

	For every class, the inspected features of that class are compared with the
	clean features of the same class through the Mahalanobis distance. During
	ten rounds, inspected samples lying close to the clean cluster are moved
	into it: the closest 10% in the first round, afterwards every sample below
	a threshold that grows linearly up to ``args.epsilon``. The covariance of
	the clean cluster is then turned into ``eigs.T @ inv(sqrtm(cov))``.

	Args:
		args: namespace providing ``num_classes``, ``dimension``,
			``reduce_dimension``, ``epsilon`` and ``target_layer``.
		feats_clean: 2-D array of clean features, one row per sample.
		feats_inspection: 2-D array of features under inspection, one row per sample.
		class_indices_clean: class label of every row of ``feats_clean``.
		class_indices_inspection: class label of every row of ``feats_inspection``.
		pindex: values intersected with row positions of ``feats_inspection`` to
			report tpr/fpr in the log.
			NOTE(review): this treats ``pindex`` as positions of poisoned samples;
			confirm the caller does not pass per-sample 0/1 flags instead.
		save_path: directory in which
			``{target_layer}_{dimension}_projection_matrix.npy`` is written.

	Returns:
		A list with one projection matrix per class. Classes without inspected
		samples get an all-zero matrix of shape ``(n_features, args.dimension)``.
	"""
	# ``import scipy`` alone does not guarantee that these submodules are loaded.
	from scipy.linalg import sqrtm
	from scipy.spatial.distance import cdist

	# Arrays are required so that ``labels == target_class`` is elementwise.
	class_indices_clean = np.asarray(class_indices_clean)
	class_indices_inspection = np.asarray(class_indices_inspection)

	num_rounds = 10
	epsilon = args.epsilon
	projection_matrix = []
	for target_class in range(args.num_classes):
		inspection_mask = class_indices_inspection == target_class
		feature_class_inspection_loc = np.where(inspection_mask)[0]
		if len(feature_class_inspection_loc) == 0:
			# Nothing to inspect for this class: keep a placeholder so that the
			# list stays indexable by class label.
			projection_matrix.append(np.zeros((feats_inspection.shape[1], args.dimension)))
			continue
		feature_class_inspection = feats_inspection[inspection_mask]
		feature_class_clean = feats_clean[class_indices_clean == target_class]

		pindex_in_inspection_all = len(np.intersect1d(feature_class_inspection_loc, pindex))
		pindex_not_in_inspection_all = len(np.setdiff1d(feature_class_inspection_loc, pindex))

		if args.reduce_dimension:
			# Keep the top ``args.dimension`` right singular vectors of the
			# centred union of inspected and clean features.
			feature_all = np.concatenate((feature_class_inspection, feature_class_clean), axis=0)
			feature_all -= np.mean(feature_all, axis=0)
			_, _, V = np.linalg.svd(feature_all, full_matrices=False)
			dim = args.dimension
			eigs = V[0:dim]
			feature_class_inspection = np.dot(feature_class_inspection, eigs.T)
			feature_class_clean = np.dot(feature_class_clean, eigs.T)
		else:
			dim = feature_class_inspection.shape[1]
			eigs = np.eye(dim)

		for step in range(num_rounds):
			# Re-estimate the clean cluster from everything absorbed so far.
			cluster_0_center = np.mean(feature_class_clean, axis=0)
			# atleast_2d keeps the one-dimensional case usable by pinv/sqrtm.
			cluster_0_cov = np.atleast_2d(np.cov(feature_class_clean, rowvar=False))
			mahalanobis_distance = cdist(feature_class_inspection, [cluster_0_center], 'mahalanobis', VI=np.linalg.pinv(cluster_0_cov))
			if step == 0:
				# First round: absorb the closest 10% regardless of absolute distance.
				threshold = np.percentile(mahalanobis_distance, 10)
			else:
				threshold = epsilon / 10 * (step + 1)
			cluster_0_indices = np.where(mahalanobis_distance <= threshold)[0]

			feature_class_clean = np.concatenate((feature_class_clean, feature_class_inspection[cluster_0_indices]), axis=0)
			feature_class_inspection = np.delete(feature_class_inspection, cluster_0_indices, axis=0)
			feature_class_inspection_loc = np.delete(feature_class_inspection_loc, cluster_0_indices, axis=0)

			# Share of pindex / non-pindex samples that are still not absorbed.
			pindex_in_inspection = len(np.intersect1d(feature_class_inspection_loc, pindex))
			pindex_not_in_inspection = len(np.setdiff1d(feature_class_inspection_loc, pindex))
			tpr = pindex_in_inspection / pindex_in_inspection_all if pindex_in_inspection_all != 0 else 0
			fpr = pindex_not_in_inspection / pindex_not_in_inspection_all if pindex_not_in_inspection_all != 0 else 0
			logging.info("target class: {}, iter: {}, tpr: {}, fpr: {}".format(target_class, step, tpr, fpr))

		if len(feature_class_inspection) != 0:
			# Samples that were never absorbed stay suspicious; only their count is reported.
			logging.info("target class: {}, suspicious indices: {}".format(target_class, len(feature_class_inspection_loc)))

		# Projection for this class: eigenvectors * sigma^(-1/2), with sigma the
		# covariance estimated at the start of the final round.
		sigma = cluster_0_cov
		projection_matrix.append(np.dot(eigs.T, np.linalg.inv(sqrtm(sigma))))

	if save_path:
		os.makedirs(save_path, exist_ok=True)
	save_path = os.path.join(save_path, f'{args.target_layer}_{args.dimension}_projection_matrix.npy')
	np.save(save_path, projection_matrix)
	return projection_matrix

def get_features_clean(name, model, dataloader, target_layer, args=None):
	"""Collect flattened hooked-layer activations and labels for a dataloader.

	A forward hook captures the output of the hooked layer for every batch.
	For ``mobilenet_v2`` the hooked layer is ``model.avgpool`` and its output
	is average-pooled to 1x1; for ``densenet161`` the output of
	``target_layer`` goes through ReLU and 1x1 average pooling; for all other
	supported names the output of ``target_layer`` is used as is.

	Args:
		name: model architecture name selecting the post-processing above.
		model: the network; it is switched to eval mode.
		dataloader: iterable yielding ``(x_batch, y_batch, *extras)``.
		target_layer: module whose output is captured (unused for ``mobilenet_v2``).
		args: optional namespace; when it has a ``device`` attribute the inputs
			are moved there, otherwise they are left where they are.

	Returns:
		``(activations, labels)`` as NumPy arrays, one row / entry per sample.

	Raises:
		ValueError: if ``name`` is not a supported architecture.
	"""
	plain_archs = ('preactresnet18', 'vgg19', 'vgg19_bn', 'resnet18', 'resnet18_cifar', 'resnet18_ctrl', 'mobilenet_v3_large', 'efficientnet_b3', 'convnext_tiny', 'vit_b_16')
	if name in plain_archs or name == 'densenet161':
		hooked_module = target_layer
	elif name == 'mobilenet_v2':
		hooked_module = model.avgpool
	else:
		# Previously every batch was skipped silently and torch.cat failed later.
		raise ValueError(f"unsupported model name for feature extraction: {name}")

	device = getattr(args, 'device', None)
	captured = []

	def layer_hook(module, inp, out):
		captured.append(out.detach())

	activations_all = []
	labels_all = []
	model.eval()
	hook = hooked_module.register_forward_hook(layer_hook)
	try:
		# no_grad has to cover the forward passes, not only model.eval().
		with torch.no_grad():
			for x_batch, y_batch, *flags in dataloader:
				if device is not None:
					x_batch = x_batch.to(device)
				labels_all.extend(torch.as_tensor(y_batch).tolist())
				captured.clear()
				_ = model(x_batch)
				# Use the first capture of this forward pass.
				feats = captured[0]
				if name == 'densenet161':
					feats = torch.nn.functional.relu(feats)
				if name in ('mobilenet_v2', 'densenet161'):
					feats = torch.nn.functional.adaptive_avg_pool2d(feats, (1, 1))
				activations_all.append(feats.reshape(feats.size(0), -1).cpu())
	finally:
		# Always detach the hook, even when a forward pass fails.
		hook.remove()

	activations_all = torch.cat(activations_all, dim=0).numpy()
	labels_all = np.array(labels_all)
	return activations_all, labels_all

def get_features(name, model, dataloader, target_layer, args, variance=False, clean_dataloader=None):
	"""Extract flattened hooked-layer activations, optionally whitened per class.

	A forward hook captures the output of the hooked layer for every batch.
	For ``mobilenet_v2`` the hooked layer is ``model.avgpool`` and its output
	is average-pooled to 1x1; for ``densenet161`` the output of
	``target_layer`` goes through ReLU and 1x1 average pooling; for all other
	supported names the output of ``target_layer`` is used as is.

	When ``variance`` is true, every sample is additionally multiplied by the
	projection matrix of its own label. The matrices are loaded from
	``args.var_module_path`` when the cached file exists; otherwise they are
	estimated with ``estimate_variance`` from ``clean_dataloader``.

	Args:
		name: model architecture name selecting the post-processing above.
		model: the network; it is switched to eval mode.
		dataloader: iterable yielding ``(x_batch, y_batch, *flags)``. When extra
			items are present, ``flags[1]`` is collected per sample and handed to
			``estimate_variance`` as ``pindex``.
			NOTE(review): presumably a per-sample poison indicator; confirm it
			matches what ``estimate_variance`` expects.
		target_layer: module whose output is captured (unused for ``mobilenet_v2``).
		args: namespace providing ``device`` and, when ``variance`` is true,
			``var_module_path``, ``target_layer``, ``dimension`` and ``save_path``.
		variance: whether to apply the per-class projection.
		clean_dataloader: clean data used to estimate the projection when no
			cached file is found.

	Returns:
		A CPU tensor with one row per sample: the raw activations, or the
		projected activations when ``variance`` is true.

	Raises:
		ValueError: if ``name`` is unsupported, or if the projection has to be
			estimated but ``clean_dataloader`` is None.
	"""
	plain_archs = ('preactresnet18', 'vgg19', 'vgg19_bn', 'resnet18', 'resnet18_cifar', 'mobilenet_v3_large', 'efficientnet_b3', 'convnext_tiny', 'vit_b_16', 'resnet18_ctrl')
	if name in plain_archs or name == 'densenet161':
		hooked_module = target_layer
	elif name == 'mobilenet_v2':
		hooked_module = model.avgpool
	else:
		# Previously every batch was skipped silently and torch.cat failed later.
		raise ValueError(f"unsupported model name for feature extraction: {name}")

	captured = []

	def layer_hook(module, inp, out):
		captured.append(out.detach())

	activations_all = []
	labels_all = []
	pindex = []
	model.eval()
	hook = hooked_module.register_forward_hook(layer_hook)
	try:
		# no_grad has to cover the forward passes, not only model.eval().
		with torch.no_grad():
			for x_batch, y_batch, *flags in dataloader:
				x_batch = x_batch.to(args.device)
				labels_all.extend(torch.as_tensor(y_batch).tolist())
				if flags:
					pindex.extend(torch.as_tensor(flags[1]).tolist())
				else:
					pindex.extend([0] * len(y_batch))
				captured.clear()
				_ = model(x_batch)
				# Use the first capture of this forward pass.
				feats = captured[0]
				if name == 'densenet161':
					feats = torch.nn.functional.relu(feats)
				if name in ('mobilenet_v2', 'densenet161'):
					feats = torch.nn.functional.adaptive_avg_pool2d(feats, (1, 1))
				activations_all.append(feats.reshape(feats.size(0), -1).cpu())
	finally:
		# Always detach the hook, even when a forward pass fails.
		hook.remove()

	activations_all = torch.cat(activations_all, dim=0)
	if not variance:
		return activations_all

	activations_np = activations_all.numpy()
	labels_np = np.array(labels_all)
	projection_matrix_loc = os.path.join(args.var_module_path, f'{args.target_layer}_{args.dimension}_projection_matrix.npy')
	if os.path.exists(projection_matrix_loc):
		projection_matrix = np.load(projection_matrix_loc)
	else:
		if clean_dataloader is None:
			raise ValueError("clean_dataloader is required to estimate the projection matrix when no cached file exists")
		activations_clean, labels_clean = get_features_clean(name, model, clean_dataloader, target_layer, args)
		# Plain integer labels are needed for the elementwise class comparison.
		labels_clean = np.array([int(label) for label in labels_clean])
		# NOTE(review): the result is saved under args.save_path while the cache
		# lookup above reads args.var_module_path; confirm both point to the same place.
		projection_matrix = estimate_variance(args, activations_clean, activations_np, labels_clean, labels_np, pindex, args.save_path)
	projection_matrix = np.asarray(projection_matrix)

	# Project every sample with the matrix of its own label. One matmul per
	# class gives the same result as einsum('nd,ndc->nc') over a dense
	# (N, D, C) tensor without materialising that tensor.
	out_dim = projection_matrix.shape[-1]
	result_dtype = np.result_type(activations_np.dtype, projection_matrix.dtype, np.float64)
	result_matrix = np.zeros((activations_np.shape[0], out_dim), dtype=result_dtype)
	for label in np.unique(labels_np):
		rows = labels_np == label
		result_matrix[rows] = activations_np[rows] @ projection_matrix[label]
	return torch.tensor(result_matrix)


