import argparse
import copy
import json
import os
import random

import numpy as np
import pandas as pd
import torch
from scipy.optimize import minimize
from scipy.spatial.distance import cosine
from sklearn.decomposition import TruncatedSVD
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from transformers import CLIPProcessor, CLIPModel

import utils


parser = argparse.ArgumentParser(description="Run Fairness Evaluation Script")
# Every option is a mandatory string flag, so they are registered from one table.
for _flag, _help in (
	("--model_ID", "Identifier for the model"),
	("--query_type", "Type of query to evaluate"),
	("--eval_concept", "Concept used for evaluation (e.g., gender)"),
	("--debias_concept", "Concept used for debiasing (e.g., gender)"),
):
	parser.add_argument(_flag, type=str, required=True, help=_help)

args = parser.parse_args()


# Seed torch, NumPy and the stdlib RNG with the same fixed value.
seed = 42
for _seed_fn in (torch.manual_seed, np.random.seed, random.seed):
	_seed_fn(seed)

#### Debiasing Functions ##########################################################
def get_proj_matrix(embeddings, subspace_dim = None):
	"""Return the projector onto the orthogonal complement of an SVD subspace.

	A TruncatedSVD with ``subspace_dim`` components is fitted on
	``embeddings``; its components (transposed) form the basis ``B`` and the
	result is ``I - B (B^T B)^{-1} B^T``.

	Args:
		embeddings: 2-D array-like of row vectors accepted by TruncatedSVD.
		subspace_dim: number of SVD components; defaults to ``len(embeddings)``.

	Returns:
		Square NumPy array whose side equals the number of rows of the basis.
	"""
	if subspace_dim is None:
		subspace_dim = len(embeddings)
	tSVD = TruncatedSVD(n_components=subspace_dim)
	# Only the fitted components are needed; the transformed data is discarded.
	tSVD.fit(embeddings)
	basis = tSVD.components_.T

	# orthogonal projection onto span(basis), then its complement
	gram_inv = np.linalg.inv(np.matmul(basis.T, basis))
	proj = np.matmul(np.matmul(basis, gram_inv), basis.T)
	return np.eye(proj.shape[0]) - proj

def get_img_df(_path):
	"""Read the CSV at ``_path`` and parse its ``img_embed`` column into arrays.

	Each ``img_embed`` cell has its first and last character stripped and the
	remainder parsed as whitespace-separated numbers.
	"""
	def _parse_embedding(cell):
		inner = cell[1:-1]
		return np.fromstring(inner, sep=' ')

	return pd.read_csv(_path, converters={'img_embed': _parse_embedding})

def demo_parity(_img_df, top_k, _concept, cls):
	"""Absolute gap between a group's share of the top-k rows and of all rows.

	Args:
		_img_df: DataFrame holding the ``_concept`` column.
		top_k: positional row indices (used with ``iloc``) of the retrieved items.
		_concept: column name to inspect.
		cls: group value whose share is measured.
	"""
	share_in_top_k = (_img_df.iloc[top_k][_concept] == cls).mean()
	share_overall = (_img_df[_concept] == cls).mean()
	return np.abs(share_in_top_k - share_overall)

def worst_demo_parity(_img_df, top_k, _concept, _class_dict):
	"""Return the largest ``demo_parity`` gap over the groups of ``_concept``.

	Groups are taken from ``_class_dict[_concept]``; ties keep the first group
	encountered.

	Returns:
		(largest gap, group that produced it).
	"""
	largest_gap, worst_group = -1000000, None
	for group in _class_dict[_concept]:
		gap = demo_parity(_img_df, top_k, _concept, group)
		if gap > largest_gap:
			largest_gap, worst_group = gap, group
	return largest_gap, worst_group

def disparate_impact(_img_df, top_k, _concept, disadv_cls, adv_cls):
	"""Ratio of two groups' shares within the top-k rows.

	Returns the share of ``disadv_cls`` among the ``top_k`` rows divided by the
	share of ``adv_cls`` among the same rows.
	"""
	retrieved = _img_df.iloc[top_k][_concept]
	share_disadv = (retrieved == disadv_cls).mean()
	share_adv = (retrieved == adv_cls).mean()
	return share_disadv / share_adv

def worst_disparate_impact(_img_df, top_k, _concept, _class_dict):
	"""Disparate impact between the least and most represented top-k groups.

	The groups of ``_class_dict[_concept]`` are scanned for the lowest and the
	highest share among the ``top_k`` rows (ties keep the first group seen) and
	``disparate_impact`` is evaluated for that pair.

	Returns:
		(ratio, least represented group, most represented group).
	"""
	retrieved = _img_df.iloc[top_k][_concept]
	highest_share, lowest_share = -100000, 10000
	advantaged_group = disadvantaged_group = None

	for group in _class_dict[_concept]:
		share = (retrieved == group).mean()
		if share > highest_share:
			highest_share, advantaged_group = share, group
		if share < lowest_share:
			lowest_share, disadvantaged_group = share, group

	ratio = disparate_impact(_img_df, top_k, _concept, disadvantaged_group, advantaged_group)
	return ratio, disadvantaged_group, advantaged_group



def apply_proj_debiasing(embed_array, P0, norm=True):
	"""Project row embeddings with ``P0`` and optionally L2-normalise the rows.

	The product ``embed_array @ P0.T`` is computed in the wider of the two
	input dtypes and then cast to float32. Both arguments may be torch tensors
	or NumPy arrays; the conversion is explicit so the result is always a
	torch tensor.

	Args:
		embed_array: (n, d) embeddings, torch tensor or NumPy array.
		P0: (d, d) projection matrix, torch tensor or NumPy array.
		norm: when true, divide every output row by its L2 norm.

	Returns:
		float32 torch tensor of shape (n, d).
	"""
	embeds = torch.as_tensor(embed_array)
	proj = torch.as_tensor(P0)
	# Multiply in the wider dtype, then down-cast once.
	work_dtype = torch.promote_types(embeds.dtype, proj.dtype)
	proj = proj.to(device=embeds.device, dtype=work_dtype)
	proj_debiased = (embeds.to(work_dtype) @ proj.T).float()
	if norm:
		proj_debiased = proj_debiased / proj_debiased.norm(dim=1, keepdim=True)
	return proj_debiased


def rotate_projection_equal_inner_product(embed_array, P0, B):
	"""Rotate each row's in-subspace part onto one fixed direction.

	Every row ``V`` is split into its projection ``V_s`` onto the column span
	of ``B`` and the remainder ``V_perp``. ``V_s`` is replaced by a vector of
	the same norm along ``u``, the unit vector proportional to
	``B (B^T B)^{-1} 1``, and added back to ``V_perp``.

	All rows are processed in one batched solve; the Gram matrix and ``u`` do
	not depend on the row and are computed once.

	Args:
		embed_array: (n, d) torch tensor or NumPy array of row embeddings.
		P0: unused; kept so existing callers keep working.
		B: (d, k) torch tensor or NumPy array whose columns span the subspace.

	Returns:
		Tuple ``(V_perp, V_s, V_s_new, V_rotated)`` of (n, d) NumPy arrays.
	"""
	if isinstance(embed_array, torch.Tensor):
		embed_array = embed_array.detach().cpu().numpy()
	if isinstance(B, torch.Tensor):
		B = B.detach().cpu().numpy()

	G = B.T @ B

	# Common target direction inside span(B).
	u = B @ np.linalg.solve(G, np.ones(B.shape[1]))
	u = u / np.linalg.norm(u)

	# Least-squares coefficients for all rows at once: G c = B^T V^T.
	V_s_array = (B @ np.linalg.solve(G, B.T @ embed_array.T)).T
	V_perp_array = embed_array - V_s_array

	# Keep each row's in-subspace magnitude, change only its direction.
	V_s_new_array = np.linalg.norm(V_s_array, axis=1, keepdims=True) * u
	V_rotated_array = V_perp_array + V_s_new_array

	return V_perp_array, V_s_array, V_s_new_array, V_rotated_array

def apply_new_wring_debiasing(embed_array, P0, U):
	"""Return only the rotated embeddings from ``rotate_projection_equal_inner_product``.

	The helper returns four arrays; the last one (the rotated vectors) is the
	debiased result.
	"""
	rotation_parts = rotate_projection_equal_inner_product(embed_array, P0, U)
	_, _, _, rotated = rotation_parts
	return rotated


# Rank features by the fitted model's importances and keep the top n
def get_top_n_features(rf, feature_names, n=10):
	"""Return the ``n`` rows with the highest ``rf.feature_importances_``.

	Returns:
		DataFrame with columns ``feature`` and ``importance``, sorted by
		descending importance.
	"""
	ranked = pd.DataFrame({
		'feature': feature_names,
		'importance': rf.feature_importances_,
	}).sort_values(by='importance', ascending=False)
	return ranked.head(n)

def set_top_features_to_constant(X, top_features_df, feature_names, c=0):
	"""Return a copy of ``X`` with the listed feature columns overwritten by ``c``.

	Column positions are looked up with ``feature_names.index`` for every
	entry of ``top_features_df['feature']``; ``X`` itself is not modified.
	"""
	columns = [feature_names.index(name) for name in top_features_df['feature']]
	overwritten = X.copy()
	overwritten[:, columns] = c
	return overwritten

# Row indices of the samples the classifier is least sure about
def get_low_confidence_indices(rf, X, fraction=0.2):
	"""Return indices of the ``fraction`` of rows with the lowest confidence.

	Confidence is the maximum of ``rf.predict_proba(X)`` per row. The number
	of rows kept is ``int(len(X) * fraction)``, which may be zero.
	"""
	confidences = np.max(rf.predict_proba(X), axis=1)
	keep = int(len(confidences) * fraction)
	order = np.argsort(confidences)
	return order[:keep]


# Function to set top features to the average over low-confidence samples
def set_top_features_to_low_confidence_mean(X, rf, top_features_df, feature_names, fraction=0.2):
	"""Overwrite the top features with their mean over low-confidence rows.

	Delegates to ``compute_low_confidence_feature_means`` and
	``apply_feature_means`` so the two code paths cannot drift apart. Those
	helpers always keep at least one low-confidence row, so a small
	``fraction`` no longer averages an empty slice and fills the columns with
	NaN.

	Args:
		X: 2-D feature array; it is not modified.
		rf: fitted classifier exposing ``predict_proba``.
		top_features_df: DataFrame whose ``feature`` column names the features
			to overwrite.
		feature_names: list used to map feature names to column positions.
		fraction: share of rows treated as low-confidence.

	Returns:
		Copy of ``X`` with every listed column set to its low-confidence mean.
	"""
	means, top_feature_indices = compute_low_confidence_feature_means(
		rf, X, top_features_df, feature_names, fraction
	)
	return apply_feature_means(X, top_feature_indices, means)

def compute_low_confidence_feature_means(rf, X, top_features_df, feature_names, fraction=0.2):
	"""Mean of the selected features over the least confident rows.

	Confidence is the per-row maximum of ``rf.predict_proba(X)``. The lowest
	``int(len(X) * fraction)`` rows are used, but never fewer than one.

	Returns:
		(means, column indices), both in the order of
		``top_features_df['feature']``.
	"""
	confidences = np.max(rf.predict_proba(X), axis=1)
	keep = max(1, int(len(confidences) * fraction))
	least_confident_rows = np.argsort(confidences)[:keep]

	columns = [feature_names.index(name) for name in top_features_df['feature']]
	selected = X[least_confident_rows][:, columns]
	return np.mean(selected, axis=0), columns

def apply_feature_means(X, top_feature_indices, means):
	"""Return a copy of ``X`` with each listed column filled by its paired mean.

	``means[i]`` is written into column ``top_feature_indices[i]`` for every
	row; ``X`` itself is left untouched.
	"""
	filled = X.copy()
	for position, column in enumerate(top_feature_indices):
		replacement = means[position]
		filled[:, column] = replacement
	return filled
#########################################################



####Define the model ###################
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

model_ID = args.model_ID 
# NOTE(review): `model` and `processor` are not referenced again in the visible
# part of this file; only `wrapped_clip` is used below.
model = CLIPModel.from_pretrained(model_ID)
processor = CLIPProcessor.from_pretrained(model_ID)
wrapped_clip = utils.CLIPTextWrapper(model_ID, device=device)

#######################################


#### load in the precomputed image embeddings and labels ##########
img_df = get_img_df(f"spawrious_2/{model_ID.split('/')[-1]}.csv")
# Stack the per-row embedding vectors into one (num_images, dim) float tensor.
img_embeds = torch.tensor(np.vstack(list(img_df['img_embed']))).float()
# Scale every row to unit L2 norm.
img_embeds = img_embeds / img_embeds.norm(dim=1, keepdim=True)
####################################################################

######## Define the classes for debiasing and evaluation ##########
# Change this table to reflect the concepts in whichever dataset you're using.
class_dict = {
    'breed': ["bulldog", "corgi", "dachshund", "labrador"],
    'background': ["jungle", "mountain", "snow", "desert"],
}

####################################################################


####### Define the directions for text debiasing #######

# For every class: four prompt variants stored under "<class>_phrases".
class_phrases = {
    'breed': {
        f"{breed}_phrases": [
            f'a photo of a {breed}',
            f"an image of a {breed}",
            f"a {breed}",
            f'a picture of a {breed}',
        ]
        for breed in class_dict['breed']
    },
    'background': {
        f"{background}_phrases": [
            f'a photo of a dog with a {background} background',
            f"an image of a dog with a {background} background",
            f"a dog with a {background} background",
            f'a picture of a dog with a {background} background',
        ]
        for background in class_dict['background']
    },
}

# Class-free prompts for each concept.
class_phrases['breed']['neutral_phrases'] = [
    'a photo of a dog',
    "an image of a dog",
    "a dog",
    'a picture of a dog',
]
class_phrases['background']['neutral_phrases'] = [
    'a picture of a dog',
    "an image of a dog",
    "a picture of a dog outside",
    'a photo of a dog',
]

# Neutral prompts shared by all concepts.
total_neutral_phrases = [
    'a photo of a dog', "an image of a dog", "a dog", "a dog",
    'a photo of my dog', "this dog", "my dog", "a puppy",
]

####################################################################


##### Get P0 and U for text directions ######

# Embed every class phrase list plus the per-concept neutral phrases.
class_embeddings = {}
for _class, _members in class_dict.items():
	_concept_embeds = {
		f"{cls}_embeddings": wrapped_clip.get_joint_embed(class_phrases[_class][f"{cls}_phrases"])
		for cls in _members
	}
	_concept_embeds['neutral_embeddings'] = wrapped_clip.get_joint_embed(class_phrases[_class]["neutral_phrases"])
	class_embeddings[f"{_class}_embeddings"] = _concept_embeds

total_neutral_embeddings = wrapped_clip.get_joint_embed(total_neutral_phrases)


# One direction per class: the mean of its phrase embeddings (kept as a 1-row matrix).
direction_dict = {}
for _class, _members in class_dict.items():
	_concept_embeds = class_embeddings[f"{_class}_embeddings"]
	_concept_dirs = {
		f"{cls}_dirs": _concept_embeds[f"{cls}_embeddings"].mean(dim=0, keepdim=True)
		for cls in _members
	}
	_concept_dirs['neutral_dirs'] = _concept_embeds["neutral_embeddings"].mean(dim=0, keepdim=True)
	direction_dict[f"{_class}_dirs"] = _concept_dirs

# Optional step (disabled): subtract the concept's neutral direction.
subtract_neutral = False

if subtract_neutral:
	for _class, _members in class_dict.items():
		_concept_dirs = direction_dict[f"{_class}_dirs"]
		for cls in _members:
			_concept_dirs[f"{cls}_dirs"] = _concept_dirs[f"{cls}_dirs"] - _concept_dirs['neutral_dirs']

# Optional step (enabled): remove the component lying in the span of the
# shared neutral embeddings.
make_perp_to_neutral = True

if make_perp_to_neutral:
	_total_neutral_embeddings = total_neutral_embeddings

	proj_matrix = _total_neutral_embeddings.T @ np.linalg.pinv(_total_neutral_embeddings.T)
	for _class, _members in class_dict.items():
		_concept_dirs = direction_dict[f"{_class}_dirs"]
		for cls in _members:
			_concept_dirs[f"{cls}_dirs"] = _concept_dirs[f"{cls}_dirs"] - _concept_dirs[f"{cls}_dirs"] @ proj_matrix


# Rescale every class direction to unit length.
for _class, _members in class_dict.items():
	_concept_dirs = direction_dict[f"{_class}_dirs"]
	for cls in _members:
		_concept_dirs[f"{cls}_dirs"] = _concept_dirs[f"{cls}_dirs"] / _concept_dirs[f"{cls}_dirs"].norm(dim=1, keepdim=True)


# P0: complement projector of the class directions; U: the directions as columns.
P0_dict = {}
U_dict = {}

for _class, _members in class_dict.items():
	_class_dirs = np.concatenate([direction_dict[f"{_class}_dirs"][f"{cls}_dirs"] for cls in _members])
	P0_dict[f"P0_{_class}"] = get_proj_matrix(copy.deepcopy(_class_dirs), subspace_dim = None)
	U_dict[f"U_{_class}"] = _class_dirs.T

####################################################################


### Define the queries
# The query family is taken from the --query_type command-line option.
query_type = args.query_type 

def get_search_queries(query_type = "dangerous"):
	"""Return the three search phrases for a query family.

	``'dangerous'`` and ``'protective'`` have their own phrase lists; any
	other value falls back to the playful phrases.
	"""
	phrases_by_type = {
		'dangerous': ['threatening a child', "attacking someone", "lunging forward"],
		'protective': ['protecting a child', "protecting someone", "guarding"],
	}
	fallback_phrases = ['playing with a child', "playing fetch", "wagging tail"]
	return phrases_by_type.get(query_type, fallback_phrases)


# Resolve the chosen query family into its phrases and print them.
query_terms = get_search_queries(query_type)
print(query_terms)

####################################################################

########### Function for collecting results
def skew_score(_img_df, top_k, _concept, cls):
	"""Log-ratio of a group's share in the top-k rows to its share overall.

	Computed as ``log(share in top_k) - log(share in all rows)``; positive
	values mean the group is over-represented among the ``top_k`` rows.
	"""
	share_in_top_k = (_img_df.iloc[top_k][_concept] == cls).mean()
	share_overall = (_img_df[_concept] == cls).mean()
	return np.log(share_in_top_k) - np.log(share_overall)

def max_skew_score(_img_df, top_k, _concept, _class_dict):
	"""Return the largest ``skew_score`` over the groups of ``_concept``.

	Groups come from ``_class_dict[_concept]``; ties keep the first group
	encountered.

	Returns:
		(largest skew, group that produced it).
	"""
	largest_skew, skewed_group = -1000000, None
	for group in _class_dict[_concept]:
		skew = skew_score(_img_df, top_k, _concept, group)
		if skew > largest_skew:
			largest_skew, skewed_group = skew, group
	return largest_skew, skewed_group

def get_metrics(sample_df, baseline_top_k, _concept, class_dict, proj_top_k, new_wring_top_k, unified_deb_top_k):
	"""Collect the fairness metrics of four rankings for one concept.

	For the baseline, projection, new-wring and unified rankings this records
	the worst demographic-parity gap, the maximum skew and the worst disparate
	impact, together with the groups each metric reports. Keys are prefixed
	``baseline_``, ``proj_``, ``new_wring_`` and ``unified_deb_``.

	Returns:
		Flat dict, filled metric by metric and, within a metric, ranking by
		ranking.
	"""
	rankings = (
		("baseline", baseline_top_k),
		("proj", proj_top_k),
		("new_wring", new_wring_top_k),
		("unified_deb", unified_deb_top_k),
	)
	metrics = {}

	# Demographic Parity Worst
	for method, top_k in rankings:
		gap, group = worst_demo_parity(sample_df, top_k, _concept, class_dict)
		metrics[f"{method}_dp_worst"] = gap
		metrics[f"{method}_dp_disadvantaged_group"] = group

	# Max Skew
	for method, top_k in rankings:
		skew, group = max_skew_score(sample_df, top_k, _concept, class_dict)
		metrics[f"{method}_max_skew"] = skew
		metrics[f"{method}_ms_disadvantaged_group"] = group

	# Disparate Impact Worst (with advantaged group)
	for method, top_k in rankings:
		ratio, low_group, high_group = worst_disparate_impact(sample_df, top_k, _concept, class_dict)
		metrics[f"{method}_di_worst"] = ratio
		metrics[f"{method}_di_disadvantaged_group"] = low_group
		metrics[f"{method}_di_advantaged_group"] = high_group

	return metrics


####################################################################



###### Run the retrieval experiment 
debias_concept = args.debias_concept 
eval_concept = args.eval_concept 
k = 1000

num_bootstrap = 100

result_dict = {}

for b in range(num_bootstrap):
	print(f"bootstrap {b}/{num_bootstrap}")

	result_dict[b] = {}

	# Bootstrap resample with replacement, then hold out 20% as the reference
	# split used to fit the debiasing statistics.
	sample_df = img_df.iloc[np.random.choice(img_df.index, size=len(img_df), replace=True)]
	sample_df, ref_df = train_test_split(sample_df, test_size=0.2) 
	
	# The frames still carry the original row labels here, so they select the
	# matching embedding rows before the index is reset.
	img_embeds_sample = img_embeds[list(sample_df.index)]
	img_embeds_ref = img_embeds[list(ref_df.index)]

	sample_df = sample_df.reset_index(drop=True)
	ref_df = ref_df.reset_index(drop=True)


	proj_debiased_query = {}
	new_wring_debiased_query = {}


	# Fit a classifier for the debias concept on the reference embeddings.
	X, y = img_embeds_ref.to('cpu').numpy(), ref_df[debias_concept]

	# Feature names are simply the embedding dimension indices.
	feature_names = list(range(X.shape[1]))

	rf = RandomForestClassifier(n_estimators=10, random_state=42)
	rf.fit(X, y)

	top_features = get_top_n_features(rf, feature_names, n=20)

	means, top_feature_indices = compute_low_confidence_feature_means(
		rf, X, top_features, feature_names, fraction=0.1
	)

	# One-hot basis over the selected feature dimensions.
	# NOTE(review): U_sfid is not used again in the visible part of this file.
	U_sfid = np.zeros((X.shape[1], len(top_features)))
	for f_indx, f in enumerate(top_features['feature']):
		U_sfid[f, f_indx] = 1


	# Per-class mean image embedding on the reference split. These do not
	# depend on the query, so they are computed once per bootstrap. A boolean
	# mask selects the rows instead of a list-wrapped index array, whose
	# interpretation by torch depends on the version.
	center_dict = {}
	center_dict[debias_concept] = {}

	for cls in class_dict[debias_concept]:
		class_mask = torch.as_tensor((ref_df[debias_concept] == cls).to_numpy())
		center_dict[debias_concept][cls] = img_embeds_ref[class_mask].mean(dim=0, keepdim=True)

	class_centers = np.concatenate([center_dict[debias_concept][c] for c in class_dict[debias_concept]])
	im_U = class_centers.T

	# The feature-mean replacement of the image embeddings is query independent as well.
	unified_deb_img_embeds = apply_feature_means(img_embeds_sample.to('cpu').numpy(), top_feature_indices, means)


	for q in query_terms:
		print(q)
		result_dict[b][q] = {}

		# NOTE(review): deliberately left inside the query loop. get_proj_matrix
		# builds a TruncatedSVD without a random_state, so it presumably draws
		# from the global NumPy RNG; calling it once per query keeps the random
		# stream, and hence later bootstrap samples, as before - TODO confirm.
		im_P0 = get_proj_matrix(class_centers.copy(), subspace_dim = None)

		txt_proj_key = f"proj_{debias_concept}_{q}_embeds"
		txt_wring_key = f"new_wring_{debias_concept}_{q}_embeds"
		im_proj_key = f"proj_im_{debias_concept}_{q}_embeds"
		im_wring_key = f"new_wring_im_{debias_concept}_{q}_embeds"

		query_embeds = wrapped_clip.get_joint_embed([q])

		### Text directions
		proj_debiased_query[txt_proj_key] = apply_proj_debiasing(query_embeds, P0_dict[f"P0_{debias_concept}"])
		new_wring_debiased_query[txt_wring_key] = apply_new_wring_debiasing(
			query_embeds, P0_dict[f"P0_{debias_concept}"], U_dict[f"U_{debias_concept}"]
		)

		unified_deb_query = apply_feature_means(query_embeds.numpy(), top_feature_indices, means)

		### Image directions
		proj_debiased_query[im_proj_key] = apply_proj_debiasing(query_embeds, im_P0)
		new_wring_debiased_query[im_wring_key] = apply_new_wring_debiasing(query_embeds, im_P0, im_U)

		# Rank the sample images by similarity to each query variant and keep
		# the k best; argsort is ascending, so the last k entries are taken.
		baseline_similarities = (img_embeds_sample @ query_embeds.T).squeeze()
		baseline_top_k = np.argsort(baseline_similarities)[-k:].numpy()

		new_wring_similarities = (img_embeds_sample @ new_wring_debiased_query[txt_wring_key].T).squeeze()
		new_wring_top_k = np.argsort(new_wring_similarities)[-k:].numpy()

		proj_similarities = (img_embeds_sample @ proj_debiased_query[txt_proj_key].T).squeeze()
		proj_top_k = np.argsort(proj_similarities)[-k:].numpy()

		# Both operands are NumPy arrays here, so no .numpy() conversion is needed.
		unified_deb_similarities = (unified_deb_img_embeds @ unified_deb_query.T).squeeze()
		unified_deb_top_k = np.argsort(unified_deb_similarities)[-k:]

		### Image-direction variants
		new_wring_im_similarities = (img_embeds_sample @ new_wring_debiased_query[im_wring_key].T).squeeze()
		new_wring_im_top_k = np.argsort(new_wring_im_similarities)[-k:].numpy()

		proj_im_similarities = (img_embeds_sample @ proj_debiased_query[im_proj_key].T).squeeze()
		proj_im_top_k = np.argsort(proj_im_similarities)[-k:].numpy()

		# Metrics for the text- and image-derived directions, measured on the
		# evaluation concept and on the debias concept itself.
		metrics_txt_eval = get_metrics(sample_df, baseline_top_k, eval_concept, class_dict, proj_top_k, new_wring_top_k, unified_deb_top_k)
		metrics_txt_debias = get_metrics(sample_df, baseline_top_k, debias_concept, class_dict, proj_top_k, new_wring_top_k, unified_deb_top_k)

		metrics_im_eval = get_metrics(sample_df, baseline_top_k, eval_concept, class_dict, proj_im_top_k, new_wring_im_top_k, unified_deb_top_k)
		metrics_im_debias = get_metrics(sample_df, baseline_top_k, debias_concept, class_dict, proj_im_top_k, new_wring_im_top_k, unified_deb_top_k)

		result_dict[b][q]['metrics_txt_eval'] = metrics_txt_eval
		result_dict[b][q]['metrics_txt_debias'] = metrics_txt_debias

		result_dict[b][q]['metrics_im_eval'] = metrics_im_eval
		result_dict[b][q]['metrics_im_debias'] = metrics_im_debias
####################################################################

# save results
# Create the output directory first so a missing folder cannot discard the
# finished bootstrap run.
os.makedirs("results", exist_ok=True)
with open(f"results/spawrious_{model_ID.split('/')[-1]}_{query_type}_eval_{eval_concept}_debias_{debias_concept}_results.json", "w") as f:
	json.dump(result_dict, f, indent=4)
		
