#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sun Feb 12 11:30:26 2023

@author: kwesi
"""

# Standard-library and third-party imports
import os
import torch
import numpy as np
import itertools
from sklearn.decomposition import PCA

import warnings
# Suppress every warning category for the whole process. This also hides
# NumPy RuntimeWarnings (e.g. division by zero when a zero-length vector is
# normalised), so NaNs produced by the helpers below appear without notice.
warnings.filterwarnings('ignore')

# import gluonnlp as nlp

###################
# Module-level report file. It is opened (and truncated) in the current
# working directory as soon as this module is imported; the reporting
# functions below write to it through this global handle. No explicit
# close() is visible in this part of the file.
target = open("multi-debiasing.txt", 'w')
###################

###########################
# Load/Read Word Embedding
###########################

########
def load_embed(filename, max_vocab=-1):
    """
    Read word embeddings from a space-separated text file.

    Every non-blank line must hold a word followed by its vector
    components, separated by single spaces. No header line is skipped, so
    a word2vec-style "count dim" first line would be read as an entry.

    Parameters
    ----------
    filename : path of the embedding file (decoded as UTF-8).
    max_vocab : stop after this many entries; the default -1 (and any
        value that is never reached) reads the whole file.

    Returns
    -------
    (words, embeds) : list of words and a NumPy array with one row per word.

    Raises
    ------
    ValueError : if a non-blank line contains a word but no vector.
    """
    words, embeds = [], []
    # Explicit encoding: the platform default is not UTF-8 everywhere, which
    # would corrupt or reject non-ASCII vocabulary entries.
    with open(filename, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.rstrip()
            if not line:
                # Blank (e.g. trailing) lines used to crash the unpacking below.
                continue
            word, vector = line.split(' ', 1)
            words.append(word)
            embeds.append(np.fromstring(vector, sep=' '))
            if len(embeds) == max_vocab:
                break
    return words, np.array(embeds)

########
def saveEmbed(path, words, word_embeds):
    """
    Write embeddings to a text file, one "word v1 v2 ..." line per entry.

    Parameters
    ----------
    path : destination file; it is overwritten and encoded as UTF-8.
    words : iterable of words.
    word_embeds : iterable of vectors, paired positionally with ``words``
        (pairing stops at the shorter of the two).

    No header line is written.
    """
    # Explicit encoding so non-ASCII words are written identically on every
    # platform instead of depending on the locale default.
    with open(path, 'w', encoding='utf-8') as f:
        for word, embed in zip(words, word_embeds):
            vector_str = ' '.join(str(x) for x in embed)
            print(word, vector_str, file=f)
               
########
def get(word_vectors, word):
    """
    Return the entry stored under ``word`` in ``word_vectors``.

    The lookup uses plain indexing, so a missing word raises whatever the
    container raises (KeyError for a dict).
    """
    entry = word_vectors[word]
    return entry

########
def get_many(word_vectors, words):
    """
    Return the entries of ``word_vectors`` for each word in ``words``.

    The result is a list in the same order as ``words``; repeated words
    yield repeated entries.
    """
    found = []
    for token in words:
        found.append(word_vectors[token])
    return found

########
def get_vecs(word_vectors, words):
    """
    Stack the vectors of ``words`` into one NumPy array, one row per word.

    Rows follow the order of ``words``.
    """
    rows = []
    for token in words:
        rows.append(word_vectors[token])
    return np.vstack(rows)

########
def vectors(word_vectors):
    """
    Stack every value of ``word_vectors`` into one array, in the mapping's
    iteration order.
    """
    all_rows = list(word_vectors.values())
    return np.vstack(all_rows)

########
def words(word_vectors):
    """
    Collect the ``.word`` attribute of every value in ``word_vectors``.

    The values must therefore be objects exposing ``.word``; plain NumPy
    vectors do not qualify.
    """
    labels = []
    for entry in word_vectors.values():
        labels.append(entry.word)
    return labels

########
def update_vectors(words, new_vectors, word_vectors):
    """
    Store ``new_vectors[i]`` under ``words[i]`` in ``word_vectors``.

    ``word_vectors`` is modified in place; nothing is returned.
    """
    position = 0
    for token in words:
        word_vectors[token] = new_vectors[position]
        position += 1
           
########        
def remove_center(embeddings):
    """
    Subtract the mean over axis 0 from ``embeddings`` in place.

    Returns ``(center, embeddings)``: the mean with a leading axis of
    length 1, and the same (now centred) array object that was passed in.
    """
    column_means = embeddings.mean(axis=0)
    center = column_means[np.newaxis, :]
    # In-place on purpose: the caller's array is modified.
    embeddings -= center
    return center, embeddings

########
# def bias_two_means(vec1, vec2 ):
#     vec1_mean, vec2_mean = np.mean(vec1, axis=0), np.mean(vec2, axis=0)
#     bias_direction = (vec1_mean - vec2_mean) / np.linalg.norm(vec1_mean - vec2_mean)

#     return bias_direction / np.linalg.norm(bias_direction), vec1_mean / np.linalg.norm(vec1_mean), vec2_mean/ np.linalg.norm(vec2_mean)

#For OPL+ICR
def bias_two_means(vecTORs ):
    """
    Return the normalised mean over axis 0 of ``vecTORs``.

    The mean is given a leading axis of length 1 before it is passed to
    ``unit_vector``; callers in this file squeeze that axis away. Despite
    the name, a single set of vectors is averaged.
    """
    mean_row = vecTORs.mean(axis=0)[np.newaxis, :]
    return unit_vector(mean_row)

########
def get_he_she_basis(emb):
    """
    Return the unit vector pointing from row 1 to row 0 of ``emb``.

    ``emb`` must be two-dimensional (checked with an assert).
    """
    assert(len(emb.shape) == 2)
    difference = emb[0] - emb[1]
    return difference / np.linalg.norm(difference)

########
def get_basis(emb):
    """
    Return the first principal component of ``emb`` as a unit vector.

    A two-component PCA is fitted on the two-dimensional array ``emb``
    (checked with an assert); only the first component is used.
    """
    assert(len(emb.shape) == 2)
    fitted = PCA(n_components=2).fit(emb)
    leading = fitted.components_[0]
    return leading / np.linalg.norm(leading)

########
def proj(u, a):
    """
    Return ``(u . a) * u``.

    This equals the orthogonal projection of ``a`` onto ``u`` only when
    ``u`` has unit length; no normalisation is done here.
    """
    coefficient = np.dot(u, a)
    return coefficient * u

########
########
def gsConstrained(matrix,v1,v2):
    """
    Build an orthonormal basis whose first two rows span ``v1`` and ``v2``.

    Row 0 is ``v1`` normalised, row 1 is the normalised part of ``v2``
    orthogonal to row 0. The remaining rows are filled by classical
    Gram-Schmidt over the rows of ``matrix``, taken in order.

    Candidate rows that are (numerically) already inside the span of the
    rows built so far are skipped and the next candidate is tried. The
    previous implementation always used the first ``len(matrix) - 2`` rows
    and divided by a zero norm in that situation (for example when ``v1``
    is a standard basis vector), silently producing NaN rows. When none of
    the leading candidates is skipped, the same rows are used as before.

    Parameters
    ----------
    matrix : 2-D array-like of candidate vectors (callers pass an identity
        matrix); the result has the same shape.
    v1, v2 : vectors whose length matches the number of columns of
        ``matrix``.

    Returns
    -------
    u : array with the shape of ``matrix`` holding the basis as rows.

    Raises
    ------
    ValueError : if ``v1`` is zero, ``v2`` is zero or parallel to ``v1``,
        or ``matrix`` has too few independent rows to complete the basis.
    """
    tol = 1e-10  # relative threshold below which a residual counts as zero
    matrix = np.asarray(matrix, dtype=float)
    v1 = np.asarray(v1, dtype=float).reshape(-1)
    v2 = np.asarray(v2, dtype=float).reshape(-1)
    n_rows = matrix.shape[0]
    u = np.zeros(matrix.shape)

    norm_v1 = np.linalg.norm(v1)
    if norm_v1 == 0:
        raise ValueError("v1 must be a non-zero vector")
    u[0] = v1 / norm_v1

    residual = v2 - np.dot(u[0], v2) * u[0]
    residual_norm = np.linalg.norm(residual)
    if residual_norm <= tol * np.linalg.norm(v2):
        raise ValueError("v2 must be non-zero and not parallel to v1")
    u[1] = residual / residual_norm

    filled = 2
    for candidate in matrix:
        if filled == n_rows:
            break
        built = u[:filled]
        # Remove the components along every basis row built so far.
        residual = candidate - np.matmul(np.matmul(built, candidate), built)
        residual_norm = np.linalg.norm(residual)
        if residual_norm <= tol * np.linalg.norm(candidate):
            # Candidate is linearly dependent on the current basis.
            continue
        u[filled] = residual / residual_norm
        filled += 1

    if filled < n_rows:
        raise ValueError("matrix does not contain enough linearly "
                         "independent rows to complete the basis")
    return u


 

########
def basis(vec):
    """
    Return the normalised part of ``vec[1]`` left after subtracting
    ``vec[0]`` scaled by their product ``vec[0] @ vec[1].T``.

    The result is orthogonal to ``vec[0]`` when ``vec[0]`` has unit length.
    """
    e1, e2 = vec[0], vec[1]
    overlap = float(np.matmul(e1, e2.T))
    remainder = e2 - e1 * overlap
    return remainder / np.linalg.norm(remainder)

########
def proj_new(vec):
    """
    Return ``vec[0]`` scaled by the product ``vec[0] @ vec[1].T``.

    For a unit-length ``vec[0]`` this is the projection of ``vec[1]`` onto
    ``vec[0]``.
    """
    e1, e2 = vec[0], vec[1]
    overlap = float(np.matmul(e1, e2.T))
    return e1 * overlap

########
def rotation(v1, v2, x):
    """
    Rotate the 2-D point ``x`` by an angle graded between ``v1`` and ``v2``.

    ``thetaP`` is the angle between ``v1`` and ``v2`` (their dot product is
    used directly, without normalising), and ``theta = |thetaP - pi/2|`` is
    the largest rotation applied. The angle used for ``x`` is ``theta``
    scaled by a factor that depends on ``phi`` (the angle between ``v1``
    and ``x``) and on which side of ``v1`` the point lies (the sign of its
    component ``d`` along ``v2P``).

    Returns
    -------
    (x_copy, v1, v2, v2P, rotated) : a copy of the input point, the two
        directions as given, the unit part of ``v2`` orthogonal to ``v1``,
        and the rotated point. When no branch applies (``d == 0``,
        ``phi == thetaP`` on the positive side, or a NaN in the
        comparisons) the last element is ``x`` itself, unrotated.
    """
    x_copy = x.copy()

    # Part of v2 orthogonal to v1 (proj only projects correctly for unit v1).
    orthogonal_part = v2 - proj(v1, v2)
    v2P = orthogonal_part / np.linalg.norm(orthogonal_part)

    thetaP = np.arccos(np.dot(v1, v2))
    theta = np.abs(thetaP - np.pi / 2)

    x_unit = x / np.linalg.norm(x)
    phi = np.arccos(np.dot(v1 / np.linalg.norm(v1), x_unit))
    d = np.dot(v2P, x_unit)

    thetaX = None
    if d > 0:
        # Same side of v1 as v2.
        if phi < thetaP:
            thetaX = theta * (phi / thetaP)
        elif phi > thetaP:
            thetaX = theta * ((np.pi - phi) / (np.pi - thetaP + 1e-10))
    elif d < 0:
        # Opposite side of v1.
        if phi >= np.pi - thetaP:
            thetaX = theta * ((np.pi - phi) / thetaP)
        elif phi < np.pi - thetaP:
            thetaX = theta * (phi / (np.pi - thetaP + 1e-10))

    if thetaX is None:
        return x_copy, v1, v2, v2P, x

    cos_t = np.cos(thetaX)
    sin_t = np.sin(thetaX)
    # Counter-clockwise 2x2 rotation matrix, always stored as float64.
    R = np.array([[cos_t, -sin_t], [sin_t, cos_t]], dtype=float)

    return x_copy, v1, v2, v2P, np.matmul(R, x)

########        
def correction2d_new(U, v1, v2, x):
    """
    Apply ``rotation(v1, v2, x)`` and return its 5-tuple unchanged.

    ``U`` is accepted for interface compatibility but is not used.
    """
    outcome = rotation(v1, v2, x)
    return outcome
   
########
def unit_vector(vector):
    """ Scale ``vector`` by the inverse of its norm (``np.linalg.norm``). """
    length = np.linalg.norm(vector)
    return vector / length

########
def angle_between(v1, v2):
    """ Angle in radians between 'v1' and 'v2'.

    Both vectors are normalised first and the cosine is clipped to
    [-1, 1] before ``arccos`` so rounding error cannot produce NaN.
    """
    cosine = np.dot(unit_vector(v1), unit_vector(v2))
    return np.arccos(np.clip(cosine, -1.0, 1.0))

########
def load_wordList(filename):
    """
    Read a comma-separated word list from ``filename``.

    The file is decoded as ISO-8859-1, every space character is removed,
    and the text is split on commas.

    Returns
    -------
    list of str, in file order.
    """
    # Context manager so the handle is released even if read() raises.
    with open(filename, "r", encoding="ISO-8859-1") as my_file:
        data = my_file.read()
    # NOTE(review): only spaces are removed; newlines or tabs stay attached
    # to the neighbouring word. Confirm the list files have none.
    return data.replace(' ', '').split(",")

########
def getKeyFromValue(dict, value):
    """
    Return a generator over the keys of ``dict`` whose value equals ``value``.

    Keys are produced lazily in the mapping's iteration order.
    """
    entries = dict.items()
    return (key for key, candidate in entries if candidate == value)
########

def Weat_dotprodResult(embd, allPairs, class_Vocab, flag = True):
    """
    Report the absolute dot products between class directions.

    For every pair of class labels in ``allPairs`` the normalised mean
    vector of each class is computed (``bias_two_means``) and the absolute
    value of their dot product, rounded to 4 decimals, is recorded. The
    mean, standard deviation, maximum pair and minimum pair are printed to
    stdout and written to the module-level ``target`` file.

    Parameters
    ----------
    embd : mapping from word to vector.
    allPairs : iterable of hashable label pairs.
    class_Vocab : mapping from class label to its list of words.
    flag : when true, return the minimum pair and its score.

    Returns
    -------
    ``(min_key, min_val)`` if ``flag`` is true, otherwise ``None``.
    """
    rule = "#################################" + "\n"
    spacer = " " + "\n"

    dotProdList = {}
    for pair in allPairs:
        first_dir = bias_two_means(get_vecs(embd, class_Vocab[pair[0]])).squeeze()
        second_dir = bias_two_means(get_vecs(embd, class_Vocab[pair[1]])).squeeze()
        dotProdList[pair] = np.abs(np.round(np.dot(first_dir, second_dir), decimals = 4))

    scores = list(dotProdList.values())

    # ---- console report ----
    print()
    target.write(spacer)
    mean_score = np.round(np.mean(scores), decimals = 4)
    print("Mean of Absolute Dot Product: ", mean_score)
    print()
    std_score = np.round(np.std(scores), decimals = 4)
    print("Standard Deviation of Absolute Dot Product: ", std_score)
    print()
    max_val = np.max(scores)
    # First pair (in insertion order) that attains the extreme value.
    max_key = list(getKeyFromValue(dotProdList, max_val))[0]
    max_shown = np.round(max_val, decimals = 4)
    print("Class Pair with Max Dot Product: ", str(max_key), " with ", max_shown, " dot prodct")
    print()
    min_val = np.min(scores)
    min_key = list(getKeyFromValue(dotProdList, min_val))[0]
    min_shown = np.round(min_val, decimals = 4)
    print("Class Pair with Min Dot Product: ", str(min_key), " with ", min_shown, " dot prodct")
    print()
    print()

    # ---- same figures, written to the report file ----
    target.writelines([
        rule,
        "Mean of Absolute Dot Product: " + str(mean_score) + "\n",
        spacer,
        rule,
        "Standard Deviation of Absolute Dot Product: " + str(std_score) + "\n",
        spacer,
        rule,
        "Class Pair with Max Dot Product: " + str(max_key) + " with " + str(max_shown) + " dot prodct" + "\n",
        spacer,
        rule,
        "Class Pair with Min Dot Product: " + str(min_key) + " with " + str(min_shown) + " dot prodct" + "\n",
    ])

    if flag:
        return min_key, min_val


   
def closest_vec_span(v, spanMatrix):
    """
    Return the unit vector inside the span of ``spanMatrix`` closest to ``v``.

    ``v`` and every span vector are normalised on local copies, ``v`` is
    projected onto each span vector, and the summed projection is
    normalised. The sum is the orthogonal projection onto the span only if
    the span vectors are mutually orthogonal; that is not checked here.

    Unlike the previous version this does not modify ``v`` or
    ``spanMatrix`` in place, normalises all span vectors rather than just
    the first two, and accepts a span with a single vector.

    Parameters
    ----------
    v : 1-D vector.
    spanMatrix : non-empty sequence of 1-D vectors of the same length.

    Returns
    -------
    1-D unit vector. It contains NaN if ``v`` is orthogonal to the whole
    span, because the zero projection cannot be normalised.

    Raises
    ------
    ValueError : if ``spanMatrix`` is empty.
    """
    if len(spanMatrix) == 0:
        raise ValueError("spanMatrix must contain at least one vector")

    v_unit = np.asarray(v)
    v_unit = v_unit / np.linalg.norm(v_unit)

    closest_vec = 0.0
    for direction in spanMatrix:
        direction = np.asarray(direction)
        direction = direction / np.linalg.norm(direction)
        closest_vec = closest_vec + np.dot(direction, v_unit) * direction
    return closest_vec / np.linalg.norm(closest_vec)


def class_closestSpan(word_vecs, spanMatrix, new_ClassLabels_list_span, ClassVocab):
    """
    Pick a class label based on its alignment with the current span.

    For each label in ``new_ClassLabels_list_span`` the score from
    ``dotProd_closestSpan`` is computed on that class's word vectors. The
    label with the smallest score is returned (the first one in list order
    on ties).
    """
    dotProd_dict = {
        label: dotProd_closestSpan(get_vecs(word_vecs, ClassVocab[label]), spanMatrix)
        for label in new_ClassLabels_list_span
    }
    smallest = np.min(list(dotProd_dict.values()))
    matches = list(getKeyFromValue(dotProd_dict, smallest))
    return matches[0]

   
def dotProd_closestSpan(embd, spanMatrix):
    """
    Dot product between a class direction and its closest vector in the span.

    The class direction is the normalised mean of ``embd``; the span
    vector comes from ``closest_vec_span``.
    """
    class_dir = bias_two_means(embd).squeeze()
    span_dir = closest_vec_span(class_dir, spanMatrix)
    return np.dot(span_dir, class_dir)
   
def removeElementList(unwanted_num, fromList):
    """
    Return a new list with the items of ``fromList`` that are not contained
    in ``unwanted_num``; order is preserved and ``fromList`` is untouched.
    """
    kept = []
    for ele in fromList:
        if ele in unwanted_num:
            continue
        kept.append(ele)
    return kept

###############################################################################
###############################################################################
# Example call (as used in ICR): OSCaRSpan(base_emb, all_wordsVocab, all_pairs, SpanMatrix, ClassVocab, closest_class)

def OSCaRSpan(word_vecs, Vocabs, allPairs, spanMatrix, class_Vocab, closest_class):
    """
    Rectify one class direction against the span built so far.

    Steps:
      1. ``v2`` is the normalised mean of the class ``closest_class``;
         ``v1`` is the closest unit vector to ``v2`` inside ``spanMatrix``.
      2. ``v2`` is flipped if it makes an obtuse angle with ``v1``.
      3. All word vectors are expressed in an orthonormal basis whose first
         two rows span ``v1`` and ``v2`` (``gsConstrained``).
      4. The first two coordinates of every word are passed through
         ``correction2d_new``; the other coordinates are left untouched.
      5. The result is mapped back to the original coordinates.

    Parameters
    ----------
    word_vecs : mapping word -> vector. Its iteration order is assumed to
        match ``Vocabs``, because the rows are re-keyed positionally.
    Vocabs : list of words used as keys of the returned mapping.
    allPairs : unused; kept for interface compatibility.
    spanMatrix : list of span vectors. The part of ``v2`` orthogonal to
        ``v1`` is appended to it IN PLACE.
    class_Vocab : mapping class label -> list of words.
    closest_class : label of the class to rectify.

    Returns
    -------
    (word_vecs, spanMatrix) : a new word -> vector dict and the extended
        span list (the same list object that was passed in).
    """
    class_emb = get_vecs(word_vecs, class_Vocab[closest_class])

    # 1. Class direction and its closest counterpart inside the span.
    v2 = bias_two_means(class_emb).squeeze()
    v1 = closest_vec_span(v2, spanMatrix)

    # 2. Keep the two directions within 90 degrees of each other.
    if angle_between(v1, v2) > np.pi / 2:
        v2 = -v2

    # 3. Change of basis: columns 0 and 1 are the coordinates in the v1/v2 plane.
    rot_matrix = gsConstrained(np.identity(v1.shape[0]), v1, v2)
    proj_newBasis = np.matmul(vectors(word_vecs), rot_matrix.T)

    # Unit part of v2 orthogonal to v1; it extends the span.
    v2_prime = v2 - v1 * (v2.dot(v1))
    v2_prime = v2_prime / np.linalg.norm(v2_prime)
    spanMatrix.append(v2_prime)

    # v1 and v2 expressed in the 2-D (v1, v2_prime) coordinates.
    bias_direction = np.array([v1.dot(v1), v1.dot(v2_prime)])
    bias_direction = bias_direction / np.linalg.norm(bias_direction)
    orth_direction = np.array([v2.dot(v1), v2.dot(v2_prime)])
    orth_direction = orth_direction / np.linalg.norm(orth_direction)

    # 4. Correct every word inside the plane; the last tuple item is the
    #    corrected 2-D point.
    corrected_2d = [
        correction2d_new(rot_matrix, bias_direction, orth_direction, point)[-1]
        for point in proj_newBasis[:, :2]
    ]
    proj_newBasis[:, :2] = np.array(corrected_2d)

    # 5. Back to the original coordinate system.
    rotated_allD = np.matmul(proj_newBasis, rot_matrix)
    word_vecs = dict(zip(Vocabs, rotated_allD))

    return word_vecs, spanMatrix

###############################################################################
###############################################################################

def OSCaRPairwise(word_vecs, Vocabs, allPairs, class_Vocab, min_pairDotProd):
    """
    Rectify the directions of one pair of classes and start a new span.

    Steps:
      1. ``v1`` and ``v2`` are the normalised mean vectors of the classes
         ``min_pairDotProd[0]`` and ``min_pairDotProd[1]``.
      2. ``v2`` is flipped if it makes an obtuse angle with ``v1``.
      3. All word vectors are expressed in an orthonormal basis whose first
         two rows span ``v1`` and ``v2`` (``gsConstrained``).
      4. The first two coordinates of every word are passed through
         ``correction2d_new``; the other coordinates are left untouched.
      5. The result is mapped back to the original coordinates.

    Parameters
    ----------
    word_vecs : mapping word -> vector. Its iteration order is assumed to
        match ``Vocabs``, because the rows are re-keyed positionally.
    Vocabs : list of words used as keys of the returned mapping.
    allPairs : unused; kept for interface compatibility.
    class_Vocab : mapping class label -> list of words.
    min_pairDotProd : pair of class labels to rectify.

    Returns
    -------
    (word_vecs, span_matrix) : a new word -> vector dict, and a new list
        ``[v1, v2_prime]`` where ``v2_prime`` is the unit part of ``v2``
        orthogonal to ``v1``.
    """
    first_emb = get_vecs(word_vecs, class_Vocab[min_pairDotProd[0]])
    second_emb = get_vecs(word_vecs, class_Vocab[min_pairDotProd[1]])

    # 1. Class directions.
    v1 = bias_two_means(first_emb).squeeze()
    v2 = bias_two_means(second_emb).squeeze()

    # 2. Keep the two directions within 90 degrees of each other.
    if angle_between(v1, v2) > np.pi / 2:
        v2 = -v2

    # 3. Change of basis: columns 0 and 1 are the coordinates in the v1/v2 plane.
    rot_matrix = gsConstrained(np.identity(v1.shape[0]), v1, v2)
    proj_newBasis = np.matmul(vectors(word_vecs), rot_matrix.T)

    # Unit part of v2 orthogonal to v1; together with v1 it seeds the span.
    v2_prime = v2 - v1 * (v2.dot(v1))
    v2_prime = v2_prime / np.linalg.norm(v2_prime)
    span_matrix = [v1, v2_prime]

    # v1 and v2 expressed in the 2-D (v1, v2_prime) coordinates.
    bias_direction = np.array([v1.dot(v1), v1.dot(v2_prime)])
    bias_direction = bias_direction / np.linalg.norm(bias_direction)
    orth_direction = np.array([v2.dot(v1), v2.dot(v2_prime)])
    orth_direction = orth_direction / np.linalg.norm(orth_direction)

    # 4. Correct every word inside the plane; the last tuple item is the
    #    corrected 2-D point.
    corrected_2d = [
        correction2d_new(rot_matrix, bias_direction, orth_direction, point)[-1]
        for point in proj_newBasis[:, :2]
    ]
    proj_newBasis[:, :2] = np.array(corrected_2d)

    # 5. Back to the original coordinate system.
    rotated_allD = np.matmul(proj_newBasis, rot_matrix)
    word_vecs = dict(zip(Vocabs, rotated_allD))

    return word_vecs, span_matrix


###############################################################################


def ICR(emb_to_debias, dictOfEmb,  ClassLabels_list, ClassVocab, opl_TrainFeat, iter_ICR, mode = None):
    """
    Run the iterative rectification over all classes and save the features.

    The embeddings are keyed by synthetic names ("word-i", and "bias_word-i"
    for the rows of ``emb_to_debias`` outside training mode). Each
    iteration first rectifies the class pair with the smallest absolute dot
    product measured BEFORE the loop (``OSCaRPairwise``), then rectifies
    every remaining class against the growing span (``OSCaRSpan``), reports
    the dot-product statistics, and saves the resulting vectors to
    ``ICR_features/Iter<k>/<mode>_feat.npy``.

    Parameters
    ----------
    emb_to_debias : 2-D array of embeddings to debias.
    dictOfEmb : unused; kept for interface compatibility.
    ClassLabels_list : list of class labels.
    ClassVocab : mapping class label -> list of (synthetic) word keys.
    opl_TrainFeat : 2-D array of training features, appended below
        ``emb_to_debias`` when ``mode`` is not "train_ICR"; unused otherwise.
    iter_ICR : number of iterations.
    mode : "train_ICR" saves all rows; any other value saves only the rows
        of ``emb_to_debias``. The value is also the file-name prefix (the
        default ``None`` yields "None_feat.npy"; previously it raised a
        TypeError when the name was built).

    Returns
    -------
    None. Results are written to disk and to the module-level ``target``.
    """
    if mode == "train_ICR":
        all_wordsVocab = ["word-" + str(i+1) for i in range(len(emb_to_debias))]
        saved_vocab = all_wordsVocab
        base_emb = dict(zip(all_wordsVocab, emb_to_debias.copy()))
    else:
        emb_to_debiasVocab = ["bias_word-" + str(i+1) for i in range(len(emb_to_debias))]
        all_wordsVocab = emb_to_debiasVocab + ["word-" + str(i+1) for i in range(len(opl_TrainFeat))]
        # Only the rows that were asked to be debiased are saved.
        saved_vocab = emb_to_debiasVocab
        stacked = np.concatenate((emb_to_debias, opl_TrainFeat), axis=0)
        base_emb = dict(zip(all_wordsVocab, stacked))

    target.write("Debiasing" + "\n")
    target.write("" + "\n")
    target.write("###############################################################################" + "\n")
    target.write("Dot Product Scores Before Applying ICR" + "\n")
    target.write("" + "\n")
    all_pairs = list(itertools.combinations(ClassLabels_list,2))
    # NOTE(review): the starting pair is chosen once, from the scores before
    # any rectification, and reused in every iteration — confirm intended.
    min_pair_dotProd, _ = Weat_dotprodResult(base_emb, all_pairs, ClassVocab)

    # Formatting instead of "+" so a non-string mode (e.g. the default None)
    # no longer raises TypeError.
    fname_feat = f"{mode}_feat.npy"

    for ii in range(iter_ICR):
        feat_dir = "ICR_features/" + 'Iter' + str(ii+1)
        os.makedirs(feat_dir, exist_ok=True)

        target.write("" + "\n")
        target.write("" + "\n")
        target.write("###############################################################################" + "\n")
        target.write("ITERATION " + str(ii + 1) + "\n")

        base_emb, SpanMatrix = OSCaRPairwise(base_emb, all_wordsVocab, all_pairs, ClassVocab, min_pair_dotProd)

        # Rectify the remaining classes one at a time; each pass removes the
        # class it handled, so the loop ends when none are left.
        remaining = removeElementList(min_pair_dotProd, ClassLabels_list)
        while remaining:
            closest_class = class_closestSpan(base_emb, SpanMatrix, remaining, ClassVocab)
            base_emb, SpanMatrix = OSCaRSpan(base_emb, all_wordsVocab, all_pairs, SpanMatrix, ClassVocab, closest_class)
            remaining = removeElementList([closest_class], remaining)

        Weat_dotprodResult(base_emb, all_pairs, ClassVocab)

        with open(f'{feat_dir}/{fname_feat}', 'wb') as f:
            np.save(f, get_vecs(base_emb, saved_vocab))
     

